#include "nnt_device.h"
#include "nnt_device_defs.h"
#include "nnt_defs.h"
#include "nnt_pci_conf_access.h"
#include "nnt_pci_conf_access_recovery.h"
#include "nnt_memory_access.h"
#include "nnt_gpu.h"
#include <linux/err.h>
#include <linux/module.h>

MODULE_LICENSE("GPL");

/* Forward declarations of functions defined later in this file. */
int create_file_name_mstflint(struct pci_dev* pci_device, struct nnt_device* nnt_dev, enum nnt_device_type device_type);
int create_file_name_mft(struct pci_dev* pci_device, struct nnt_device* nnt_dev, enum nnt_device_type device_type);
int nnt_device_structure_init(struct nnt_device** nnt_device);
int create_nnt_device(struct pci_dev* pci_device, enum nnt_device_type device_type, int is_alloc_chrdev_region);
int check_pci_id_range(unsigned short pci_device_id, unsigned short id_range_start);
int is_connectx(unsigned short pci_device_id);
int is_connectx3(unsigned short pci_device_id);
int is_bluefield(unsigned short pci_device_id);
int is_bluefield4(unsigned short pci_device_id);
int is_pcie_switch(unsigned short pci_device_id);
int is_quantum(unsigned short pci_device_id);
int is_spectrum(unsigned short pci_device_id);
int is_switch_ib(unsigned short pci_device_id);
int is_livefish_device(unsigned short pci_device_id);
int is_nic(unsigned short pci_device_id);
int is_switch(unsigned short pci_device_id);
int is_toolspf(unsigned short pci_device_id);
int create_device_file(struct nnt_device* current_nnt_device,
                       dev_t device_number,
                       int minor,
                       struct file_operations* fop,
                       int is_alloc_chrdev_region);
int create_devices(dev_t device_number, struct file_operations* fop, int is_alloc_chrdev_region);
int find_all_vendor_devices(unsigned int vendor_id);
/* Called by get_vsc_address_by_type() before its definition further down. */
int is_vsc_type(struct nnt_device* current_nnt_device, unsigned int vsec_address, unsigned int vsc_type);

/* List of all NNT devices. It is used to check whether a device is still
   available, since a device can be removed by a hotplug event. */
LIST_HEAD(nnt_device_list);

/* PCI device IDs reported by devices in RMA mode (see is_rma_device()).
   Each ID is listed exactly once. */
unsigned short supported_rma_device_ids[] = {
  QUANTUM3_RMA_PCI_ID,   QUANTUM4_RMA_PCI_ID,   QUANTUM5_RMA_PCI_ID,
  CONNECTX8_RMA_PCI_ID,  ARCUSEE_RMA_PCI_ID,    QUANTUM2_RMA_PCI_ID,
  CONNECTX7_RMA_PCI_ID,  BLUEFIELD3_RMA_PCI_ID, SPECTRUM4_RMA_PCI_ID,
  SPECTRUM5_RMA_PCI_ID,  SPECTRUM6_RMA_PCI_ID,  CONNECTX9_RMA_PCI_ID,
  CONNECTX10_RMA_PCI_ID, CONNECTX9_PURE_PCIE_SWITCH_RMA_PCI_ID,
  ARCUS2_RMA_PCI_ID};
/* Number of entries; derived from the element itself so it tracks any type change. */
const unsigned int SUPPORTED_RMA_DEVICE_IDS_TABLE_SIZE =
  sizeof(supported_rma_device_ids) / sizeof(supported_rma_device_ids[0]);

/* PCI device IDs recognized by is_vera_pci_device(). */
u_int16_t VERA_PCI_IDS[] = {VERA1_PCI_ID, VERA2_PCI_ID, VERA3_PCI_ID, VERA4_PCI_ID, VERA5_PCI_ID};

/* Return 1 when pci_device_id is one of the entries in VERA_PCI_IDS, 0 otherwise. */
int is_vera_pci_device(u_int16_t pci_device_id)
{
    const unsigned int vera_id_count = sizeof(VERA_PCI_IDS) / sizeof(VERA_PCI_IDS[0]);
    unsigned int idx = 0;

    while (idx < vera_id_count)
    {
        if (pci_device_id == VERA_PCI_IDS[idx])
        {
            return 1;
        }
        idx++;
    }

    return 0;
}

/*
 * Fetch the nnt_device attached to an open file.
 *
 * On success *nnt_device is set and 0 is returned. -EINVAL is returned when
 * the file has no private data (*nnt_device untouched) or when the attached
 * device has no pci_device (*nnt_device is still set in that case).
 */
int get_nnt_device(struct file* file, struct nnt_device** nnt_device)
{
    struct nnt_device* attached = file->private_data;

    if (!attached)
    {
        return -EINVAL;
    }

    *nnt_device = attached;
    if (attached->pci_device)
    {
        return 0;
    }

    printk(KERN_ERR "mst_pciconf driver: pci_device pointer is NULL\n");
    return -EINVAL;
}

/* Attach the first enabled list entry whose minor number matches the file's inode.
   If no entry matches, file->private_data is left unchanged. */
void set_private_data_open(struct file* file)
{
    struct nnt_device* candidate;
    const int wanted_minor = iminor(file_inode(file));

    /* Read-only walk: no entry is removed, so the plain iterator suffices. */
    list_for_each_entry(candidate, &nnt_device_list, entry)
    {
        if (candidate->device_minor_number != wanted_minor || !candidate->device_enabled)
        {
            continue;
        }

        file->private_data = candidate;
        break;
    }
}

/*
 * Backward-compatibility open path: find the list entry whose stored
 * domain:bus:device.function equals (domain, bus, devfn). Refresh its
 * pci_device pointer, enable it under the file's minor number, and attach it
 * to the file.
 *
 * Returns 0 on success, -ENXIO when the matching device's bus or slot cannot
 * be found, or -EINVAL when no list entry matches.
 */
int set_private_data_bc(struct file* file, unsigned int bus, unsigned int devfn, unsigned int domain)
{
    struct nnt_device* current_nnt_device = NULL;
    int minor = iminor(file_inode(file));

    list_for_each_entry(current_nnt_device, &nnt_device_list, entry)
    {
        struct pci_bus* pci_bus;

        /* Compare the stored address first. Unrelated entries (possibly
           hot-removed devices) must not be touched or abort the search. */
        if ((current_nnt_device->dbdf.bus != bus) ||
            (PCI_SLOT(current_nnt_device->dbdf.devfn) != PCI_SLOT(devfn)) ||
            (PCI_FUNC(current_nnt_device->dbdf.devfn) != PCI_FUNC(devfn)) ||
            (current_nnt_device->dbdf.domain != domain))
        {
            continue;
        }

        pci_bus = pci_find_bus(current_nnt_device->dbdf.domain, current_nnt_device->dbdf.bus);
        if (!pci_bus)
        {
            return -ENXIO;
        }

        /* NOTE(review): pci_get_slot() takes a reference that is never
           released with pci_dev_put(); kept as before - confirm ownership. */
        current_nnt_device->pci_device = pci_get_slot(pci_bus, current_nnt_device->dbdf.devfn);
        if (!current_nnt_device->pci_device)
        {
            return -ENXIO;
        }

        current_nnt_device->device_minor_number = minor;
        current_nnt_device->device_enabled = true;
        file->private_data = current_nnt_device;
        return 0;
    }

    return -EINVAL;
}

/* Attach the list entry whose minor number matches the file's inode.
   Returns 0 on success, or logs and returns -EINVAL when none matches. */
int set_private_data(struct file* file)
{
    struct nnt_device* candidate;
    const int minor = iminor(file_inode(file));

    list_for_each_entry(candidate, &nnt_device_list, entry)
    {
        if (candidate->device_minor_number != minor)
        {
            continue;
        }

        file->private_data = candidate;
        return 0;
    }

    printk(KERN_ERR "failed to find device with minor=%d\n", minor);
    return -EINVAL;
}

int create_file_name_mstflint(struct pci_dev* pci_device, struct nnt_device* nnt_dev, enum nnt_device_type device_type)
{
    sprintf(nnt_dev->device_name, "%4.4x:%2.2x:%2.2x.%1.1x_%s", pci_domain_nr(pci_device->bus), pci_device->bus->number,
            PCI_SLOT(pci_device->devfn), PCI_FUNC(pci_device->devfn),
            (device_type == NNT_PCICONF) ? MSTFLINT_PCICONF_DEVICE_NAME : MSTFLINT_MEMORY_DEVICE_NAME);

    printk(KERN_DEBUG
           "MSTFlint device name created: id: %d, slot id: %d, device name: /dev/%s domain: 0x%x bus: 0x%x\n",
           pci_device->device, PCI_FUNC(pci_device->devfn), nnt_dev->device_name, pci_domain_nr(pci_device->bus),
           pci_device->bus->number);

    return 0;
}

int create_file_name_mft(struct pci_dev* pci_device, struct nnt_device* nnt_dev, enum nnt_device_type device_type)
{
    sprintf(nnt_dev->device_name, "mst/mt%d_%s0.%x", pci_device->device,
            (device_type == NNT_PCICONF) ? MFT_PCICONF_DEVICE_NAME : MFT_MEMORY_DEVICE_NAME,
            PCI_FUNC(pci_device->devfn));

    printk(KERN_DEBUG "MFT device name created: id: %d, slot id: %d, device name: /dev/%s domain: 0x%x bus: 0x%x\n",
           pci_device->device, PCI_FUNC(pci_device->devfn), nnt_dev->device_name, pci_domain_nr(pci_device->bus),
           pci_device->bus->number);

    return 0;
}

/* Allocate a zero-filled nnt_device into *nnt_device.
   Returns 0 on success or -ENOMEM (with *nnt_device set to NULL). */
int nnt_device_structure_init(struct nnt_device** nnt_device)
{
    /* kzalloc() already zero-fills the allocation; no memset is needed. */
    *nnt_device = kzalloc(sizeof(**nnt_device), GFP_KERNEL);
    if (!*nnt_device)
    {
        return -ENOMEM;
    }

    return 0;
}

/*
 * Allocate an nnt_device for pci_device and append it to nnt_device_list.
 * The node name uses MSTFlint naming when is_alloc_chrdev_region is set and
 * MFT naming otherwise. Returns 0 on success or the error from allocation or
 * name building; on error nothing is added to the list.
 */
int create_nnt_device(struct pci_dev* pci_device, enum nnt_device_type device_type, int is_alloc_chrdev_region)
{
    struct nnt_device* new_device = NULL;
    int status;

    status = nnt_device_structure_init(&new_device);
    if (status != 0)
    {
        goto ReturnOnError;
    }

    status = is_alloc_chrdev_region ? create_file_name_mstflint(pci_device, new_device, device_type)
                                    : create_file_name_mft(pci_device, new_device, device_type);
    if (status != 0)
    {
        goto ReturnOnError;
    }

    /* Remember the PCI address so the device can be looked up again later. */
    new_device->dbdf.bus = pci_device->bus->number;
    new_device->dbdf.devfn = pci_device->devfn;
    new_device->dbdf.domain = pci_domain_nr(pci_device->bus);
    new_device->pci_device = pci_device;
    new_device->device_type = device_type;

    switch (device_type)
    {
        case NNT_PCICONF:
            DEBUG_PRINTK(
              "Created NNT_PCICONF device - Device ID: %d, Domain: 0x%x, Bus: 0x%x, Slot: %d, Function: %d\n\n",
              pci_device->device, pci_domain_nr(pci_device->bus), pci_device->bus->number,
              PCI_SLOT(pci_device->devfn), PCI_FUNC(pci_device->devfn));
            break;

        case NNT_PCI_MEMORY:
            DEBUG_PRINTK(
              "Created NNT_PCI_MEMORY device - Device ID: %d, Domain: 0x%x, Bus: 0x%x, Slot: %d, Function: %d\n\n",
              pci_device->device, pci_domain_nr(pci_device->bus), pci_device->bus->number,
              PCI_SLOT(pci_device->devfn), PCI_FUNC(pci_device->devfn));
            break;

        case NNT_PCICONF_RECOVERY:
            DEBUG_PRINTK(
              "Created NNT_PCICONF_RECOVERY device - Device ID: %d, Domain: 0x%x, Bus: 0x%x, Slot: %d, Function: %d\n\n",
              pci_device->device, pci_domain_nr(pci_device->bus), pci_device->bus->number,
              PCI_SLOT(pci_device->devfn), PCI_FUNC(pci_device->devfn));
            break;

        default:
            break;
    }

    list_add_tail(&new_device->entry, &nnt_device_list);
    return 0;

ReturnOnError:
    /* kfree(NULL) is a no-op, so this also covers the failed-allocation case. */
    kfree(new_device);
    return status;
}

/* Return 1 when pci_device_id lies in the inclusive window
   [id_range_start, id_range_start + 255], 0 otherwise. */
int check_pci_id_range(unsigned short pci_device_id, unsigned short id_range_start)
{
    /* Computed in int, so start + 255 cannot wrap for start near 0xffff. */
    const int window_last = id_range_start + 255;

    if (pci_device_id < id_range_start)
    {
        return 0;
    }

    return pci_device_id <= window_last;
}

/* Return 1 for IDs in the 256-ID window that starts at CONNECTX3_PCI_ID. */
int is_connectx(unsigned short pci_device_id)
{
    const unsigned short family_first_id = CONNECTX3_PCI_ID;

    return check_pci_id_range(pci_device_id, family_first_id);
}

/* Return 1 for the ConnectX-3 and ConnectX-3 Pro device IDs, 0 otherwise. */
int is_connectx3(unsigned short pci_device_id)
{
    if (pci_device_id == CONNECTX3_PCI_ID)
    {
        return 1;
    }

    return pci_device_id == CONNECTX3PRO_PCI_ID;
}

/* Return 1 for IDs in the BlueField or BlueField DPU-aux windows, or any
   BlueField-4 ID; 0 otherwise. */
int is_bluefield(unsigned short pci_device_id)
{
    if (check_pci_id_range(pci_device_id, BLUEFIELD_PCI_ID))
    {
        return 1;
    }
    if (check_pci_id_range(pci_device_id, BLUEFIELD_DPU_AUX_PCI_ID))
    {
        return 1;
    }

    return is_bluefield4(pci_device_id);
}

/* Return 1 when the ID falls in any of the four BlueField-4 ID windows. */
int is_bluefield4(unsigned short pci_device_id)
{
    /* Checked in this order; the first hit wins. */
    const unsigned short bf4_window_starts[] = {
      BLUEFIELD4_CRYPTO_ENABLED_PCI_ID,
      BLUEFIELD4_CRYPTO_DISABLED_PCI_ID,
      BLUEFIELD4_NETWORK_CONTROLLER_PCI_ID,
      BLUEFIELD4_MANAGMENT_INTERFACE_PCI_ID,
    };
    unsigned int k;

    for (k = 0; k < sizeof(bf4_window_starts) / sizeof(bf4_window_starts[0]); k++)
    {
        if (check_pci_id_range(pci_device_id, bf4_window_starts[k]))
        {
            return 1;
        }
    }

    return 0;
}

/* Return 1 for IDs in the 256-ID window that starts at SCHRODINGER_PCI_ID. */
int is_pcie_switch(unsigned short pci_device_id)
{
    const unsigned short window_start = SCHRODINGER_PCI_ID;

    return check_pci_id_range(pci_device_id, window_start);
}

/* Return 1 for IDs in the 256-ID window that starts at QUANTUM_PCI_ID. */
int is_quantum(unsigned short pci_device_id)
{
    const unsigned short window_start = QUANTUM_PCI_ID;

    return check_pci_id_range(pci_device_id, window_start);
}

/* Return 1 for the exact SPECTRUM_PCI_ID or any ID in the window starting at
   SPECTRUM2_PCI_ID. */
int is_spectrum(unsigned short pci_device_id)
{
    if (pci_device_id == SPECTRUM_PCI_ID)
    {
        return 1;
    }

    return check_pci_id_range(pci_device_id, SPECTRUM2_PCI_ID);
}

/* Return 1 for the SwitchIB and SwitchIB-2 device IDs, 0 otherwise. */
int is_switch_ib(unsigned short pci_device_id)
{
    if (pci_device_id == SWITCHIB_PCI_ID)
    {
        return 1;
    }

    return pci_device_id == SWITCHIB2_PCI_ID;
}

/* Return 1 when pci_device_id appears in supported_rma_device_ids, 0 otherwise. */
int is_rma_device(unsigned short pci_device_id)
{
    /* Unsigned index to match SUPPORTED_RMA_DEVICE_IDS_TABLE_SIZE (avoids a
       signed/unsigned comparison). */
    unsigned int i;

    for (i = 0; i < SUPPORTED_RMA_DEVICE_IDS_TABLE_SIZE; i++)
    {
        if (supported_rma_device_ids[i] == pci_device_id)
        {
            return 1;
        }
    }

    return 0;
}

/* Return 1 for livefish IDs that support BAR access; currently only
   SPECTRUM6_RECOVERY_ID. */
int is_livefish_device_with_supported_bar_access(unsigned short pci_device_id)
{
    return pci_device_id == SPECTRUM6_RECOVERY_ID;
}

/* Return 1 for IDs in [CONNECTX3_LIVEFISH_ID, CONNECTX3_PCI_ID) that are not
   RMA IDs; 0 otherwise. */
int is_livefish_device(unsigned short pci_device_id)
{
    if (pci_device_id < CONNECTX3_LIVEFISH_ID || pci_device_id >= CONNECTX3_PCI_ID)
    {
        return 0;
    }

    return !is_rma_device(pci_device_id);
}

/* A NIC is either a ConnectX or a BlueField device. */
int is_nic(unsigned short pci_device_id)
{
    if (is_connectx(pci_device_id))
    {
        return 1;
    }

    return is_bluefield(pci_device_id);
}

/* A switch is a PCIe switch, Quantum, Spectrum or SwitchIB device. */
int is_switch(unsigned short pci_device_id)
{
    if (is_pcie_switch(pci_device_id) || is_quantum(pci_device_id))
    {
        return 1;
    }

    return is_spectrum(pci_device_id) || is_switch_ib(pci_device_id);
}

/*
 * Return 1 when pci_device_id, reduced by the tools-PF offset, is a NIC or
 * switch ID; 0 otherwise.
 */
int is_toolspf(unsigned short pci_device_id)
{
    /* Decimal offset between a tools-PF ID and the NIC/switch ID it maps to. */
    const unsigned short toolspf_id_offset = 4000;
    unsigned short base_id;

    /* IDs below the offset cannot be tools PFs. Subtracting would go negative,
       wrap to a large unsigned short and could match an unrelated ID window. */
    if (pci_device_id < toolspf_id_offset)
    {
        return 0;
    }

    base_id = pci_device_id - toolspf_id_offset;
    return is_nic(base_id) || is_switch(base_id);
}

/* Return 1 for PCI IDs treated as CPU devices; currently the Vera IDs only. */
int is_cpu_pci_device(u_int16_t pci_device_id)
{
    const int vera_match = is_vera_pci_device(pci_device_id);

    return vera_match;
}

/*
 * Return 1 when a pciconf node should be created for pci_device_id: the ID is
 * a NIC, tools PF, livefish, RMA, switch, GPU or CPU device.
 */
int is_pciconf_device(unsigned short pci_device_id)
{
    /* Evaluate each classifier once; the debug log reuses these values
       instead of recomputing every predicate. */
    const int nic = is_nic(pci_device_id);
    const int toolspf = is_toolspf(pci_device_id);
    const int livefish = is_livefish_device(pci_device_id);
    const int rma = is_rma_device(pci_device_id);
    const int sw = is_switch(pci_device_id);
    const int gpu = is_gpu_pci_device(pci_device_id);
    const int cpu = is_cpu_pci_device(pci_device_id);
    const int result = nic || toolspf || livefish || rma || sw || gpu || cpu;

    DEBUG_PRINTK(
      "is_pciconf_device check for device 0x%x: result=%d (is_nic: %d, is_toolspf: %d, is_livefish: %d, is_rma: %d, is_switch: %d, is_gpu: %d, is_cpu: %d)\n",
      pci_device_id, result, nic, toolspf, livefish, rma, sw, gpu, cpu);

    return result;
}

/*
 * Return 1 when a PCI memory (CR-space) node should be created for
 * pci_device_id. That is the case for GPU, switch, tools-PF and ConnectX-3
 * IDs, and for livefish IDs that support BAR access.
 */
int is_pcicr_device(unsigned short pci_device_id)
{
    /* Evaluate each classifier once and log every input of the result,
       including is_switch. */
    const int gpu = is_gpu_pci_device(pci_device_id);
    const int sw = is_switch(pci_device_id);
    const int toolspf = is_toolspf(pci_device_id);
    const int connectx3 = is_connectx3(pci_device_id);
    const int livefish = is_livefish_device(pci_device_id);
    const int bar_access = is_livefish_device_with_supported_bar_access(pci_device_id);
    const int result = gpu || sw || toolspf || connectx3 || (livefish && bar_access);

    DEBUG_PRINTK(
      "is_pcicr_device check for device 0x%x: result=%d (is_gpu: %d, is_switch: %d, is_toolspf: %d, is_connectx3: %d, is_livefish: %d, is_livefish_device_with_supported_bar_access: %d)\n",
      pci_device_id, result, gpu, sw, toolspf, connectx3, livefish, bar_access);

    return result;
}

/*
 * Initialize the per-device lock and numbering. When is_alloc_chrdev_region
 * is set, also create the /dev node and register the character device with
 * file operations fop.
 *
 * Returns 0 on success or a negative errno from device_create()/cdev_add().
 * NOTE(review): on failure, partially created resources are not undone here.
 * destroy_nnt_devices() calls cdev_del()/device_destroy() for every list
 * entry - confirm the caller's error path invokes it.
 */
int create_device_file(struct nnt_device* current_nnt_device,
                       dev_t device_number,
                       int minor,
                       struct file_operations* fop,
                       int is_alloc_chrdev_region)
{
    struct device* device = NULL;
    int error = 0;
    int count = 1;

    /* Defaults for the mode in which this driver does not create device files
       itself (kept for backward compatibility). */
    current_nnt_device->device_minor_number = -1;
    current_nnt_device->device_number = device_number;
    current_nnt_device->mcdev.owner = THIS_MODULE;

    mutex_init(&current_nnt_device->lock);

    if (!is_alloc_chrdev_region)
    {
        return 0;
    }

    /* Create device with a new minor number. */
    current_nnt_device->device_minor_number = minor;
    current_nnt_device->device_number = MKDEV(MAJOR(device_number), minor);

    current_nnt_device->device_enabled = true;
    current_nnt_device->connectx_wa_slot_p1 = 0;

    /* Create device node. device_create() reports failure with an ERR_PTR,
       never NULL. The name goes through "%s" so it is never parsed as a
       format string. */
    device = device_create(nnt_driver_info.class_driver, NULL, current_nnt_device->device_number, NULL, "%s",
                           current_nnt_device->device_name);
    if (IS_ERR(device))
    {
        printk(KERN_ERR "Device creation failed\n");
        return PTR_ERR(device);
    }

    /* cdev_init() zero-fills the whole cdev, so the owner must be (re)set
       afterwards. Otherwise the module is not pinned while the file is open. */
    cdev_init(&current_nnt_device->mcdev, fop);
    current_nnt_device->mcdev.owner = THIS_MODULE;

    /* Add device to the system. */
    error = cdev_add(&current_nnt_device->mcdev, current_nnt_device->device_number, count);

    return error;
}

unsigned int get_vsc_address_by_type(struct nnt_device* current_nnt_device, int vsc_type)
{
    unsigned int vsc_address = 0;
    int type_check_result = 0;

    DEBUG_PRINTK("Starting VSC search for device ID: %d, looking for VSC type: %d\n",
                 current_nnt_device->pci_device->device, vsc_type);

    // DEBUG_PRINTK("Current PCI device details:\n");
    // DEBUG_PRINTK("  - Device ID: 0x%x\n", current_nnt_device->pci_device->device);
    // DEBUG_PRINTK("  - Vendor ID: 0x%x\n", current_nnt_device->pci_device->vendor);
    // DEBUG_PRINTK("  - Subsystem ID: 0x%x\n", current_nnt_device->pci_device->subsystem_device);
    // DEBUG_PRINTK("  - Subsystem Vendor: 0x%x\n", current_nnt_device->pci_device->subsystem_vendor);
    // DEBUG_PRINTK("  - Header Type: 0x%x\n", current_nnt_device->pci_device->hdr_type);
    // DEBUG_PRINTK("  - Class: 0x%x\n", current_nnt_device->pci_device->class);
    // DEBUG_PRINTK("  - Bus: 0x%x\n", current_nnt_device->pci_device->bus->number);
    // DEBUG_PRINTK("  - Device: %d\n", PCI_SLOT(current_nnt_device->pci_device->devfn));
    // DEBUG_PRINTK("  - Function: %d\n", PCI_FUNC(current_nnt_device->pci_device->devfn));
    // DEBUG_PRINTK("  - Domain: 0x%x\n", pci_domain_nr(current_nnt_device->pci_device->bus));
    // DEBUG_PRINTK("  - Error State: %d\n", current_nnt_device->pci_device->error_state);

    // Attempt to find the first Vendor-Specific Capability (VSC)
    vsc_address = pci_find_capability(current_nnt_device->pci_device, VENDOR_SPECIFIC_CAPABILITY_ID);
    DEBUG_PRINTK("pci_find_capability (VSC) returned: 0x%x for device ID: %d\n", vsc_address,
                 current_nnt_device->pci_device->device);

    // Check if the first VSC is of the specified type
    if (vsc_address)
    {
        DEBUG_PRINTK("Found VSC at address 0x%x\n", vsc_address);
        type_check_result = is_vsc_type(current_nnt_device, vsc_address, vsc_type);
        if (type_check_result == 1)
        {
            return vsc_address;
        }
        else if (type_check_result == -1)
        {
            DEBUG_PRINTK("Error reading VSC type at address 0x%x for device ID: %d\n", vsc_address,
                         current_nnt_device->pci_device->device);
            return vsc_address;
        }
    }

    // Iterate through the capabilities linked list to find the specified VSC type
    while ((vsc_address =
              pci_find_next_capability(current_nnt_device->pci_device, vsc_address, VENDOR_SPECIFIC_CAPABILITY_ID)))
    {
        DEBUG_PRINTK("Checking next VSC at address 0x%x\n", vsc_address);
        type_check_result = is_vsc_type(current_nnt_device, vsc_address, vsc_type);
        if (type_check_result == 1)
        {
            return vsc_address;
        }
        else if (type_check_result == -1)
        {
            DEBUG_PRINTK("Error reading VSC type at address 0x%x for device ID: %d\n", vsc_address,
                         current_nnt_device->pci_device->device);
            return vsc_address;
        }
    }

    DEBUG_PRINTK("No VSC of type: %d found for device ID: %d after checking all capabilities\n", vsc_type,
                 current_nnt_device->pci_device->device);
    return 0;
}

/* Return 1 if VSC is of type vsc_type */
int is_vsc_type(struct nnt_device* current_nnt_device, unsigned int vsec_address, unsigned int vsc_type)
{
    u_int8_t type = 0;
    int error = 0;
    /* Read the capability type field */
    if ((error = pci_read_config_byte(current_nnt_device->pci_device, (vsec_address + PCI_TYPE_OFFSET), &type)))
    {
        printk(KERN_ERR "Reading VSC type failed with error %d\n", error);
        return -1;
    }
    if (type == vsc_type)
    {
        return 1;
    }
    return 0;
}

/*
 * For every device in nnt_device_list: create its device file, locate its
 * functional/recovery VSC and VPD capability, choose the access methods for
 * its type and, when is_alloc_chrdev_region is set, run the access init.
 *
 * Returns 0 on success, the error from create_device_file(), or -EINVAL for
 * an unexpected recovery VSC offset. Access-init failures are logged and do
 * not abort creation of the remaining devices.
 */
int create_devices(dev_t device_number, struct file_operations* fop, int is_alloc_chrdev_region)
{
    struct nnt_device* current_nnt_device = NULL;
    int minor = 0;
    int error = 0;

    DEBUG_PRINTK("Starting create_devices with is_alloc_chrdev_region=%d\n\n", is_alloc_chrdev_region);

    /* Read-only walk over the list: no entry is removed here. */
    list_for_each_entry(current_nnt_device, &nnt_device_list, entry)
    {
        DEBUG_PRINTK("Processing device: %s, type=%d, minor=%d\n",
                     current_nnt_device->device_name,
                     current_nnt_device->device_type,
                     minor);

        /* Create the device file. */
        if ((error = create_device_file(current_nnt_device, device_number, minor, fop, is_alloc_chrdev_region)) != 0)
        {
            DEBUG_PRINTK("Failed to create device file, error: %d\n", error);
            goto ReturnOnFinished;
        }

        /* Functional VSC exists when the device is in Functional/Zombiefish mode. */
        current_nnt_device->pciconf_device.functional_vsc_offset =
          get_vsc_address_by_type(current_nnt_device, FUNCTIONAL_VSC);
        if (!current_nnt_device->pciconf_device.functional_vsc_offset)
        {
            DEBUG_PRINTK("No functional VSC found, checking for recovery VSC\n");
            current_nnt_device->pciconf_device.recovery_vsc_offset =
              get_vsc_address_by_type(current_nnt_device, RECOVERY_VSC);
            if (current_nnt_device->pciconf_device.recovery_vsc_offset != 0 &&
                current_nnt_device->pciconf_device.recovery_vsc_offset != RECOVERY_VSC_OFFSET_IN_CONFIG_SPACE)
            {
                printk(KERN_ERR "ERROR - Invalid recovery VSC offset: 0x%x (expected 0x54 or 0).\n",
                       current_nnt_device->pciconf_device.recovery_vsc_offset);
                /* Proper errno instead of a bare -1. */
                error = -EINVAL;
                goto ReturnOnFinished;
            }
        }
        current_nnt_device->vpd_capability_address =
          pci_find_capability(current_nnt_device->pci_device, PCI_CAP_ID_VPD);

        /* NNT_PCICONF_RECOVERY is for LF and late LF. */
        if (!current_nnt_device->pciconf_device.functional_vsc_offset &&
            current_nnt_device->device_type != NNT_PCI_MEMORY)
        {
            current_nnt_device->device_type = NNT_PCICONF_RECOVERY;
            DEBUG_PRINTK(
              "Device with device ID: %d is in LF or Late LF mode. current_nnt_device->device_type set to: %d\n",
              current_nnt_device->pci_device->device, current_nnt_device->device_type);
        }

        switch (current_nnt_device->device_type)
        {
            case NNT_PCICONF:
                DEBUG_PRINTK("Setting up NNT_PCICONF access methods for %s\n\n", current_nnt_device->device_name);
                current_nnt_device->access.read = read_pciconf;
                current_nnt_device->access.write = write_pciconf;
                current_nnt_device->access.init = init_pciconf;
                break;

            case NNT_PCICONF_RECOVERY:
                DEBUG_PRINTK("Setting up NNT_PCICONF_RECOVERY access methods for %s\n\n",
                             current_nnt_device->device_name);
                current_nnt_device->access.read = read_pciconf_no_vsec;
                current_nnt_device->access.write = write_pciconf_no_vsec;
                current_nnt_device->access.init = init_pciconf_no_vsec;
                break;

            case NNT_PCI_MEMORY:
                DEBUG_PRINTK("Setting up NNT_PCI_MEMORY access methods for %s\n\n", current_nnt_device->device_name);
                current_nnt_device->access.read = read_memory;
                current_nnt_device->access.write = write_memory;
                current_nnt_device->access.init = init_memory;
                break;

            default:
                /* Unknown type: no access methods are installed. */
                break;
        }

        /* Guard against a missing init callback (unknown device type). */
        if (is_alloc_chrdev_region && current_nnt_device->access.init)
        {
            /* Best effort. Use a separate variable so an init failure never
               leaks into the function's return value, which previously
               happened only when the failing device was the last one. */
            int init_error = current_nnt_device->access.init(current_nnt_device);
            if (init_error)
            {
                DEBUG_PRINTK("Failed to initialize device access, error: %d\n", init_error);
            }
        }

        minor++;
    }

ReturnOnFinished:
    return error;
}

/*
 * Enumerate all PCI devices of vendor_id and create NNT list entries for
 * them. Pciconf entries are created for NNT_PCICONF_DEVICES/NNT_ALL_DEVICES
 * and memory entries for NNT_PCI_DEVICES/NNT_ALL_DEVICES. A device qualifies
 * when its ID is recognized or with_unknown is set. Then
 * create_devices() is run over the whole list.
 *
 * Returns 0 on success or the first error encountered.
 */
int create_nnt_devices(dev_t device_number,
                       int is_alloc_chrdev_region,
                       struct file_operations* fop,
                       enum nnt_device_type_flag nnt_device_flag,
                       unsigned int vendor_id,
                       int with_unknown)
{
    struct pci_dev* pci_device = NULL;
    int error_code = 0;
    int is_pciconf = 0;
    int is_pcicr = 0;

    DEBUG_PRINTK("create_nnt_devices called with:\n");
    DEBUG_PRINTK("  nnt_device_flag = %d (0=NNT_PCICONF_DEVICES, 1=NNT_PCI_DEVICES, 2=NNT_ALL_DEVICES)\n",
                 nnt_device_flag);
    DEBUG_PRINTK("  with_unknown = %d\n", with_unknown);
    DEBUG_PRINTK("  vendor_id = 0x%x\n", vendor_id);
    DEBUG_PRINTK("  is_alloc_chrdev_region = %d\n\n", is_alloc_chrdev_region);

    /* Find all PCI devices of the requested vendor. */
    while ((pci_device = pci_get_device(vendor_id, PCI_ANY_ID, pci_device)) != NULL)
    {
        is_pciconf = is_pciconf_device(pci_device->device);
        is_pcicr = is_pcicr_device(pci_device->device);

        DEBUG_PRINTK("Device 0x%x: is_pciconf=%d, is_pcicr=%d\n", pci_device->device, is_pciconf, is_pcicr);

        if ((nnt_device_flag == NNT_PCICONF_DEVICES) || (nnt_device_flag == NNT_ALL_DEVICES))
        {
            /* Create pciconf device. */
            if (with_unknown || is_pciconf)
            {
                if ((error_code = create_nnt_device(pci_device, NNT_PCICONF, is_alloc_chrdev_region)) != 0)
                {
                    printk(KERN_ERR "Failed to create NNT_PCICONF device\n");
                    goto ReturnOnIterationError;
                }
            }
        }

        if ((nnt_device_flag == NNT_PCI_DEVICES) || (nnt_device_flag == NNT_ALL_DEVICES))
        {
            /* Create pci memory device. */
            if (with_unknown || is_pcicr)
            {
                if ((error_code = create_nnt_device(pci_device, NNT_PCI_MEMORY, is_alloc_chrdev_region)) != 0)
                {
                    printk(KERN_ERR "Failed to create NNT_PCI_MEMORY device\n");
                    goto ReturnOnIterationError;
                }
            }
        }
    }

    /* Create the devices. */
    return create_devices(device_number, fop, is_alloc_chrdev_region);

ReturnOnIterationError:
    /* pci_get_device() returned pci_device with a reference held; leaving the
       loop early (instead of iterating to NULL) must drop it explicitly. */
    pci_dev_put(pci_device);
    return error_code;
}

/* Return how many PCI devices carry vendor ID vendor_id. */
int find_all_vendor_devices(unsigned int vendor_id)
{
    struct pci_dev* pdev;
    int found = 0;

    for (pdev = pci_get_device(vendor_id, PCI_ANY_ID, NULL); pdev != NULL;
         pdev = pci_get_device(vendor_id, PCI_ANY_ID, pdev))
    {
        found++;
    }

    return found;
}

/* Total number of PCI devices under either the Mellanox or the Nvidia vendor ID. */
int get_amount_of_nvidia_devices(void)
{
    const int mellanox_count = find_all_vendor_devices(NNT_MELLANOX_PCI_VENDOR);
    const int nvidia_count = find_all_vendor_devices(NNT_NVIDIA_PCI_VENDOR);

    return mellanox_count + nvidia_count;
}

/*
 * Lock the nnt_device attached to file.
 * Returns 0 when the lock is taken, or -EINVAL when file is NULL or has no
 * attached device. A missing file used to return positive 1, which is not a
 * valid errno and could be mistaken for success by "< 0" checks.
 */
int mutex_lock_nnt(struct file* file)
{
    struct nnt_device* nnt_device;

    if (!file)
    {
        return -EINVAL;
    }

    nnt_device = file->private_data;

    if (!nnt_device)
    {
        return -EINVAL;
    }

    mutex_lock(&nnt_device->lock);

    return 0;
}

/* Unlock the nnt_device attached to file. Does nothing when file or its
   attached device is NULL, mirroring the checks in mutex_lock_nnt(). */
void mutex_unlock_nnt(struct file* file)
{
    struct nnt_device* nnt_device;

    if (!file)
    {
        return;
    }

    nnt_device = file->private_data;
    if (nnt_device)
    {
        mutex_unlock(&nnt_device->lock);
    }
}

/*
 * Remove and free every entry of nnt_device_list. When the driver owns its
 * chrdev region, each entry's cdev is deleted and its device node destroyed
 * first.
 */
void destroy_nnt_devices(int is_alloc_chrdev_region)
{
    struct nnt_device* victim;
    struct nnt_device* next_victim;

    /* _safe variant: entries are unlinked and freed while iterating. */
    list_for_each_entry_safe(victim, next_victim, &nnt_device_list, entry)
    {
        if (is_alloc_chrdev_region)
        {
            cdev_del(&victim->mcdev);
            device_destroy(nnt_driver_info.class_driver, victim->device_number);
        }

        list_del(&victim->entry);
        kfree(victim);
    }
}

/* Backward-compatibility teardown: unlink and free every list entry.
   No cdev or device node is touched here. */
void destroy_nnt_devices_bc(void)
{
    struct nnt_device* victim;
    struct nnt_device* next_victim;

    /* _safe variant: entries are unlinked and freed while iterating. */
    list_for_each_entry_safe(victim, next_victim, &nnt_device_list, entry)
    {
        list_del(&victim->entry);
        kfree(victim);
    }
}

/*
 * Backward-compatibility removal: find the list entry whose stored
 * domain:bus:device.function equals nnt_device's, refresh its pci_device
 * pointer and mark it disabled. The entry itself stays on the list.
 *
 * Returns 0 (also when no entry matches), or -ENXIO when the matching
 * device's bus or slot cannot be found.
 */
int destroy_nnt_device_bc(struct nnt_device* nnt_device)
{
    struct nnt_device* current_nnt_device;
    unsigned int current_function;
    unsigned int current_device;

    list_for_each_entry(current_nnt_device, &nnt_device_list, entry)
    {
        struct pci_bus* pci_bus;

        current_function = PCI_FUNC(current_nnt_device->dbdf.devfn);
        current_device = PCI_SLOT(current_nnt_device->dbdf.devfn);

        /* Compare the stored address first. Unrelated entries must neither be
           modified (their pci_device could be overwritten with NULL) nor
           abort the search with -ENXIO. */
        if ((current_nnt_device->dbdf.bus != nnt_device->dbdf.bus) ||
            (current_device != PCI_SLOT(nnt_device->dbdf.devfn)) ||
            (current_function != PCI_FUNC(nnt_device->dbdf.devfn)) ||
            (current_nnt_device->dbdf.domain != nnt_device->dbdf.domain))
        {
            continue;
        }

        pci_bus = pci_find_bus(current_nnt_device->dbdf.domain, current_nnt_device->dbdf.bus);
        if (!pci_bus)
        {
            return -ENXIO;
        }

        /* NOTE(review): the pci_get_slot() reference is never dropped; kept as before. */
        current_nnt_device->pci_device = pci_get_slot(pci_bus, current_nnt_device->dbdf.devfn);
        if (!current_nnt_device->pci_device)
        {
            return -ENXIO;
        }

        /* The device is going away: disable it so it is no longer opened. */
        current_nnt_device->device_enabled = false;
        printk(KERN_DEBUG "Device removed: domain: %d, bus: %d, device:%d, function:%d \n",
               current_nnt_device->dbdf.domain, current_nnt_device->dbdf.bus, current_device, current_function);
        return 0;
    }

    return 0;
}

/*
 * For each list entry whose pci_dev is not in pci_channel_io_normal, look for
 * a pci_dev with the same vendor/device ID at the same
 * domain:bus:device.function. If one is found, swap it in and reset its
 * error_state to pci_channel_io_normal. Always returns 0.
 */
int rescan(void)
{
    struct nnt_device* current_nnt_device = NULL;
    struct pci_dev* new_pci_device;

    list_for_each_entry(current_nnt_device, &nnt_device_list, entry)
    {
        /* NOTE(review): the old pci_dev is read below without a held reference;
           it may already be stale - confirm lifetime. */
        struct pci_dev* old_pci_device = current_nnt_device->pci_device;

        /* Entries whose pci_device was cleared cannot be rescanned. */
        if (!old_pci_device)
        {
            continue;
        }

        /* Check if rescan is needed */
        if (old_pci_device->error_state == pci_channel_io_normal)
        {
            printk(KERN_DEBUG "No need to rescan for device %s as its error state is in pci_channel_io_normal\n",
                   current_nnt_device->device_name);
            continue;
        }

        new_pci_device = NULL;
        /* Search under the device's own vendor ID. The list holds devices of
           both the Mellanox and the Nvidia vendor, so hard-coding one vendor
           skipped the others. */
        while ((new_pci_device = pci_get_device(old_pci_device->vendor, old_pci_device->device, new_pci_device)) !=
               NULL)
        {
            /* Checking if this is the device that we need to replace. Compare
               the full devfn (slot and function); comparing only the function
               could select a sibling in another slot on the same bus. */
            if (new_pci_device->devfn == old_pci_device->devfn &&
                new_pci_device->bus->number == old_pci_device->bus->number &&
                pci_domain_nr(new_pci_device->bus) == pci_domain_nr(old_pci_device->bus))
            {
                printk(
                  KERN_DEBUG
                  "A new instance of the pci_dev structure has been discovered. new  pointer address: %p, old pointer address: %p, mst device name: %s\n",
                  new_pci_device, old_pci_device, current_nnt_device->device_name);

                if (new_pci_device->error_state != pci_channel_io_normal)
                {
                    printk(
                      KERN_DEBUG
                      "The new instance of the pci_dev structure is also not in pci_channel_io_normal, error_state = %d.\n",
                      new_pci_device->error_state);
                }

                /* Breaking out keeps the pci_get_device() reference on the new
                   device, which is now stored in the entry. */
                current_nnt_device->pci_device = new_pci_device;
                current_nnt_device->pci_device->error_state = pci_channel_io_normal;
                break;
            }
        }
    }

    return 0;
}
