/*
 * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at:
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef BATCH_H
#define BATCH_H

#include "util.h"
#include "openvswitch/util.h"

/* Generic batch
 * =============
 *
 * This structure is used to build batches of generic objects.
 * Its purpose is to iterate over the objects in the collection.
 * There are two iteration types: conserving and filtering.
 *
 * Consider the following batch:
 *
 *     struct batch foos = BATCH_INITIALIZER;
 *     int bars[N];
 *
 *     batch_add(&foos, &bars[0]);
 *     batch_add(&foos, &bars[1]);
 *     batch_add(&foos, &bars[2]);
 *     batch_add(&foos, &bars[3]);
 *
 * When objects must be conserved in the batch:
 *
 *     int *foo;
 *
 *     BATCH_FOREACH (foo, &foos) {
 *         baz(foo);
 *     }
 *     // 'foos' is unchanged.
 *
 * When objects must be filtered out from the batch:
 *
 *     int *foo;
 *
 *     BATCH_FOREACH_POP (idx, foo, &foos) {
 *         if (idx == 0 || idx == 2) {
 *             batch_add(&foos, foo);
 *         }
 *     }
 *     // Only '&bars[0]' and '&bars[2]' are in 'foos'.
 *
 * If 'BATCH_FOREACH_POP' is interrupted before the iteration completes,
 * all objects that were not reinserted with 'batch_add' (including those
 * not yet visited) are no longer in the batch.
 */

/* Maximum number of elements a 'struct batch' can hold: 'batch_add'
 * refuses new elements once 'count' reaches the size of 'ptrs'. */
#define BATCH_SIZE 32

struct batch {
    size_t count;           /* Number of valid entries at the front of
                             * 'ptrs'. */
    void *md;               /* Opaque pointer, NULL after initialization and
                             * carried along by 'batch_copy'.  NOTE(review):
                             * presumably per-batch metadata; confirm
                             * against users. */
    void *ptrs[BATCH_SIZE]; /* Element pointers; only the first 'count'
                             * entries are meaningful. */
};

/* Static initializer for an empty batch with a NULL 'md'. */
#define BATCH_INITIALIZER { .count = 0, .ptrs = {0}, .md = NULL, }

static inline void
batch_init(struct batch *batch)
{
    *batch = (struct batch) BATCH_INITIALIZER;
}

static inline struct batch
batch_init_one(void *ptr)
{
    return (struct batch) {
        .count = 1,
        .ptrs = {ptr},
    };
}

static inline void
batch_copy(struct batch *dst, struct batch *src)
{
    *dst = *src;
}

static inline size_t
batch_size(struct batch *batch)
{
    return batch->count;
}

/* Returns true when 'batch' holds no elements. */
static inline bool
batch_is_empty(struct batch *batch)
{
    return batch_size(batch) == 0;
}

/* Returns true when 'batch' has no free slot left, i.e. when its element
 * count equals the capacity of its 'ptrs' array. */
static inline bool
batch_is_full(struct batch *batch)
{
    const size_t capacity = ARRAY_SIZE(batch->ptrs);

    return capacity == batch_size(batch);
}

/* Appends 'ptr' to 'batch'.
 *
 * Returns true if the element was stored.  Returns false, leaving 'batch'
 * untouched, when the batch is already at capacity. */
static inline bool
batch_add(struct batch *batch, void *ptr)
{
    if (batch->count >= ARRAY_SIZE(batch->ptrs)) {
        /* No room left: the element is not stored. */
        return false;
    }

    batch->ptrs[batch->count] = ptr;
    batch->count++;
    return true;
}

/* Returns the first element of 'batch', or NULL if the batch is empty. */
static inline void *
batch_first(struct batch *batch)
{
    if (batch_is_empty(batch)) {
        return NULL;
    }
    return batch->ptrs[0];
}

/* Helper for BATCH_FOREACH_POP: empties 'batch' by resetting its element
 * count to zero and returns the count it had before the reset.
 *
 * The element pointers are left in place in 'ptrs'.  That is what allows
 * the popping iteration to keep reading them after the reset. */
static inline size_t
batch_foreach_pop_start(struct batch *batch)
{
    size_t n_pending = batch_size(batch);

    batch->count = 0;
    return n_pending;
}

/* Iterate over all elements of the batch.
 * The iteration itself does not modify the batch.
 *
 * IDX (size_t): (optional)
 *   Name of the variable that holds the index position
 *   of each iterated element.  Declared by the macro and scoped
 *   to the loop.
 *
 * PTR (void * or other object pointer type):
 *   Receives each successive element of the batch during the
 *   iteration.  Must be declared before the iteration.
 *
 * BATCH (struct batch *):
 *   Batch to iterate over.  It is evaluated on every iteration, so it
 *   must be free of side effects.
 */
#define BATCH_FOREACH(...) \
    OVERLOAD_SAFE_MACRO(BATCH_FOREACH_3, BATCH_FOREACH_2, 3, __VA_ARGS__)

/* 'PTR' is parenthesized like every other argument.  No trailing '\' is
 * left on the last line, so the definition cannot silently absorb the
 * following source line. */
#define BATCH_FOREACH_3(IDX, PTR, BATCH) \
    for (size_t (IDX) = 0; \
         (IDX) < batch_size(BATCH) && ((PTR) = (BATCH)->ptrs[(IDX)], true); \
         (IDX)++)

/* Two-argument form: a unique, hidden index variable is generated. */
#define BATCH_FOREACH_2(PTR, BATCH) \
    BATCH_FOREACH_3(OVS_JOIN(idx, __COUNTER__), PTR, BATCH)

/* Iterate over all elements of the batch, removing them from it.
 * The batch is emptied when the iteration starts.  Elements that should
 * remain in the batch must be put back by using 'batch_add'.
 *
 * Re-added elements are stored from the front of the batch, over slots
 * that were already visited.  The number of elements re-added must
 * therefore never exceed the number of elements visited so far (e.g.
 * re-add at most one element per iteration).  Otherwise, elements not
 * yet visited are overwritten.
 *
 * IDX (size_t): (optional)
 *   Name of the variable that holds the index position
 *   of each iterated element.  Declared by the macro and scoped
 *   to the loop.
 *
 * PTR (void * or other object pointer type):
 *   Receives each successive element of the batch during the
 *   iteration.  Must be declared before the iteration.
 *
 * BATCH (struct batch *):
 *   Batch to iterate over.  It is evaluated more than once, so it must
 *   be free of side effects.
 */
#define BATCH_FOREACH_POP(...) \
    OVERLOAD_SAFE_MACRO(BATCH_FOREACH_POP_3, BATCH_FOREACH_POP_2, 3, \
                        __VA_ARGS__)

#define BATCH_FOREACH_POP_3(IDX, PTR, BATCH) \
    BATCH_FOREACH_POP__(IDX, PTR, BATCH, OVS_JOIN(start_size, __COUNTER__))

/* Two-argument form: a unique, hidden index variable is generated. */
#define BATCH_FOREACH_POP_2(PTR, BATCH) \
    BATCH_FOREACH_POP__(OVS_JOIN(idx, __COUNTER__), PTR, BATCH, \
                        OVS_JOIN(start_size, __COUNTER__))

/* START_SIZE captures the element count before the batch is emptied.
 * The loop is bounded by it rather than by the (reset, then growing)
 * 'count'.  No trailing '\' is left on the last line. */
#define BATCH_FOREACH_POP__(IDX, PTR, BATCH, START_SIZE) \
    for (size_t (START_SIZE) = batch_foreach_pop_start(BATCH), (IDX) = 0; \
         (IDX) < (START_SIZE) && ((PTR) = (BATCH)->ptrs[(IDX)], true); \
         (IDX)++)

#endif /* BATCH_H */
