/*
 * SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
 * Copyright (c) 2019-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#include <config.h>
#include "netdev-doca-vdpa.h"

#include <rte_bus.h>
#include <rte_dev.h>
#include <rte_malloc.h>
#include <rte_pci.h>
#include <rte_vdpa.h>
#include <rte_vhost.h>

#include "netdev-provider.h"
#include "openvswitch/shash.h"
#include "openvswitch/vlog.h"
#include "util.h"

VLOG_DEFINE_THIS_MODULE(netdev_doca_vdpa);

/* Size, in bytes, of the stack buffers that receive the vhost interface name
 * in the new/destroy device callbacks. */
#define NETDEV_DPDK_VDPA_STATS_MAX_STR_SIZE 128

/* Configuration state of a relay.
 *
 * INIT is the zero value, so a freshly zero-allocated relay starts there.
 * The hardware configuration path switches the relay to HW on success and
 * the destruct path switches it back to INIT. */
enum netdev_doca_vdpa_mode {
    NETDEV_DPDK_VDPA_MODE_INIT,
    NETDEV_DPDK_VDPA_MODE_HW,
};

/* State of one vDPA relay: a VF probed as a vDPA device and attached to a
 * vhost-user socket. */
struct netdev_doca_vdpa_relay {
    PADDED_MEMBERS(CACHE_LINE_SIZE,
        /* Heap copy of the caller-supplied VF device arguments. */
        char *vf_devargs;
        /* Heap copy of the vhost-user socket path given to the vhost
         * driver. */
        char *vm_socket;
        /* Heap copy of the socket path used as the key in 'relays_map'. */
        char *vhost_name;
        /* Set by the vhost new_device callback, cleared by destroy_device
         * and by destruct. */
        bool started;
        /* INIT until hardware configuration succeeds, then HW. */
        enum netdev_doca_vdpa_mode hw_mode;
        /* vDPA device looked up by name during hardware configuration. */
        struct rte_vdpa_device *vdpa_dev;
        /* Device handle produced by the class=vdpa device iteration; passed
         * to rte_dev_remove() on teardown. */
        struct rte_device *rte_dev;
        /* vhost device id reported by new_device; -1 when no device is
         * connected. */
        int vid;
        /* Node of this relay in 'relays_map', or NULL when not registered. */
        struct shash_node *map_item;
        /* Taken by destroy_device and by the custom stats reader. */
        struct ovs_mutex lock;
        );
};

/* Maps a vhost socket name to its relay, so that the vhost callbacks (which
 * only receive a vid) can find the relay they belong to.  Accesses are made
 * under 'relays_map_lock'. */
static struct shash relays_map = SHASH_INITIALIZER(&relays_map);
static struct ovs_mutex relays_map_lock = OVS_MUTEX_INITIALIZER;

/* Registers 'relay' in 'relays_map' under 'relay->vhost_name'.
 *
 * Does nothing when the relay already owns a map node. */
static void
relays_map_add(struct netdev_doca_vdpa_relay *relay)
{
    if (!relay->map_item) {
        ovs_mutex_lock(&relays_map_lock);
        relay->map_item = shash_add(&relays_map, relay->vhost_name, relay);
        ovs_mutex_unlock(&relays_map_lock);
    }
}

/* Unregisters 'relay' from 'relays_map' and forgets its map node.
 *
 * Does nothing when the relay is not currently registered. */
static void
relays_map_remove(struct netdev_doca_vdpa_relay *relay)
{
    if (relay->map_item) {
        ovs_mutex_lock(&relays_map_lock);
        shash_delete(&relays_map, relay->map_item);
        relay->map_item = NULL;
        ovs_mutex_unlock(&relays_map_lock);
    }
}

/* Releases 'ptr', which must be NULL or memory obtained from the libc
 * allocator.
 *
 * 'ptr' is passed by value, so the caller's copy of the pointer is NOT
 * cleared by this function; callers that keep the pointer around must reset
 * it themselves.  free(NULL) is a no-op, so no NULL guard is needed. */
static void
netdev_doca_vdpa_free(void *ptr)
{
    free(ptr);
}

/* Frees the heap strings owned by 'relay' and resets the members to NULL.
 *
 * Resetting matters: the relay outlives this call and may be configured
 * again, at which point the configuration path frees 'vhost_name' before
 * replacing it.  A stale pointer there would be freed twice. */
static void
netdev_doca_vdpa_free_relay_strings(struct netdev_doca_vdpa_relay *relay)
{
    netdev_doca_vdpa_free(relay->vm_socket);
    relay->vm_socket = NULL;

    netdev_doca_vdpa_free(relay->vf_devargs);
    relay->vf_devargs = NULL;

    netdev_doca_vdpa_free(relay->vhost_name);
    relay->vhost_name = NULL;
}

/* Allocates a zeroed, cache-line aligned relay with no connected vhost
 * device (vid == -1) and an initialized lock.
 *
 * Returns the new relay, or NULL if the allocation fails. */
void *
netdev_doca_vdpa_alloc_relay(void)
{
    struct netdev_doca_vdpa_relay *relay;

    relay = rte_zmalloc("ovs_doca_vdpa_relay", sizeof *relay,
                        CACHE_LINE_SIZE);
    if (!relay) {
        /* rte_zmalloc() reports failure by returning NULL; do not
         * dereference it. */
        VLOG_ERR("Failed to allocate memory for vdpa relay");
        return NULL;
    }

    relay->vid = -1;
    ovs_mutex_init(&relay->lock);
    return relay;
}

/* vhost 'new_device' callback: marks the relay that owns 'vid' as started.
 *
 * The relay is located through the vhost interface name of 'vid'.
 *
 * Returns 0 on success, -1 if the interface name cannot be obtained or no
 * relay is registered under it. */
static int
netdev_doca_vdpa_new_device(int vid)
{
    char ifname[NETDEV_DPDK_VDPA_STATS_MAX_STR_SIZE];
    struct netdev_doca_vdpa_relay *relay;

    /* On failure 'ifname' is left unset and must not be used as a key. */
    if (rte_vhost_get_ifname(vid, ifname, sizeof ifname)) {
        VLOG_ERR("cannot get ifname for vid=%d", vid);
        return -1;
    }

    ovs_mutex_lock(&relays_map_lock);
    relay = shash_find_data(&relays_map, ifname);
    ovs_mutex_unlock(&relays_map_lock);
    if (!relay) {
        VLOG_ERR("cannot find relay for ifname=%s", ifname);
        return -1;
    }

    /* Same locking as the destroy callback: the stats reader inspects
     * 'started' and 'vid' under 'relay->lock'. */
    ovs_mutex_lock(&relay->lock);
    relay->vid = vid;
    relay->started = true;
    ovs_mutex_unlock(&relay->lock);
    VLOG_INFO("create device callback, vid=%d, ifname=%s", vid, ifname);

    return 0;
}

/* vhost 'destroy_device' callback: marks the relay that owns 'vid' as
 * stopped and forgets the vid.
 *
 * The relay is located through the vhost interface name of 'vid'; nothing is
 * changed if the name cannot be obtained or no relay is registered under
 * it. */
static void
netdev_doca_vdpa_destroy_device(int vid)
{
    char ifname[NETDEV_DPDK_VDPA_STATS_MAX_STR_SIZE];
    struct netdev_doca_vdpa_relay *relay;

    /* On failure 'ifname' is left unset and must not be used as a key. */
    if (rte_vhost_get_ifname(vid, ifname, sizeof ifname)) {
        VLOG_ERR("cannot get ifname for vid=%d", vid);
        return;
    }

    ovs_mutex_lock(&relays_map_lock);
    relay = shash_find_data(&relays_map, ifname);
    ovs_mutex_unlock(&relays_map_lock);
    if (!relay) {
        VLOG_ERR("cannot find relay for vid=%d", vid);
        return;
    }

    ovs_mutex_lock(&relay->lock);
    relay->started = false;
    relay->vid = -1;
    ovs_mutex_unlock(&relay->lock);
    VLOG_INFO("destroy device callback, vid=%d, ifname=%s", vid,
              relay->vhost_name);
}

/* vhost device callbacks registered for every relay socket; they track
 * whether a vhost device is currently connected to the relay. */
static const struct rte_vhost_device_ops netdev_doca_vdpa_sample_devops = {
        .new_device = netdev_doca_vdpa_new_device,
        .destroy_device = netdev_doca_vdpa_destroy_device,
};

/* Bus comparison callback for rte_bus_find().
 *
 * Returns 0 when the string 'name' begins with the name of 'bus', nonzero
 * otherwise. */
static int
bus_name_cmp(const struct rte_bus *bus, const void *name)
{
    const char *prefix = rte_bus_name(bus);
    size_t prefix_len = strlen(prefix);

    return strncmp(prefix, name, prefix_len);
}

/* Tells whether 'dev_name' is the device named in 'vf_devargs'.
 *
 * The first occurrence of 'dev_name' inside 'vf_devargs' must be followed
 * either by the end of the string or by the ',' that introduces the devargs
 * (e.g. ",class=vdpa"), so that "dev.1" does not match "dev.10". */
static bool
netdev_doca_vdpa_is_same_dev(const char *vf_devargs, const char *dev_name)
{
    const char *match = strstr(vf_devargs, dev_name);
    char terminator;

    if (match == NULL) {
        return false;
    }

    terminator = match[strlen(dev_name)];
    return terminator == '\0' || terminator == ',';
}

/* Probes the VF described by 'relay->vf_devargs' as a vDPA device and
 * attaches it to the vhost-user client socket 'relay->vm_socket'.
 *
 * The user string is normalized before probing: an "auxiliary:" bus prefix
 * is added when no bus is given and the name is not a PCI address, and
 * ",class=vdpa" is appended when missing.
 *
 * On success the relay is registered in 'relays_map', switched to HW mode
 * and 0 is returned.  On failure everything set up so far is undone, the
 * relay's device handles are cleared and a nonzero value is returned. */
static int
netdev_doca_vdpa_config_hw_impl(struct netdev_doca_vdpa_relay *relay)
{
    struct rte_dev_iterator dev_iter;
    struct rte_pci_addr pci_addr;
    const char *vhost_path;
    char *vf_devargs;
    char *comma;
    bool is_pci;
    int err;

    vf_devargs = xstrdup(relay->vf_devargs);
    vhost_path = relay->vm_socket;

    /* If bus is not explicitly provided, choose "pci"/"auxiliary" based
     * on provided string. */
    if (!rte_bus_find(NULL, bus_name_cmp, vf_devargs)) {
        /* If provided devargs, need to omit them for the pci check. */
        comma = strchr(vf_devargs, ',');
        if (comma) {
            *comma = '\0';
            is_pci = !rte_pci_addr_parse(vf_devargs, &pci_addr);
            *comma = ',';
        } else {
            is_pci = !rte_pci_addr_parse(vf_devargs, &pci_addr);
        }

        /* If not PCI, implicitly add auxiliary. */
        if (!is_pci) {
            char *new_vf_devargs;

            new_vf_devargs = xasprintf("auxiliary:%s", vf_devargs);
            free(vf_devargs);
            vf_devargs = new_vf_devargs;
        }
    }

    /* If not explicitly provided, add class=vdpa devarg. */
    if (!strstr(vf_devargs, ",class=vdpa")) {
        vf_devargs = xrealloc(vf_devargs, strlen(vf_devargs) +
                              strlen(",class=vdpa") + 1);
        strcat(vf_devargs, ",class=vdpa");
    }

    err = rte_dev_probe(vf_devargs);
    VLOG_INFO("Probe for VDPA device '%s'. User conf '%s'", vf_devargs,
              relay->vf_devargs);
    if (err) {
        VLOG_ERR("Failed to probe for VDPA device %s", vf_devargs);
        goto err_probe;
    }

    /* Walk the vdpa class devices looking for the one just probed. */
    RTE_DEV_FOREACH (relay->rte_dev, "class=vdpa", &dev_iter) {
        const char *dev_name = rte_dev_name(relay->rte_dev);

        relay->vdpa_dev = rte_vdpa_find_device_by_name(dev_name);
        if (!relay->vdpa_dev) {
            VLOG_ERR("Failed to find vdpa device id for %s", dev_name);
            /* 'err' is still 0 from the successful probe; without this the
             * failure would be reported as success. */
            err = -1;
            goto err_probe;
        }
        if (netdev_doca_vdpa_is_same_dev(vf_devargs, dev_name)) {
            break;
        }
        relay->vdpa_dev = NULL;
    }
    if (!relay->vdpa_dev) {
        /* The iteration ran to completion, so 'relay->rte_dev' is NULL here
         * and must not be passed to rte_dev_name(). */
        VLOG_ERR("Failed to find vdpa device id for %s", vf_devargs);
        err = -1;
        goto err_probe;
    }

    err = rte_vhost_driver_register(vhost_path, RTE_VHOST_USER_CLIENT);
    if (err) {
        VLOG_ERR("rte_vhost_driver_register failed");
        goto err_find;
    }

    err = rte_vhost_driver_callback_register(vhost_path,
            &netdev_doca_vdpa_sample_devops);
    if (err) {
        VLOG_ERR("rte_vhost_driver_callback_register failed");
        goto err_cb_reg;
    }

    err = rte_vhost_driver_attach_vdpa_device(vhost_path, relay->vdpa_dev);
    if (err) {
        VLOG_ERR("Failed to attach vdpa device");
        goto err_cb_reg;
    }

    err = rte_vhost_driver_start(vhost_path);
    if (err) {
        VLOG_ERR("Failed to start vhost driver: %s", vhost_path);
        goto err_start;
    }

    /* NOTE(review): the relay only becomes visible to the vhost callbacks
     * once it is added to the map below, i.e. after the driver was started.
     * Confirm a new_device callback cannot fire in between. */
    netdev_doca_vdpa_free(relay->vhost_name);
    relay->vhost_name = xstrdup(vhost_path);
    relays_map_add(relay);
    relay->hw_mode = NETDEV_DPDK_VDPA_MODE_HW;
    free(vf_devargs);
    return 0;

err_start:
    if (rte_vhost_driver_detach_vdpa_device(vhost_path)) {
        VLOG_ERR("Failed to detach vdpa device: %s", relay->vm_socket);
    }
err_cb_reg:
    if (rte_vhost_driver_unregister(relay->vm_socket)) {
        VLOG_ERR("Failed to unregister vhost driver for %s", relay->vm_socket);
    }
err_find:
    if (rte_dev_remove(relay->rte_dev)) {
        VLOG_ERR("Failed to rte_dev_remove %p", relay->vdpa_dev);
    }
err_probe:
    /* Do not leave handles to a removed or unmatched device behind. */
    relay->vdpa_dev = NULL;
    relay->rte_dev = NULL;
    free(vf_devargs);
    return err;
}

/* Configures 'relay' for the vhost socket 'vm_socket' and the VF described
 * by 'vf_devargs'.
 *
 * Both strings are copied into the relay.  A relay that has already left the
 * INIT mode is considered configured and is left untouched.
 *
 * Returns 0 on success (or when already configured), otherwise the error
 * reported by the hardware configuration, in which case the relay's strings
 * are released again. */
int
netdev_doca_vdpa_config_impl(struct netdev_doca_vdpa_relay *relay,
                             const char *vm_socket,
                             const char *vf_devargs)
{
    int rc;

    if (relay->hw_mode == NETDEV_DPDK_VDPA_MODE_INIT) {
        relay->vm_socket = xstrdup(vm_socket);
        relay->vf_devargs = xstrdup(vf_devargs);

        rc = netdev_doca_vdpa_config_hw_impl(relay);
        if (rc != 0) {
            netdev_doca_vdpa_free_relay_strings(relay);
        }
        return rc;
    }

    /* Already configured: nothing to do. */
    return 0;
}

/* Tears down a relay that was configured in HW mode: detaches the vDPA
 * device from the vhost socket, unregisters the vhost driver, removes the
 * relay from 'relays_map', removes the probed device and releases the
 * relay's strings.
 *
 * Afterwards the relay is back in INIT mode with no stale pointers, so it can
 * be configured again.  Relays that are not in HW mode are left untouched. */
void
netdev_doca_vdpa_destruct_impl(struct netdev_doca_vdpa_relay *relay)
{
    if (relay->hw_mode != NETDEV_DPDK_VDPA_MODE_HW) {
        return;
    }

    relay->started = false;
    if (rte_vhost_driver_detach_vdpa_device(relay->vm_socket)) {
        VLOG_ERR("Failed to detach vdpa device: %s", relay->vm_socket);
    }

    if (rte_vhost_driver_unregister(relay->vm_socket)) {
        VLOG_ERR("Failed to unregister vhost driver for %s", relay->vm_socket);
    }
    relay->hw_mode = NETDEV_DPDK_VDPA_MODE_INIT;
    relays_map_remove(relay);

    ignore(rte_dev_remove(relay->rte_dev));
    /* The device is gone; drop the handles that referred to it. */
    relay->rte_dev = NULL;
    relay->vdpa_dev = NULL;

    netdev_doca_vdpa_free_relay_strings(relay);
    /* A later re-configuration frees 'vhost_name' before replacing it, so
     * the freed pointers must not survive this call. */
    relay->vm_socket = NULL;
    relay->vf_devargs = NULL;
    relay->vhost_name = NULL;
}

/* Fills 'cstm_stats' with the per-virtqueue hardware counters of 'relay'.
 *
 * Counter 0 is "HW mode" (always 1).  Virtqueue q contributes two counters,
 * at index 2q+1 ("Packets", from the "completed_descriptors" stat) and at
 * index 2q+2 ("Errors", from the "completion errors" stat).  Virtqueues are
 * labelled in pairs: even ones as rx, odd ones as tx.
 *
 * On any failure before the counters are allocated, 'cstm_stats' is left
 * untouched and an error is logged. */
static
void netdev_doca_vdpa_get_hw_stats(struct netdev_doca_vdpa_relay *relay,
                                   struct netdev_custom_stats *cstm_stats)
{
    enum stats_vals {
        VDPA_CUSTOM_STATS_HW_MODE,
        VDPA_CUSTOM_STATS_PACKETS,
        VDPA_CUSTOM_STATS_ERRORS,
        VDPA_CUSTOM_STATS_TOTAL_SIZE
    };
    const char *stats_names[] = {
        [VDPA_CUSTOM_STATS_HW_MODE] = "HW mode",
        [VDPA_CUSTOM_STATS_PACKETS] = "Packets",
        [VDPA_CUSTOM_STATS_ERRORS] = "Errors"
    };
    struct rte_vdpa_device *vdev = relay->vdpa_dev;
    struct rte_vdpa_stat_name *vdpa_stats_names;
    char name[NETDEV_CUSTOM_STATS_NAME_SIZE];
    uint16_t q, num_q, counter, index;
    struct rte_vdpa_stat *stats;
    int stats_n, names_n, i;

    if (!vdev) {
        VLOG_ERR("Failed to find vdpa device id for %s", relay->vf_devargs);
        return;
    }

    /* First call only asks how many stats the device exposes. */
    stats_n = rte_vdpa_get_stats_names(vdev, NULL, 0);
    if (stats_n <= 0) {
        VLOG_ERR("Failed to get names number of device %s.",
                 relay->vf_devargs);
        return;
    }

    vdpa_stats_names = rte_zmalloc("ovs_doca_vdpa_stats_names",
                                   sizeof(*vdpa_stats_names) * stats_n,
                                   0);
    if (!vdpa_stats_names) {
        VLOG_ERR("Failed to allocate memory for stats names of device %s.",
                 relay->vf_devargs);
        return;
    }

    /* Keep the result in an int so that a negative error value is not
     * truncated before the comparison. */
    names_n = rte_vdpa_get_stats_names(vdev, vdpa_stats_names, stats_n);
    if (names_n != stats_n) {
        VLOG_ERR("Failed to get names of device %s.", relay->vf_devargs);
        goto err_names;
    }

    stats = rte_zmalloc("ovs_doca_vdpa_stats", sizeof(*stats) * stats_n, 0);
    if (!stats) {
        VLOG_ERR("Failed to allocate memory for stats of device %s.",
                 relay->vf_devargs);
        goto err_names;
    }

    num_q = rte_vhost_get_vring_num(relay->vid);
    if (num_q == 0) {
        VLOG_ERR("Failed to get num of actual virtqs for device %s.",
                 relay->vf_devargs);
        goto err_stats;
    }

    cstm_stats->size = 2 * num_q + 1;
    cstm_stats->counters = xcalloc(cstm_stats->size,
                                   sizeof *cstm_stats->counters);

    for (q = 0; q < num_q; q++) {
        int filled = rte_vdpa_get_stats(vdev, q, stats, stats_n);

        if (filled == 0) {
            continue;
        } else if (filled < 0) {
            VLOG_ERR("Failed to get vdpa queue statistics for device %s "
                     "queue %d.", relay->vf_devargs, q);
            break;
        }

        /* Only the first 'filled' entries belong to this queue; the rest of
         * the buffer may still hold values of a previous queue. */
        for (i = 0; i < filled; ++i) {
            const char *stat_name;

            /* The id indexes the names table; ignore ids outside of it. */
            if (stats[i].id >= (uint64_t) stats_n) {
                continue;
            }
            stat_name = vdpa_stats_names[stats[i].id].name;

            if (!strncmp(stat_name, "completed_descriptors",
                         RTE_VDPA_STATS_NAME_SIZE)) {
                index = VDPA_CUSTOM_STATS_PACKETS;
                counter = q * 2 + 1;
            } else if (!strncmp(stat_name, "completion errors",
                                RTE_VDPA_STATS_NAME_SIZE)) {
                index = VDPA_CUSTOM_STATS_ERRORS;
                counter = q * 2 + 2;
            } else {
                /* Stat is not exported as a custom counter. */
                continue;
            }

            snprintf(name, NETDEV_CUSTOM_STATS_NAME_SIZE, "%s queue_%u_%s",
                     stats_names[index], q / 2,
                     q % 2 == 0 ? "rx" : "tx");
            ovs_strlcpy(cstm_stats->counters[counter].name, name,
                        NETDEV_CUSTOM_STATS_NAME_SIZE);
            cstm_stats->counters[counter].value = stats[i].value;
        }
    }

    ovs_strlcpy(cstm_stats->counters[VDPA_CUSTOM_STATS_HW_MODE].name,
                stats_names[VDPA_CUSTOM_STATS_HW_MODE],
                NETDEV_CUSTOM_STATS_NAME_SIZE);
    cstm_stats->counters[VDPA_CUSTOM_STATS_HW_MODE].value = 1;

err_stats:
    rte_free(stats);
err_names:
    rte_free(vdpa_stats_names);
}

/* Fills 'cstm_stats' with the two counters reported while no vhost device is
 * connected: "HW mode" and "Disconnected", both with value 1.
 *
 * The counters array is newly allocated. */
static void
netdev_doca_vdpa_get_offline_stats(struct netdev_custom_stats *cstm_stats)
{
    static const char *const labels[] = {
        "HW mode",
        "Disconnected",
    };
    const size_t n_labels = sizeof labels / sizeof labels[0];
    size_t k;

    cstm_stats->size = n_labels;
    cstm_stats->counters = xcalloc(cstm_stats->size,
                                   sizeof *cstm_stats->counters);

    for (k = 0; k < n_labels; k++) {
        ovs_strlcpy(cstm_stats->counters[k].name, labels[k],
                    NETDEV_CUSTOM_STATS_NAME_SIZE);
        cstm_stats->counters[k].value = 1;
    }
}

/* Collects the custom statistics of a vDPA port into 'cstm_stats'.
 *
 * The relay counters come first: the hardware counters when a vhost device
 * is connected, the offline counters otherwise.  If 'cb' is not NULL, the
 * counters it reports for 'netdev' are appended after them.
 *
 * If 'relay->lock' is contended, the relay counters are skipped and only
 * 'cb' (if any) fills 'cstm_stats'.
 *
 * 'dev_mutex' is taken while the relay is inspected and is always released
 * before 'cb' runs, on both paths.
 *
 * Always returns 0. */
int
netdev_doca_vdpa_get_custom_stats_impl(struct netdev_doca_vdpa_relay *relay,
                                       struct netdev_custom_stats *cstm_stats,
                                       struct ovs_mutex *dev_mutex,
                                       const struct netdev *netdev,
                                       int (*cb)(const struct netdev *,
                                                 struct netdev_custom_stats *))
{
    struct netdev_custom_stats rep_cstm_stats;
    struct netdev_custom_counter *add_counter;
    size_t i;

    ovs_mutex_lock(dev_mutex);
    if (ovs_mutex_trylock(&relay->lock)) {
        /* Relay is busy.  Drop 'dev_mutex' before calling 'cb', exactly as
         * the regular path below does. */
        ovs_mutex_unlock(dev_mutex);
        if (cb) {
            cb(netdev, cstm_stats);
        }
        return 0;
    }

    if (!relay->started) {
        netdev_doca_vdpa_get_offline_stats(cstm_stats);
    } else {
        netdev_doca_vdpa_get_hw_stats(relay, cstm_stats);
    }
    ovs_mutex_unlock(&relay->lock);
    ovs_mutex_unlock(dev_mutex);

    if (cb) {
        /* Zeroed so that a 'cb' that reports nothing appends nothing. */
        memset(&rep_cstm_stats, 0, sizeof rep_cstm_stats);
        cb(netdev, &rep_cstm_stats);

        /* Grow the relay counters and copy the callback's counters behind
         * them. */
        cstm_stats->counters =
            xrealloc(cstm_stats->counters,
                     (cstm_stats->size + rep_cstm_stats.size) *
                     sizeof *cstm_stats->counters);
        add_counter = &cstm_stats->counters[cstm_stats->size];
        cstm_stats->size += rep_cstm_stats.size;
        for (i = 0; i < rep_cstm_stats.size; i++, add_counter++) {
            add_counter->value = rep_cstm_stats.counters[i].value;
            ovs_strlcpy(add_counter->name, rep_cstm_stats.counters[i].name,
                        NETDEV_CUSTOM_STATS_NAME_SIZE);
        }
        netdev_free_custom_stats_counters(&rep_cstm_stats);
    }
    return 0;
}
