/*
 * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at:
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


#include <config.h>

#include "cmap.h"
#include "concurrent-array.h"
#include "id-fpool.h"
#include "id-refmap.h"
#include "hash.h"
#include "openvswitch/vlog.h"
#include "ovs-atomic.h"
#include "timeval.h"

VLOG_DEFINE_THIS_MODULE(id_refmap);
/* Rate limiter gating the debug traces emitted by id_refmap_ref() and
 * id_refmap_unref(). */
static struct vlog_rate_limit rl = VLOG_RATE_LIMIT_INIT(600, 600);

/* A reference-counted association between fixed-size keys and IDs taken
 * from a contiguous range [base_id, base_id + n_ids). */
struct id_refmap {
    /* Nodes indexed by the hash of their key. */
    struct cmap map;
    /* Allocator providing and recycling the IDs. */
    struct id_fpool *pool;
    /* Held around insertions into, removals from and destruction
     * of 'map'. */
    struct ovs_mutex map_lock;
    /* Size in bytes of every key hashed, compared and copied. */
    size_t key_size;
    /* Delay between the last unref of a node and its actual release.
     * Zero means the node is released immediately on last unref. */
    unsigned int delay_ms;
    /* Number of entries in 'free_lists'; upper bound for a valid 'tid'. */
    unsigned int nb_thread;
    /* Value reported when no ID matches, set to 'base_id - 1'. */
    uint32_t invalid_id;
    /* First valid ID. */
    uint32_t base_id;
    /* Number of valid IDs starting at 'base_id'. */
    uint32_t n_ids;
    /* One list of 'struct delayed_release_item' per thread.
     * Only allocated when 'delay_ms' is non-zero. */
    struct ovs_list *free_lists;
    /* Optional callback printing a key in debug traces, may be NULL. */
    id_refmap_format format;
    /* Private copy of the map name, used in log messages. */
    char *name;
    /* Reverse lookup: entry 'id - base_id' holds the node owning 'id'. */
    struct concurrent_array *id2node;
};

/* One key <-> ID association.
 *
 * Nodes are allocated with 'key_size' extra bytes after the structure to
 * hold a copy of the key. */
struct id_refmap_node {
    /* Element of the 'map' cmap in 'struct id_refmap'. */
    struct cmap_node map_node;
    /* Used to defer freeing the node through ovsrcu_gc(). */
    struct ovsrcu_gc_node rcu_node;
    /* Number of references currently held on 'id'. */
    struct ovs_refcount refcount;
    /* Hash of 'key', used for insertion into and removal from the cmap. */
    uint32_t hash;
    uint32_t id; /* value. */
    /* Key bytes.  Declared as a C99 flexible array member rather than a
     * zero-length array, which is a compiler extension. */
    char key[];
};

/* A node whose last reference was dropped and that is waiting for its
 * release delay to expire. */
struct delayed_release_item {
    /* Element of one of the 'free_lists' in 'struct id_refmap'. */
    struct ovs_list in_free_list;
    /* Node to release once the delay has elapsed. */
    struct id_refmap_node *node;
    /* Value of time_msec() when the item was queued. */
    long long int insertion_time_ms;
    /* Thread ID given at unref time, reused for the release. */
    unsigned int tid;
};

/* Creates and returns a new ID reference map called 'name'.
 *
 * Keys are 'key_size' bytes long and IDs are taken from the contiguous
 * range ['base_id', 'base_id' + 'n_ids').  'nb_thread' is forwarded to the
 * ID pool and sizes the per-thread delayed release lists.  'format' is an
 * optional callback used to print keys in debug traces.
 *
 * If 'id_release_delay_ms' is non-zero, nodes losing their last reference
 * are queued and only released by id_refmap_upkeep().
 *
 * If 'disable_map_shrink' is true, the minimum load of the internal cmap
 * is set to 0.0.
 *
 * 'name' is copied.  Returns NULL if 'n_ids' is 0, if the range leaves no
 * sentinel according to the check below, or if either the ID pool or the
 * reverse array cannot be created. */
struct id_refmap *
id_refmap_create(unsigned int nb_thread, const char *name,
                 size_t key_size, id_refmap_format format,
                 uint32_t base_id, uint32_t n_ids,
                 unsigned int id_release_delay_ms,
                 bool disable_map_shrink)
{
    struct concurrent_array *reverse;
    struct id_refmap *irfm;
    struct id_fpool *ids;

    if (!n_ids) {
        VLOG_ERR("Failed to create id-refmap %s: invalid range of 0.", name);
        return NULL;
    }

    /* The ID right before the range is used as the error value.  Refuse
     * ranges for which it equals the ID right after the range; with
     * 32-bit unsigned arithmetic that is the case when 'n_ids' is
     * UINT32_MAX. */
    if (base_id - 1 == base_id + n_ids) {
        VLOG_ERR("Failed to create id-refmap %s: "
                 "range is too large and no sentinel value can be set.",
                 name);
        return NULL;
    }

    ids = id_fpool_create(nb_thread, base_id, n_ids);
    if (!ids) {
        VLOG_ERR("Failed to create id-refmap %s: "
                 "ID pool could not be initialized.",
                 name);
        return NULL;
    }

    reverse = concurrent_array_create();
    if (!reverse) {
        VLOG_ERR("Failed to create id-refmap %s: "
                 "reversal array could not be initialized.",
                 name);
        id_fpool_destroy(ids);
        return NULL;
    }

    irfm = xzalloc(sizeof *irfm);

    /* Key <-> ID containers. */
    cmap_init(&irfm->map);
    if (disable_map_shrink) {
        cmap_set_min_load(&irfm->map, 0.0);
    }
    ovs_mutex_init(&irfm->map_lock);
    irfm->pool = ids;
    irfm->id2node = reverse;

    /* ID range description.  The range being contiguous, the ID preceding
     * the first valid one can never be handed out. */
    irfm->base_id = base_id;
    irfm->n_ids = n_ids;
    irfm->invalid_id = base_id - 1;

    /* Keys and logging. */
    irfm->key_size = key_size;
    irfm->format = format;
    irfm->name = xstrdup(name);

    /* Delayed release: one list per thread, only if a delay is requested. */
    irfm->nb_thread = nb_thread;
    irfm->delay_ms = id_release_delay_ms;
    if (id_release_delay_ms > 0) {
        struct ovs_list *lists = xcalloc(nb_thread, sizeof *lists);
        unsigned int tid;

        for (tid = 0; tid < nb_thread; tid++) {
            ovs_list_init(&lists[tid]);
        }
        irfm->free_lists = lists;
    }

    return irfm;
}

/* Frees the memory of 'node', key bytes included.
 *
 * Used as the callback given to ovsrcu_gc() when a node is retired. */
static void
id_refmap_node_free(struct id_refmap_node *node)
{
    free(node);
}

/* Destroys 'irfm' and everything it owns.  Does nothing if 'irfm' is NULL.
 *
 * Items pending a delayed release are released first.  Nodes still in the
 * map at that point are reported with a warning and handed to ovsrcu_gc(). */
void
id_refmap_destroy(struct id_refmap *irfm)
{
    struct id_refmap_node *leftover;

    if (!irfm) {
        return;
    }

    if (irfm->delay_ms > 0) {
        /* Run the upkeep of every thread with a date lying past the
         * release delay, so that queued items are released now. */
        long long int deadline = time_msec() + irfm->delay_ms + 42;
        unsigned int tid;

        for (tid = 0; tid < irfm->nb_thread; tid++) {
            id_refmap_upkeep(irfm, tid, deadline);
        }
        free(irfm->free_lists);
    }

    ovs_mutex_lock(&irfm->map_lock);
    if (!cmap_is_empty(&irfm->map)) {
        VLOG_WARN("%s: %s called with elements remaining in the map",
                  irfm->name, __func__);
        /* The cmap is destroyed right after this loop: the nodes are
         * retired without being removed from it. */
        CMAP_FOR_EACH (leftover, map_node, &irfm->map) {
            ovsrcu_gc(id_refmap_node_free, leftover, rcu_node);
        }
    }
    cmap_destroy(&irfm->map);
    ovs_mutex_unlock(&irfm->map_lock);
    ovs_mutex_destroy(&irfm->map_lock);

    concurrent_array_destroy(irfm->id2node);
    id_fpool_destroy(irfm->pool);
    free(irfm->name);
    free(irfm);
}

/* Returns the entry of the reverse array for 'id', or NULL if 'id' is not
 * within [base_id, base_id + n_ids). */
static struct id_refmap_node *
id_refmap_node_from_id(struct id_refmap *irfm, uint32_t id)
{
    /* Offset of 'id' within the range.  The subtraction is unsigned: an
     * 'id' below 'base_id' wraps to a large value, so the single
     * comparison below rejects IDs on both sides of the range. */
    uint32_t offset = id - irfm->base_id;

    if (offset < irfm->n_ids) {
        return concurrent_array_get(irfm->id2node, offset);
    }

    return NULL;
}

/* Returns a pointer to the key stored for 'id', or NULL if 'id' is out of
 * range or has no node registered.
 *
 * The returned pointer refers to memory inside the node itself. */
void *
id_refmap_find_key(struct id_refmap *irfm, uint32_t id)
{
    struct id_refmap_node *owner = id_refmap_node_from_id(irfm, id);

    if (!owner) {
        return NULL;
    }

    return owner->key;
}

/* Looks up the node whose key equals the 'key_size' bytes at 'key'.
 *
 * Nodes whose reference count reads 0 are ignored.  Returns NULL if no
 * node qualifies. */
static struct id_refmap_node *
id_refmap_node_from_key(struct id_refmap *irfm, void *key)
{
    const uint32_t key_hash = hash_bytes(key, irfm->key_size, 0);
    struct id_refmap_node *candidate;

    CMAP_FOR_EACH_WITH_HASH (candidate, map_node, key_hash, &irfm->map) {
        if (memcmp(candidate->key, key, irfm->key_size)) {
            /* Hash collision with a different key. */
            continue;
        }
        if (ovs_refcount_read(&candidate->refcount) == 0) {
            /* Same key, but no reference is held on it anymore. */
            continue;
        }
        return candidate;
    }

    return NULL;
}

/* Returns the ID currently associated with 'key', without taking a
 * reference on it.  Returns the map's invalid ID ('base_id - 1') if no
 * referenced node matches 'key'. */
uint32_t
id_refmap_find(struct id_refmap *irfm, void *key)
{
    struct id_refmap_node *match = id_refmap_node_from_key(irfm, key);

    return match ? match->id : irfm->invalid_id;
}

/* Takes a reference on the ID associated with 'key'.
 *
 * If a referenced node already exists for 'key', its reference count is
 * incremented.  Otherwise a node is allocated, an ID is requested from the
 * pool on behalf of thread 'tid', and the node is published in both the
 * reverse array and the map.
 *
 * On success, stores the ID in '*id' and returns true.  Returns false if
 * the pool could not provide an ID, in which case '*id' is not written. */
bool
id_refmap_ref(struct id_refmap *irfm, unsigned int tid, void *key, uint32_t *id)
{
    struct id_refmap_node *node;
    uint32_t idx;
    bool error;

    error = false;

    node = id_refmap_node_from_key(irfm, key);
    if (node) {
        idx = node->id - irfm->base_id;
        ovs_assert(idx < irfm->n_ids);
        if (!ovs_refcount_try_ref_rcu(&node->refcount)) {
            /* If already fallen to 0, the release for this node
             * was scheduled but this reference is interrupting it.
             *
             * The lookup above skips nodes whose count reads 0, so this
             * is only reached if the count dropped in between.
             *
             * NOTE(review): nothing here cancels a release already
             * started or queued by id_refmap_unref(); confirm that the
             * callers cannot ref and unref the same key concurrently. */
            ovs_refcount_init(&node->refcount);
        }
        goto log_value;
    }

    /* The key is stored inline, right after the node structure. */
    node = xzalloc(sizeof(struct id_refmap_node) + irfm->key_size);
    node->hash = hash_bytes(key, irfm->key_size, 0);
    ovs_refcount_init(&node->refcount);
    memcpy(node->key, key, irfm->key_size);

    node->id = irfm->invalid_id;
    if (!id_fpool_new_id(irfm->pool, tid, &node->id)) {
        /* The node is still traced below, then freed. */
        error = true;
        goto log_value;
    }

    ovs_mutex_lock(&irfm->map_lock);
    idx = node->id - irfm->base_id;
    ovs_assert(idx < irfm->n_ids);
    concurrent_array_set(irfm->id2node, idx, node);
    cmap_insert(&irfm->map, &node->map_node, node->hash);
    ovs_mutex_unlock(&irfm->map_lock);

log_value:
    if (OVS_UNLIKELY(!VLOG_DROP_DBG((&rl)))) {
        struct ds s = DS_EMPTY_INITIALIZER;

        if (irfm->format) {
            ds_put_cstr(&s, ", '");
            irfm->format(&s, key);
            ds_put_cstr(&s, "'");
        }
        /* ovs_refcount_read() returns an unsigned int: print it with
         * the matching conversion specifier. */
        VLOG_DBG("%s: %p ref, refcnt=%u id=%"PRIu32"%s",
                 irfm->name, node, ovs_refcount_read(&node->refcount),
                 node->id, ds_cstr(&s));
        ds_destroy(&s);
    }

    if (error) {
        free(node);
    } else {
        *id = node->id;
    }

    return !error;
}

/* Retires 'node': clears its entry in the reverse array, gives its ID back
 * to the pool on behalf of thread 'tid', removes it from the map and
 * defers freeing its memory through ovsrcu_gc(). */
static void
id_refmap_node_release(struct id_refmap *irfm, unsigned int tid,
                       struct id_refmap_node *node)
{
    const uint32_t slot = node->id - irfm->base_id;

    ovs_assert(slot < irfm->n_ids);

    /* Stop resolving the ID to this node before recycling the ID. */
    concurrent_array_set(irfm->id2node, slot, NULL);
    id_fpool_free_id(irfm->pool, tid, node->id);

    /* Map writers are serialized by 'map_lock'. */
    ovs_mutex_lock(&irfm->map_lock);
    cmap_remove(&irfm->map, &node->map_node, node->hash);
    ovs_mutex_unlock(&irfm->map_lock);

    ovsrcu_gc(id_refmap_node_free, node, rcu_node);
}

/* Releases the node carried by 'item', using the thread ID recorded when
 * the item was queued, then frees 'item' itself.
 *
 * The caller must have removed 'item' from its list beforehand. */
static void
delayed_release_item_destroy(struct id_refmap *irfm,
                             struct delayed_release_item *item)
{
    struct id_refmap_node *expired = item->node;
    unsigned int owner_tid = item->tid;

    id_refmap_node_release(irfm, owner_tid, expired);
    free(item);
}

/* Releases the nodes queued by thread 'tid' whose release delay has
 * elapsed at time 'now', in milliseconds.
 *
 * Does nothing if 'irfm' is NULL or was created without a release delay.
 * 'tid' must be lower than the number of threads given at creation. */
void
id_refmap_upkeep(struct id_refmap *irfm, unsigned int tid,
                 long long int now)
{
    struct delayed_release_item *item;
    struct ovs_list *pending;

    if (!irfm || !irfm->delay_ms) {
        return;
    }

    ovs_assert(tid < irfm->nb_thread);
    pending = &irfm->free_lists[tid];

    LIST_FOR_EACH_SAFE (item, in_free_list, pending) {
        long long int expiry = item->insertion_time_ms + irfm->delay_ms;

        /* Items are appended in insertion order: scanning stops at the
         * first one that has not expired yet. */
        if (now < expiry) {
            break;
        }
        ovs_list_remove(&item->in_free_list);
        delayed_release_item_destroy(irfm, item);
    }
}

/* Queues 'node' for a deferred release on the list of thread 'tid'.
 *
 * The item is stamped with the current time and appended at the tail of
 * the list; id_refmap_upkeep() releases it once the delay has elapsed.
 *
 * The per-thread lists only exist when the map was created with a non-zero
 * release delay; the caller visible in this file only queues nodes in that
 * case. */
static void
id_refmap_node_schedule_release(struct id_refmap *irfm, unsigned int tid,
                                struct id_refmap_node *node)
{
    struct delayed_release_item *item;
    struct ovs_list *free_list;

    /* 'tid' indexes 'free_lists', which holds 'nb_thread' entries.
     * Reject an out-of-range value the same way id_refmap_upkeep() does
     * instead of writing past the end of the array. */
    ovs_assert(tid < irfm->nb_thread);

    item = xzalloc(sizeof *item);
    item->node = node;
    item->tid = tid;

    free_list = &irfm->free_lists[tid];
    item->insertion_time_ms = time_msec();
    ovs_list_push_back(free_list, &item->in_free_list);
}

/* Drops one reference on 'id' on behalf of thread 'tid'.
 *
 * When the last reference is dropped, the node is either released at once
 * (no release delay configured) or queued for id_refmap_upkeep().
 *
 * Returns true only if this call dropped the last reference.  Returns
 * false if 'id' is outside of the map's range, if no node is registered
 * for 'id' (an error is logged), or if other references remain. */
bool
id_refmap_unref(struct id_refmap *irfm, unsigned int tid, uint32_t id)
{
    struct id_refmap_node *node;

    /* Checks both under and over flow. */
    if ((id - irfm->base_id) >= irfm->n_ids) {
        return false;
    }

    node = id_refmap_node_from_id(irfm, id);
    if (node == NULL) {
        VLOG_ERR("%s: Unknown ID %" PRIu32 " provided for unref.",
                 OVS_SOURCE_LOCATOR, id);
        return false;
    }

    if (OVS_UNLIKELY(!VLOG_DROP_DBG((&rl)))) {
        /* The count is traced before being decremented.
         * ovs_refcount_read() returns an unsigned int: print it with
         * the matching conversion specifier. */
        VLOG_DBG("%s: %p unref, refcnt=%u id %"PRIu32, irfm->name,
                 node, ovs_refcount_read(&node->refcount), node->id);
    }

    /* ovs_refcount_unref() returns the count held before the decrement:
     * 1 means this was the last reference. */
    if (ovs_refcount_unref(&node->refcount) == 1) {
        if (irfm->delay_ms == 0) {
            id_refmap_node_release(irfm, tid, node);
        } else {
            id_refmap_node_schedule_release(irfm, tid, node);
        }
        return true;
    }

    return false;
}
