// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

#include <getopt.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <libgen.h>
#include <fcntl.h>
#include <errno.h>
#include <time.h>
#include <sys/ioctl.h>
#include <sys/eventfd.h>
#include <sys/stat.h>
#include <sys/mman.h>
#include <linux/vfio.h>
#include <string.h>
#include <limits.h>
#include <sys/epoll.h>
#include <linux/mman.h>
#include <signal.h>
#include <pthread.h>
#include <stdatomic.h> /* For atomic operations */

#include "vfio_mlx5.h"
#include "vfio.h"

/* Default IOVA: 0x400000000 (16GB).
 * On ARM the IOVA address 0x8000000 is reserved, which prevents us from
 * mapping IOVA 0 with len > 128MB. Avoid that by using a high IOVA
 * address, which allows memory lengths > 128MB on any architecture.
 * On some GB systems, IOVA 0x40000000 (1GB) also falls in a reserved region,
 * so we default to the higher 0x400000000 (16GB) to avoid such issues.
 */
/* Default start of the IOVA range used for the DMA memory (see --iova) */
#define DEFAULT_IOVA 0x400000000 /* 16GB */
/* Default size of the DMA memory region (see --memsize) */
#define DEFAULT_MEM_SIZE (128 * 1024 * 1024UL) /* 128MB */

/* Stream opened by --file; when non-NULL it is handed to vfio_mlx5_log_set() */
static FILE *log_file;
/* Set by --noiommu; selects the NoIOMMU code paths in main() */
static bool noiommu;

/* Application-wide state shared by main(), the event loop and the stats thread */
struct mlx5_vfio_ctx {
	/* VFIO container fd returned by vfio_container_fd_open(), -1 until opened */
	int container_fd;
	/* Number of valid entries in vdevs[] and mdevs[] (one per --device) */
	int num_devices;
	/* Per-device VFIO state: BDF/vf_token string, device fd and event fd */
	struct vfio_pci_dev vdevs[VFIO_MLX5_MAX_DEVICES];
	/* Value of --nvfs, passed to vfio_mlx5_device_add() for every device */
	int num_vfs;
	/* Length in bytes of the region mapped at mem_vaddr */
	uint64_t mem_size;
	/* Anonymous mmap()ed region handed to vfio_mlx5_init() */
	void *mem_vaddr;
	/* IOVA of mem_vaddr; replaced by the physical address in NoIOMMU mode */
	uint64_t iova;
	/* Stats collection thread: its id and the flag that keeps it looping */
	pthread_t stats_thread;
	atomic_int stats_thread_running;

	/* libvfio-mlx5 components */
	/* Handle returned by vfio_mlx5_init() */
	struct vfio_mlx5_handle *vmh;
	/* Handles returned by vfio_mlx5_device_add(); NULL for devices not added */
	struct vfio_mlx5_dev *mdevs[VFIO_MLX5_MAX_DEVICES];
};

/* Seconds between two stats dumps; values <= 0 mean no stats thread is used */
static int stats_interval;

/**
 * Stats collection thread entry point
 * Dumps the stats of every device, sleeps stats_interval seconds and repeats
 * until the stats_thread_running flag of the context is cleared.
 * @param arg: the struct mlx5_vfio_ctx pointer given to pthread_create()
 * @return: always NULL
 */
static void *stats_collection_thread(void *arg)
{
	struct mlx5_vfio_ctx *tctx = arg;
	int dev;

	fprintf(stdout, "Stats collection thread started\n");

	/* The stop flag is sampled once per round, before the devices are dumped */
	while (atomic_load(&tctx->stats_thread_running)) {
		for (dev = 0; dev < tctx->num_devices; dev++)
			vfio_mlx5_dev_stats_dump(tctx->mdevs[dev], NULL);

		sleep(stats_interval);
	}

	fprintf(stdout, "Stats collection thread stopped\n");
	return NULL;
}

/**
 * Start stats collection thread
 * @param dev: Device structure
 * @return: 0 on success, <0 on error
 */
static int start_stats_thread(struct mlx5_vfio_ctx *ctx)
{
	if (stats_interval <= 0)
		return 0; // No stats collection requested

	/* Mark thread as running */
	ctx->stats_thread_running = 1;

	/* Create the thread */
	int err = pthread_create(&ctx->stats_thread, NULL,
				 stats_collection_thread, ctx);

	if (err != 0) {
		ctx->stats_thread_running = 0;
		fprintf(stderr, "Failed to create stats thread, err(%d)\n",
			-err);
		return -err;
	}

	fprintf(stdout, "Stats collection thread started\n");

	return 0;
}

/**
 * Stop stats collection thread
 * @param dev: Device structure
 * @return: 0 on success, <0 on error
 */
static void stop_stats_thread(struct mlx5_vfio_ctx *ctx)
{
	if (stats_interval <= 0)
		return; // No stats collection requested
	/* Signal thread to stop */
	ctx->stats_thread_running = 0;

	/* Wait for thread to finish */
	int result = pthread_join(ctx->stats_thread, NULL);

	if (result != 0) {
		fprintf(stdout, "Failed to join stats thread, err(%d)\n",
			-errno);
		return;
	}

	fprintf(stdout, "Stats collection thread stopped\n");
}

/* Raised by the SIGINT handler, polled by main_event_loop() to leave its loop */
static volatile sig_atomic_t stop_main_loop;

/* SIGINT handler: its only action is to raise the stop flag */
static void handle_sigint(int signum)
{
	stop_main_loop = 1;
	(void)signum; /* the signal number is not needed */
}

/**
 * Main event loop
 * Registers the eventfd of every device in one epoll instance and, until
 * stop_main_loop is raised (SIGINT) or epoll_wait() fails, drains each
 * signalled eventfd and calls vfio_mlx5_events_process() for its device.
 * A device whose event processing fails is removed from the epoll set and
 * its stats are dumped; the remaining devices keep being served.
 * @param ctx: application context with opened event fds and added devices
 */
static void main_event_loop(struct mlx5_vfio_ctx *ctx)
{
	struct epoll_event ev, events[VFIO_MLX5_MAX_DEVICES];
	int epfd = -1;
	int ret, i, n;

	fprintf(stdout, "Entering main_event_loop\n");

	epfd = epoll_create1(0);
	if (epfd < 0) {
		perror("epoll_create1");
		return;
	}

	/* Add all device eventfds to epoll */
	for (i = 0; i < ctx->num_devices; i++) {
		if (ctx->vdevs[i].event_fd < 0)
			continue;
		memset(&ev, 0, sizeof(ev));
		ev.events = EPOLLIN;
		ev.data.u32 = i; /* store device index */
		if (epoll_ctl(epfd, EPOLL_CTL_ADD, ctx->vdevs[i].event_fd,
			      &ev) < 0) {
			/* Not fatal: the other devices are still monitored */
			fprintf(stderr,
				"epoll_ctl add failed for device %d, eventfd %d\n",
				i, ctx->vdevs[i].event_fd);
		}
	}

	fprintf(stdout, "epoll_wait for events on all devices...\n");
	while (!stop_main_loop) {
		int nfds = epoll_wait(epfd, events, ctx->num_devices, -1);

		if (nfds < 0) {
			/* A signal interrupted the wait: re-check stop_main_loop */
			if (errno == EINTR)
				continue;
			perror("epoll_wait");
			break;
		}
		for (n = 0; n < nfds; n++) {
			int idx = events[n].data.u32;
			uint64_t val;
			/* Read eventfd to clear the event */
			if (read(ctx->vdevs[idx].event_fd, &val, sizeof(val)) !=
			    sizeof(val)) {
				perror("eventfd read");
				continue;
			}
			fprintf(stdout, "Event received for device %d\n", idx);
			/* Call event process for this device */
			ret = vfio_mlx5_events_process(ctx->mdevs[idx]);
			if (ret >= 0)
				continue;
			fprintf(stderr, "Failed to process events, ret(%d)\n",
				ret);
			/* Decide on the epoll_ctl() return value: errno is not
			 * reset on success and may hold a stale error.
			 */
			if (epoll_ctl(epfd, EPOLL_CTL_DEL,
				      ctx->vdevs[idx].event_fd, NULL) < 0)
				fprintf(stderr,
					"Device %d failed to be removed from epoll fd, errno %d\n",
					idx, errno);
			else
				fprintf(stderr,
					"Device %d removed from epoll fd\n",
					idx);
			/* collect stats and dump status which has the latest health info
			 * vfio_mlx5_dev_stats()
			 * Users are supposed to perform vfio_mlx5_dev_del() after detecting the error.
			 * then cleanup any other resources for this device e.g vfio fds, etc ..
			 * We recommend to do all of the above in a separate thread,
			 * to avoid any event handling stalls.
			 */
			vfio_mlx5_dev_stats_dump(ctx->mdevs[idx], NULL);
		}
	}

	fprintf(stdout, "Exiting main_event_loop\n");
	close(epfd);
}

static void print_usage(const char *prog)
{
	printf("Usage: %s [OPTIONS]\n", prog);
	printf("Options:\n");
	printf("  --help                Show this help message\n");
	printf("  --device=PCIBDF, -d PCIBDF   Add PCI device (can be specified up to 8 times)\n");
	printf("           use PCIBDF,vf_token=... to add PCI device with vf_token\n");
	printf("  --memsize=SIZE, -m SIZE      Set memory size in bytes (4K aligned, default 128MB)\n");
	printf("  --nvfs=NUM, -n NUM           Set number of vfs for devices (default 0)\n");
	printf("  --file=FILE, -f FILE         Log output to file (default stdout)\n");
	printf("  --stats=int, -s int          Set stats collection interval in seconds (default 1)\n");
	printf("  --noiommu                    Enable NoIOMMU mode\n");
	printf("  --iova=IOVA_ADDR, -i IOVA_ADDR  Set IOVA start address (default 0x%lx)\n",
	       DEFAULT_IOVA);

	printf("\nExample:\n");
	printf("  %s --device=0000:03:00.0 --memsize=256M --nvfs=2\n", prog);
	printf("  %s --device=0000:03:00.0 --memsize=1g --nvfs=2 --noiommu\n",
	       prog);
}

/**
 * Parse a memory size string
 * Accepts a non-negative integer in any base understood by strtoull()
 * (decimal, 0x hex, 0 octal) with an optional single K/M/G suffix
 * (case-insensitive, powers of 1024). The result must be a multiple of 4096.
 * @param str: string to parse
 * @param size: receives the size in bytes; left untouched on error
 * @return: 0 on success, -1 on NULL arguments, malformed or negative input,
 *          overflow or a size that is not 4K aligned
 */
static int parse_size(const char *str, uint64_t *size)
{
	unsigned long long val, mult = 1;
	char *endptr;

	if (!str || !size)
		return -1;

	/* strtoull() accepts "-N" and returns the negated value as a huge
	 * unsigned number, so a leading minus sign is rejected explicitly.
	 */
	if (str[strspn(str, " \t\n\v\f\r")] == '-')
		return -1;

	errno = 0;
	val = strtoull(str, &endptr, 0);
	if (endptr == str || errno == ERANGE)
		return -1;

	switch (*endptr) {
	case 'K':
	case 'k':
		mult = 1ULL << 10;
		endptr++;
		break;
	case 'M':
	case 'm':
		mult = 1ULL << 20;
		endptr++;
		break;
	case 'G':
	case 'g':
		mult = 1ULL << 30;
		endptr++;
		break;
	default:
		break;
	}

	/* Nothing may follow the number and its optional suffix */
	if (*endptr != '\0')
		return -1;

	/* Refuse sizes that would wrap around when the suffix is applied */
	if (val > ULLONG_MAX / mult)
		return -1;
	val *= mult;

	if (val % 4096 != 0)
		return -1;

	*size = val;
	return 0;
}

/* The single application context: filled in by parse_args() and used by
 * main(). Members not listed here (vdevs[], the stats thread fields) start
 * out zero-initialized.
 */
struct mlx5_vfio_ctx ctx = {
	.container_fd = -1,
	.num_devices = 0,
	.num_vfs = 0,
	.mem_size = DEFAULT_MEM_SIZE,
	.mem_vaddr = NULL,
	.iova = DEFAULT_IOVA,

	.vmh = NULL,
	.mdevs = { NULL },
};

/**
 * Parse the command line into the global ctx, log_file, noiommu and
 * stats_interval.
 * Every value-taking option except --stats requires its value, so both
 * "--memsize=256M" and "--memsize 256M" work. --stats keeps an optional
 * value: a bare "--stats" selects a 1 second interval.
 * @param argc: argument count as received by main()
 * @param argv: argument vector as received by main()
 * @return: 0 on success, 1 when --help was handled, -1 on invalid input
 */
static int parse_args(int argc, char *argv[])
{
	char *devices[VFIO_MLX5_MAX_DEVICES] = { NULL };
	uint64_t mem_size = DEFAULT_MEM_SIZE;
	unsigned int num_devices = 0;
	int c, option_index = 0;
	int num_vfs = 0;
	char *endptr;
	long lval;

	/* optional_argument leaves optarg NULL unless the "--opt=value" form
	 * is used; only the --stats handler copes with that, so the other
	 * value options are declared as required_argument.
	 */
	static struct option long_options[] = {
		{ "help", no_argument, 0, 0 },
		{ "device", required_argument, 0, 'd' },
		{ "memsize", required_argument, 0, 'm' },
		{ "nvfs", required_argument, 0, 'n' },
		{ "file", required_argument, 0, 'f' },
		{ "stats", optional_argument, 0, 's' },
		{ "noiommu", no_argument, 0, 9999 },
		{ "iova", required_argument, 0, 'i' },
		{ 0, 0, 0, 0 }
	};

	while ((c = getopt_long(argc, argv, "d:m:n:f:s:i:", long_options,
				&option_index)) != -1) {
		switch (c) {
		case 0:
			if (strcmp(long_options[option_index].name, "help") ==
			    0) {
				print_usage(argv[0]);
				return 1;
			}
			break;
		case 'd':
			if (num_devices >= VFIO_MLX5_MAX_DEVICES) {
				fprintf(stderr, "Too many devices (max %d)\n",
					VFIO_MLX5_MAX_DEVICES);
				return -1;
			}
			/* Refuse what would otherwise be silently truncated when
			 * copied into pci_bdf_vf_token below.
			 */
			if (strlen(optarg) >= PCI_NAME_MAX) {
				fprintf(stderr, "Device string too long: %s\n",
					optarg);
				return -1;
			}
			devices[num_devices] = optarg;
			num_devices++;
			break;
		case 'm':
			if (parse_size(optarg, &mem_size)) {
				fprintf(stderr, "Invalid memory size: %s\n",
					optarg);
				return -1;
			}
			break;
		case 'n':
			errno = 0;
			lval = strtol(optarg, &endptr, 10);
			if (endptr == optarg || *endptr != '\0' ||
			    errno == ERANGE || lval < 0 || lval > INT_MAX) {
				fprintf(stderr, "Invalid number of vfs: %s\n",
					optarg);
				return -1;
			}
			num_vfs = (int)lval;
			break;
		case 'f':
			/* A repeated --file replaces the previous log file */
			if (log_file)
				fclose(log_file);
			log_file = fopen(optarg, "w");
			if (!log_file) {
				fprintf(stderr,
					"Failed to open log file %s, err(%d)\n",
					optarg, errno);
				return -1;
			}
			break;
		case 's':
			/* "--stats" without a value: default interval */
			if (!optarg) {
				stats_interval = 1;
				break;
			}
			errno = 0;
			lval = strtol(optarg, &endptr, 10);
			if (endptr == optarg || *endptr != '\0' ||
			    errno == ERANGE || lval < 0 || lval > INT_MAX) {
				fprintf(stderr, "Invalid stats interval: %s\n",
					optarg);
				return -1;
			}
			stats_interval = (int)lval;
			break;
		case 'i':
			errno = 0;
			ctx.iova = strtoull(optarg, &endptr, 0);
			if (endptr == optarg || *endptr != '\0' ||
			    errno == ERANGE) {
				fprintf(stderr, "Invalid iova: %s\n", optarg);
				return -1;
			}
			break;

		case 9999:
			noiommu = true;
			break;
		case '?':
		default:
			print_usage(argv[0]);
			return -1;
		}
	}

	if (num_devices == 0) {
		fprintf(stderr, "No devices specified\n");
		print_usage(argv[0]);
		return -1;
	}

	/* prepare context */
	ctx.num_devices = num_devices;
	ctx.mem_size = mem_size;
	ctx.num_vfs = num_vfs;
	for (unsigned int i = 0; i < num_devices; i++) {
		strncpy(ctx.vdevs[i].pci_bdf_vf_token, devices[i],
			PCI_NAME_MAX);
		/* strncpy() does not terminate when the source fills the buffer */
		ctx.vdevs[i].pci_bdf_vf_token[PCI_NAME_MAX - 1] = '\0';
	}

	return 0;
}

static void setup_sighandler(void)
{
	struct sigaction sa;
	memset(&sa, 0, sizeof(sa));
	sa.sa_handler = handle_sigint;
	sigaction(SIGINT, &sa, NULL);
}

/* Write the given OOM killer score adjustment for this process to
 * /proc/self/oom_score_adj. Best effort: if the file cannot be opened
 * nothing is written and no error is reported.
 */
static void set_oom_score_adj(int score)
{
	FILE *adj = fopen("/proc/self/oom_score_adj", "w");

	if (adj) {
		fprintf(adj, "%d\n", score);
		fclose(adj);
	}
}

/**
 * Program entry point.
 * Sequence: parse the command line, open the VFIO container and every
 * requested device, mmap the DMA memory (huge pages in NoIOMMU mode) and
 * register it with VFIO (IOMMU mode only), initialize libvfio-mlx5, open an
 * eventfd and add each device, start the optional stats thread and run the
 * event loop until SIGINT. Resources are then released in reverse order.
 * @return: 0 on success, the parse_args() result when it is non-zero,
 *          otherwise a negative error code for the step that failed
 */
int main(int argc, char *argv[])
{
	/* Number of vdevs[] slots vfio_pci_dev_open() was attempted on. Only
	 * these are closed at cleanup: untouched slots are zero-filled (fd
	 * values of 0) and must not be handed to vfio_pci_dev_close().
	 */
	int num_opened = 0;
	int saved_errno;

	set_oom_score_adj(-1000);

	int err = parse_args(argc, argv);
	if (err)
		return err;

	ctx.container_fd = vfio_container_fd_open(noiommu);
	if (ctx.container_fd < 0) {
		/* Capture errno before fprintf() gets a chance to change it */
		saved_errno = errno;
		fprintf(stderr, "Failed to open VFIO container%s, err(%d)\n",
			noiommu ? " in NoIOMMU mode" : " fd", -saved_errno);
		err = saved_errno ? -saved_errno : -1;
		goto cleanup;
	}

	for (int i = 0; i < ctx.num_devices; i++) {
		num_opened = i + 1;
		err = vfio_pci_dev_open(ctx.container_fd,
					ctx.vdevs[i].pci_bdf_vf_token,
					&ctx.vdevs[i], noiommu);
		if (err) {
			fprintf(stderr,
				"Failed to open VFIO device %s, err(%d) errno (%d)\n",
				ctx.vdevs[i].pci_bdf_vf_token, err, -errno);
			goto cleanup;
		}
	}

	/* Memory mapping with different flags based on mode */
	int mmap_flags = MAP_PRIVATE | MAP_ANONYMOUS |
			 (noiommu ? (MAP_HUGETLB | MAP_HUGE_1GB) : 0);
	ctx.mem_vaddr = mmap(NULL, ctx.mem_size, PROT_READ | PROT_WRITE,
			     mmap_flags, -1, 0);

	if (ctx.mem_vaddr == MAP_FAILED) {
		saved_errno = errno;
		fprintf(stderr, "Failed to mmap memory, errno(%d)\n",
			-saved_errno);
		err = saved_errno ? -saved_errno : -1;
		goto cleanup;
	}
	fprintf(stdout, "mmap succeeded, mem_vaddr=%p\n", ctx.mem_vaddr);

	if (noiommu) {
		/* Needed for pagemap to have valid phy virt address mapping */
		memset(ctx.mem_vaddr, 0, ctx.mem_size);

		/* In NoIOMMU mode, get physical address */
		ctx.iova = get_physical_address(ctx.mem_vaddr);
		if (ctx.iova == 0) {
			fprintf(stderr,
				"Failed to get physical address for NoIOMMU mode\n");
			err = -EFAULT;
			goto cleanup;
		}

		/* Note: Memory registration is skipped in NoIOMMU mode */
		fprintf(stdout, "Memory registration skipped (NoIOMMU mode)\n");
	} else {
		err = vfio_mem_register(ctx.container_fd, ctx.mem_vaddr,
					ctx.iova, ctx.mem_size);
		if (err) {
			/* %p supplies its own prefix on glibc, so no literal 0x */
			fprintf(stderr,
				"Failed to register memory with vfio, err(%d) errno(%d), iova(0x%lx), vaddr(%p)\n",
				err, -errno, ctx.iova, ctx.mem_vaddr);
			goto cleanup;
		}
	}

	/* libvfio_mlx5 API starts here */
	if (log_file)
		vfio_mlx5_log_set(MLX5_LOG_LVL_INFO, log_file, log_file);

	fprintf(stdout,
		"vfio_mlx5_init: memory size(%lu) iova(0x%lx) devices(%d)\n",
		ctx.mem_size, ctx.iova, ctx.num_devices);
	ctx.vmh = vfio_mlx5_init(ctx.mem_vaddr, ctx.mem_size, ctx.iova,
				 ctx.num_devices);
	if (!ctx.vmh) {
		saved_errno = errno;
		fprintf(stderr, "vfio_mlx5_init failed, errno(%d)\n",
			-saved_errno);
		/* err is still 0 here: set it so the exit status shows failure */
		err = saved_errno ? -saved_errno : -1;
		goto cleanup_mlx5_init;
	}

	for (int i = 0; i < ctx.num_devices; i++) {
		ctx.vdevs[i].event_fd =
			vfio_interrupt_fd_open(ctx.vdevs[i].device_fd);
		if (ctx.vdevs[i].event_fd < 0) {
			saved_errno = errno;
			fprintf(stderr,
				"Failed to open eventfd for device %s, err(%d)\n",
				ctx.vdevs[i].pci_bdf_vf_token, -saved_errno);
			err = saved_errno ? -saved_errno : -1;
			goto cleanup_mlx5_devs;
		}
		ctx.mdevs[i] = vfio_mlx5_device_add(
			ctx.vmh, ctx.vdevs[i].pci_bdf_vf_token,
			ctx.vdevs[i].device_fd, ctx.num_vfs);
		if (!ctx.mdevs[i]) {
			saved_errno = errno;
			err = saved_errno ? -saved_errno : -1;
			fprintf(stderr,
				"Failed to add device %s, err(%d) errno(%d)\n",
				ctx.vdevs[i].pci_bdf_vf_token, err,
				-saved_errno);
			/* The cleanup loop skips slots whose mdevs[] entry is
			 * NULL, so this eventfd has to be closed here.
			 */
			vfio_interrupt_fd_close(ctx.vdevs[i].event_fd);
			goto cleanup_mlx5_devs;
		}
		fprintf(stdout, "vfio_mlx5_device_add device %s, eventfd(%d)\n",
			ctx.vdevs[i].pci_bdf_vf_token, ctx.vdevs[i].event_fd);
	}

	setup_sighandler();
	/* Start the stats thread */
	err = start_stats_thread(&ctx);
	if (err) {
		fprintf(stderr, "Failed to start stats thread, err(%d)\n", err);
		goto cleanup_mlx5_devs;
	}

	main_event_loop(&ctx);

	/* Stop the stats thread */
	fprintf(stdout, "Stopping stats collection thread...\n");
	stop_stats_thread(&ctx);

cleanup_mlx5_devs:
	for (int i = 0; i < ctx.num_devices; i++) {
		if (!ctx.mdevs[i])
			continue;
		struct mlx5_dev_stats stats = { 0 };

		/* Collect and dump the final stats while the device handle is
		 * still valid, i.e. before vfio_mlx5_device_del() is called.
		 */
		vfio_mlx5_dev_stats(ctx.mdevs[i], &stats);
		vfio_mlx5_dev_stats_dump(ctx.mdevs[i], &stats);

		fprintf(stdout, "vfio_mlx5_device_del device %s\n",
			ctx.vdevs[i].pci_bdf_vf_token);
		vfio_mlx5_device_del(ctx.mdevs[i]);
		vfio_interrupt_fd_close(ctx.vdevs[i].event_fd);
	}
	fprintf(stdout, "vfio_mlx5_uninit\n");
	/* NOTE(review): uninit is given the memory base rather than ctx.vmh;
	 * confirm against the vfio_mlx5_uninit() prototype.
	 */
	vfio_mlx5_uninit(ctx.mem_vaddr);

cleanup_mlx5_init:
	if (!noiommu)
		vfio_mem_unregister(ctx.container_fd, ctx.iova, ctx.mem_size);

cleanup:
	if (ctx.mem_vaddr != MAP_FAILED && ctx.mem_vaddr != NULL)
		munmap(ctx.mem_vaddr, ctx.mem_size);

	for (int i = 0; i < num_opened; i++)
		vfio_pci_dev_close(&ctx.vdevs[i]);

	if (ctx.container_fd >= 0)
		vfio_container_fd_close(ctx.container_fd);

	if (log_file)
		fclose(log_file);

	return err;
}
