// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

#include <errno.h>
#include <stddef.h>
#include <stdint.h>
#include <sys/mman.h>
#include <pthread.h>
#include <stdio.h>

#include "util/util.h"
#include "mlx5_memlay.h"
#include "dev.h"
#include "version_config.h"

/*
 * Two-step stringification: STR(x) macro-expands x first and then
 * STR_HELPER() turns the expanded tokens into a string literal.
 */
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)

/*
 * Prefix handed to the library logger. Adjacent string literals are
 * concatenated, giving "vfio-mlx5v<major>.<minor>" built from the
 * VFIO_MLX5_VERSION_MAJOR / VFIO_MLX5_VERSION_MINOR macros.
 */
#define LOG_PREFIX                                         \
	"vfio-mlx5v" STR(VFIO_MLX5_VERSION_MAJOR) "." STR( \
		VFIO_MLX5_VERSION_MINOR)

/*
 * Library global logger.
 *
 * Statically initialized with the library prefix and the default log level.
 * The output/error streams are not set here; vfio_mlx5_init() fills in
 * stdout/stderr when they are still NULL, and vfio_mlx5_log_set() lets the
 * caller choose them explicitly.
 */
event_log_t libmlx5_logger = {
	.prefix = LOG_PREFIX,
	.level = DEFAULT_LOG_LEVEL,
};

void vfio_mlx5_log_set(enum vfio_mlx5_log_level level, FILE *outf, FILE *errf)
{
	event_log_init(&libmlx5_logger, LOG_PREFIX, outf, errf);
	event_log_level_set(&libmlx5_logger, level);
}

/*
 * Reset device slot @index inside @vmh to its initial state.
 *
 * The slot is zeroed, then tagged with its own index, the handle's base IOVA
 * and an invalid file descriptor (-1). Asserts that @index is below the
 * handle's max_devices and that the slot starts on a 4K boundary.
 */
static void init_dev_mem_lay(struct vfio_mlx5_handle *vmh, unsigned int index)
{
	struct vfio_mlx5_dev *slot = vfio_mlx5_dev_get(vmh, index);

	assert(index < vmh->hdr.max_devices);

	/* The device area must start on a 4K boundary */
	assert(!((uintptr_t)slot & (MLX5_ADAPTER_PAGE_SIZE - 1)));

	memset(slot, 0, sizeof(*slot));

	slot->vmh_iova = vmh->hdr.base_iova;
	slot->index = index;
	slot->device_fd = -1;
}

/* Minimum number of adapter pages that must be available per device */
#define MLX5_MIN_PAGES_PER_DEVICE 32

/*
 * Validate the caller-provided persistent storage before it is initialized.
 *
 * @vmh:         start of the storage region (may not be NULL, must be 4K
 *               aligned).
 * @storage_len: size of the region in bytes.
 * @num_devices: number of devices requested, 1..VFIO_MLX5_MAX_DEVICES.
 *
 * Returns 0 when the storage is usable, -EINVAL otherwise (an error is
 * logged describing which check failed).
 */
static int memlay_validate_storage(const struct vfio_mlx5_handle *vmh,
				   size_t storage_len, unsigned int num_devices)
{
	size_t min_storage_size;

	/* NULL would otherwise slip through the alignment check below */
	if (!vmh) {
		log_error("persist_storage must not be NULL\n");
		return -EINVAL;
	}

	if (num_devices > VFIO_MLX5_MAX_DEVICES) {
		log_error("Number of devices exceeds max (%u)\n",
			  VFIO_MLX5_MAX_DEVICES);
		return -EINVAL;
	}

	if (num_devices < 1) {
		log_error("Number of devices must be greater than 0\n");
		return -EINVAL;
	}

	/*
	 * Computed only after num_devices is known to be in range, and in
	 * size_t arithmetic so the product cannot wrap at 32 bits.
	 * The per-device term covers both bitmap and pages area minimums.
	 */
	min_storage_size = sizeof(struct vfio_mlx5_handle) +
			   (size_t)num_devices * MLX5_ADAPTER_PAGE_SIZE *
				   MLX5_MIN_PAGES_PER_DEVICE;

	if (storage_len < min_storage_size) {
		log_error(
			"Persist storage size needs to be at least %zu bytes\n",
			min_storage_size);
		return -EINVAL;
	}

	/* Check if persist_storage is page aligned */
	if ((uintptr_t)vmh & (MLX5_ADAPTER_PAGE_SIZE - 1)) {
		log_error("persist_storage must be 4K page aligned\n");
		return -EINVAL;
	}

	return 0;
}

/*
 * Lay out a fresh handle at the start of @storage.
 *
 * The header is zeroed and then populated with the magic value, library
 * version, device capacity, total storage size and base IOVA; the handle
 * starts in the ENABLED state with resume_token 0. Every one of the
 * @num_devices device slots is then reset via init_dev_mem_lay().
 */
static void memlay_init(void *storage, size_t storage_len, uint64_t iova,
			unsigned int num_devices)
{
	struct vfio_mlx5_handle *handle = storage;
	struct vfio_mlx5_hdr *header = &handle->hdr;
	unsigned int slot;

	/* Start from an all-zero header */
	memset(header, 0, sizeof(*header));

	header->magic_value = MLX5_MAGIC_VALUE;
	header->version.major = VFIO_MLX5_VERSION_MAJOR;
	header->version.minor = VFIO_MLX5_VERSION_MINOR;
	header->max_devices = num_devices;
	header->total_mem_size = storage_len;
	header->base_iova = iova;
	header->state = VFIO_MLX5_STATE_ENABLED;
	header->resume_token = 0;

	/* Device slots read max_devices/base_iova, so the header goes first */
	for (slot = 0; slot < header->max_devices; slot++)
		init_dev_mem_lay(handle, slot);
}

/*
 * dbg_memlay_print() dumps the memory layout of a handle.
 *
 * When DEBUG evaluates to non-zero only the prototype is emitted here and
 * the real implementation is provided further down in this file. Otherwise
 * the function is a no-op stub so callers need no conditional compilation.
 */
#if DEBUG
static void dbg_memlay_print(struct vfio_mlx5_handle *vmh);
#else
static void dbg_memlay_print(struct vfio_mlx5_handle *vmh)
{
	/* silence the unused-parameter warning */
	(void)vmh;
}
#endif

/*
 * Initialize a vfio_mlx5 handle inside caller-provided persistent storage.
 *
 * @persist_storage: 4K-aligned memory region that will hold the handle.
 * @storage_len:     size of that region in bytes.
 * @iova:            IOVA corresponding to the start of @persist_storage.
 * @num_devices:     number of device slots to provision.
 *
 * The storage is validated, locked in memory with mlock(), laid out with
 * memlay_init(), and the remainder of the region (from the page_alloc
 * member onwards) is handed to the page allocator.
 *
 * Returns the handle (which aliases @persist_storage) on success, or NULL if
 * validation or mlock() fails; the reason is logged.
 */
struct vfio_mlx5_handle *vfio_mlx5_init(void *persist_storage,
					size_t storage_len, uint64_t iova,
					unsigned int num_devices)
{
	struct vfio_mlx5_handle *vmh;
	size_t storage_len_for_pages;
	int err;

	/* Fall back to the standard streams if no logger was configured */
	if (!libmlx5_logger.outf)
		libmlx5_logger.outf = stdout;
	if (!libmlx5_logger.errf)
		libmlx5_logger.errf = stderr;

	log_info("Initializing vfio_mlx5_handle: "
		 "Storage size: %zu, iova: 0x%llx, num_devices: %u",
		 storage_len, (unsigned long long)iova, num_devices);

	/* Check size, alignment and device count of the storage */
	err = memlay_validate_storage(persist_storage, storage_len,
				      num_devices);
	if (err) {
		/* err is a negative errno value; strerror() needs it positive */
		log_error("Failed to validate storage size: %s",
			  strerror(-err));
		return NULL;
	}

	/* Lock the persist_storage memory region to prevent it from being
	 * swapped out */
	if (mlock(persist_storage, storage_len) != 0) {
		log_error("Failed to mlock persist_storage: %s",
			  strerror(errno));
		return NULL;
	}

	memlay_init(persist_storage, storage_len, iova, num_devices);

	vmh = (struct vfio_mlx5_handle *)persist_storage;

	/* Available memory for page allocator is the rest of the storage
	 * starting from the page allocator struct ptr
	 */
	storage_len_for_pages =
		storage_len - offsetof(struct vfio_mlx5_handle, page_alloc);
	mlx5_vfio_page_alloc_init(&vmh->page_alloc, storage_len_for_pages,
				  iova + offsetof(struct vfio_mlx5_handle,
						  page_alloc));
	dbg_memlay_print(vmh);
	return vmh;
}

/*
 * Wipe the whole persistent storage of @vmh and release its memory lock.
 *
 * The region length is taken from hdr.total_mem_size. It is read before
 * the wipe because zeroing the storage also zeroes the header, which would
 * otherwise leave munlock() with a length of 0. A NULL @vmh is ignored.
 */
void vfio_mlx5_uninit(struct vfio_mlx5_handle *vmh)
{
	size_t total_len;

	if (!vmh)
		return;

	/* Snapshot the length: the memset below clears the header too */
	total_len = vmh->hdr.total_mem_size;

	memset(vmh, 0, total_len);
	if (munlock(vmh, total_len) != 0)
		log_error("Failed to munlock persist_storage: %s",
			  strerror(errno));
}

/* Process-private lock serializing updates of hdr.dev_cnt */
static pthread_spinlock_t dev_cnt_spinlock;
/* Set once the spinlock above has been initialized */
static bool dev_cnt_spinlock_init;
/* Guarantees the spinlock is initialized exactly once, even under races */
static pthread_once_t dev_cnt_spinlock_once = PTHREAD_ONCE_INIT;

/* One-time initializer for dev_cnt_spinlock, run through pthread_once() */
static void dev_cnt_spinlock_setup(void)
{
	pthread_spin_init(&dev_cnt_spinlock, PTHREAD_PROCESS_PRIVATE);
	dev_cnt_spinlock_init = true;
}

/*
 * Atomically reserve the next device slot of @vmh.
 *
 * If hdr.dev_cnt is still below hdr.max_devices, the current count is
 * stored in *@index (when @index is non-NULL), the count is incremented and
 * true is returned. Otherwise nothing changes and false is returned.
 *
 * The lock is created lazily. pthread_once() is used instead of a plain
 * flag check so that two threads entering for the first time cannot both
 * initialize the spinlock (re-initializing a lock that may already be held).
 */
static bool dev_cnt_test_and_inc(struct vfio_mlx5_handle *vmh, int *index)
{
	bool result = false;

	pthread_once(&dev_cnt_spinlock_once, dev_cnt_spinlock_setup);

	pthread_spin_lock(&dev_cnt_spinlock);
	if (vmh->hdr.dev_cnt < vmh->hdr.max_devices) {
		if (index)
			*index = vmh->hdr.dev_cnt;
		vmh->hdr.dev_cnt++;
		result = true;
	}
	pthread_spin_unlock(&dev_cnt_spinlock);
	return result;
}

static void dev_cnt_dec(struct vfio_mlx5_handle *vmh)
{
	pthread_spin_lock(&dev_cnt_spinlock);
	vmh->hdr.dev_cnt--;
	pthread_spin_unlock(&dev_cnt_spinlock);
}

/*
 * Add a device to @vmh.
 *
 * @vmh:       handle that owns the device slots.
 * @bdf:       PCI address of the device; PCI_BDF_LEN bytes are copied from it.
 * @device_fd: file descriptor stored in the device slot.
 * @num_vfs:   number of VFs stored in the device slot.
 *
 * Reserves the next free slot, resets it, fills in the caller's parameters
 * and runs mlx5_vfio_device_init() on it.
 *
 * Returns the device on success. On failure returns NULL with errno set:
 * EINVAL for NULL arguments, ENOSPC when all slots are taken, otherwise the
 * errno left by device initialization (or its return code when errno was
 * not set).
 */
struct vfio_mlx5_dev *vfio_mlx5_device_add(struct vfio_mlx5_handle *vmh,
					   const char *bdf, int device_fd,
					   int num_vfs)
{
	struct vfio_mlx5_dev *dev;
	int saved_errno;
	int dev_index;
	int err;

	if (!vmh || !bdf) {
		log_error("Invalid arguments: vmh and bdf must not be NULL");
		errno = EINVAL;
		return NULL;
	}

	errno = 0; /* reset errno */

	if (!dev_cnt_test_and_inc(vmh, &dev_index)) {
		log_error("No more space for devices on this handle");
		errno = ENOSPC;
		return NULL;
	}

	init_dev_mem_lay(vmh, dev_index);
	dev = vfio_mlx5_dev_get(vmh, dev_index);

	dev->num_vfs = num_vfs;
	dev->device_fd = device_fd;
	/* NOTE(review): reads exactly PCI_BDF_LEN bytes from bdf — assumes the
	 * caller's buffer is at least that long; confirm against the API doc.
	 */
	memcpy(dev->pci_bdf, bdf, PCI_BDF_LEN);

	err = mlx5_vfio_device_init(dev, &vmh->page_alloc, vmh);
	if (err) {
		/* Capture errno before logging/locking can overwrite it */
		saved_errno = errno;

		log_error("%s add failed, err(%d) errno(%d)", bdf, err,
			  saved_errno);
		dev_cnt_dec(vmh);

		/* errno must be positive whatever the sign convention of err */
		if (!saved_errno)
			saved_errno = err < 0 ? -err : err;
		errno = saved_errno;
		return NULL;
	}

	return dev;
}

/*
 * Remove a device by uninitializing it with mlx5_vfio_device_uninit().
 *
 * The device count of the owning handle is left untouched, so the slot keeps
 * counting against max_devices in dev_cnt_test_and_inc().
 */
void vfio_mlx5_device_del(struct vfio_mlx5_dev *dev)
{
	mlx5_vfio_device_uninit(dev);
	/* The owning handle is not reachable from here, so dev_cnt cannot be
	 * decremented with dev_cnt_dec(). The API does not require it anyway.
	 */
}

/*
 * Look up device slot @index of @vmh.
 *
 * Returns a pointer into the handle's devices array, or NULL (after logging
 * an error) when @index is not below VFIO_MLX5_MAX_DEVICES. The bound is the
 * compile-time array capacity, not the handle's max_devices.
 */
struct vfio_mlx5_dev *vfio_mlx5_dev_get(struct vfio_mlx5_handle *vmh,
					uint32_t index)
{
	if (index < VFIO_MLX5_MAX_DEVICES)
		return vmh->devices + index;

	log_error("Invalid device index: %u", index);
	return NULL;
}

/*
 * Return the slot index stored in @dev (the value written by
 * init_dev_mem_lay() when the slot was set up).
 */
unsigned int vfio_mlx5_dev_index(const struct vfio_mlx5_dev *dev)
{
	return dev->index;
}

/* Suspend/Resume API */

/*
 * Suspend a single device.
 *
 * Only a device in the ENABLED state can be suspended; in that case the
 * result of mlx5_vfio_dev_suspend() is returned. Any other state is logged
 * and rejected with -EINVAL.
 */
int vfio_mlx5_device_suspend(struct vfio_mlx5_dev *dev)
{
	if (dev->state == VFIO_MLX5_STATE_ENABLED)
		return mlx5_vfio_dev_suspend(dev);

	log_error("Device is not enabled");
	return -EINVAL;
}

/*
 * Suspend the whole handle.
 *
 * Every added device that is still ENABLED is suspended first, then the
 * handle itself is marked SUSPENDED.
 *
 * Returns 0 on success, -EBUSY when the handle is already suspended, or the
 * error of the first device that failed to suspend (the handle state is left
 * unchanged in that case).
 */
int vfio_mlx5_suspend(struct vfio_mlx5_handle *vmh)
{
	if (vmh->hdr.state == VFIO_MLX5_STATE_SUSPENDED) {
		log_error("vfio_mlx5_handle is already suspended");
		return -EBUSY;
	}

	/* Catch devices the user did not suspend individually */
	for (int slot = 0; slot < vmh->hdr.dev_cnt; slot++) {
		struct vfio_mlx5_dev *cur = vfio_mlx5_dev_get(vmh, slot);

		if (cur->state == VFIO_MLX5_STATE_ENABLED) {
			int rc = mlx5_vfio_dev_suspend(cur);

			if (rc)
				return rc;
		}
	}

	vmh->hdr.state = VFIO_MLX5_STATE_SUSPENDED;
	return 0;
}

struct vfio_mlx5_handle *vfio_mlx5_resume(void *persist_storage)
{
	struct vfio_mlx5_handle *vmh;

	vmh = (struct vfio_mlx5_handle *)persist_storage;

	if (vmh->hdr.magic_value != MLX5_MAGIC_VALUE) {
		log_error("vfio_mlx5_handle is not initialized");
		return NULL;
	}

	if (vmh->hdr.state != VFIO_MLX5_STATE_SUSPENDED) {
		log_error("vfio_mlx5_handle is not suspended");
		return NULL;
	}

	pagealloc_spinlock_init();
	vmh->hdr.state = VFIO_MLX5_STATE_ENABLED;
	vmh->hdr.resume_token++;

	for (int i = 0; i < vmh->hdr.dev_cnt; i++) {
		struct vfio_mlx5_dev *dev = vfio_mlx5_dev_get(vmh, i);

		dev->resume_token = vmh->hdr.resume_token;
	}

	return vmh;
}

/*
 * Resume a single suspended device on an already resumed handle.
 *
 * @vmh:       handle returned by vfio_mlx5_resume().
 * @dev:       device belonging to @vmh.
 * @device_fd: new file descriptor stored in the device before resuming.
 *
 * Returns the result of mlx5_vfio_dev_resume() on success, or -EINVAL when
 * an argument is NULL, the handle is not ENABLED, the device's resume_token
 * does not match the handle's, or the device is not SUSPENDED.
 */
int vfio_mlx5_device_resume(struct vfio_mlx5_handle *vmh,
			    struct vfio_mlx5_dev *dev, int device_fd)
{
	if (!vmh || !dev) {
		log_error("Invalid arguments: vmh and dev must not be NULL");
		return -EINVAL;
	}

	if (vmh->hdr.state != VFIO_MLX5_STATE_ENABLED) {
		log_error(
			"vfio_mlx5_resume() wasn't called before vfio_mlx5_device_resume()");
		return -EINVAL;
	}

	/* The token is stamped on every device by vfio_mlx5_resume() */
	if (dev->resume_token != vmh->hdr.resume_token) {
		log_error("Bad device handle, resume_token mismatch");
		return -EINVAL;
	}

	if (dev->state != VFIO_MLX5_STATE_SUSPENDED) {
		log_error("Device is not suspended");
		return -EINVAL;
	}

	dev->device_fd = device_fd;
	return mlx5_vfio_dev_resume(dev, &vmh->page_alloc, vmh);
}

#if DEBUG

/*
 * Print one horizontal rule of the layout table to stdout.
 * Wrapped in do { } while (0) so it behaves as a single statement and the
 * caller supplies the terminating semicolon.
 */
#define PRNT_SEPARATOR()                                                    \
	do {                                                                \
		fprintf(stdout,                                             \
			"%-18s | %-20s | %-18s | %-12s | %-12s | %-10s\n", \
			"------------------", "--------------------",       \
			"-------------------", "-----------",               \
			"-----------", "----------");                       \
	} while (0)

/*
 * Print the table header (column titles framed by two rules) to stdout.
 * Also a single statement, so it is safe inside unbraced if/else bodies.
 */
#define PRINTF_HEADER_LINE()                                                \
	do {                                                                \
		PRNT_SEPARATOR();                                           \
		fprintf(stdout,                                             \
			"%-18s | %-20s | %-19s | %-12s | %-12s | %-10s\n", \
			"vaddr", "struct", "iova", "offset", "size",        \
			"bytes");                                           \
		PRNT_SEPARATOR();                                           \
	} while (0)

/*
 * Format @size (in bytes) into @buf as a short human readable string.
 *
 * @buf:      destination buffer.
 * @buf_size: capacity of @buf; the output is truncated to fit (snprintf).
 * @size:     byte count to format.
 *
 * Sizes of 1 GiB, 1 MiB or 1 KiB and above are printed with two decimals
 * and the suffix "GB", "MB" or "KB" respectively (1024-based); smaller
 * sizes are printed as an exact byte count with the suffix "B".
 */
static inline void prnt_human_readable_size(char *buf, size_t buf_size,
					    size_t size)
{
	const size_t kib = (size_t)1 << 10;
	const size_t mib = (size_t)1 << 20;
	const size_t gib = (size_t)1 << 30;

	if (size >= gib) {
		snprintf(buf, buf_size, "%.2f GB", (double)size / gib);
	} else if (size >= mib) {
		snprintf(buf, buf_size, "%.2f MB", (double)size / mib);
	} else if (size >= kib) {
		snprintf(buf, buf_size, "%.2f KB", (double)size / kib);
	} else {
		/* %zu is the matching conversion for size_t */
		snprintf(buf, buf_size, "%zu B", size);
	}
}

/*
 * Alignment marker for table cells: evaluates to "*" when _size is a
 * multiple of MLX5_ADAPTER_PAGE_SIZE and to " " otherwise. The whole
 * expansion is parenthesized so it can be used inside larger expressions.
 */
#define _IS_4K(_size)                                                     \
	((((unsigned long)(_size) & (MLX5_ADAPTER_PAGE_SIZE - 1)) == 0) ? \
		 "*" :                                                    \
		 " ")

/*
 * Print one row of the layout table and advance the running cursors.
 *
 * vaddr:       address of the item being described.
 * struct_name: label printed in the "struct" column.
 * iova:        lvalue holding the current IOVA; advanced by size.
 * offset:      lvalue holding the current offset from vmh; advanced by size.
 * size:        size of the item in bytes.
 *
 * Expects a variable named vmh in the calling scope and asserts that vaddr
 * sits exactly offset bytes past it. The last column shows offset + size in
 * human readable form. Arguments are evaluated more than once, so they must
 * be free of side effects.
 */
#define PRNT_ROW(vaddr, struct_name, iova, offset, size)                    \
	do {                                                                \
		char row_end_str[32];                                       \
		prnt_human_readable_size(row_end_str, sizeof(row_end_str), \
					 (offset) + (size));                \
		assert((uintptr_t)(vaddr) == (uintptr_t)vmh + (offset));    \
		fprintf(stdout,                                             \
			"0x%016lx | %-20s | 0x%016lx%s | 0x%09lx%s | "      \
			"0x%09lx%s | %s\n",                                 \
			(unsigned long)(uintptr_t)(vaddr), (struct_name),   \
			(unsigned long)(iova), _IS_4K(iova),                \
			(unsigned long)(offset), _IS_4K(offset),            \
			(unsigned long)(size), _IS_4K(size), row_end_str);  \
		(iova) += (size);                                           \
		(offset) += (size);                                         \
	} while (0)

/*
 * Print the memory layout of @vmh to stdout and assert on bad offsets from
 * vmh, which makes it useful for validating the layout.
 *
 * The table is walked in address order: header, every device slot, the page
 * allocator struct, the page bitmap, the first page, the last page and a
 * final BOUNDARY row. Each PRNT_ROW() asserts that the item's address equals
 * vmh + curr_offset and then advances curr_iova/curr_offset by the row size,
 * so the statement order below is significant.
 */
static void dbg_memlay_print(struct vfio_mlx5_handle *vmh)
{
	struct vfio_mlx5_hdr *hdr = &vmh->hdr;
	struct page_allocator *page_alloc = NULL;
	struct vfio_mlx5_dev *dev = NULL;
	/* running cursors, advanced by every PRNT_ROW() */
	size_t curr_iova = hdr->base_iova;
	size_t curr_offset = 0;
	char buf[32];

	/* title line with the total storage size */
	prnt_human_readable_size(buf, sizeof(buf), hdr->total_mem_size);
	fprintf(stdout, "VFIO MLX5 Memory Layout: for %s \n", buf);

	PRINTF_HEADER_LINE();
	PRNT_ROW(hdr, "vfio_mlx5_hdr", curr_iova, curr_offset, sizeof(*hdr));

	/* all VFIO_MLX5_MAX_DEVICES slots are listed, not just max_devices;
	 * the "[i]" in the label is printed literally
	 */
	for (unsigned int i = 0; i < VFIO_MLX5_MAX_DEVICES; i++) {
		dev = vfio_mlx5_dev_get(vmh, i);
		assert(dev != NULL);
		PRNT_ROW((uintptr_t)dev, "vfio_mlx5_dev[i]", curr_iova,
			 curr_offset, sizeof(*dev));
	}

	page_alloc = &vmh->page_alloc;
	PRNT_ROW(page_alloc, "page_allocator", curr_iova, curr_offset,
		 sizeof(struct page_allocator));

	/* the bitmap row spans from the end of the allocator struct up to the
	 * first page, whose offset is derived from the allocator's base_iova
	 */
	size_t first_page_offset = page_alloc->base_iova - vmh->hdr.base_iova;
	size_t page_bitmap_size = first_page_offset - curr_offset;
	PRNT_ROW(page_alloc->page_bitmap, "bitmap", curr_iova, curr_offset,
		 page_bitmap_size);

	/* NOTE(review): arithmetic on void * is a GNU C extension */
	void *first_page = (void *)vmh + first_page_offset;
	PRNT_ROW(first_page, "page[0]", curr_iova, curr_offset,
		 MLX5_ADAPTER_PAGE_SIZE);

	/* NOTE(review): assumes num_pages >= 1, otherwise the index wraps */
	unsigned long last_page_index = page_alloc->num_pages - 1;
	size_t last_page_offset = page_alloc->base_iova - vmh->hdr.base_iova +
				  last_page_index * MLX5_ADAPTER_PAGE_SIZE;
	void *last_page = (void *)(uintptr_t)vmh + last_page_offset;
	/* jump the cursors straight to the last page, skipping the pages in
	 * between
	 */
	curr_offset =
		first_page_offset + MLX5_ADAPTER_PAGE_SIZE * last_page_index;
	curr_iova = hdr->base_iova + curr_offset;
	assert(curr_iova <= vmh->hdr.base_iova + vmh->hdr.total_mem_size);
	PRNT_ROW(last_page, "page[N]", curr_iova, curr_offset,
		 MLX5_ADAPTER_PAGE_SIZE);
	PRNT_SEPARATOR();
	/* zero-sized row marking the first byte past the last page */
	PRNT_ROW(last_page + MLX5_ADAPTER_PAGE_SIZE, "BOUNDARY", curr_iova,
		 curr_offset, 0);
	PRNT_SEPARATOR();
}
#endif /* DEBUG */
