/*
 * SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
 * Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: GPL-2.0-only or BSD-2-Clause
 */

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include <errno.h>
#include <fcntl.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/un.h>

#include "vlogger/vlogger.h"
#include "utils/lock_wrapper.h"
#include "core/sock/sock-redirect.h"
#include "core/util/list.h"
#include "core/util/agent.h"

/* Log prefix for this file: "agent:<line>:<function>() ".
 * Presumably consumed by the __log_* macros defined outside this file.
 */
#undef MODULE_NAME
#define MODULE_NAME "agent:"
#undef MODULE_HDR
#define MODULE_HDR MODULE_NAME "%d:%s() "

/* Number of message buffers preallocated into the free queue by agent::agent() */
#define AGENT_DEFAULT_MSG_NUM  (512)
#define AGENT_DEFAULT_MSG_GROW (16) /* number of messages to grow */
#define AGENT_DEFAULT_INACTIVE                                                                     \
    (10) /* periodic time for establishment connection attempts (in sec) */
#define AGENT_DEFAULT_ALIVE (1) /* periodic time for alive check (in sec) */

/* Force system call
 * Calls the OS implementation of _func and stores its return value in _result.
 * Static build: the global ::_func is called directly.
 * Otherwise: the function pointer saved in orig_os_api is used when it is set
 * (presumably so the call bypasses the library's own interposed wrappers),
 * falling back to ::_func.
 */
#ifdef XLIO_STATIC_BUILD
#define sys_call(_result, _func, ...)                                                              \
    do {                                                                                           \
        _result = ::_func(__VA_ARGS__);                                                            \
    } while (0)
#else /* XLIO_STATIC_BUILD */
#define sys_call(_result, _func, ...)                                                              \
    do {                                                                                           \
        if (orig_os_api._func)                                                                     \
            _result = orig_os_api._func(__VA_ARGS__);                                              \
        else                                                                                       \
            _result = ::_func(__VA_ARGS__);                                                        \
    } while (0)
#endif /* XLIO_STATIC_BUILD */

/* Print user notification
 * NOTE: expands to code that reads a variable named `rc` from the calling
 * scope; -EPROTONOSUPPORT selects the version-mismatch wording.
 * The banner is printed at VLOG_WARNING when running under the MSHV
 * hypervisor (mce_sys_var::HYPER_MSHV) and at VLOG_DEBUG otherwise.
 */
#define output_fatal()                                                                             \
    do {                                                                                           \
        vlog_levels_t _level =                                                                     \
            (mce_sys_var::HYPER_MSHV == safe_mce_sys().hypervisor ? VLOG_WARNING : VLOG_DEBUG);    \
        vlog_printf(_level, "*************************************************************\n");    \
        if (rc == -EPROTONOSUPPORT)                                                                \
            vlog_printf(                                                                           \
                _level,                                                                            \
                "* Protocol version mismatch was found between the library and the service. *\n"); \
        else                                                                                       \
            vlog_printf(_level, "* Can not establish connection with the service.      *\n");      \
        vlog_printf(_level, "* UDP/TCP connections are likely to be limited.             *\n");    \
        vlog_printf(_level, "*************************************************************\n");    \
    } while (0)

/* Global agent instance. Initialized to nullptr here; presumably assigned by
 * library start-up code outside this file (TODO confirm).
 */
agent *g_p_agent = nullptr;

/*
 * Construct the agent.
 * Steps:
 *   - preallocate the message pool;
 *   - create the notification directory (service_notify_dir);
 *   - create the per-process pid file and bind the per-process socket file;
 *   - attempt the initial handshake with the daemon.
 * On any failure, a notification banner is printed and every resource acquired
 * so far is released. The agent is then left in AGENT_CLOSED state, in which
 * the other entry points of this file return early.
 */
agent::agent()
    : m_state(AGENT_CLOSED)
    , m_sock_fd(-1)
    , m_pid_fd(-1)
    , m_sock_file {}
    , m_pid_file {}
    , m_msg_num(AGENT_DEFAULT_MSG_NUM)
{
    int rc = 0;
    agent_msg_t *msg = nullptr;
    int i = 0;

    INIT_LIST_HEAD(&m_cb_queue);
    INIT_LIST_HEAD(&m_free_queue);
    INIT_LIST_HEAD(&m_wait_queue);

    /* Fill free queue with empty messages.
     * m_msg_num is recounted so that it reflects only the buffers that were
     * actually allocated.
     */
    i = m_msg_num;
    m_msg_num = 0;
    const char *path = safe_mce_sys().service_notify_dir;
    while (i--) {
        /* coverity[overwrite_var] */
        msg = (agent_msg_t *)calloc(1, sizeof(*msg));
        if (!msg) {
            rc = -ENOMEM;
            __log_dbg("failed queue creation (rc = %d)", rc);
            goto err;
        }
        msg->length = 0;
        msg->tag = AGENT_MSG_TAG_INVALID;
        list_add_tail(&msg->item, &m_free_queue);
        m_msg_num++;
    }

    if ((mkdir(path, 0777) != 0) && (errno != EEXIST)) {
        rc = -errno;
        __log_dbg("failed create folder %s (rc = %d)", path, rc);
        goto err;
    }

    /* snprintf() returns the length the untruncated string would have had, so
     * any value >= the size passed in means the file name did not fit.
     */
    rc = snprintf(m_sock_file, sizeof(m_sock_file) - 1, "%s/%s.%d.sock", path, XLIO_AGENT_BASE_NAME,
                  getpid());
    if ((rc < 0) || ((size_t)rc >= (sizeof(m_sock_file) - 1))) {
        rc = -ENOMEM;
        __log_dbg("failed allocate sock file (rc = %d)", rc);
        goto err;
    }

    rc = snprintf(m_pid_file, sizeof(m_pid_file) - 1, "%s/%s.%d.pid", path, XLIO_AGENT_BASE_NAME,
                  getpid());
    if ((rc < 0) || ((size_t)rc >= (sizeof(m_pid_file) - 1))) {
        rc = -ENOMEM;
        __log_dbg("failed allocate pid file (rc = %d)", rc);
        goto err;
    }

    sys_call(m_pid_fd, open, m_pid_file, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR | S_IRGRP);
    if (m_pid_fd < 0) {
        rc = -errno;
        __log_dbg("failed open pid file (rc = %d)", rc);
        goto err;
    }

    rc = create_agent_socket();
    if (rc < 0) {
        __log_dbg("failed open sock file (rc = %d)", rc);
        goto err;
    }

    /* Initialization is mostly completed
     * At the moment it does not matter if connection with
     * daemon can be done here or later
     */
    m_state = AGENT_INACTIVE;

    rc = send_msg_init();
    if (rc < 0) {
        __log_dbg("failed establish connection with daemon (rc = %d)", rc);
        goto err;
    }

    /* coverity[leaked_storage] */
    return;

err:
    /* There is no chance to establish connection with daemon
     * because of internal problems or communication protocol
     * variance
     * So do not try anymore
     */
    m_state = AGENT_CLOSED;

    output_fatal();

    while (!list_empty(&m_free_queue)) {
        /* coverity[overwrite_var] */
        msg = list_first_entry(&m_free_queue, agent_msg_t, item);
        list_del_init(&msg->item);
        free(msg);
    }

    /* 0 is a valid descriptor; only -1 means "not opened" */
    if (m_pid_fd >= 0) {
        int ret = 0;
        NOT_IN_USE(ret);
        /* coverity[RESOURCE_LEAK] */
        sys_call(ret, close, m_pid_fd);
        /* coverity[leaked_handle] */
        m_pid_fd = -1;
        unlink(m_pid_file);
    }

    if (m_sock_fd >= 0) {
        int ret = 0;
        NOT_IN_USE(ret);
        sys_call(ret, close, m_sock_fd);
        m_sock_fd = -1;
        unlink(m_sock_file);
    }

    /* coverity[leaked_storage] */
    return;
}

/*
 * Tear down the agent.
 * Steps:
 *   - flush pending messages with a final progress() pass;
 *   - notify the daemon with XLIO_MSG_EXIT;
 *   - release every queued callback and message buffer;
 *   - close and unlink the socket and pid files.
 * Nothing is done when construction failed (AGENT_CLOSED).
 */
agent::~agent()
{
    agent_msg_t *msg = nullptr;
    agent_callback_t *cb = nullptr;

    if (AGENT_CLOSED == m_state) {
        return;
    }

    progress();
    send_msg_exit();

    m_state = AGENT_CLOSED;

    /* This delay is needed to allow process EXIT message
     * before event from system file monitor is raised
     */
    usleep(1000);

    while (!list_empty(&m_cb_queue)) {
        cb = list_first_entry(&m_cb_queue, agent_callback_t, item);
        list_del_init(&cb->item);
        free(cb);
    }

    while (!list_empty(&m_free_queue)) {
        msg = list_first_entry(&m_free_queue, agent_msg_t, item);
        list_del_init(&msg->item);
        free(msg);
    }

    /* Messages that progress() could not deliver (e.g. send failure or the
     * link being down) are still owned by the wait queue; release them too.
     */
    while (!list_empty(&m_wait_queue)) {
        msg = list_first_entry(&m_wait_queue, agent_msg_t, item);
        list_del_init(&msg->item);
        free(msg);
    }

    /* 0 is a valid descriptor; only -1 means "not opened" */
    if (m_sock_fd >= 0) {
        int ret = 0;
        NOT_IN_USE(ret);
        sys_call(ret, close, m_sock_fd);
        m_sock_fd = -1;
        unlink(m_sock_file);
    }

    if (m_pid_fd >= 0) {
        int ret = 0;
        NOT_IN_USE(ret);
        sys_call(ret, close, m_pid_fd);
        m_pid_fd = -1;
        unlink(m_pid_file);
    }
}

/*
 * Add (fn, arg) to the callback list unless the identical pair is already
 * registered. Does nothing when the agent is closed or fn is null.
 * An allocation failure is silently ignored (the callback is simply not added).
 */
void agent::register_cb(agent_cb_t fn, void *arg)
{
    struct list_head *pos = nullptr;
    agent_callback_t *node = nullptr;
    bool already_registered = false;

    if ((AGENT_CLOSED == m_state) || !fn) {
        return;
    }

    m_cb_lock.lock();

    /* Duplicate check: the same function/argument pair is registered only once */
    list_for_each(pos, &m_cb_queue)
    {
        node = list_entry(pos, agent_callback_t, item);
        if ((node->cb == fn) && (node->arg == arg)) {
            already_registered = true;
            break;
        }
    }

    if (!already_registered) {
        node = (agent_callback_t *)calloc(1, sizeof(*node));
        if (node) {
            node->cb = fn;
            node->arg = arg;
            list_add_tail(&node->item, &m_cb_queue);
        }
    }

    m_cb_lock.unlock();
    /* coverity[leaked_storage] */
}

/*
 * Remove and free the first callback entry matching (fn, arg), if any.
 * Does nothing when the agent is closed.
 */
void agent::unregister_cb(agent_cb_t fn, void *arg)
{
    struct list_head *pos = nullptr;
    agent_callback_t *victim = nullptr;

    if (AGENT_CLOSED == m_state) {
        return;
    }

    m_cb_lock.lock();

    /* Locate the entry first, then unlink it outside the iteration */
    list_for_each(pos, &m_cb_queue)
    {
        agent_callback_t *node = list_entry(pos, agent_callback_t, item);
        if ((node->cb == fn) && (node->arg == arg)) {
            victim = node;
            break;
        }
    }

    if (victim) {
        list_del_init(&victim->item);
        free(victim);
    }

    m_cb_lock.unlock();
}

/*
 * Queue a copy of `data` (length bytes, tagged with `tag`) for delivery to the
 * daemon by progress().
 *
 * Messages are queued only while the agent is AGENT_ACTIVE. In any other
 * state the message is dropped and 0 is returned, so the queue cannot grow
 * without bound while the link is down. progress() invokes the registered
 * callbacks once it detects that the link with the daemon is up again.
 *
 * Returns:
 *   0             - queued, or dropped as described above;
 *   -EBADF        - the agent socket is not open;
 *   -EINVAL       - length exceeds the message payload size;
 *   -ENOMEM       - the free queue was empty and no new buffer could be allocated.
 */
int agent::put(const void *data, size_t length, intptr_t tag)
{
    agent_msg_t *msg = nullptr;

    if (AGENT_CLOSED == m_state) {
        return 0;
    }

    if (m_sock_fd < 0) {
        return -EBADF;
    }

    if (length > sizeof(msg->data)) {
        return -EINVAL;
    }

    m_msg_lock.lock();

    if (AGENT_ACTIVE != m_state) {
        m_msg_lock.unlock();
        return 0;
    }

    /* allocate new messages in case free queue is empty */
    if (list_empty(&m_free_queue)) {
        for (int i = 0; i < AGENT_DEFAULT_MSG_GROW; i++) {
            /* coverity[overwrite_var] */
            msg = (agent_msg_t *)malloc(sizeof(*msg));
            if (!msg) {
                break;
            }
            msg->length = 0;
            msg->tag = AGENT_MSG_TAG_INVALID;
            list_add_tail(&msg->item, &m_free_queue);
            m_msg_num++;
        }

        /* list_first_entry() on an empty list would yield the list head itself
         * reinterpreted as a message; bail out instead of corrupting memory.
         */
        if (list_empty(&m_free_queue)) {
            m_msg_lock.unlock();
            return -ENOMEM;
        }
    }

    /* get message from free queue */
    /* coverity[overwrite_var] */
    msg = list_first_entry(&m_free_queue, agent_msg_t, item);
    list_del_init(&msg->item);

    /* fill the message and put it into wait queue */
    memcpy(&msg->data, data, length);
    msg->length = length;
    msg->tag = tag;
    list_add_tail(&msg->item, &m_wait_queue);

    m_msg_lock.unlock();

    return 0;
}

void agent::progress(void)
{
    agent_msg_t *msg = nullptr;
    struct timeval tv_now = TIMEVAL_INITIALIZER;
    static struct timeval tv_inactive_elapsed = TIMEVAL_INITIALIZER;
    static struct timeval tv_alive_elapsed = TIMEVAL_INITIALIZER;

    if (AGENT_CLOSED == m_state) {
        return;
    }

    gettime(&tv_now);

    /* Attempt to establish connection with daemon */
    if (AGENT_INACTIVE == m_state) {
        /* Attempt can be done less often than progress in active state */
        /* cppcheck-suppress syntaxError */
        if (tv_cmp(&tv_inactive_elapsed, &tv_now, <)) {
            tv_inactive_elapsed = tv_now;
            tv_inactive_elapsed.tv_sec += AGENT_DEFAULT_INACTIVE;
            if (0 <= send_msg_init()) {
                progress_cb();
                goto go;
            }
        }
        return;
    }

go:
    /* Check connection with daemon during active state */
    if (list_empty(&m_wait_queue)) {
        if (tv_cmp(&tv_alive_elapsed, &tv_now, <)) {
            check_link();
        }
    } else {
        tv_alive_elapsed = tv_now;
        tv_alive_elapsed.tv_sec += AGENT_DEFAULT_ALIVE;

        /* Process all messages that are in wait queue */
        m_msg_lock.lock();
        while (!list_empty(&m_wait_queue)) {
            msg = list_first_entry(&m_wait_queue, agent_msg_t, item);
            if (0 > send(msg)) {
                break;
            }
            list_del_init(&msg->item);
            msg->length = 0;
            msg->tag = AGENT_MSG_TAG_INVALID;
            list_add_tail(&msg->item, &m_free_queue);
        }
        m_msg_lock.unlock();
    }
}

/*
 * Invoke every registered callback, in registration order, with its argument.
 * Callbacks run while m_cb_lock is held. Unless that lock is recursive (not
 * visible here), a callback calling register_cb()/unregister_cb() would
 * self-deadlock.
 */
void agent::progress_cb(void)
{
    struct list_head *node = nullptr;

    m_cb_lock.lock();
    list_for_each(node, &m_cb_queue)
    {
        agent_callback_t *const registered = list_entry(node, agent_callback_t, item);
        registered->cb(registered->arg);
    }
    m_cb_lock.unlock();
}

/*
 * Send one queued message to the daemon over the connected agent socket
 * (blocking send()).
 *
 * Returns:
 *   the number of bytes sent on success;
 *   -ENODEV if the agent is not active;
 *   -EBADF  if the socket is not open;
 *   -EINVAL if msg is null;
 *   -errno  if send() fails (the agent is then switched to AGENT_INACTIVE so
 *           that progress() retries the handshake later).
 */
int agent::send(agent_msg_t *msg)
{
    int rc = 0;

    if (AGENT_ACTIVE != m_state) {
        return -ENODEV;
    }

    if (m_sock_fd < 0) {
        return -EBADF;
    }

    if (!msg) {
        return -EINVAL;
    }

    /* send() in blocking manner */
    sys_call(rc, send, m_sock_fd, (void *)&msg->data, msg->length, 0);
    if (rc < 0) {
        /* Capture errno before logging: the logger may perform I/O that
         * overwrites it, which would make the returned code meaningless.
         */
        const int err = errno;
        __log_dbg("Failed to send() errno %d (%s)", err, strerror(err));
        m_state = AGENT_INACTIVE;
        __log_dbg("Agent is inactivated. state = %d", m_state);
        return -err;
    }

    return rc;
}

int agent::send_msg_init(void)
{
    int rc = 0;
    struct sockaddr_un server_addr;
    struct xlio_msg_init data;

    if (AGENT_ACTIVE == m_state) {
        return 0;
    }

    if (m_sock_fd < 0) {
        return -EBADF;
    }

    /* Set server address */
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sun_family = AF_UNIX;
    strncpy(server_addr.sun_path, XLIO_AGENT_ADDR, sizeof(server_addr.sun_path) - 1);

    sys_call(rc, connect, m_sock_fd, (struct sockaddr *)&server_addr, sizeof(struct sockaddr_un));
    if (rc < 0) {
        __log_dbg("Failed to connect() errno %d (%s)", errno, strerror(errno));
        rc = -ECONNREFUSED;
        goto err;
    }

    memset(&data, 0, sizeof(data));
    data.hdr.code = XLIO_MSG_INIT;
    data.hdr.ver = XLIO_AGENT_VER;
    data.hdr.pid = getpid();
    data.ver = (PRJ_LIBRARY_MAJOR << 12) | (PRJ_LIBRARY_MINOR << 8) | (PRJ_LIBRARY_RELEASE << 4) |
        PRJ_LIBRARY_REVISION;

    /* send(XLIO_MSG_INIT) in blocking manner */
    sys_call(rc, send, m_sock_fd, &data, sizeof(data), 0);
    if (rc < 0) {
        __log_dbg("Failed to send(XLIO_MSG_INIT) errno %d (%s)", errno, strerror(errno));
        rc = -ECONNREFUSED;
        goto err;
    }

    /* recv(XLIO_MSG_INIT|ACK) in blocking manner */
    memset(&data, 0, sizeof(data));
    sys_call(rc, recv, m_sock_fd, &data, sizeof(data), 0);
    if (rc < (int)sizeof(data)) {
        __log_dbg("Failed to recv(XLIO_MSG_INIT) errno %d (%s)", errno, strerror(errno));
        rc = -ECONNREFUSED;
        goto err;
    }

    if (data.hdr.code != (XLIO_MSG_INIT | XLIO_MSG_ACK) || data.hdr.pid != getpid()) {
        __log_dbg("Protocol is not supported: code = 0x%X pid = %d", data.hdr.code, data.hdr.pid);
        rc = -EPROTO;
        goto err;
    }

    if (data.hdr.ver < XLIO_AGENT_VER) {
        __log_dbg("Protocol version mismatch: agent ver = 0x%X service ver = 0x%X", XLIO_AGENT_VER,
                  data.hdr.ver);
        rc = -EPROTONOSUPPORT;
        goto err;
    }

    m_state = AGENT_ACTIVE;
    __log_dbg("Agent is activated. state = %d", m_state);

err:
    return rc;
}

/*
 * Tell the daemon this process is exiting (XLIO_MSG_EXIT, blocking send()).
 * The agent is switched to AGENT_INACTIVE before sending, regardless of the
 * outcome.
 *
 * Returns:
 *   0       - on success;
 *   -ENODEV - if the agent is not active;
 *   -EBADF  - if the socket is not open;
 *   -errno  - if send() fails.
 */
int agent::send_msg_exit(void)
{
    int rc = 0;
    struct xlio_msg_exit data;

    if (AGENT_ACTIVE != m_state) {
        return -ENODEV;
    }

    if (m_sock_fd < 0) {
        return -EBADF;
    }

    m_state = AGENT_INACTIVE;
    __log_dbg("Agent is inactivated. state = %d", m_state);

    memset(&data, 0, sizeof(data));
    data.hdr.code = XLIO_MSG_EXIT;
    data.hdr.ver = XLIO_AGENT_VER;
    data.hdr.pid = getpid();

    /* send(XLIO_MSG_EXIT) in blocking manner */
    sys_call(rc, send, m_sock_fd, &data, sizeof(data), 0);
    if (rc < 0) {
        /* Capture errno before logging: the logger may overwrite it */
        const int err = errno;
        __log_dbg("Failed to send(XLIO_MSG_EXIT) errno %d (%s)", err, strerror(err));
        return -err;
    }

    return 0;
}

/*
 * Create the agent's UNIX datagram socket and bind it to m_sock_file.
 * send_msg_init() later connects it to the daemon address.
 * On failure after socket() succeeded, m_sock_fd is left open; the caller
 * (agent::agent error path) closes it and unlinks m_sock_file.
 *
 * Returns:
 *   0             - on success;
 *   -ENAMETOOLONG - if m_sock_file does not fit into sun_path;
 *   -errno        - of the failing system call otherwise.
 */
int agent::create_agent_socket(void)
{
    int rc = 0;
    int err = 0;
    int optval = 1;
    struct timeval opttv;
    struct sockaddr_un sock_addr;

    /* A path longer than sun_path would be silently truncated by strncpy(),
     * binding the socket to a different file than m_sock_file; that file
     * would then never be removed by the unlink(m_sock_file) calls.
     */
    if (strlen(m_sock_file) >= sizeof(sock_addr.sun_path)) {
        __log_dbg("Socket file name is too long: %s", m_sock_file);
        return -ENAMETOOLONG;
    }

    memset(&sock_addr, 0, sizeof(sock_addr));
    sock_addr.sun_family = AF_UNIX;
    strncpy(sock_addr.sun_path, m_sock_file, sizeof(sock_addr.sun_path) - 1);
    /* remove possible old socket */
    unlink(m_sock_file);

    /* In every failure branch errno is captured before logging, because the
     * logger may perform I/O that overwrites it.
     */
    sys_call(m_sock_fd, socket, AF_UNIX, SOCK_DGRAM, 0);
    if (m_sock_fd < 0) {
        err = errno;
        __log_dbg("Failed to call socket() errno %d (%s)", err, strerror(err));
        return -err;
    }

    sys_call(rc, setsockopt, m_sock_fd, SOL_SOCKET, SO_REUSEADDR, (const void *)&optval,
             sizeof(optval));
    if (rc < 0) {
        err = errno;
        __log_dbg("Failed to call setsockopt(SO_REUSEADDR) errno %d (%s)", err, strerror(err));
        return -err;
    }

    /* Sets the timeout value as 3 sec that specifies the maximum amount of time
     * an input function waits until it completes.
     */
    opttv.tv_sec = 3;
    opttv.tv_usec = 0;
    sys_call(rc, setsockopt, m_sock_fd, SOL_SOCKET, SO_RCVTIMEO, (const void *)&opttv,
             sizeof(opttv));
    if (rc < 0) {
        err = errno;
        __log_dbg("Failed to call setsockopt(SO_RCVTIMEO) errno %d (%s)", err, strerror(err));
        return -err;
    }

    /* bind created socket */
    sys_call(rc, bind, m_sock_fd, (struct sockaddr *)&sock_addr, sizeof(sock_addr));
    if (rc < 0) {
        err = errno;
        __log_dbg("Failed to call bind() errno %d (%s)", err, strerror(err));
        return -err;
    }

    return 0;
}

/*
 * Probe the daemon by re-issuing connect() to XLIO_AGENT_ADDR on the agent
 * socket. If the call fails, the agent is switched to AGENT_INACTIVE, so that
 * progress() will retry the full handshake.
 */
void agent::check_link(void)
{
    /* The daemon address never changes: build it once and reuse it */
    static const struct sockaddr_un daemon_addr = [] {
        struct sockaddr_un addr;
        memset(&addr, 0, sizeof(addr));
        addr.sun_family = AF_UNIX;
        strncpy(addr.sun_path, XLIO_AGENT_ADDR, sizeof(addr.sun_path) - 1);
        return addr;
    }();
    int status = 0;

    sys_call(status, connect, m_sock_fd, (const struct sockaddr *)&daemon_addr,
             sizeof(struct sockaddr_un));
    if (status >= 0) {
        return;
    }

    __log_dbg("Failed to connect() errno %d (%s)", errno, strerror(errno));
    m_state = AGENT_INACTIVE;
    __log_dbg("Agent is inactivated. state = %d", m_state);
}
