/*
 * SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
 * Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: GPL-2.0-only or BSD-2-Clause
 */

#include "common/def.h"
#include "common/log.h"
#include "common/sys.h"
#include "common/base.h"
#include <infiniband/verbs.h>
#include <pthread.h>
#include <unistd.h>
#include "core/xlio_base.h"

#if defined(EXTRA_API_ENABLED) && (EXTRA_API_ENABLED == 1)

// Payload transmitted from the accept callback and compared against in the RX callback.
static const char *data_to_send = "I Love XLIO!";

// Buffer registered with ibv_reg_mr() in the accept callback, together with the
// protection domain and memory region handles obtained there.
static char sndbuf[256] = {};
static struct ibv_pd *pd = nullptr;
static struct ibv_mr *mr_buf = nullptr;

// Sockets delivered to the accept callback; handed to the cleanup helper by the server side.
static std::vector<xlio_socket_t> accepted_sockets;

// Progress counters incremented by the callbacks and polled by the test body.
static int connected_counter = 0;
static int terminated_counter = 0;
static int rx_cb_counter = 0;
static int comp_cb_counter = 0;

/**
 * Fixture for the fork-based send/receive test in which the accepting side
 * transmits a message and the connecting side receives it.
 *
 * All callbacks are static and report progress through the file-level
 * counters, which the test body polls on.
 */
class ultra_api_socket_send_receive_2 : public ultra_api_base {
public:
    // Only resets errno; the test body invokes ultra_api_base::SetUp() itself
    // after it has forked.
    virtual void SetUp() { errno = EOK; }
    virtual void TearDown() {}

    // Thin forwarder to base_destroy_poll_group().
    void destroy_poll_group(xlio_poll_group_t group) { base_destroy_poll_group(group); }

    /**
     * Socket event callback.
     *
     * ESTABLISHED bumps connected_counter; both CLOSED and TERMINATED bump
     * terminated_counter. Any other event is ignored.
     */
    static void socket_event_cb(xlio_socket_t sock, uintptr_t userdata_sq, int event, int value)
    {
        UNREFERENCED_PARAMETER(sock);
        UNREFERENCED_PARAMETER(userdata_sq);
        UNREFERENCED_PARAMETER(value);
        if (event == XLIO_SOCKET_EVENT_ESTABLISHED) {
            connected_counter++;
        } else if (event == XLIO_SOCKET_EVENT_CLOSED || event == XLIO_SOCKET_EVENT_TERMINATED) {
            terminated_counter++;
        }
    }

    /**
     * Completion callback: counts every invocation in comp_cb_counter.
     */
    static void socket_comp_cb(xlio_socket_t sock, uintptr_t userdata_sq, uintptr_t userdata_op)
    {
        UNREFERENCED_PARAMETER(sock);
        UNREFERENCED_PARAMETER(userdata_sq);
        UNREFERENCED_PARAMETER(userdata_op);
        comp_cb_counter++;
    }

    /**
     * Receive callback: counts the invocation, checks the received bytes
     * against data_to_send and hands the buffer back via xlio_socket_buf_free().
     *
     * @param sock  Socket the data arrived on; used to release @p buf.
     * @param data  Start of the received bytes.
     * @param len   Number of received bytes.
     * @param buf   Buffer handle released before the callback returns.
     */
    static void socket_rx_cb(xlio_socket_t sock, uintptr_t userdata_sq, void *data, size_t len,
                             struct xlio_buf *buf)
    {
        UNREFERENCED_PARAMETER(userdata_sq);
        rx_cb_counter++;

        // Assume that the data_to_send is received in one packet. The length
        // guard keeps memcmp() from reading beyond the reference string when
        // more bytes than expected are delivered.
        const bool payload_ok =
            (len <= strlen(data_to_send)) && (memcmp(data, data_to_send, len) == 0);

        // Release the buffer before reporting: GTEST_FAIL() returns from the
        // callback, so a free placed after it would be skipped on a mismatch.
        xlio_api->xlio_socket_buf_free(sock, buf);

        if (!payload_ok) {
            GTEST_FAIL();
        }
    }

    /**
     * Accept callback: records the new socket, registers sndbuf as a memory
     * region on the socket's protection domain and sends data_to_send on it.
     *
     * A failed ASSERT_* returns from the callback immediately, leaving the
     * remaining steps undone.
     *
     * NOTE(review): no ibv_dereg_mr() for mr_buf is visible in this file;
     * confirm whether the registration is released elsewhere.
     */
    static void socket_accept_cb(xlio_socket_t sock, xlio_socket_t parent_sock,
                                 uintptr_t parent_userdata)
    {
        UNREFERENCED_PARAMETER(parent_sock);
        UNREFERENCED_PARAMETER(parent_userdata);

        // NOTE(review): meaning of the (0, 0x1) arguments is defined by the
        // XLIO API; presumably flags and the new userdata - confirm.
        int rc = xlio_api->xlio_socket_update(sock, 0, 0x1);
        ASSERT_EQ(rc, 0);
        accepted_sockets.push_back(sock);
        connected_counter++;

        pd = xlio_api->xlio_socket_get_pd(sock);
        ASSERT_TRUE(pd != NULL);
        mr_buf = ibv_reg_mr(pd, sndbuf, sizeof(sndbuf), IBV_ACCESS_LOCAL_WRITE);
        ASSERT_TRUE(mr_buf != NULL);
        base_send_single_msg(sock, data_to_send, strlen(data_to_send), 0x1, 0, mr_buf, sndbuf);
    }
};

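/**
 * @test ultra_api_socket_send_receive_2.ti_1
 * @brief
 *    Create TCP socket/connect/send(target)/receive(initiator)
 * @details
 *    The test forks. The child is the target (server): it binds to
 *    server_addr, listens, and sends the message from the accept callback.
 *    The parent is the initiator (client): it binds to client_addr, connects
 *    to server_addr and polls until the receive callback has fired.
 */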
TEST_F(ultra_api_socket_send_receive_2, ti_1)
{
    int rc;
    int pid = fork();
    // A failed fork() would otherwise fall into the client branch below with
    // no server process to connect to or synchronize with.
    ASSERT_GE(pid, 0) << "fork() failed, errno=" << errno;
    ultra_api_base::SetUp();
    xlio_poll_group_t group;
    xlio_socket_t sock;

    // Both processes create their own poll group and socket after the fork.
    base_create_poll_group(&group, &socket_event_cb, &socket_comp_cb, &socket_rx_cb,
                           &socket_accept_cb);
    xlio_socket_attr sattr = {
        .flags = 0,
        .domain = server_addr.addr.sa_family,
        .group = group,
        .userdata_sq = 0,
    };
    base_create_socket(&sattr, &sock);
    if (pid == 0) {
        // Child process - server side
        rc = xlio_api->xlio_socket_bind(sock, (struct sockaddr *)&server_addr, sizeof(server_addr));
        ASSERT_EQ(0, rc);

        rc = xlio_api->xlio_socket_listen(sock);
        ASSERT_EQ(0, rc);

        barrier_fork(pid, true); // Listening; let the parent connect

        // Poll until a connection was counted and the send issued from the
        // accept callback has completed.
        while (connected_counter < 1 || comp_cb_counter < 1) {
            xlio_api->xlio_poll_group_poll(group);
        }

        base_wait_for_delayed_acks(group);

        barrier_fork(pid, true);

        base_destroy_socket(sock);
        base_cleanup_accepted_sockets(accepted_sockets);
        while (terminated_counter < 1) {
            xlio_api->xlio_poll_group_poll(group);
        }

        destroy_poll_group(group);
        // Report the child's verdict to the parent through the exit status.
        exit(testing::Test::HasFailure());
    } else {
        // Parent process - client side
        rc = xlio_api->xlio_socket_bind(sock, (struct sockaddr *)&client_addr, sizeof(client_addr));
        ASSERT_EQ(0, rc);

        barrier_fork(pid, true); // Wait for child to bind and listen

        rc = xlio_api->xlio_socket_connect(sock, (struct sockaddr *)&server_addr,
                                           sizeof(server_addr));
        ASSERT_EQ(0, rc);

        // Poll until the connection is established and the message arrived.
        while (connected_counter < 1 || rx_cb_counter < 1) {
            xlio_api->xlio_poll_group_poll(group);
        }

        base_wait_for_delayed_acks(group);

        barrier_fork(pid, true); // Wait for child to accept + receive last ack

        base_destroy_socket(sock);
        while (terminated_counter < 1) {
            xlio_api->xlio_poll_group_poll(group);
        }

        destroy_poll_group(group);

        wait_fork(pid);
    }
}

#endif /* EXTRA_API_ENABLED */
