// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <fcntl.h>
#include <errno.h>
#include <sys/ioctl.h>
#include <sys/eventfd.h>
#include <linux/vfio.h>
#include <string.h>

#include "vfio.h"

/**
 * Helper for noiommu mode: translates a virtual address of the calling
 * process into a physical address by reading /proc/self/pagemap.
 *
 * Every pagemap entry is one 64-bit word per virtual page. Bit 63 is the
 * "page present" flag and bits 0-54 carry the page frame number (PFN).
 *
 * @param vaddr  Virtual address to translate.
 *
 * @return The physical address, or 0 on failure: page size unavailable,
 *         pagemap could not be opened/seeked/read, the page is not present,
 *         or the kernel reported a PFN of 0 (which is what an unprivileged
 *         reader gets instead of the real frame number).
 */
uint64_t get_physical_address(const void *vaddr)
{
	uint64_t page_size, page_frame_number;
	uint64_t entry;
	long sc_page_size;
	ssize_t nread;
	off_t offset;
	int fd;

	/* Get the page size; sysconf() signals failure with -1 */
	sc_page_size = sysconf(_SC_PAGESIZE);
	if (sc_page_size <= 0) {
		fprintf(stderr, "Failed to get page size, err(%d)\n", -errno);
		return 0;
	}
	page_size = (uint64_t)sc_page_size;

	/* Open the pagemap file */
	fd = open("/proc/self/pagemap", O_RDONLY);
	if (fd < 0) {
		fprintf(stderr, "Failed to open pagemap, err(%d)\n", -errno);
		return 0;
	}

	/* Calculate the offset in the pagemap file */
	offset = ((uintptr_t)vaddr / page_size) * sizeof(uint64_t);

	/* Seek to the page entry */
	if (lseek(fd, offset, SEEK_SET) != offset) {
		fprintf(stderr, "Failed to seek in pagemap, err(%d)\n", -errno);
		close(fd);
		return 0;
	}

	/* Read the entry; a short read does not set errno, report it as EIO */
	nread = read(fd, &entry, sizeof(entry));
	if (nread != (ssize_t)sizeof(entry)) {
		fprintf(stderr, "Failed to read pagemap, err(%d)\n",
			nread < 0 ? -errno : -EIO);
		close(fd);
		return 0;
	}

	close(fd);

	/* Check if the page is present */
	if (!(entry & (1ULL << 63))) {
		fprintf(stderr, "Page not present in memory\n");
		return 0;
	}

	/* Extract the page frame number (bits 0-54) */
	page_frame_number = entry & ((1ULL << 55) - 1);

	/*
	 * A present page with PFN 0 means the kernel hid the frame number
	 * (the caller lacks CAP_SYS_ADMIN). Returning only the in-page offset
	 * would hand the device a bogus DMA address, so fail instead.
	 */
	if (!page_frame_number) {
		fprintf(stderr,
			"Page frame number not available (CAP_SYS_ADMIN required)\n");
		return 0;
	}

	/* Calculate the physical address: PFN * page_size + offset_in_page */
	return (page_frame_number * page_size) + ((uintptr_t)vaddr % page_size);
}

/**
 * Opens the VFIO container device (/dev/vfio/vfio) and checks that it can be
 * used: the kernel API version must equal VFIO_API_VERSION and the requested
 * IOMMU model must be supported.
 *
 * @param noiommu  true to require VFIO_NOIOMMU_IOMMU, false to require
 *                 VFIO_TYPE1v2_IOMMU.
 *
 * @return The container file descriptor (>= 0) on success, negative errno on
 *         failure. A version mismatch yields -EINVAL and a missing IOMMU
 *         extension yields -EOPNOTSUPP; the descriptor is closed on failure.
 */
int vfio_container_fd_open(bool noiommu)
{
	int container_fd;
	int api_version;
	int iommu_type;
	int ret;

	container_fd = open("/dev/vfio/vfio", O_RDWR);
	if (container_fd < 0) {
		/* Save errno first: fprintf() is allowed to overwrite it */
		ret = -errno;
		fprintf(stderr, "Failed to open /dev/vfio/vfio, err(%d)\n",
			ret);
		return ret;
	}

	api_version = ioctl(container_fd, VFIO_GET_API_VERSION);
	if (api_version != VFIO_API_VERSION) {
		/*
		 * errno is only meaningful when the ioctl itself failed; a
		 * successful call reporting another version leaves it stale
		 * (possibly 0, which callers would read as a valid fd).
		 */
		ret = api_version < 0 ? -errno : -EINVAL;
		fprintf(stderr, "Failed to verify VFIO API version, err(%d)\n",
			ret);
		goto close_cfd;
	}

	/* Check for required IOMMU extension */
	iommu_type = noiommu ? VFIO_NOIOMMU_IOMMU : VFIO_TYPE1v2_IOMMU;
	ret = ioctl(container_fd, VFIO_CHECK_EXTENSION, iommu_type);
	if (ret <= 0) {
		/* 0 means "not supported" (errno untouched), < 0 is an error */
		ret = ret < 0 ? -errno : -EOPNOTSUPP;
		fprintf(stderr,
			"Failed to find required %s extension, err(%d)\n",
			noiommu ? "VFIO_NOIOMMU_IOMMU" : "VFIO TYPE1v2 IOMMU",
			ret);
		goto close_cfd;
	}

	return container_fd;
close_cfd:
	close(container_fd);
	return ret;
}

/* Closes the given container descriptor; a negative value is ignored. */
void vfio_container_fd_close(int container_fd)
{
	if (container_fd >= 0)
		close(container_fd);
}

/**
 * Resolves the IOMMU group for a given PCI device name and constructs the corresponding
 * VFIO group device path.
 *
 * Given a PCI device name (e.g., "0000:17:00.0"), this function reads the
 * /sys/bus/pci/devices/<name>/iommu_group symlink, takes the last path
 * component of its target as the group name and builds the VFIO group device
 * path (e.g., "/dev/vfio/7" for regular mode or "/dev/vfio/noiommu-7" for
 * noiommu mode).
 *
 * @param pci_name   The PCI device name as a string (e.g., "0000:17:00.0").
 * @param vfio_path  Output buffer to store the resulting VFIO group device path;
 *                   up to SYSFS_PATH_MAX bytes are written.
 * @param noiommu    Whether to use noiommu mode paths.
 *
 * @return 0 on success, negative errno on failure: -EINVAL for NULL arguments
 *         or an unparsable symlink target, -ENAMETOOLONG when a path does not
 *         fit its buffer, or the errno reported by readlink().
 */
int sysfs_pci2vfio_path(const char *pci_name, char *vfio_path, bool noiommu)
{
	char iommu_group_path[SYSFS_PATH_MAX];
	char link_target[SYSFS_PATH_MAX];
	const char *group_name;
	ssize_t len;
	int ret;

	if (!pci_name || !vfio_path)
		return -EINVAL;

	/* Build the sysfs path to the IOMMU group symlink */
	ret = snprintf(iommu_group_path, sizeof(iommu_group_path),
		       "/sys/bus/pci/devices/%s/iommu_group", pci_name);
	if (ret < 0 || (size_t)ret >= sizeof(iommu_group_path)) {
		/* Truncation does not set errno; -errno could be 0 (success) */
		ret = -ENAMETOOLONG;
		fprintf(stderr,
			"Failed to build iommu_group path for %s, err(%d)\n",
			pci_name, ret);
		return ret;
	}

	/* Read the symlink to get the group number */
	len = readlink(iommu_group_path, link_target, sizeof(link_target) - 1);
	if (len < 0) {
		/* Save errno first: fprintf() is allowed to overwrite it */
		ret = -errno;
		fprintf(stderr, "Failed to readlink %s, err(%d)\n",
			iommu_group_path, ret);
		return ret;
	}
	/* readlink() truncates silently; a full buffer cannot be trusted */
	if ((size_t)len >= sizeof(link_target) - 1) {
		ret = -ENAMETOOLONG;
		fprintf(stderr, "Failed to readlink %s, err(%d)\n",
			iommu_group_path, ret);
		return ret;
	}
	link_target[len] = '\0';

	/* The group number is the last component of the symlink target */
	group_name = strrchr(link_target, '/');
	if (!group_name || !*(group_name + 1)) {
		fprintf(stderr, "Failed to parse group number from %s\n",
			link_target);
		return -EINVAL;
	}
	group_name++; /* skip '/' */

	/* Build the VFIO device path based on mode */
	ret = snprintf(vfio_path, SYSFS_PATH_MAX, "/dev/vfio/%s%s",
		       noiommu ? "noiommu-" : "", group_name);
	if (ret < 0 || ret >= SYSFS_PATH_MAX) {
		ret = -ENAMETOOLONG;
		fprintf(stderr,
			"Failed to build vfio path for group %s, err(%d)\n",
			group_name, ret);
		return ret;
	}

	return 0;
}

/*
 * Probes the container with VFIO_IOMMU_GET_INFO to tell which IOMMU model
 * is active.
 *
 * Returns 1 when the query succeeds and reports VFIO_IOMMU_INFO_PGSIZES
 * (iommu mode), 0 for noiommu, -errno for any other failure.
 */
static int vfio_is_iommu_mode(int container_fd)
{
	struct vfio_iommu_type1_info iommu_info = {
		.argsz = sizeof(iommu_info),
	};

	if (ioctl(container_fd, VFIO_IOMMU_GET_INFO, &iommu_info) == 0)
		return !!(iommu_info.flags & VFIO_IOMMU_INFO_PGSIZES);

	switch (errno) {
	case ENOTTY:
		/* noiommu is set (TYPE1 ioctl not supported in noiommu) */
	case EINVAL:
		/* no IOMMU type set yet (or noiommu before device added) */
		return 0;
	default:
		return -errno;
	}
}

/**
 * Opens a PCI device through VFIO: resolves and opens its VFIO group, attaches
 * the group to the container, selects the IOMMU model and obtains the device
 * file descriptor.
 *
 * @param container_fd  Open VFIO container descriptor.
 * @param pci_name      Device string, either "<bdf>" or "<bdf>,vf_token=<token>".
 * @param vdev          Filled on success with device_fd, group_fd,
 *                      container_fd, event_fd (-1) and pci_bdf_vf_token (the
 *                      bare BDF). vfio_path is written as soon as the group
 *                      path has been resolved.
 * @param noiommu       true to use VFIO_NOIOMMU_IOMMU, false for
 *                      VFIO_TYPE1v2_IOMMU.
 *
 * @return 0 on success, negative errno on failure. No descriptor opened by
 *         this function is left open on failure.
 */
int vfio_pci_dev_open(int container_fd, const char *pci_name,
		      struct vfio_pci_dev *vdev, bool noiommu)
{
	struct vfio_group_status gstatus = { .argsz = sizeof(gstatus) };
	char dev_pci_name[PCI_NAME_MAX];
	char *dev_spec = NULL;
	size_t bdf_len;
	size_t spec_len;
	int device_fd;
	int group_fd;
	int iommu_type;
	int ret;

	if (!pci_name || !vdev)
		return -EINVAL;

	/*
	 * The BDF is everything up to the first ','. Bound it by the real
	 * buffer size; a parse failure does not set errno, so return -EINVAL
	 * explicitly instead of a stale (possibly 0) errno.
	 */
	bdf_len = strcspn(pci_name, ",");
	if (bdf_len == 0 || bdf_len >= sizeof(dev_pci_name)) {
		fprintf(stderr,
			"Invalid input format %s. Expected: <bdf>,vf_token=<token> or <bdf>\n",
			pci_name);
		return -EINVAL;
	}
	memcpy(dev_pci_name, pci_name, bdf_len);
	dev_pci_name[bdf_len] = '\0';

	/* Get the VFIO device path for the given PCI device */
	ret = sysfs_pci2vfio_path(dev_pci_name, vdev->vfio_path, noiommu);
	if (ret) {
		fprintf(stderr,
			"Failed to get VFIO device path for %s, err(%d)\n",
			dev_pci_name, ret);
		return ret;
	}

	group_fd = open(vdev->vfio_path, O_RDWR);
	if (group_fd < 0) {
		/* Save errno first: fprintf() is allowed to overwrite it */
		ret = -errno;
		fprintf(stderr, "Failed to open VFIO group path %s, err(%d)\n",
			vdev->vfio_path, ret);
		return ret;
	}

	if (ioctl(group_fd, VFIO_GROUP_GET_STATUS, &gstatus)) {
		ret = -errno;
		fprintf(stderr, "Failed to get VFIO group status, err(%d)\n",
			ret);
		goto abort;
	}

	if (!(gstatus.flags & VFIO_GROUP_FLAGS_VIABLE)) {
		/* The ioctl succeeded, so errno carries no information here */
		ret = -EINVAL;
		fprintf(stderr, "VFIO group is not viable, err(%d)\n", ret);
		goto abort;
	}

	/* Add group to the container */
	if (ioctl(group_fd, VFIO_GROUP_SET_CONTAINER, &container_fd)) {
		ret = -errno;
		fprintf(stderr, "Failed to set VFIO group container, err(%d)\n",
			ret);
		goto abort;
	}

	/* Enable the IOMMU model we want */
	iommu_type = noiommu ? VFIO_NOIOMMU_IOMMU : VFIO_TYPE1v2_IOMMU;
	if (ioctl(container_fd, VFIO_SET_IOMMU, iommu_type) < 0) {
		int expected_result = noiommu ? 0 : 1;
		int iommu_mode_result;
		int set_err = errno;

		/* EBUSY (regular IOMMU) or EINVAL (noiommu) means IOMMU type already set
		 * Verify it's the correct one
		 */
		if (set_err != EBUSY && set_err != EINVAL) {
			ret = -set_err;
			fprintf(stderr,
				"Failed to set %s IOMMU type, err(%d)\n",
				noiommu ? "noiommu" : "VFIO", ret);
			goto abort;
		}

		/* IOMMU already set - verify it matches our expectation */
		iommu_mode_result = vfio_is_iommu_mode(container_fd);
		if (iommu_mode_result < 0) {
			fprintf(stderr,
				"Failed to verify IOMMU mode, err(%d)\n",
				iommu_mode_result);
			ret = iommu_mode_result;
			goto abort;
		}
		if (iommu_mode_result != expected_result) {
			fprintf(stderr,
				"IOMMU mode mismatch: container has %s but requested %s, err(%d)\n",
				iommu_mode_result ? "VFIO" : "noiommu",
				noiommu ? "noiommu" : "VFIO", -EINVAL);
			ret = -EINVAL;
			goto abort;
		}
	}

	/*
	 * Build the device string handed to the kernel from pci_name itself
	 * rather than from vdev->pci_bdf_vf_token, which this function has not
	 * filled in yet. The ',' before the vf_token option is replaced by a
	 * space, giving "<bdf> vf_token=<token>".
	 */
	spec_len = strlen(pci_name);
	dev_spec = malloc(spec_len + 1);
	if (!dev_spec) {
		ret = -ENOMEM;
		fprintf(stderr,
			"Failed to allocate device name buffer, err(%d)\n",
			ret);
		goto abort;
	}
	memcpy(dev_spec, pci_name, spec_len + 1);
	/* strcspn() stopped either at the first ',' or at the terminator */
	if (dev_spec[bdf_len] == ',')
		dev_spec[bdf_len] = ' ';

	device_fd = ioctl(group_fd, VFIO_GROUP_GET_DEVICE_FD, dev_spec);
	if (device_fd < 0) {
		ret = -errno;
		fprintf(stderr, "Failed to get device fd, err(%d), path(%s)\n",
			ret, pci_name);
		goto abort;
	}
	free(dev_spec);

	vdev->device_fd = device_fd;
	vdev->group_fd = group_fd;
	vdev->container_fd = container_fd;
	vdev->event_fd = -1;
	/* dev_pci_name is shorter than PCI_NAME_MAX, so this NUL-terminates */
	strncpy(vdev->pci_bdf_vf_token, dev_pci_name, PCI_NAME_MAX);
	return 0;
abort:
	free(dev_spec);
	close(group_fd);
	return ret;
}

/**
 * Closes the device and group descriptors held in @vdev.
 *
 * Each closed descriptor is reset to -1 so that calling this function again
 * on the same structure does not close an unrelated descriptor that reused
 * the same number. A NULL @vdev is ignored. Other fields of @vdev are left
 * untouched.
 */
void vfio_pci_dev_close(struct vfio_pci_dev *vdev)
{
	if (!vdev)
		return;

	if (vdev->device_fd >= 0) {
		close(vdev->device_fd);
		vdev->device_fd = -1;
	}
	if (vdev->group_fd >= 0) {
		close(vdev->group_fd);
		vdev->group_fd = -1;
	}
}

/*
 * Maps [vaddr, vaddr + size) at the given IOVA with read and write access
 * by issuing VFIO_IOMMU_MAP_DMA on the container. Returns the raw ioctl()
 * result.
 *
 * must be called after at least 1 device has been opened via vfio_pci_dev_open()
 */
int vfio_mem_register(int container_fd, const void *vaddr, uint64_t iova,
		      uint64_t size)
{
	struct vfio_iommu_type1_dma_map map_req = {
		.argsz = sizeof(map_req),
		.flags = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE,
		.vaddr = (uintptr_t)vaddr,
		.iova = iova,
		.size = size,
	};

	return ioctl(container_fd, VFIO_IOMMU_MAP_DMA, &map_req);
}

/*
 * Removes the DMA mapping of @size bytes at @iova by issuing
 * VFIO_IOMMU_UNMAP_DMA on the container. A failure is only reported on
 * stderr, using the same negative error code convention as the other
 * messages in this file.
 */
void vfio_mem_unregister(int container_fd, uint64_t iova, uint64_t size)
{
	struct vfio_iommu_type1_dma_unmap dma_unmap = {
		.argsz = sizeof(dma_unmap),
		.iova = iova,
		.size = size,
	};

	if (ioctl(container_fd, VFIO_IOMMU_UNMAP_DMA, &dma_unmap))
		fprintf(stderr, "Failed to unmap dma, err(%d)\n", -errno);
}

#define MLX5_VFIO_IRQ_VEC_IDX 0
/**
 * Creates an eventfd and installs it as the trigger of MSI-X vector
 * MLX5_VFIO_IRQ_VEC_IDX of the device; every other vector reported by the
 * device is passed as -1 in the same VFIO_DEVICE_SET_IRQS request.
 *
 * @param device_fd  VFIO device descriptor.
 *
 * @return The eventfd descriptor (non-blocking, close-on-exec) on success,
 *         negative errno on failure: the ioctl()/eventfd() errno, -EINVAL if
 *         the device has no eventfd-capable MSI-X vector to use, or -ENOMEM.
 */
int vfio_interrupt_fd_open(int device_fd)
{
	struct vfio_irq_info irq = { .argsz = sizeof(irq) };
	struct vfio_irq_set *irq_set_buf;
	size_t argsz;
	int irq_eventfd;
	int ret;
	__u32 i;

	irq.index = VFIO_PCI_MSIX_IRQ_INDEX;
	if (ioctl(device_fd, VFIO_DEVICE_GET_IRQ_INFO, &irq)) {
		/* Save errno first: fprintf() is allowed to overwrite it */
		ret = -errno;
		fprintf(stderr, "Failed to get VFIO device IRQ info, err(%d)\n",
			ret);
		return ret;
	}

	if ((irq.flags & VFIO_IRQ_INFO_EVENTFD) == 0) {
		fprintf(stderr,
			"VFIO device does not support eventfd for IRQs, err(%d)\n",
			-EINVAL);
		return -EINVAL;
	}

	/* The vector we are about to program must exist in the data array */
	if (irq.count <= MLX5_VFIO_IRQ_VEC_IDX) {
		fprintf(stderr,
			"VFIO device has no usable MSI-X vector, err(%d)\n",
			-EINVAL);
		return -EINVAL;
	}

	/* Header followed by one int (eventfd or -1) per vector */
	argsz = sizeof(*irq_set_buf) + sizeof(int) * irq.count;
	irq_set_buf = malloc(argsz);
	if (!irq_set_buf) {
		fprintf(stderr, "Failed to allocate IRQ set buffers, err(%d)\n",
			-ENOMEM);
		return -ENOMEM;
	}

	/* setup eventfd for command completion interrupts */
	irq_eventfd = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC);
	if (irq_eventfd < 0) {
		ret = -errno;
		fprintf(stderr, "Failed to create eventfd, err(%d)\n", ret);
		goto free_buf;
	}

	/* data[] is a byte array, so store each descriptor with memcpy() */
	for (i = 0; i < irq.count; i++) {
		int fd = (i == MLX5_VFIO_IRQ_VEC_IDX) ? irq_eventfd : -1;

		memcpy(irq_set_buf->data + (size_t)i * sizeof(fd), &fd,
		       sizeof(fd));
	}

	/* Enable MSI-X interrupts.  */
	irq_set_buf->argsz = (__u32)argsz;
	irq_set_buf->index = VFIO_PCI_MSIX_IRQ_INDEX;
	irq_set_buf->flags = VFIO_IRQ_SET_DATA_EVENTFD |
			     VFIO_IRQ_SET_ACTION_TRIGGER;
	irq_set_buf->start = 0;
	irq_set_buf->count = irq.count;

	if (ioctl(device_fd, VFIO_DEVICE_SET_IRQS, irq_set_buf)) {
		ret = -errno;
		fprintf(stderr, "Failed to set VFIO device IRQs, err(%d)\n",
			ret);
		goto close_eventfd;
	}

	free(irq_set_buf);
	return irq_eventfd;

close_eventfd:
	close(irq_eventfd);
free_buf:
	free(irq_set_buf);
	return ret;
}

/* Closes the interrupt eventfd; a negative descriptor is ignored. */
void vfio_interrupt_fd_close(int irq_eventfd)
{
	if (irq_eventfd < 0)
		return;

	close(irq_eventfd);
}
