/*
 * Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at:
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <config.h>
#include <errno.h>

#include "batch.h"
#include "id-fpool.h"
#include "netdev-offload.h"
#include "netdev-offload-private.h"
#include "netdev-offload-provider.h"
#include "netdev-provider.h"
#include "netdev-vport.h"
#include "openvswitch/vlog.h"
#include "ovs-doca.h"

VLOG_DEFINE_THIS_MODULE(netdev_offload_ext);

/* Whether CT actions on NAT connections are offloaded.  Has external linkage;
 * updated from the "ct-action-on-nat-conns" other_config key by
 * netdev_set_flow_api_enabled_ext(), which refuses to enable it together
 * with "doca-init". */
bool netdev_offload_ct_on_ct_nat = false;
/* CT labels mapping; set to true at most once by
 * netdev_set_flow_api_enabled_ext() from "ct-labels-mapping" and read via
 * netdev_is_ct_labels_mapping_enabled(). */
static bool ct_labels_mapping = false;
/* CT zone tables disabled; set to true at most once by
 * netdev_set_flow_api_enabled_ext() from "disable-zone-tables" (ignored when
 * "doca-init" is set) and read via netdev_is_zone_tables_disabled(). */
static bool disable_zone_tables = false;

/* Invokes the 'upkeep' callback of the flow API provider attached to
 * 'netdev', passing 'quiescing' through unchanged.  Silently does nothing
 * when 'netdev' has no flow API or the provider has no 'upkeep' hook. */
void
netdev_offload_upkeep(struct netdev *netdev, bool quiescing)
{
    const struct netdev_flow_api *api;

    api = ovsrcu_get(const struct netdev_flow_api *, &netdev->flow_api);
    if (!api || !api->upkeep) {
        return;
    }

    api->upkeep(netdev, quiescing);
}

/* Fills 'stats' with offload statistics for 'netdev' by delegating to the
 * flow API provider's 'get_stats' callback.
 *
 * Returns the callback's return value, or EOPNOTSUPP if 'netdev' has no flow
 * API attached or the provider does not implement 'get_stats'. */
int
netdev_offload_get_stats(struct netdev *netdev,
                         struct netdev_offload_stats *stats)
{
    const struct netdev_flow_api *api;

    api = ovsrcu_get(const struct netdev_flow_api *, &netdev->flow_api);
    if (!api || !api->get_stats) {
        return EOPNOTSUPP;
    }

    return api->get_stats(netdev, stats);
}

/* Queries the CT counter identified by 'counter' on 'netdev' through the flow
 * API provider's 'ct_counter_query' callback, which fills 'stats'.  'now' and
 * 'prev_now' are forwarded to the callback unchanged.
 *
 * Returns the callback's return value, or EOPNOTSUPP if 'netdev' has no flow
 * API attached or the provider does not implement 'ct_counter_query'. */
int
netdev_ct_counter_query(struct netdev *netdev,
                        uintptr_t counter,
                        long long now,
                        long long prev_now,
                        struct dpif_flow_stats *stats)
{
    const struct netdev_flow_api *api;

    api = ovsrcu_get(const struct netdev_flow_api *, &netdev->flow_api);
    if (!api || !api->ct_counter_query) {
        return EOPNOTSUPP;
    }

    return api->ct_counter_query(netdev, counter, now, prev_now, stats);
}

/* Reports whether CT labels mapping has been turned on through the
 * "ct-labels-mapping" other_config key (see
 * netdev_set_flow_api_enabled_ext()). */
bool
netdev_is_ct_labels_mapping_enabled(void)
{
    const bool enabled = ct_labels_mapping;

    return enabled;
}

/* Reports whether CT zone tables have been disabled through the
 * "disable-zone-tables" other_config key (see
 * netdev_set_flow_api_enabled_ext()). */
bool
netdev_is_zone_tables_disabled(void)
{
    const bool disabled = disable_zone_tables;

    return disabled;
}

void
netdev_set_flow_api_enabled_ext(const struct smap *ovs_other_config)
{
    {
        bool req_conf = smap_get_bool(ovs_other_config,
                                      "ct-action-on-nat-conns", false);

        if (req_conf && smap_get_bool(ovs_other_config, "doca-init", false)) {
            VLOG_WARN_ONCE("ct-action-on-nat-conns is not supported by OVS-DOCA.");
        } else if (netdev_offload_ct_on_ct_nat != req_conf) {
            netdev_offload_ct_on_ct_nat = req_conf;
            VLOG_INFO("offloads CT on NAT connections: %s",
                      netdev_offload_ct_on_ct_nat ? "enabled" : "disabled");
        }
    }

    if (smap_get_bool(ovs_other_config, "ct-labels-mapping", false)) {
        static struct ovsthread_once once = OVSTHREAD_ONCE_INITIALIZER;

        if (ovsthread_once_start(&once)) {
            ct_labels_mapping = true;
            VLOG_INFO("CT offloads: labels mapping enabled");
            ovsthread_once_done(&once);
        }
    }

    if (smap_get_bool(ovs_other_config, "disable-zone-tables", false)) {
        static struct ovsthread_once once = OVSTHREAD_ONCE_INITIALIZER;

        if (ovsthread_once_start(&once)) {
            if (!smap_get_bool(ovs_other_config, "doca-init", false)) {
                disable_zone_tables = true;
                VLOG_INFO("CT offloads: zone tables disabled");
            } else {
                VLOG_WARN("disable-zone-tables flag is ignored with doca");
            }
            ovsthread_once_done(&once);
        }
    }
}

/* Asks the flow API provider of 'netdev' to compute a hardware hash of
 * 'packet' with 'seed', storing it in '*hash'.
 *
 * Returns the 'packet_hw_hash' callback's return value, or EOPNOTSUPP if
 * 'netdev' has no flow API attached or the provider lacks that callback. */
int
netdev_packet_hw_hash(struct netdev *netdev,
                      struct dp_packet *packet,
                      uint32_t seed,
                      uint32_t *hash)
{
    const struct netdev_flow_api *api;

    api = ovsrcu_get(const struct netdev_flow_api *, &netdev->flow_api);
    if (!api || !api->packet_hw_hash) {
        return EOPNOTSUPP;
    }

    return api->packet_hw_hash(netdev, packet, seed, hash);
}

/* Asks the flow API provider of 'netdev' for a hardware-derived entropy value
 * for 'packet', storing it in '*entropy'.
 *
 * Returns the 'packet_hw_entropy' callback's return value, or EOPNOTSUPP if
 * 'netdev' has no flow API attached or the provider lacks that callback. */
int
netdev_packet_hw_entropy(struct netdev *netdev,
                         struct dp_packet *packet,
                         uint16_t *entropy)
{
    const struct netdev_flow_api *api;

    api = ovsrcu_get(const struct netdev_flow_api *, &netdev->flow_api);
    if (!api || !api->packet_hw_entropy) {
        return EOPNOTSUPP;
    }

    return api->packet_hw_entropy(netdev, packet, entropy);
}

/* Looks up the netdev registered for 'port_no' and 'dpif_type', like
 * 'netdev_ports_get', but without taking a reference on the result.  Intended
 * for short accesses (within an RCU grace period) when the caller already
 * knows any such reference remains valid the whole time.
 *
 * Returns NULL when no entry matches or the matching entry is not marked
 * visible. */
struct netdev *
netdev_ports_get_short(odp_port_t port_no, const char *dpif_type)
{
    struct port_to_netdev_data *data;
    struct netdev *netdev = NULL;
    bool visible = false;

    ovs_rwlock_rdlock(&port_to_netdev_rwlock);
    data = netdev_ports_lookup(port_no, dpif_type);
    if (data) {
        /* Acquire pairs with whoever publishes 'visible', so 'data->netdev'
         * is only read after the entry has been marked visible. */
        atomic_read_explicit(&data->visible, &visible, memory_order_acquire);
        netdev = visible ? data->netdev : NULL;
    }
    ovs_rwlock_unlock(&port_to_netdev_rwlock);

    return netdev;
}
