// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

#include <errno.h>
#include <stdint.h>
#include <stdlib.h>

#include "fwpages.h"
#include "cmd.h"
#include "dev.h"
#include "pagealloc.h"
#include "util/util.h"
#include "vfio_mlx5.h"
#include "util/event_log.h"

/* Documentation-only check: the mailbox size helpers in this file are written
 * in terms of manage_pages_in but are applied to both the "in" and the "out"
 * mailboxes, which relies on the two layouts having the same base size.
 * The original note states that this equality is a HW requirement.
 */
_Static_assert(DEVX_ST_SZ_BYTES(manage_pages_in) ==
		       DEVX_ST_SZ_BYTES(manage_pages_out),
	       "manage_pages_in and manage_pages_out must be the same size");

/* The number of page addresses (pas[] entries) that an in/out mailbox of
 * mbox_len bytes can hold after the fixed manage_pages header.
 * The argument is parenthesized so that expression arguments
 * (e.g. "a ? b : c" or "x + y") are evaluated as a whole before the header
 * size is subtracted. mbox_len must be at least the header size.
 */
#define MBOX_PAGES_NUM(mbox_len)                            \
	(((mbox_len) - DEVX_ST_SZ_BYTES(manage_pages_in)) / \
	 DEVX_FLD_SZ_BYTES(manage_pages_in, pas[0]))

/* Bytes needed for an in/out mailbox carrying npages page addresses:
 * one pas[] entry per page, plus the fixed manage_pages header.
 */
#define PAGES_MBOX_SIZE(npages)                                    \
	(((npages) * DEVX_FLD_SZ_BYTES(manage_pages_in, pas[0])) + \
	 DEVX_ST_SZ_BYTES(manage_pages_in))

/* Build a MANAGE_PAGES input mailbox in "in" for function func_id.
 *
 * On entry *inlen is the usable size of "in"; it bounds how many page
 * addresses can be written. Pages are taken from dev->page_alloc one at a
 * time until the mailbox is full or an allocation fails.
 *
 * - At least one page allocated: the mailbox is marked MLX5_PAGES_GIVE,
 *   input_num_entries is the number of pages actually written and *inlen is
 *   trimmed to exactly fit them.
 * - No page allocated: the mailbox is marked MLX5_PAGES_CANT_GIVE,
 *   input_num_entries is set to the number of pages the mailbox could have
 *   held and *inlen is trimmed to the bare header.
 *
 * Returns the number of pages written into the mailbox (0 for CANT_GIVE).
 */
static size_t fwpages_fill_in(struct vfio_mlx5_dev *dev, void *in,
			      uint16_t func_id, size_t *inlen)
{
	const uint32_t capacity = MBOX_PAGES_NUM(*inlen);
	uint32_t filled = 0;

	DEVX_SET(manage_pages_in, in, opcode, MLX5_CMD_OP_MANAGE_PAGES);
	DEVX_SET(manage_pages_in, in, function_id, func_id);

	while (filled < capacity) {
		uint64_t iova;

		/* Stop at the first allocation failure, keep what we have */
		if (mlx5_vfio_page_alloc(dev->page_alloc, &iova))
			break;

		DEVX_SET64(manage_pages_in, in, pas[filled], iova);
		filled++;
	}

	if (filled > 0) {
		/* Full or partial give: hand FW the pages we did get */
		*inlen = PAGES_MBOX_SIZE(filled);
		DEVX_SET(manage_pages_in, in, input_num_entries, filled);
		DEVX_SET(manage_pages_in, in, op_mod, MLX5_PAGES_GIVE);
		return filled;
	}

	/* Nothing could be allocated: header-only CANT_GIVE mailbox */
	*inlen = DEVX_ST_SZ_BYTES(manage_pages_in);

	DEVX_SET(manage_pages_in, in, input_num_entries, capacity);
	DEVX_SET(manage_pages_in, in, op_mod, MLX5_PAGES_CANT_GIVE);
	return 0;
}

/* Return to dev->page_alloc every page listed in a MANAGE_PAGES input
 * mailbox. Only MLX5_PAGES_GIVE mailboxes are processed; for any other
 * op_mod this is a no-op.
 */
static void fwpages_in_free_pages(struct vfio_mlx5_dev *dev, const void *in)
{
	uint32_t count;
	uint32_t idx;

	if (DEVX_GET(manage_pages_in, in, op_mod) == MLX5_PAGES_GIVE) {
		count = DEVX_GET(manage_pages_in, in, input_num_entries);

		for (idx = 0; idx < count; idx++) {
			uint64_t iova =
				DEVX_GET64(manage_pages_in, in, pas[idx]);

			mlx5_vfio_page_free(dev->page_alloc, iova);
		}
	}
}

/* Return to dev->page_alloc every page listed in a MANAGE_PAGES output
 * mailbox; output_num_entries gives the number of pas[] entries to free.
 */
static void fwpages_out_free_pages(struct vfio_mlx5_dev *dev, const void *out)
{
	const uint32_t count =
		DEVX_GET(manage_pages_out, out, output_num_entries);
	uint32_t idx = 0;

	while (idx < count) {
		uint64_t iova = DEVX_GET64(manage_pages_out, out, pas[idx]);

		mlx5_vfio_page_free(dev->page_alloc, iova);
		idx++;
	}
}

/* Synchronously give npages pages to function func_id during startup.
 *
 * All npages pages must be allocated and must fit in the input mailbox; a
 * partial give is treated as -ENOMEM and the pages already allocated are
 * released. On success dev->firmware_pages is increased by npages.
 *
 * Returns 0 on success, -ENOMEM if the mailbox or the pages are not
 * available, or the error returned by mlx5_vfio_cmd_exec().
 */
static int mlx5_vfio_give_startup_pages(struct vfio_mlx5_dev *dev,
					uint16_t func_id, uint32_t npages)
{
	uint32_t out[DEVX_ST_SZ_DW(manage_pages_out)] = {};
	size_t inlen = PAGES_MBOX_SIZE(npages);
	uint32_t npages_alloc;
	uint32_t *in;
	int ret;

	/* The requested size comes from the FW page count and is not bounded
	 * by the pre-allocated buffer size, so unlike the async paths the
	 * buffer lookup is checked here before the mailbox is written.
	 */
	in = grow_buf_get(&dev->heap->fwpage_reqs.in, &inlen);
	if (!in) {
		ret = -ENOMEM;
		dev_err(dev,
			"No mailbox buffer for %u startup pages, err(%d)",
			npages, ret);
		return ret;
	}

	npages_alloc = fwpages_fill_in(dev, in, func_id, &inlen);
	if (npages_alloc < npages) {
		ret = -ENOMEM;
		dev_err(dev,
			"Not sufficient pages available for startup, err(%d)",
			ret);
		goto err;
	}

	fwp_event(dev, FWP_EVENT_GIVE_BOOT,
		  "Giving startup pages for func_id %d, npages %u\n", func_id,
		  npages);
	ret = mlx5_vfio_cmd_exec(dev, in, inlen, out, sizeof(out));
	if (ret) {
		dev_err(dev,
			"Failed to execute command MLX5_CMD_OP_MANAGE_PAGES, "
			"ret(%d)",
			ret);
		goto err;
	}
	dev->firmware_pages += npages;
	return 0;

err:
	/* Releases whatever fwpages_fill_in() placed in the mailbox */
	fwpages_in_free_pages(dev, in);
	return ret;
}

static int mlx5_vfio_query_pages(struct vfio_mlx5_dev *dev, int boot,
				 uint16_t *func_id, uint32_t *npages)
{
	uint32_t query_pages_out[DEVX_ST_SZ_DW(query_pages_out)] = {};
	uint32_t query_pages_in[DEVX_ST_SZ_DW(query_pages_in)] = {};
	int err;

	DEVX_SET(query_pages_in, query_pages_in, opcode,
		 MLX5_CMD_OP_QUERY_PAGES);
	DEVX_SET(query_pages_in, query_pages_in, op_mod, boot ? 0x01 : 0x02);

	err = mlx5_vfio_cmd_exec(dev, query_pages_in, sizeof(query_pages_in),
				 query_pages_out, sizeof(query_pages_out));
	if (err) {
		dev_err(dev,
			"Failed to execute MLX5_CMD_OP_QUERY_PAGES command, "
			"err(%d)",
			err);
		return err;
	}

	*npages = DEVX_GET(query_pages_out, query_pages_out, num_pages);
	*func_id = DEVX_GET(query_pages_out, query_pages_out, function_id);

	return 0;
}

/* Query how many startup pages are requested and give them synchronously.
 *
 * boot is forwarded to the QUERY_PAGES command to select which page count is
 * queried. On success *pages_given is set to the number of pages given; on
 * failure it is left untouched.
 *
 * Returns 0 on success or the error of the failing step.
 */
int mlx5_vfio_satisfy_startup_pages(struct vfio_mlx5_dev *dev, int boot,
				    int *pages_given)
{
	uint16_t function_id = 0;
	uint32_t npages = 0;
	int ret;

	ret = mlx5_vfio_query_pages(dev, boot, &function_id, &npages);
	if (ret) {
		dev_err(dev, "Failed to query pages, err(%d)", ret);
		return ret;
	}

	ret = mlx5_vfio_give_startup_pages(dev, function_id, npages);
	if (ret) {
		/* npages is unsigned: print it with %u, not %d */
		dev_err(dev, "Failed to give %u pages to function %u, err(%d)",
			npages, function_id, ret);
		return ret;
	}

	*pages_given = npages;

	return 0;
}

/* Post an asynchronous MANAGE_PAGES "give" for func_id on the page request
 * command slot.
 *
 * npages is the amount requested. The amount actually posted may be smaller
 * when the command slot buffers could not be grown or when page allocation
 * falls short; if no page could be allocated a CANT_GIVE command is posted
 * instead. The events and messages below always log the requested npages.
 *
 * Returns 0 when the command was posted, otherwise the error from
 * mlx5_vfio_cmd_page_req_post() after releasing the pages placed in the
 * mailbox.
 */
static int mlx5_vfio_give_pages_async(struct vfio_mlx5_dev *dev,
				      uint16_t func_id, size_t npages)
{
	size_t inlen, outlen;
	unsigned int event;
	size_t avail_pages;
	uint32_t *in;
	int ret;

	outlen = DEVX_ST_SZ_BYTES(manage_pages_out);
	inlen = PAGES_MBOX_SIZE(npages);

	/* try to grow the cmd slot layout in/out buffers if needed
	 * in case growing was needed but didn't happen,
	 * inlen, outlen will be shrunk
	 */
	mlx5_cmd_page_req_try_grow(dev, &inlen, &outlen);

	in = grow_buf_get(&dev->heap->fwpage_reqs.in, &inlen);
	/* in/out are pre-allocated, no need to check for NULL */

	avail_pages = fwpages_fill_in(dev, in, func_id, &inlen);

	event = FWP_EVENT_GIVE;
	if (avail_pages == 0)
		event = FWP_EVENT_CANT_GIVE;
	else if (avail_pages != npages)
		event = FWP_EVENT_GIVE_PARTIAL;

	/* npages is a size_t, so it is printed with %zu */
	fwp_event(dev, event, "Posting func_id (%d) npages (%zu)\n", func_id,
		  npages);
	ret = mlx5_vfio_cmd_page_req_post(dev, in, inlen, outlen);
	if (!ret)
		return 0;

	event = FWP_EVENT_GIVE_ERROR;
	if (avail_pages == 0)
		event = FWP_EVENT_CANT_GIVE_ERROR;

	fwp_event_err(dev, event,
		      "Failed post manage pages give command "
		      "func_id(%d) npages(%zu), ret(%d)\n",
		      func_id, npages, ret);
	/* No-op for a CANT_GIVE mailbox, frees the listed pages for GIVE */
	fwpages_in_free_pages(dev, in);
	return ret;
}

/* Post an asynchronous MANAGE_PAGES "take" (reclaim) for func_id on the page
 * request command slot.
 *
 * npages is the amount requested; the amount actually asked of FW is capped
 * to what the output mailbox can hold after the optional grow attempt.
 *
 * Returns 0 when the command was posted, otherwise the error from
 * mlx5_vfio_cmd_page_req_post().
 */
static int mlx5_vfio_reclaim_pages_async(struct vfio_mlx5_dev *dev,
					 uint32_t func_id, int npages)
{
	size_t inlen, outlen;
	void *in;
	int ret;

	inlen = DEVX_ST_SZ_BYTES(manage_pages_in);
	outlen = PAGES_MBOX_SIZE(npages);

	/* try to grow the cmd slot layout in/out buffers if needed
	 * in case growing was needed but didn't happen,
	 * inlen, outlen will be shrunk
	 */
	mlx5_cmd_page_req_try_grow(dev, &inlen, &outlen);

	in = grow_buf_get(&dev->heap->fwpage_reqs.in, &inlen);
	/* in/out are pre-allocated, no need to check for NULL */

	/* actual npages is what the outlen can hold, so adapt ... */
	npages = MBOX_PAGES_NUM(outlen);
	DEVX_SET(manage_pages_in, in, opcode, MLX5_CMD_OP_MANAGE_PAGES);
	DEVX_SET(manage_pages_in, in, op_mod, MLX5_PAGES_TAKE);
	DEVX_SET(manage_pages_in, in, function_id, func_id);
	DEVX_SET(manage_pages_in, in, input_num_entries, npages);

	/* func_id is a uint32_t here, so it is printed with %u */
	fwp_event(dev, FWP_EVENT_TAKE,
		  "Posting reclaim pages func_id (%u) npages %d\n", func_id,
		  npages);
	ret = mlx5_vfio_cmd_page_req_post(dev, in, inlen, outlen);
	if (!ret)
		return 0;

	fwp_event_err(dev, FWP_EVENT_TAKE_ERROR,
		      "Failed post manage pages reclaim command "
		      "func_id(%u) npages(%d), ret(%d)\n",
		      func_id, npages, ret);
	return ret;
}

/* Pop queued page requests until one is successfully posted or the FIFO is
 * empty. A request with npages > 0 is posted as a give; any other request is
 * posted as a reclaim of -npages pages.
 *
 * dev->page_reqs.inflight is set before each post attempt, stays set when the
 * post succeeded and is cleared again when the post failed.
 */
static void fwpage_fifo_process(struct vfio_mlx5_dev *dev)
{
	const struct fwpage_req *req;

	while ((req = fwpage_fifo_pop(&dev->page_reqs.fifo))) {
		int ret;

		dev->page_reqs.inflight = 1;

		if (req->npages > 0)
			ret = mlx5_vfio_give_pages_async(dev, req->func_id,
							 req->npages);
		else
			ret = mlx5_vfio_reclaim_pages_async(dev, req->func_id,
							    -req->npages);

		/* Posted: the completion handler continues from here */
		if (!ret)
			return;

		/* Post failed: drop this request and try the next one */
		dev->page_reqs.inflight = 0;
	}
}

/* Handle a firmware page request event: queue the request and, unless a
 * page command is already in flight, start processing the FIFO.
 *
 * A negative npages is logged as a "take" request, anything else as a "give"
 * request. A failed FIFO push is logged as an error and the function still
 * goes on to the in-flight check.
 */
void mlx5_vfio_handle_page_req_event(struct vfio_mlx5_dev *dev,
				     uint16_t func_id, int npages)
{
	struct fwpage_req req = { func_id, npages };
	unsigned int event;

	if (npages < 0)
		event = FWP_EVENT_FW_REQ_TAKE;
	else
		event = FWP_EVENT_FW_REQ_GIVE;

	fwp_event(dev, event, "func_id %d, npages %d\n", func_id, npages);

	if (fwpage_fifo_push(&dev->page_reqs.fifo, &req)) {
		/* TOO MANY PAGE REQ EQEs should never happen */
		fwp_event_err(dev, FWP_EVENT_FW_REQ_DROP_ERROR,
			      "Fifo push failed, func_id %d, npages %d\n",
			      func_id, npages);
	}

	if (!dev->page_reqs.inflight) {
		/* Idle: process the request in FIFO right away */
		fwpage_fifo_process(dev);
		return;
	}

	/* Busy: the queued request is picked up on command completion */
	fwp_event(
		dev, FWP_EVENT_FW_REQ_BUSY,
		"busy, requests[%d] pending in fifo, func_id %d, npages %d\n",
		fwpage_fifo_len(&dev->page_reqs.fifo), func_id, npages);
}

/* Log the completion of a MANAGE_PAGES command.
 *
 * in/out are the command's input and output mailboxes and status is the
 * delivery status of the command. The completion is an error when either the
 * delivery status or the status read from the output mailbox is non-zero.
 * For a TAKE command the logged page count is the number of entries FW
 * returned; otherwise it is the number of entries in the input mailbox.
 */
static void log_manage_pages_comp(struct vfio_mlx5_dev *dev, const void *in,
				  const void *out, int status)
{
	uint32_t op_mod, opcode, syndrome, npages;
	unsigned int event = FWP_EVENT_GIVE_ERROR;
	uint8_t out_status;
	uint16_t func_id;
	bool is_error;

	opcode = DEVX_GET(manage_pages_in, in, opcode);
	op_mod = DEVX_GET(manage_pages_in, in, op_mod);
	func_id = DEVX_GET(manage_pages_in, in, function_id);
	npages = DEVX_GET(manage_pages_in, in, input_num_entries);
	if (op_mod == MLX5_PAGES_TAKE)
		npages = DEVX_GET(manage_pages_out, out, output_num_entries);

	mlx5_cmd_mbox_status(out, &out_status, &syndrome);

	is_error = (status || out_status);

	switch (op_mod) {
	case MLX5_PAGES_GIVE:
		event = is_error ? FWP_EVENT_GIVE_ERROR :
				   FWP_EVENT_GIVE_SUCCESS;
		break;
	case MLX5_PAGES_TAKE:
		event = is_error ? FWP_EVENT_TAKE_ERROR :
				   FWP_EVENT_TAKE_SUCCESS;
		break;
	case MLX5_PAGES_CANT_GIVE:
		event = is_error ? FWP_EVENT_CANT_GIVE_ERROR :
				   FWP_EVENT_CANT_GIVE_SUCCESS;
		break;
	default:
		/* Unknown op_mod: keep the FWP_EVENT_GIVE_ERROR default */
		break;
	}

	/* npages and syndrome are unsigned 32-bit values: print with %u */
	if (is_error)
		fwp_event_err(
			dev, event,
			"Completed with ERROR: opcode(0x%x) op_mod(0x%x) "
			"func_id(%d) npages(%u) delivery_status(%d) out_status(%d) syndrome(%u)\n",
			opcode, op_mod, func_id, npages, status, out_status,
			syndrome);
	else
		fwp_event(dev, event, "Completed: func_id(%d) npages(%u)\n",
			  func_id, npages);
}

/* Subtract the pages listed in a MANAGE_PAGES output mailbox
 * (output_num_entries) from the dev->firmware_pages counter.
 */
static void fwpage_count_reclaimed(struct vfio_mlx5_dev *dev, const void *out)
{
	const uint32_t reclaimed =
		DEVX_GET(manage_pages_out, out, output_num_entries);

	dev->firmware_pages -= reclaimed;
}

/* Add the pages listed in a MANAGE_PAGES input mailbox (input_num_entries)
 * to the dev->firmware_pages counter.
 */
static void fwpage_count_given(struct vfio_mlx5_dev *dev, const void *in)
{
	const uint32_t given =
		DEVX_GET(manage_pages_in, in, input_num_entries);

	dev->firmware_pages += given;
}

/* Completion handler for the command posted on the page request slot.
 *
 * status is the delivery status of the command. The in/out mailboxes are
 * copied back from the command slot, the completion is logged, and then:
 *  - on success of a TAKE: the returned pages are freed and subtracted from
 *    the firmware page counter;
 *  - on success of a GIVE: the given pages are added to the counter;
 *  - on failure of a GIVE: the pages we tried to give are freed;
 *  - CANT_GIVE (either outcome) and a failed TAKE need no page handling.
 * In all cases the in-flight flag is cleared and the next queued request, if
 * any, is processed.
 *
 * Returns status if it is non-zero, otherwise the status read from the
 * output mailbox (0 on success).
 */
int mlx5_vfio_page_request_cmd_comp(struct vfio_mlx5_dev *dev, int status)
{
	size_t inlen, outlen;
	uint8_t out_status;
	uint32_t syndrome;
	uint32_t op_mod;
	void *in, *out;
	int ret;

	inlen = mlx5_vfio_cmd_in_size(dev, CMD_SLOT_PAGE_REQ);
	outlen = mlx5_vfio_cmd_out_size(dev, CMD_SLOT_PAGE_REQ);

	/* can't fail as we pre-allocated the buffers to the correct size */
	in = grow_buf_get(&dev->heap->fwpage_reqs.in, &inlen);
	out = grow_buf_get(&dev->heap->fwpage_reqs.out, &outlen);

	mlx5_vfio_cmd_copy_in(dev, in, inlen, CMD_SLOT_PAGE_REQ);
	mlx5_vfio_cmd_copy_out(dev, out, outlen, CMD_SLOT_PAGE_REQ);

	log_manage_pages_comp(dev, in, out, status);

	op_mod = DEVX_GET(manage_pages_in, in, op_mod);

	/* Delivery errors take precedence over the mailbox status */
	ret = status;
	if (!ret) {
		mlx5_cmd_mbox_status(out, &out_status, &syndrome);
		ret = out_status;
	}

	if (ret) {
		/* failure: free what we tried to give */
		if (op_mod == MLX5_PAGES_GIVE)
			fwpages_in_free_pages(dev, in);
	} else if (op_mod == MLX5_PAGES_TAKE) {
		/* success: "TAKE" the pages back */
		fwpages_out_free_pages(dev, out);
		fwpage_count_reclaimed(dev, out);
	} else if (op_mod == MLX5_PAGES_GIVE) {
		/* success: account for the pages FW now owns */
		fwpage_count_given(dev, in);
	}

	dev->page_reqs.inflight = 0;

	fwpage_fifo_process(dev); /* Process next request in FIFO if any */
	return ret;
}

/* Build one table entry: index FWP_EVENT_<x> holds the string "<x>" */
#define FWP_EVENT_STR(x) [FWP_EVENT_##x] = #x
/* Printable names of the firmware page events, indexed by FWP_EVENT_* value.
 * Any index below FWP_EVENT_COUNT that is not listed here is left NULL.
 */
static const char *fwpages_event_names[FWP_EVENT_COUNT] = {
	FWP_EVENT_STR(GIVE_BOOT),
	FWP_EVENT_STR(FW_REQ_GIVE),
	FWP_EVENT_STR(FW_REQ_TAKE),
	FWP_EVENT_STR(GIVE),
	FWP_EVENT_STR(GIVE_PARTIAL),
	FWP_EVENT_STR(CANT_GIVE),
	FWP_EVENT_STR(TAKE),
	FWP_EVENT_STR(GIVE_SUCCESS),
	FWP_EVENT_STR(TAKE_SUCCESS),
	FWP_EVENT_STR(CANT_GIVE_SUCCESS),
	FWP_EVENT_STR(GIVE_ERROR),
	FWP_EVENT_STR(TAKE_ERROR),
	FWP_EVENT_STR(CANT_GIVE_ERROR),
	FWP_EVENT_STR(FW_REQ_DROP_ERROR),
	FWP_EVENT_STR(FW_REQ_BUSY),
};

/* Return the firmware page event name table and store its number of
 * entries (FWP_EVENT_COUNT) in *count.
 */
const char **mlx5_vfio_fwpages_event_names(size_t *count)
{
	const char **names = fwpages_event_names;

	*count = FWP_EVENT_COUNT;
	return names;
}

/* Number of pages the initial in/out request buffers are sized for.
 * NOTE(review): the value is 8096 while the original comment says "8K"
 * (8192) - confirm which one is intended.
 */
#define DEFAULT_FWPAGE_REQS_SIZE 8096 /* in page num */
/* Initial byte size of each buffer: a mailbox able to carry
 * DEFAULT_FWPAGE_REQS_SIZE page addresses.
 */
#define INIT_MBOX_SIZE PAGES_MBOX_SIZE(DEFAULT_FWPAGE_REQS_SIZE)

/* Initialize the firmware page request state: allocate the zeroed in/out
 * mailbox buffers (INIT_MBOX_SIZE bytes each) and set up the event logger.
 *
 * bdf is used in the logger prefix. page_events is handed to the event
 * logger as its stats storage; when it is NULL the logger is given an event
 * count of 0.
 *
 * Returns 0 on success or -ENOMEM if a buffer could not be allocated. On
 * failure both buffer pointers are left NULL with zero length, so a later
 * mlx5_vfio_fwpage_reqs_free() on the same object is harmless.
 */
int mlx5_vfio_fwpage_reqs_init(struct mlx5_fwpage_reqs *fwpage_reqs,
			       const char *bdf, mlx5_pg_events_t page_events)
{
	fwpage_reqs->in.buf = calloc(1, INIT_MBOX_SIZE);
	fwpage_reqs->out.buf = calloc(1, INIT_MBOX_SIZE);

	if (!fwpage_reqs->in.buf || !fwpage_reqs->out.buf) {
		log_error("Failed to allocate memory for fwpage_reqs buffers");
		free(fwpage_reqs->in.buf);
		free(fwpage_reqs->out.buf);
		/* Do not leave dangling pointers behind in the caller's object */
		fwpage_reqs->in.buf = NULL;
		fwpage_reqs->out.buf = NULL;
		fwpage_reqs->in.len = 0;
		fwpage_reqs->out.len = 0;
		return -ENOMEM;
	}

	fwpage_reqs->in.len = INIT_MBOX_SIZE;
	fwpage_reqs->out.len = INIT_MBOX_SIZE;

	event_log_init(&fwpage_reqs->elogger, "fwp", libmlx5_logger.outf,
		       libmlx5_logger.errf);
	event_log_set_stats(&fwpage_reqs->elogger, page_events,
			    fwpages_event_names,
			    page_events ? FWP_EVENT_COUNT : 0);

	event_log_level_set(&fwpage_reqs->elogger, libmlx5_logger.level);
	event_log_prefix_set(&fwpage_reqs->elogger, "%s(%s)", "fwp", bdf);
	return 0;
}

/* Release the in/out mailbox buffers of fwpage_reqs.
 *
 * The pointers and lengths are reset afterwards, so calling this function
 * again on the same object does not free the buffers twice.
 */
void mlx5_vfio_fwpage_reqs_free(struct mlx5_fwpage_reqs *fwpage_reqs)
{
	free(fwpage_reqs->in.buf);
	free(fwpage_reqs->out.buf);

	fwpage_reqs->in.buf = NULL;
	fwpage_reqs->out.buf = NULL;
	fwpage_reqs->in.len = 0;
	fwpage_reqs->out.len = 0;
}
