/* SPDX-License-Identifier: GPL-2.0 */
/*
 * Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms and conditions of the GNU General Public License,
 * version 2, as published by the Free Software Foundation.
 *
 * This program is distributed in the hope it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
 * more details.
 */

#ifndef NVFS_DMA_H
#define NVFS_DMA_H

#ifdef HAVE_BLK_MQ_DMA_H
/*
 * New DMA iterator API implementation (HAVE_BLK_MQ_DMA_H)
 */

/*
 * Forward declarations for functions from pci.c that we need.
 *
 * nvme_pci_setup_data_prp() and nvme_pci_setup_data_sgl() are called by
 * nvme_nvfs_map_data() to build the command's data pointer from the DMA
 * iterator; nvme_dma_pool() and nvme_pci_first_desc_dma_addr() are used by
 * nvme_nvfs_free_descriptors() when releasing descriptor pages.
 */
static blk_status_t nvme_pci_setup_data_prp(struct request *req,
		struct blk_dma_iter *iter);
static blk_status_t nvme_pci_setup_data_sgl(struct request *req,
		struct blk_dma_iter *iter);
static inline struct dma_pool *nvme_dma_pool(struct nvme_queue *nvmeq,
		struct nvme_iod *iod);
static inline dma_addr_t nvme_pci_first_desc_dma_addr(struct nvme_command *cmd);

/*
 * Undo the NVFS DMA mappings of a request whose data pointer is an SGL.
 *
 * When no descriptor pages are attached to the iod, the only mapping is the
 * one described by the SGL descriptor embedded in the command.  Otherwise
 * the embedded descriptor's length is taken as the byte size of the segment
 * list stored in iod->descriptors[0], and each entry of that list is
 * unmapped individually.
 *
 * Always returns true.
 */
static inline bool nvme_nvfs_unmap_sgls(struct request *req)
{
	struct nvme_iod *iod = blk_mq_rq_to_pdu(req);
	struct nvme_queue *nvmeq = req->mq_hctx->driver_data;
	struct device *dma_dev = nvmeq->dev->dev;
	enum dma_data_direction dir = rq_dma_dir(req);
	dma_addr_t cmd_sgl_addr = le64_to_cpu(iod->cmd.common.dptr.sgl.addr);
	unsigned int cmd_sgl_len = le32_to_cpu(iod->cmd.common.dptr.sgl.length);
	struct nvme_sgl_desc *segments = iod->descriptors[0];
	unsigned int nr_segments;
	unsigned int idx;

	/* Single mapping carried directly in the command. */
	if (!iod->nr_descriptors) {
		nvfs_ops->nvfs_dma_unmap_page(dma_dev, iod->nvfs_cookie,
					      cmd_sgl_addr, cmd_sgl_len, dir);
		return true;
	}

	/* The command descriptor covers a list of data descriptors. */
	nr_segments = cmd_sgl_len / sizeof(*segments);
	for (idx = 0; idx < nr_segments; idx++) {
		nvfs_ops->nvfs_dma_unmap_page(dma_dev, iod->nvfs_cookie,
					      le64_to_cpu(segments[idx].addr),
					      le32_to_cpu(segments[idx].length),
					      dir);
	}

	return true;
}

static inline bool nvme_nvfs_unmap_prps(struct request *req)
{
	struct nvme_iod *iod = blk_mq_rq_to_pdu(req);
	struct nvme_queue *nvmeq = req->mq_hctx->driver_data;
	struct device *dma_dev = nvmeq->dev->dev;
	enum dma_data_direction dma_dir = rq_dma_dir(req);
	unsigned int i;

	/* Check if dma_vecs was allocated - if setup failed early, it might be NULL */
	if (!iod->dma_vecs)
		return true;

	/* Unmap all DMA vectors - pass page pointer from dma_vecs */
	for (i = 0; i < iod->nr_dma_vecs; i++) {
		nvfs_ops->nvfs_dma_unmap_page(dma_dev, 
				              iod->nvfs_cookie, 
					      iod->dma_vecs[i].addr,
					      iod->dma_vecs[i].len,
				              dma_dir);
	}
	
	/* Free the dma_vecs mempool allocation */
	mempool_free(iod->dma_vecs, nvmeq->dev->dmavec_mempool);
	iod->dma_vecs = NULL;
	iod->nr_dma_vecs = 0;
	
	return true;
}

static inline void nvme_nvfs_free_descriptors(struct request *req)
{
	struct nvme_queue *nvmeq = req->mq_hctx->driver_data;
	const int last_prp = NVME_CTRL_PAGE_SIZE / sizeof(__le64) - 1;
	struct nvme_iod *iod = blk_mq_rq_to_pdu(req);
	dma_addr_t dma_addr = nvme_pci_first_desc_dma_addr(&iod->cmd);
	int i;

	if (iod->nr_descriptors == 1) {
		dma_pool_free(nvme_dma_pool(nvmeq, iod), iod->descriptors[0],
				dma_addr);
		return;
	}

	for (i = 0; i < iod->nr_descriptors; i++) {
		__le64 *prp_list = iod->descriptors[i];
		dma_addr_t next_dma_addr = le64_to_cpu(prp_list[last_prp]);

		dma_pool_free(nvmeq->descriptor_pools.large, prp_list,
				dma_addr);
		dma_addr = next_dma_addr;
	}
}

/*
 * Tear down the NVFS mapping state of a request, if it has any.
 *
 * Requests whose iod does not carry IOD_NVFS_IO are left untouched and
 * false is returned.  Otherwise the flag is cleared first (so a second call
 * on the same request returns false), the SGL or PRP mappings are undone,
 * any descriptor pages are freed, and the nvfs_ops reference is dropped.
 *
 * Returns the result of the SGL/PRP unmap helper for NVFS requests.
 */
static inline bool nvme_nvfs_unmap_data(struct request *req)
{
	struct nvme_iod *iod = blk_mq_rq_to_pdu(req);
	bool unmapped;

	if (!(iod->flags & IOD_NVFS_IO))
		return false;

	iod->flags &= ~IOD_NVFS_IO;

	/* The command's data pointer format selects the unmap routine. */
	unmapped = nvme_pci_cmd_use_sgl(&iod->cmd) ?
			nvme_nvfs_unmap_sgls(req) :
			nvme_nvfs_unmap_prps(req);

	if (iod->nr_descriptors)
		nvme_nvfs_free_descriptors(req);

	nvfs_put_ops();
	return unmapped;
}

/*
 * Try to DMA-map a request through NVFS and build its data pointer.
 *
 * @req:        request to map
 * @is_nvfs_io: set to true once NVFS has started the DMA iteration for this
 *              request, false otherwise
 *
 * Return:
 *   BLK_STS_RESOURCE - request has integrity data or nvfs_ops could not be
 *                      taken; *is_nvfs_io is false
 *   BLK_STS_IOERR    - nvfs_blk_rq_dma_map_iter_start() declined the
 *                      request; the ops reference is dropped and
 *                      *is_nvfs_io is false
 *   otherwise        - the status of the SGL or PRP setup routine, with
 *                      *is_nvfs_io true; on a non-OK status the NVFS state
 *                      is torn down via nvme_nvfs_unmap_data()
 */
static inline blk_status_t nvme_nvfs_map_data(struct request *req,
		bool *is_nvfs_io)
{
	struct nvme_iod *iod = blk_mq_rq_to_pdu(req);
	struct nvme_queue *nvmeq = req->mq_hctx->driver_data;
	struct nvme_dev *dev = nvmeq->dev;
	struct device *dma_dev = nvmeq->dev->dev;
	enum nvme_use_sgl use_sgl = nvme_pci_use_sgls(dev, req);
	struct blk_dma_iter iter;
	blk_status_t status;
	bool build_sgl;

	*is_nvfs_io = false;

	if (blk_integrity_rq(req))
		return BLK_STS_RESOURCE;
	if (!nvfs_get_ops())
		return BLK_STS_RESOURCE;

	/* Start from a clean byte count for this request. */
	iod->total_len = 0;

	if (!nvfs_ops->nvfs_blk_rq_dma_map_iter_start(req, dma_dev,
						       &iod->dma_state, &iter,
						       &iod->nvfs_cookie)) {
		nvfs_put_ops();
		return BLK_STS_IOERR;
	}

	/* From here on the request is owned by the NVFS path. */
	*is_nvfs_io = true;
	iod->flags |= IOD_NVFS_IO;

	/* SGL when forced, or when supported and segments are large enough. */
	if (use_sgl == SGL_FORCED)
		build_sgl = true;
	else if (use_sgl == SGL_SUPPORTED && sgl_threshold)
		build_sgl = nvme_pci_avg_seg_size(req) >= sgl_threshold;
	else
		build_sgl = false;

	if (build_sgl)
		status = nvme_pci_setup_data_sgl(req, &iter);
	else
		status = nvme_pci_setup_data_prp(req, &iter);

	/* Setup failed: unmap, clear the NVFS flag and release the ops. */
	if (status != BLK_STS_OK)
		nvme_nvfs_unmap_data(req);

	return status;
}

#else /* !HAVE_BLK_MQ_DMA_H */
/*
 * Old scatterlist-based API implementation (!HAVE_BLK_MQ_DMA_H)
 */

/*
 * Forward declarations for functions from pci.c.  nvme_nvfs_map_data() calls
 * one of them, after the scatterlist has been DMA-mapped, to fill in the
 * command's PRP or SGL data pointer.
 */
static blk_status_t nvme_pci_setup_prps(struct request *req);

static blk_status_t nvme_pci_setup_sgls(struct request *req);

/*
 * Unmap a request's scatterlist through NVFS, if NVFS handled it.
 *
 * @dev: NVMe device whose struct device is passed to the NVFS unmap call
 * @req: request being completed or torn down
 *
 * Returns true only when nvfs_ops->nvfs_dma_unmap_sg() reported a non-zero
 * count; in that case the nvfs_ops reference is dropped as well.  Returns
 * false when the iod has no mapped entries, when the request is not
 * eligible for the NVFS path (P2PDMA page, integrity data, non-zero dma_len
 * where that field exists, or nvfs_ops not registered), or when the NVFS
 * unmap returned 0.
 *
 * The iod field names differ by kernel version: iod->sgt.{sgl,nents} with
 * HAVE_DMA_PCI_P2PDMA_SUPPORTED, iod->{sg,nents} otherwise.
 */
static bool nvme_nvfs_unmap_data(struct nvme_dev *dev, struct request *req)
{
	struct nvme_iod *iod = blk_mq_rq_to_pdu(req);
	enum dma_data_direction dma_dir = rq_dma_dir(req);

#ifdef HAVE_DMA_PCI_P2PDMA_SUPPORTED
	/* Nothing mapped: not ours to unmap. */
	if (!iod || !iod->sgt.nents)
		return false;

	if (iod->sgt.sgl && !is_pci_p2pdma_page(sg_page(iod->sgt.sgl)) &&
#else
	/* Nothing mapped: not ours to unmap. */
	if (!iod || !iod->nents)
		return false;

	if (iod->sg && !is_pci_p2pdma_page(sg_page(iod->sg)) &&
#endif
		!blk_integrity_rq(req) &&
#if defined(HAVE_BLKDEV_DMA_MAP_BVEC) && defined(HAVE_BLKDEV_REQ_BVEC)
		!iod->dma_len &&
#endif
		nvfs_ops != NULL) {
		int count;
#ifdef HAVE_DMA_PCI_P2PDMA_SUPPORTED
		count = nvfs_ops->nvfs_dma_unmap_sg(dev->dev, iod->sgt.sgl,
				iod->sgt.nents, dma_dir);
#else
		count = nvfs_ops->nvfs_dma_unmap_sg(dev->dev, iod->sg,
				iod->nents, dma_dir);
#endif
		/*
		 * NOTE(review): a zero count presumably means the scatterlist
		 * was not mapped by NVFS - confirm against the NVFS module.
		 * The ops reference is intentionally not dropped here.
		 */
		if (!count)
			return false;

		nvfs_put_ops();
		return true;
	}

	return false;
}

/*
 * Try to DMA-map a request through NVFS using the scatterlist API and build
 * its PRP or SGL data pointer.
 *
 * @dev:        NVMe device (provides the iod mempool and struct device)
 * @req:        request to map
 * @cmnd:       unused by this function; kept for the caller's signature
 * @is_nvfs_io: set to true once NVFS has produced a scatterlist for the
 *              request, false otherwise
 *
 * Return:
 *   BLK_STS_RESOURCE - request has integrity data, nvfs_ops could not be
 *                      taken, or the scatterlist allocation failed;
 *                      *is_nvfs_io is false
 *   BLK_STS_IOERR    - nvfs_blk_rq_map_sg() returned 0 (*is_nvfs_io false),
 *                      or it / nvfs_dma_map_sg_attrs() returned NVFS_IO_ERR
 *                      (*is_nvfs_io true)
 *   otherwise        - status of nvme_pci_setup_sgls()/nvme_pci_setup_prps()
 *                      with *is_nvfs_io true
 *
 * On every failure path taken after the scatterlist was allocated it is
 * returned to dev->iod_mempool.  The early error paths drop the nvfs_ops
 * reference directly; on a PRP/SGL setup failure that is left to
 * nvme_nvfs_unmap_data().
 *
 * The iod field names differ by kernel version: iod->sgt.{sgl,orig_nents,
 * nents} with HAVE_DMA_PCI_P2PDMA_SUPPORTED, iod->{sg,nents} otherwise.
 */
static blk_status_t nvme_nvfs_map_data(struct nvme_dev *dev, struct request *req,
		struct nvme_command *cmnd, bool *is_nvfs_io)
{
	struct nvme_iod *iod = blk_mq_rq_to_pdu(req);
	struct request_queue *q = req->q;
	enum dma_data_direction dma_dir = rq_dma_dir(req);
	blk_status_t ret = BLK_STS_RESOURCE;
	int nr_mapped;

	*is_nvfs_io = false;

	if (!blk_integrity_rq(req) && nvfs_get_ops()) {
#if defined(HAVE_BLKDEV_DMA_MAP_BVEC) && defined(HAVE_BLKDEV_REQ_BVEC)
		/* nvme_nvfs_unmap_data() only acts when dma_len is zero. */
		iod->dma_len = 0;
#endif
#ifdef HAVE_DMA_PCI_P2PDMA_SUPPORTED
		iod->sgt.sgl = mempool_alloc(dev->iod_mempool, GFP_ATOMIC);
		if (!iod->sgt.sgl) {
#else
		iod->sg = mempool_alloc(dev->iod_mempool, GFP_ATOMIC);
		if (!iod->sg) {
#endif
			nvfs_put_ops();
			return BLK_STS_RESOURCE;
		}

#ifdef HAVE_DMA_PCI_P2PDMA_SUPPORTED
		sg_init_table(iod->sgt.sgl, blk_rq_nr_phys_segments(req));
		/* associates bio pages to scatterlist */
		iod->sgt.orig_nents = nvfs_ops->nvfs_blk_rq_map_sg(q, req,
				iod->sgt.sgl);
		if (!iod->sgt.orig_nents) {
			mempool_free(iod->sgt.sgl, dev->iod_mempool);
#else
		sg_init_table(iod->sg, blk_rq_nr_phys_segments(req));
		/* associates bio pages to scatterlist */
		iod->nents = nvfs_ops->nvfs_blk_rq_map_sg(q, req, iod->sg);
		if (!iod->nents) {
			mempool_free(iod->sg, dev->iod_mempool);
#endif
			/* Zero entries: reported with *is_nvfs_io still false. */
			nvfs_put_ops();
			return BLK_STS_IOERR;
		}
		*is_nvfs_io = true;

#ifdef HAVE_DMA_PCI_P2PDMA_SUPPORTED
		if (unlikely((iod->sgt.orig_nents == NVFS_IO_ERR))) {
			/*
			 * Log orig_nents, the value just tested; sgt.nents has
			 * not been assigned yet for this request.
			 */
			pr_err("%s: failed to map sg_nents=:%d\n", __func__,
					iod->sgt.orig_nents);
			mempool_free(iod->sgt.sgl, dev->iod_mempool);
#else
		if (unlikely((iod->nents == NVFS_IO_ERR))) {
			pr_err("%s: failed to map sg_nents=:%d\n", __func__,
					iod->nents);
			mempool_free(iod->sg, dev->iod_mempool);
#endif
			nvfs_put_ops();
			return BLK_STS_IOERR;
		}

		nr_mapped = nvfs_ops->nvfs_dma_map_sg_attrs(dev->dev,
#ifdef HAVE_DMA_PCI_P2PDMA_SUPPORTED
				iod->sgt.sgl,
				iod->sgt.orig_nents,
#else
				iod->sg,
				iod->nents,
#endif
				dma_dir,
				DMA_ATTR_NO_WARN);

		if (unlikely((nr_mapped == NVFS_IO_ERR))) {
#ifdef HAVE_DMA_PCI_P2PDMA_SUPPORTED
			mempool_free(iod->sgt.sgl, dev->iod_mempool);
			nvfs_put_ops();
			/*
			 * Log the entry count handed to the mapping call;
			 * sgt.nents is only assigned after a successful map.
			 */
			pr_err("%s: failed to dma map sglist=:%d\n", __func__,
					iod->sgt.orig_nents);
#else
			mempool_free(iod->sg, dev->iod_mempool);
			nvfs_put_ops();
			pr_err("%s: failed to dma map sglist=:%d\n", __func__,
					iod->nents);
#endif
			return BLK_STS_IOERR;
		}

		/* NVFS_CPU_REQ from the DMA mapping step is treated as fatal. */
		if (unlikely(nr_mapped == NVFS_CPU_REQ)) {
#ifdef HAVE_DMA_PCI_P2PDMA_SUPPORTED
			mempool_free(iod->sgt.sgl, dev->iod_mempool);
#else
			mempool_free(iod->sg, dev->iod_mempool);
#endif
			nvfs_put_ops();
			BUG();
		}

#ifdef HAVE_DMA_PCI_P2PDMA_SUPPORTED
		iod->sgt.nents = nr_mapped;
#else
		iod->nents = nr_mapped;
#endif

		if (nvme_pci_use_sgls(dev, req)) {
			ret = nvme_pci_setup_sgls(req);
		} else {
			/* push dma address to hw registers */
			ret = nvme_pci_setup_prps(req);
		}

		if (ret != BLK_STS_OK) {
			/* Undo the DMA mapping, then give the scatterlist back. */
			nvme_nvfs_unmap_data(dev, req);
#ifdef HAVE_DMA_PCI_P2PDMA_SUPPORTED
			mempool_free(iod->sgt.sgl, dev->iod_mempool);
#else
			mempool_free(iod->sg, dev->iod_mempool);
#endif
		}
		return ret;
	}
	return ret;
}

#endif /* HAVE_BLK_MQ_DMA_H */

#endif /* NVFS_DMA_H */
