// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

#include <endian.h>
#include <errno.h>
#include <fcntl.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>
#include <sys/eventfd.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/time.h>
#include <linux/vfio.h>

#include "util/util.h"
#include "util/mmio.h"

#include "dev.h"
#include "cmd.h"
#include "fwpages.h"
#include "hcacap.h"
#include "eq.h"
#include "func.h"
#include "ifc.h"

/*
 * NOTE(review): not referenced in this file; presumably the length of a
 * firmware version string buffer.
 */
#define FW_VER_STR_LEN (18)

/*
 * HOST_ENDIANNESS access-register payload (16 bytes, packed). Written by
 * mlx5_vfio_set_hca_ctrl() with @he = MLX5_SET_HOST_ENDIANNESS.
 */
struct __packed mlx5_reg_host_endianness {
	uint8_t he; /* 0 on little-endian hosts, 0x80 on big-endian hosts */
	uint8_t rsvd[15];
};

/* Minimum spacing, in ms, between two samples in mlx5_vfio_poll_health(). */
#define POLL_HEALTH_INTERVAL 100 /* ms */
/* Consecutive unchanged health_counter samples before a warning is logged. */
#define MAX_HEALTH_MISSES 3

/*
 * Value written to the HOST_ENDIANNESS register.
 *
 * Refuse to build when the <endian.h> macros are missing: an undefined
 * identifier evaluates to 0 inside #if, which would make
 * "__BYTE_ORDER == __LITTLE_ENDIAN" silently true on any host.
 */
#if !defined(__BYTE_ORDER) || !defined(__LITTLE_ENDIAN) || \
	!defined(__BIG_ENDIAN)
#error Host endianness macros not defined, include <endian.h>
#elif __BYTE_ORDER == __LITTLE_ENDIAN
#define MLX5_SET_HOST_ENDIANNESS 0
#elif __BYTE_ORDER == __BIG_ENDIAN
#define MLX5_SET_HOST_ENDIANNESS 0x80
#else
#error Host endianness not defined
#endif

/* Bit offset of the 3-bit NIC interface state inside cmdq_addr_l_sz. */
enum mlx5_cmd_addr_l_sz_offset {
	MLX5_NIC_IFC_OFFSET = 8,
};

/* NIC interface state values read/written via cmdq_addr_l_sz. */
enum {
	MLX5_NIC_IFC_DISABLED = 1,
	MLX5_NIC_IFC_SW_RESET = 7,
};

enum mlx5dv_vfio_context_attr_flags {
	MLX5DV_VFIO_CTX_FLAGS_INIT_LINK_DOWN = 1 << 0,
};

/*
 * Read the NIC interface state: the 3-bit field starting at bit 8 of
 * cmdq_addr_l_sz (the same offset mlx5_vfio_set_nic_state() writes).
 */
static uint8_t mlx5_vfio_get_nic_state(struct mlx5_bar *bar0)
{
	uint32_t reg = be32toh(mmio_read32_be(&bar0->cmdq_addr_l_sz));

	reg >>= 8;
	return reg & 0x7;
}

/*
 * Write @state into the NIC interface field of cmdq_addr_l_sz.
 *
 * Bits [31:12] are preserved from the current register value; the low
 * 12 bits are cleared and @state is placed at MLX5_NIC_IFC_OFFSET.
 */
static void mlx5_vfio_set_nic_state(struct mlx5_bar *bar0, uint8_t state)
{
	uint32_t reg_val = be32toh(mmio_read32_be(&bar0->cmdq_addr_l_sz));

	reg_val &= 0xFFFFF000;
	reg_val |= state << MLX5_NIC_IFC_OFFSET;

	mmio_write32_be(&bar0->cmdq_addr_l_sz, htobe32(reg_val));
}

/*
 * Return true when the health buffer's fw_ver reads back as all ones,
 * which is what reads from an offline PCI device return.
 */
static bool sensor_pci_not_working(struct mlx5_bar *bar0)
{
	const uint32_t fw_ver = be32toh(mmio_read32_be(&bar0->health.fw_ver));

	return fw_ver == 0xffffffff;
}

/* NOTE(review): not referenced in this file. */
enum mlx5_fatal_assert_bit_offsets {
	MLX5_RFR_OFFSET = 31,
};

/* Fields of the health buffer's rfr_severity byte. */
enum {
	MLX5_SEVERITY_MASK = 0x7,
	MLX5_SEVERITY_VALID_MASK = 0x8,
};

/* Position of the RFR flag within rfr_severity. */
enum mlx5_rfr_severity_bit_offsets {
	MLX5_RFR_BIT_OFFSET = 0x7,
};

/* Severity levels decoded by mlx5_health_get_severity(). */
enum {
	LOGLEVEL_EMERG = 0,
	LOGLEVEL_ALERT = 1,
	LOGLEVEL_CRIT = 2,
	LOGLEVEL_ERR = 3,
	LOGLEVEL_WARNING = 4,
	LOGLEVEL_NOTICE = 5,
	LOGLEVEL_INFO = 6,
	LOGLEVEL_DEBUG = 7,
};

/*
 * Decode the severity from rfr_severity.
 *
 * The low three bits carry a level only when the valid bit is set;
 * otherwise LOGLEVEL_ERR is assumed.
 */
static int mlx5_health_get_severity(uint8_t rfr_severity)
{
	if (!(rfr_severity & MLX5_SEVERITY_VALID_MASK))
		return LOGLEVEL_ERR;

	return rfr_severity & MLX5_SEVERITY_MASK;
}

/*
 * RFR: Recovery Flow Required. When set, the error indicated in the
 * health buffer cannot be recovered by the device without a flow
 * involving a reset.
 *
 * Returns the top bit of rfr_severity (0 or 1).
 */
static int mlx5_health_get_rfr(uint8_t rfr_severity)
{
	const int rfr = rfr_severity >> MLX5_RFR_BIT_OFFSET;

	return rfr;
}

/*
 * Return true when firmware reports a non-zero health syndrome together
 * with the RFR (recovery flow required) bit.
 */
static bool sensor_fw_synd_rfr(struct mlx5_bar *bar0)
{
	uint8_t rfr, synd;

	rfr = mlx5_health_get_rfr(mmio_read8(&bar0->health.rfr_severity));
	synd = mmio_read8(&bar0->health.synd);

	return rfr != 0 && synd != 0;
}

/*
 * Read the 64-bit device internal timer.
 *
 * The timer is exposed as two 32-bit registers read by separate MMIO
 * loads. If the low word wraps between the two loads, the combined
 * value is off by 2^32. The high word is therefore read again; if it
 * moved, the low word is re-read so that it pairs with the new high word.
 */
static uint64_t mlx5_health_get_hw_timestamp(struct mlx5_bar *bar0)
{
	uint32_t timer_h, timer_h1, timer_l;

	timer_h = be32toh(mmio_read32_be(&bar0->internal_timer_h));
	timer_l = be32toh(mmio_read32_be(&bar0->internal_timer_l));
	timer_h1 = be32toh(mmio_read32_be(&bar0->internal_timer_h));
	if (timer_h != timer_h1) {
		/* Low word wrapped in between: pair a fresh low with new high */
		timer_l = be32toh(mmio_read32_be(&bar0->internal_timer_l));
		timer_h = timer_h1;
	}

	return (uint64_t)timer_h << 32 | timer_l;
}

/* Health sensor values stored in mlx5_health_record::sensor. */
enum {
	MLX5_SENSOR_NO_ERR = 0,
	MLX5_SENSOR_PCI_COMM_ERR = 1,
	MLX5_SENSOR_NIC_DISABLED = 3,
	MLX5_SENSOR_NIC_SW_RESET = 4,
	MLX5_SENSOR_FW_SYND_RFR = 5,
};

/*
 * Human-readable name of a health sensor value. Values without an entry
 * (including the gap at 2) map to "unknown sensor error".
 */
static const char *sensor_str(uint32_t sensor)
{
	static const char *const names[] = {
		[MLX5_SENSOR_NO_ERR] = "no error",
		[MLX5_SENSOR_PCI_COMM_ERR] = "PCI communication error",
		[MLX5_SENSOR_NIC_DISABLED] = "NIC disabled",
		[MLX5_SENSOR_NIC_SW_RESET] = "NIC software reset",
		[MLX5_SENSOR_FW_SYND_RFR] =
			"firmware syndrome recovery flow required",
	};

	if (sensor >= sizeof(names) / sizeof(names[0]) || !names[sensor])
		return "unknown sensor error";

	return names[sensor];
}

/* Health syndrome codes reported in the health buffer's synd byte. */
enum {
	MLX5_HEALTH_SYNDR_FW_ERR = 0x1,
	MLX5_HEALTH_SYNDR_IRISC_ERR = 0x7,
	MLX5_HEALTH_SYNDR_HW_UNRECOVERABLE_ERR = 0x8,
	MLX5_HEALTH_SYNDR_CRC_ERR = 0x9,
	MLX5_HEALTH_SYNDR_FETCH_PCI_ERR = 0xa,
	MLX5_HEALTH_SYNDR_HW_FTL_ERR = 0xb,
	MLX5_HEALTH_SYNDR_ASYNC_EQ_OVERRUN_ERR = 0xc,
	MLX5_HEALTH_SYNDR_EQ_ERR = 0xd,
	MLX5_HEALTH_SYNDR_EQ_INV = 0xe,
	MLX5_HEALTH_SYNDR_FFSER_ERR = 0xf,
	MLX5_HEALTH_SYNDR_HIGH_TEMP = 0x10,
};

/*
 * Human-readable description of a health syndrome.
 *
 * Strings are embedded mid-line by callers ("synd 0x%x (%s)"), so none
 * of them may carry a trailing newline.
 */
static const char *hsynd_str(uint8_t synd)
{
	switch (synd) {
	case MLX5_HEALTH_SYNDR_FW_ERR:
		return "firmware internal error";
	case MLX5_HEALTH_SYNDR_IRISC_ERR:
		return "irisc not responding";
	case MLX5_HEALTH_SYNDR_HW_UNRECOVERABLE_ERR:
		return "unrecoverable hardware error";
	case MLX5_HEALTH_SYNDR_CRC_ERR:
		return "firmware CRC error";
	case MLX5_HEALTH_SYNDR_FETCH_PCI_ERR:
		return "ICM fetch PCI error";
	case MLX5_HEALTH_SYNDR_HW_FTL_ERR:
		return "HW fatal error";
	case MLX5_HEALTH_SYNDR_ASYNC_EQ_OVERRUN_ERR:
		return "async EQ buffer overrun";
	case MLX5_HEALTH_SYNDR_EQ_ERR:
		return "EQ error";
	case MLX5_HEALTH_SYNDR_EQ_INV:
		return "Invalid EQ referenced";
	case MLX5_HEALTH_SYNDR_FFSER_ERR:
		return "FFSER error";
	case MLX5_HEALTH_SYNDR_HIGH_TEMP:
		return "High temperature";
	default:
		return "unrecognized error";
	}
}

static const char *mlx5_loglevel_str(int level)
{
	switch (level) {
	case LOGLEVEL_EMERG:
		return "EMERGENCY";
	case LOGLEVEL_ALERT:
		return "ALERT";
	case LOGLEVEL_CRIT:
		return "CRITICAL";
	case LOGLEVEL_ERR:
		return "ERROR";
	case LOGLEVEL_WARNING:
		return "WARNING";
	case LOGLEVEL_NOTICE:
		return "NOTICE";
	case LOGLEVEL_INFO:
		return "INFO";
	case LOGLEVEL_DEBUG:
		return "DEBUG";
	}
	return "Unknown log level";
}

static void mlx5_health_info_print(struct vfio_mlx5_dev *dev)
{
	struct mlx5_health_record *h_rec = &dev->health_rec;
	size_t i;

	dev_info(dev, "time_stamp %lu", h_rec->time_stamp);
	dev_info(dev, "hw_timestamp %lu", h_rec->hw_timestamp);
	dev_info(dev, "sensor 0x%x (%s)", h_rec->sensor,
		 sensor_str(h_rec->sensor));
	dev_info(dev, "synd 0x%x (%s)", h_rec->synd, hsynd_str(h_rec->synd));

	dev_info(dev, "assert_var:");
	for (i = 0; i < ARRAY_SIZE(h_rec->assert_var); i++)
		dev_info(dev, "[%zu] 0x%08x", i, h_rec->assert_var[i]);

	dev_info(dev, "assert_exit_ptr 0x%08x", h_rec->assert_exit_ptr);
	dev_info(dev, "assert_callra 0x%08x", h_rec->assert_callra);
	dev_info(dev, "fw_ver %u.%u.%u", (h_rec->fw_ver >> 24) & 0xff,
		 (h_rec->fw_ver >> 16) & 0xff, h_rec->fw_ver & 0xffff);
	dev_info(dev, "time %u", h_rec->time);
	dev_info(dev, "hw_id 0x%08x", h_rec->hw_id);
	dev_info(dev, "rfr %d", h_rec->rfr);
	dev_info(dev, "severity %d (%s)", h_rec->severity,
		 mlx5_loglevel_str(h_rec->severity));
	dev_info(dev, "irisc_index %d", h_rec->irisc_index);
	dev_info(dev, "synd 0x%x: %s", h_rec->synd, hsynd_str(h_rec->synd));
	dev_info(dev, "ext_synd 0x%04x", h_rec->ext_synd);
	dev_info(dev, "raw fw_ver 0x%08x", h_rec->fw_ver);
}

/*
 * Check for fatal sensors.
 *
 * @dev: vfio_mlx5_dev struct
 * @sensor_val: pointer to uint32_t to store the sensor value. It is always
 *    written: MLX5_SENSOR_NO_ERR when nothing fires, otherwise the detected
 *    sensor (MLX5_SENSOR_FW_SYND_RFR overrides MLX5_SENSOR_NIC_SW_RESET
 *    when both apply).
 *
 * @return: 0 if no fatal sensors are detected,
 *    -EIO if a fatal sensor is detected (dev->pci_err is then set)
 *    fatal errors are:
 *        - PCI communication error
 *        - NIC disabled
 */
static int mlx5_health_check_fatal_sensors(struct vfio_mlx5_dev *dev,
					   uint32_t *sensor_val)
{
	uint8_t nic_state;

	/*
	 * The caller stores this in the persistent health record, so reset it
	 * first; otherwise a sensor that has cleared since the previous poll
	 * would linger. 0 is MLX5_SENSOR_NO_ERR.
	 */
	*sensor_val = 0;

	if (sensor_pci_not_working(dev->heap->bar_map)) {
		*sensor_val = MLX5_SENSOR_PCI_COMM_ERR;
		goto err_pci;
	}

	nic_state = mlx5_vfio_get_nic_state(dev->heap->bar_map);
	if (nic_state == MLX5_NIC_IFC_DISABLED) {
		*sensor_val = MLX5_SENSOR_NIC_DISABLED;
		/* NIC is disabled, disable pci access */
		goto err_pci;
	}

	if (nic_state == MLX5_NIC_IFC_SW_RESET)
		*sensor_val = MLX5_SENSOR_NIC_SW_RESET;

	if (sensor_fw_synd_rfr(dev->heap->bar_map))
		*sensor_val = MLX5_SENSOR_FW_SYND_RFR;

	return 0;

err_pci:
	dev->pci_err = true;
	return -EIO;
}

/*
 * Snapshot the firmware health buffer from BAR0 into dev->health_rec.
 *
 * The hardware timestamp and sensor are always refreshed. When a fatal
 * PCI-level sensor fired, the remaining fields are left as they were,
 * because reading them would be meaningless.
 *
 * @return: -EIO on a fatal sensor or on a firmware syndrome with
 *   severity <= LOGLEVEL_CRIT, otherwise the sensor check result (0).
 */
static int mlx5_health_info_record(struct vfio_mlx5_dev *dev)
{
	struct health_buffer *h_buff = &dev->heap->bar_map->health;
	struct mlx5_health_record *h_rec = &dev->health_rec;
	uint8_t rfr_severity;
	int ret;
	size_t i;

	h_rec->hw_timestamp = mlx5_health_get_hw_timestamp(dev->heap->bar_map);
	ret = mlx5_health_check_fatal_sensors(dev, &h_rec->sensor);

	/* no point on reading other values if pci is not working
	 * hope previous record holds some useful information
	 */
	if (ret && dev->pci_err)
		return ret;

	/*
	 * Read all values from the device health buffer (MMIO) into the
	 * native-endianness record. Bound by both arrays so neither side can
	 * be overrun if their sizes ever differ.
	 */
	for (i = 0; i < ARRAY_SIZE(h_rec->assert_var) &&
		    i < ARRAY_SIZE(h_buff->assert_var);
	     i++)
		h_rec->assert_var[i] =
			be32toh(mmio_read32_be(&h_buff->assert_var[i]));

	h_rec->assert_exit_ptr =
		be32toh(mmio_read32_be(&h_buff->assert_exit_ptr));
	h_rec->assert_callra = be32toh(mmio_read32_be(&h_buff->assert_callra));
	h_rec->time = be32toh(mmio_read32_be(&h_buff->time));
	h_rec->hw_id = be32toh(mmio_read32_be(&h_buff->hw_id));

	rfr_severity = mmio_read8(&h_buff->rfr_severity);
	h_rec->rfr = mlx5_health_get_rfr(rfr_severity);
	h_rec->severity = mlx5_health_get_severity(rfr_severity);

	h_rec->irisc_index = mmio_read8(&h_buff->irisc_index);
	h_rec->synd = mmio_read8(&h_buff->synd);
	h_rec->ext_synd = be16toh(mmio_read16_be(&h_buff->ext_synd));
	h_rec->fw_ver = be32toh(mmio_read32_be(&h_buff->fw_ver));

	if (h_rec->synd && h_rec->severity <= LOGLEVEL_CRIT) {
		dev_err(dev, "Fatal health fw syndrome error detected");
		return -EIO;
	}

	return ret;
}

/*
 * Sample device health, at most once per POLL_HEALTH_INTERVAL ms.
 *
 * Records the health buffer and tracks whether the firmware health
 * counter advances. The health record is dumped when recording fails,
 * and once when the counter has been stuck for MAX_HEALTH_MISSES
 * consecutive samples (that case is only logged and still returns 0).
 *
 * @return: 0, or the error from mlx5_health_info_record().
 */
static int mlx5_vfio_poll_health(struct vfio_mlx5_dev *dev)
{
	struct mlx5_health_state *hs = &dev->health_state;
	struct timespec now;
	uint64_t now_ms;
	uint32_t counter;
	int err;

	clock_gettime(CLOCK_MONOTONIC, &now);
	now_ms = (uint64_t)now.tv_sec * 1000 + now.tv_nsec / 1000000;

	if (now_ms - hs->prev_time < POLL_HEALTH_INTERVAL)
		return 0;

	dev->health_rec.time_stamp = now_ms;
	err = mlx5_health_info_record(dev);
	if (err) {
		mlx5_health_info_print(dev);
		return err;
	}

	/* A live firmware keeps bumping the 24-bit health counter */
	counter = be32toh(mmio_read32_be(&dev->heap->bar_map->health_counter)) &
		  0xffffff;
	if (counter != hs->prev_count)
		hs->miss_counter = 0;
	else
		++hs->miss_counter;

	hs->prev_time = now_ms;
	hs->prev_count = counter;

	if (hs->miss_counter != MAX_HEALTH_MISSES)
		return 0;

	/* Not catastrophic, just log it */
	dev_warn(dev, "device's health compromised - reached miss count");
	mlx5_health_info_print(dev);
	return 0;
}

/*
 * Event-loop entry point: sample device health (rate limited), then
 * process pending async events.
 *
 * @return: the health-poll error if it failed, otherwise the result of
 *   mlx5_vfio_async_events_process().
 */
int vfio_mlx5_events_process(struct vfio_mlx5_dev *dev)
{
	int err = mlx5_vfio_poll_health(dev);

	return err ? err : mlx5_vfio_async_events_process(dev);
}

/*
 * Execute an ACCESS_REGISTER command against register @reg_id.
 *
 * @data_in/@size_in: register payload copied into the command input.
 * @data_out/@size_out: receives @size_out bytes of register data from the
 *   command output (written only on success).
 * @arg: register-specific argument field.
 * @write: non-zero for a write access (op_mod is set to !write).
 *
 * @return: 0 on success, -ENOMEM if the command buffers cannot be
 *   allocated, or the error returned by mlx5_vfio_cmd_exec().
 */
static int mlx5_vfio_access_reg(struct vfio_mlx5_dev *dev, const void *data_in,
				int size_in, void *data_out, int size_out,
				uint16_t reg_id, int arg, int write)
{
	int outlen = DEVX_ST_SZ_BYTES(access_register_out) + size_out;
	int inlen = DEVX_ST_SZ_BYTES(access_register_in) + size_in;
	uint32_t *out = NULL;
	uint32_t *in = NULL;
	void *data;
	int ret;

	in = calloc(1, inlen);
	out = calloc(1, outlen);
	if (!in || !out) {
		/*
		 * Fix the error code before logging: the logger may clobber
		 * errno, which previously made the returned value unreliable.
		 */
		ret = -ENOMEM;
		dev_err(dev,
			"Failed to allocate memory for register access, ret(%d)",
			ret);
		goto out;
	}

	data = DEVX_ADDR_OF(access_register_in, in, register_data);
	memcpy(data, data_in, size_in);

	DEVX_SET(access_register_in, in, opcode, MLX5_CMD_OP_ACCESS_REG);
	DEVX_SET(access_register_in, in, op_mod, !write);
	DEVX_SET(access_register_in, in, argument, arg);
	DEVX_SET(access_register_in, in, register_id, reg_id);

	ret = mlx5_vfio_cmd_exec(dev, in, inlen, out, outlen);
	if (ret) {
		dev_err(dev,
			"Failed to execute register access command for reg_id 0x%x, ret(%d)",
			reg_id, ret);
		goto out;
	}

	data = DEVX_ADDR_OF(access_register_out, out, register_data);
	memcpy(data_out, data, size_out);

out:
	free(out);
	free(in);
	return ret;
}

static int mlx5_vfio_set_hca_ctrl(struct vfio_mlx5_dev *dev)
{
	struct mlx5_reg_host_endianness he_out = {};
	struct mlx5_reg_host_endianness he_in = {};

	he_in.he = MLX5_SET_HOST_ENDIANNESS;
	return mlx5_vfio_access_reg(dev, &he_in, sizeof(he_in), &he_out,
				    sizeof(he_out), MLX5_REG_HOST_ENDIANNESS, 0,
				    1);
}

/* Issue the INIT_HCA command; returns the mlx5_vfio_cmd_exec() result. */
static int mlx5_vfio_init_hca(struct vfio_mlx5_dev *dev)
{
	uint32_t cmd_in[DEVX_ST_SZ_DW(init_hca_in)] = {};
	uint32_t cmd_out[DEVX_ST_SZ_DW(init_hca_out)] = {};

	DEVX_SET(init_hca_in, cmd_in, opcode, MLX5_CMD_OP_INIT_HCA);

	return mlx5_vfio_cmd_exec(dev, cmd_in, sizeof(cmd_in), cmd_out,
				  sizeof(cmd_out));
}

/* Issue TEARDOWN_HCA with the graceful-close profile. */
static int mlx5_vfio_teardown_hca_regular(struct vfio_mlx5_dev *dev)
{
	uint32_t cmd_in[DEVX_ST_SZ_DW(teardown_hca_in)] = {};
	uint32_t cmd_out[DEVX_ST_SZ_DW(teardown_hca_out)] = {};

	dev_info(dev, "Graceful teardown HCA ...");

	DEVX_SET(teardown_hca_in, cmd_in, opcode, MLX5_CMD_OP_TEARDOWN_HCA);
	DEVX_SET(teardown_hca_in, cmd_in, profile,
		 MLX5_TEARDOWN_HCA_IN_PROFILE_GRACEFUL_CLOSE);

	return mlx5_vfio_cmd_exec(dev, cmd_in, sizeof(cmd_in), cmd_out,
				  sizeof(cmd_out));
}

/*
 * Fast HCA teardown: issue TEARDOWN_HCA with the prepare-fast-teardown
 * profile, then force the NIC interface to DISABLED and poll (1 ms steps,
 * up to FAST_TEARDOWN_WAIT_MS) until the device reports it.
 *
 * @return: 0 once the NIC interface reads back as disabled, the command
 *   error if TEARDOWN_HCA fails, or -EIO if firmware reports
 *   FORCE_STATE_FAIL or the state does not change in time.
 */
static int mlx5_vfio_teardown_hca_fast(struct vfio_mlx5_dev *dev)
{
	/* Function-local constants; #defines here would leak file-wide */
	enum {
		FAST_TEARDOWN_WAIT_MS = 3000,
		FAST_TEARDOWN_WAIT_ONCE_MS = 1,
	};
	uint32_t out[DEVX_ST_SZ_DW(teardown_hca_out)] = {};
	uint32_t in[DEVX_ST_SZ_DW(teardown_hca_in)] = {};
	struct mlx5_bar *bar0 = dev->heap->bar_map;
	int waited = 0, state, err;
	uint8_t nic_state;

	DEVX_SET(teardown_hca_in, in, opcode, MLX5_CMD_OP_TEARDOWN_HCA);
	DEVX_SET(teardown_hca_in, in, profile,
		 MLX5_TEARDOWN_HCA_IN_PROFILE_PREPARE_FAST_TEARDOWN);
	dev_info(dev, "Fast teardown HCA ...");
	err = mlx5_vfio_cmd_exec(dev, in, sizeof(in), out, sizeof(out));
	if (err) {
		dev_err(dev,
			"Failed to execute fast teardown HCA command, err(%d)",
			err);
		return err;
	}

	state = DEVX_GET(teardown_hca_out, out, state);
	if (state == MLX5_TEARDOWN_HCA_OUT_FORCE_STATE_FAIL) {
		dev_err(dev, "teardown with fast mode failed");
		return -EIO;
	}

	mlx5_vfio_set_nic_state(bar0, MLX5_NIC_IFC_DISABLED);
	do {
		if (mlx5_vfio_get_nic_state(bar0) == MLX5_NIC_IFC_DISABLED)
			break;
		usleep(FAST_TEARDOWN_WAIT_ONCE_MS * 1000);
		waited += FAST_TEARDOWN_WAIT_ONCE_MS;
	} while (waited < FAST_TEARDOWN_WAIT_MS);

	/* Sample once so the check and the log report the same value */
	nic_state = mlx5_vfio_get_nic_state(bar0);
	if (nic_state != MLX5_NIC_IFC_DISABLED) {
		dev_err(dev, "NIC IFC still %d after %dms.", nic_state, waited);
		return -EIO;
	}

	return 0;
}

/*
 * Tear down the HCA.
 *
 * Fast teardown is tried first when the fast_teardown capability is set.
 * If it is unavailable or fails, a graceful TEARDOWN_HCA is issued, but
 * only when @post_init_hca says INIT_HCA has completed.
 *
 * @return: 0 on success; -EOPNOTSUPP when no teardown could be performed
 *   (fast teardown unavailable/failed and !@post_init_hca); otherwise the
 *   graceful teardown error.
 */
static int mlx5_vfio_teardown_hca(struct vfio_mlx5_dev *dev, bool post_init_hca)
{
	int ret;

	if (MLX5_CAP_GEN(dev, fast_teardown)) {
		ret = mlx5_vfio_teardown_hca_fast(dev);
		if (!ret)
			return 0;

		dev_err(dev,
			"Failed to fast teardown, falling back to regular teardown, ret(%d)",
			ret);
	} else {
		/* Only claim "no fast teardown" when the capability is absent */
		dev_info(dev,
			 "No Fast teardown, falling back to regular teardown");
	}

	/*
	 * TODO: For Graceful teardown + disable_hca, must reclaim all pages,
	 * see NIC HCA teardown flow in PRM
	 */
	ret = -EOPNOTSUPP;

	if (post_init_hca) {
		ret = mlx5_vfio_teardown_hca_regular(dev);
		if (ret)
			dev_err(dev, "Graceful teardown failed, ret(%d)", ret);
	}

	/* TODO: disable_hca here, but need to reclaim all pages first */
	return ret;
}

/*
 * Query the supported ISSI set and select ISSI 1.
 *
 * @return: 0 on success, -EOPNOTSUPP if bit 1 of supported_issi_dw0 is
 *   clear, or the mlx5_vfio_cmd_exec() error of either command.
 */
static int mlx5_vfio_set_issi(struct vfio_mlx5_dev *dev)
{
	uint32_t query_in[DEVX_ST_SZ_DW(query_issi_in)] = {};
	uint32_t query_out[DEVX_ST_SZ_DW(query_issi_out)] = {};
	uint32_t set_in[DEVX_ST_SZ_DW(set_issi_in)] = {};
	uint32_t set_out[DEVX_ST_SZ_DW(set_issi_out)] = {};
	uint32_t sup_issi;
	int err;

	DEVX_SET(query_issi_in, query_in, opcode, MLX5_CMD_OP_QUERY_ISSI);
	err = mlx5_vfio_cmd_exec(dev, query_in, sizeof(query_in), query_out,
				 sizeof(query_out));
	if (err != 0) {
		dev_err(dev,
			"Failed to execute MLX5_CMD_OP_QUERY_ISSI command, ret(%d)",
			err);
		return err;
	}

	/* Bit 1 is checked as the ISSI 1 support flag */
	sup_issi = DEVX_GET(query_issi_out, query_out, supported_issi_dw0);
	if ((sup_issi & (1 << 1)) == 0) {
		err = -EOPNOTSUPP;
		dev_err(dev, "ISSI 1 is not supported by the device, ret(%d)",
			err);
		return err;
	}

	DEVX_SET(set_issi_in, set_in, opcode, MLX5_CMD_OP_SET_ISSI);
	DEVX_SET(set_issi_in, set_in, current_issi, 1);
	err = mlx5_vfio_cmd_exec(dev, set_in, sizeof(set_in), set_out,
				 sizeof(set_out));
	if (err != 0)
		dev_err(dev,
			"Failed to execute MLX5_CMD_OP_SET_ISSI command, ret(%d)",
			err);

	return err;
}

/* Non-zero (1) while bit 31 of the init segment's "initializing" word is set. */
static int fw_initializing(struct mlx5_bar *init_seg)
{
	const uint32_t initializing =
		be32toh(mmio_read32_be(&init_seg->initializing));

	return initializing >> 31;
}

#define FW_INIT_TIMEOUT_SEC (60)
#define FW_INIT_TIMEOUT_MS (FW_INIT_TIMEOUT_SEC * 1000)

/*
 * Wait, in 1 ms sleeps, until firmware clears its "initializing" bit.
 *
 * @return: 0 when firmware is ready, or -ETIMEDOUT (with errno set to
 *   ETIMEDOUT) after FW_INIT_TIMEOUT_MS sleeps.
 */
static int wait_fw_init(struct vfio_mlx5_dev *dev)
{
	struct mlx5_bar *init_seg = dev->heap->bar_map;
	int waited_ms = 0;

	/* Check if FW is already ready */
	if (!fw_initializing(init_seg))
		return 0;

	dev_info(dev, "waiting fw init ...");
	while (fw_initializing(init_seg)) {
		usleep(1000); /* 1ms */
		if (++waited_ms >= FW_INIT_TIMEOUT_MS) {
			dev_err(dev, "FW wait init timeout after %dms",
				waited_ms);
			/*
			 * Set errno after logging, and return the constant:
			 * the logger may overwrite errno.
			 */
			errno = ETIMEDOUT;
			return -ETIMEDOUT;
		}
	}

	dev_info(dev, "FW init is ready after %dms", waited_ms);
	return 0;
}

/*
 * Map VFIO region 0 (BAR0) read/write and store it in
 * dev->heap->bar_map / bar_map_size.
 *
 * @return: 0 on success, -errno if the region query or mmap() fails.
 */
static int mlx5_vfio_init_bar0(struct vfio_mlx5_dev *dev)
{
	struct vfio_region_info reg = { .argsz = sizeof(reg), .index = 0 };
	void *base;

	if (ioctl(dev->device_fd, VFIO_DEVICE_GET_REGION_INFO, &reg))
		return -errno;

	base = mmap(NULL, reg.size, PROT_READ | PROT_WRITE, MAP_SHARED,
		    dev->device_fd, reg.offset);
	if (base == MAP_FAILED)
		return -errno;

	/* reg.size is a __u64, not a size_t: print it as unsigned long long */
	dev_dbg(dev, "Bar0: %p, size: %llu", base,
		(unsigned long long)reg.size);

	dev->heap->bar_map = base;
	dev->heap->bar_map_size = reg.size;

	return 0;
}

/* Unmap BAR0 if it is mapped; a repeated call is a no-op. */
static void mlx5_vfio_uninit_bar0(struct vfio_mlx5_dev *dev)
{
	if (dev->heap->bar_map != NULL) {
		munmap(dev->heap->bar_map, dev->heap->bar_map_size);
		dev->heap->bar_map = NULL;
	}
}

/* Firmware major revision: low 16 bits of the fw_rev word. */
static uint16_t fw_rev_maj(struct mlx5_bar *bar0)
{
	const uint32_t fw_rev = be32toh(mmio_read32_be(&bar0->fw_rev));

	return (uint16_t)(fw_rev & 0xffff);
}

/* Firmware minor revision: high 16 bits of the fw_rev word. */
static uint16_t fw_rev_min(struct mlx5_bar *bar0)
{
	const uint32_t fw_rev = be32toh(mmio_read32_be(&bar0->fw_rev));

	return (uint16_t)(fw_rev >> 16);
}

/* Firmware sub-minor revision: low 16 bits of cmdif_rev_fw_sub. */
static uint16_t fw_rev_sub(struct mlx5_bar *bar0)
{
	const uint32_t rev_sub =
		be32toh(mmio_read32_be(&bar0->cmdif_rev_fw_sub));

	return (uint16_t)(rev_sub & 0xffff);
}

static int mlx5_vfio_pci_config(int device_fd)
{
	struct vfio_region_info pci_config_reg = {};
	uint16_t pci_com_buf = htole16(0x6);
	ssize_t write_len = sizeof(pci_com_buf);
	char buffer[4096];

	pci_config_reg.argsz = sizeof(pci_config_reg);
	pci_config_reg.index = VFIO_PCI_CONFIG_REGION_INDEX;

	if (ioctl(device_fd, VFIO_DEVICE_GET_REGION_INFO, &pci_config_reg)) {
		log_error("Failed to get PCI config region info, err(%d)",
			  -errno);
		return -errno;
	}

	if (pwrite(device_fd, &pci_com_buf, write_len,
		   pci_config_reg.offset + 0x4) != write_len) {
		log_error("Failed to write PCI command register, err(%d)",
			  -errno);
		return -errno;
	}

	if (pread(device_fd, buffer, sizeof(buffer), pci_config_reg.offset) !=
	    sizeof(buffer)) {
		log_error("Failed to read PCI config space, err(%d)", -errno);
		return -errno;
	}

	return 0;
}

/*
 * Run the HCA bring-up sequence on a device whose BAR0 is already mapped.
 *
 * Order: PCI command register setup (mlx5_vfio_pci_config()), wait for
 * firmware, command interface init, wait for firmware again, ENABLE_HCA,
 * SET_ISSI, boot pages, HOST_ENDIANNESS, HCA caps, remaining startup
 * pages, INIT_HCA. The order is significant; do not reshuffle.
 *
 * @dev: device with dev->device_fd open and dev->heap->bar_map mapped
 *
 * @return: 0 on success, otherwise the error of the failing step. If the
 *   command interface was brought up, it is uninitialized again on
 *   failure, and a teardown is attempted after ENABLE_HCA.
 */
static int mlx5_device_init(struct vfio_mlx5_dev *dev)
{
	/* Filled by mlx5_vfio_satisfy_startup_pages(); not used further here */
	int non_boot_pages = 0;
	int boot_pages = 0;
	int ret;

	ret = mlx5_vfio_pci_config(dev->device_fd);
	if (ret) {
		dev_err(dev, "Failed to configure PCI, err(%d)", ret);
		return ret;
	}

	ret = wait_fw_init(dev);
	if (ret) {
		dev_err(dev, "Waiting FW init after pci config, err (%d)", ret);
		return ret;
	}

	ret = mlx5_vfio_cmd_interface_init(dev);
	if (ret) {
		dev_err(dev, "Failed to init cmd interface, err(%d)", ret);
		return ret;
	}

	dev_info(dev, "firmware version: %d.%d.%04d",
		 fw_rev_maj(dev->heap->bar_map), fw_rev_min(dev->heap->bar_map),
		 fw_rev_sub(dev->heap->bar_map));

	/* Firmware init is checked again once the command queue is set up */
	ret = wait_fw_init(dev);
	if (ret) {
		dev_err(dev, "Waiting FW init bit after cmdq, err (%d)", ret);
		goto err_cmd;
	}

	ret = mlx5_func_enable_hca(dev, 0);
	if (ret) {
		dev_err(dev, "Failed to enable hca, err(%d)", ret);
		goto err_cmd;
	}

	ret = mlx5_vfio_set_issi(dev);
	if (ret) {
		dev_err(dev, "Failed to set issi, err(%d)", ret);
		goto err_disable_hca;
	}

	/* Boot pages (second argument 1) are handed to firmware first */
	ret = mlx5_vfio_satisfy_startup_pages(dev, 1, &boot_pages);
	if (ret) {
		dev_err(dev, "Failed to satisfy boot startup pages, err(%d)",
			ret);
		goto err_disable_hca;
	}

	ret = mlx5_vfio_set_hca_ctrl(dev);
	if (ret) {
		dev_err(dev, "Failed to set hca ctrl, err(%d)", ret);
		goto err_disable_hca;
	}

	ret = mlx5_vfio_set_hca_cap(dev);
	if (ret) {
		dev_err(dev, "Failed to set hca cap, err(%d)", ret);
		goto err_disable_hca;
	}

	/* Remaining (non-boot) startup pages, after caps are configured */
	ret = mlx5_vfio_satisfy_startup_pages(dev, 0, &non_boot_pages);
	if (ret) {
		dev_err(dev,
			"Failed to satisfy non-boot startup pages, err(%d)",
			ret);
		goto err_disable_hca;
	}

	ret = mlx5_vfio_init_hca(dev);
	if (ret) {
		dev_err(dev, "Failed to init hca, err(%d)", ret);
		goto err_disable_hca;
	}

	return 0;

/* No need to free boot/startup pages, the entire page allocator is going down
 * with us :)
 */
err_disable_hca:
	dev_err(dev, "Teardown-HCA due to error, err(%d)", ret);
	/*
	 * NOTE(review): with post_init_hca == false, mlx5_vfio_teardown_hca()
	 * only attempts a fast teardown (when supported) and never disables
	 * the HCA enabled above; its return value is ignored here.
	 */
	mlx5_vfio_teardown_hca(dev, false /* post_init_hca */);
err_cmd:
	mlx5_vfio_cmd_interface_uninit(dev);
	return ret;
}

/*
 * Reverse of a successful mlx5_device_init(): tear the HCA down (its
 * commands go through the command interface), then release the command
 * interface.
 */
static void mlx5_device_uninit(struct vfio_mlx5_dev *dev)
{
	const bool post_init_hca = true;

	mlx5_vfio_teardown_hca(dev, post_init_hca);
	mlx5_vfio_cmd_interface_uninit(dev);
}

/*
 * Printable names of the command event counters, indexed by CMD_EVENT_*.
 * Attached to the "cmd" event log and used by the stats dump.
 */
#define CMD_EVENT_STR(x) [CMD_EVENT_##x] = #x
static const char *cmd_event_names[CMD_EVENT_COUNT] = {
	CMD_EVENT_STR(POST),
	CMD_EVENT_STR(EXEC),
	CMD_EVENT_STR(POST_ERROR),
	CMD_EVENT_STR(COMP),
	CMD_EVENT_STR(COMP_ERROR),
	CMD_EVENT_STR(COMP_OUT_FAIL),
	CMD_EVENT_STR(TIMEOUT_ERROR),
	CMD_EVENT_STR(GROW),
	CMD_EVENT_STR(GROW_FAIL),
	CMD_EVENT_STR(COPY_ERROR),
	CMD_EVENT_STR(BUSY_ERROR),
};

/*
 * Allocate and initialize dev->heap: the VMH base address, the "dev" and
 * "cmd" event logs (the latter with dev->cmd_stats attached as its stats
 * array), and the firmware page request tracker. Also stores @page_alloc
 * in dev->page_alloc.
 *
 * @return: 0 on success; -ENOMEM if the heap cannot be allocated or the
 *   mlx5_vfio_fwpage_reqs_init() error. dev->heap is only set on success.
 */
int mlx5_device_heap_alloc(struct vfio_mlx5_dev *dev,
			   struct page_allocator *page_alloc,
			   const void *vmh_vaddr)
{
	struct mlx5_dev_heap *heap;
	int err;

	heap = calloc(1, sizeof(*heap));
	if (heap == NULL) {
		log_error("%s Failed to allocate mlx5_dev_heap", dev->pci_bdf);
		return -ENOMEM;
	}

	/* TODO: move to heap */
	dev->page_alloc = page_alloc;
	heap->vmh_vaddr = (uintptr_t)vmh_vaddr;

	/* Device-scoped log */
	event_log_init(&heap->devlog, "dev", libmlx5_logger.outf,
		       libmlx5_logger.errf);
	event_log_prefix_set(&heap->devlog, "dev(%s)", dev->pci_bdf);
	event_log_level_set(&heap->devlog, libmlx5_logger.level);

	/* Command log, with per-event counters in dev->cmd_stats */
	event_log_init(&heap->cmdlog, "cmd", libmlx5_logger.outf,
		       libmlx5_logger.errf);
	event_log_prefix_set(&heap->cmdlog, "cmd(%s)", dev->pci_bdf);
	event_log_level_set(&heap->cmdlog, libmlx5_logger.level);
	event_log_set_stats(&heap->cmdlog, dev->cmd_stats, cmd_event_names,
			    CMD_EVENT_COUNT);

	err = mlx5_vfio_fwpage_reqs_init(&heap->fwpage_reqs, dev->pci_bdf,
					 dev->page_events);
	if (err != 0) {
		log_error("Failed to init fwpage_reqs, err(%d)", err);
		free(heap);
		return err;
	}

	dev->heap = heap;
	return 0;
}

/* Release dev->heap (and its fwpage request tracker); safe if already freed. */
void mlx5_device_heap_free(struct vfio_mlx5_dev *dev)
{
	if (dev->heap != NULL) {
		mlx5_vfio_fwpage_reqs_free(&dev->heap->fwpage_reqs);
		free(dev->heap);
		dev->heap = NULL;
	}
}

/*
 * Bring up an mlx5 device through its VFIO device fd (dev->device_fd).
 *
 * Allocates the per-device heap, maps BAR0, runs the HCA init sequence,
 * sets up dev->num_vfs VFs and finally creates the async EQ. On failure,
 * everything acquired so far is released in reverse order.
 *
 * @return: 0 on success (dev->state becomes VFIO_MLX5_STATE_ENABLED),
 *   otherwise a negative errno-style code.
 */
int mlx5_vfio_device_init(struct vfio_mlx5_dev *dev,
			  struct page_allocator *page_alloc,
			  const void *vmh_vaddr)
{
	struct timespec ts;
	int err;

	/*
	 * Fail early if CLOCK_MONOTONIC is unusable: mlx5_vfio_poll_health()
	 * relies on it without checking the result.
	 */
	if (clock_gettime(CLOCK_MONOTONIC, &ts)) {
		int saved_errno = errno;

		err = -saved_errno;
		log_error("Failed to get current time, err(%d), errno(%d)", err,
			  saved_errno);
		return err;
	}

	err = mlx5_device_heap_alloc(dev, page_alloc, vmh_vaddr);
	if (err) {
		log_error("Failed to allocate mlx5_dev_heap, err(%d)", err);
		return err;
	}

	err = mlx5_vfio_init_bar0(dev);
	if (err < 0) {
		dev_err(dev, "Failed to init bar0, err(%d)", err);
		mlx5_device_heap_free(dev);
		return err;
	}

	err = mlx5_device_init(dev);
	if (err) {
		dev_err(dev, "Failed to initialize mlx5 device, err(%d)", err);
		goto err_init;
	}

	err = mlx5_func_setup_vfs(dev, dev->num_vfs);
	if (err) {
		dev_err(dev, "Failed to setup vfs, err(%d)", err);
		goto err_vfs;
	}

	/* Keep this LAST in this function: after the EQ is created,
	 * mlx5_vfio_cmd_exec() will fail. Other APIs like
	 * vfio_mlx5_events_process() will use mlx5_vfio_cmd_post_async() API
	 */
	err = mlx5_vfio_create_async_eq(dev);
	if (err) {
		dev_err(dev, "Failed to create async EQs, err(%d)", err);
		goto err_eq;
	}

	dev->state = VFIO_MLX5_STATE_ENABLED;
	dev_info(dev, "Device initialized successfully");

	return 0;

err_eq:
	mlx5_func_teardown_vfs(dev, dev->num_vfs);
err_vfs:
	mlx5_device_uninit(dev);
err_init:
	mlx5_vfio_uninit_bar0(dev);
	mlx5_device_heap_free(dev);
	return err;
}

/*
 * Fully tear the device down and release its heap.
 *
 * When the heap or BAR0 is already gone, only an error is logged and the
 * heap (if any) is freed.
 */
void mlx5_vfio_device_uninit(struct vfio_mlx5_dev *dev)
{
	dev->state = VFIO_MLX5_STATE_DISABLED;

	if (dev->heap && dev->heap->bar_map) {
		vfio_mlx5_dev_stats_dump(dev, NULL);
		/*
		 * Destroy the async EQs before any other teardown step: the
		 * rest of this flow only uses the blocking
		 * mlx5_vfio_cmd_exec() API, which works only once the EQs
		 * are destroyed.
		 */
		mlx5_vfio_destroy_async_eqs(dev);
		mlx5_func_teardown_vfs(dev, dev->num_vfs);
		mlx5_device_uninit(dev);
		mlx5_vfio_uninit_bar0(dev);
		dev_info(dev, "Device uninitialized");
	} else {
		log_error("%s Skip uninit, device bar is not mapped",
			  dev->pci_bdf);
	}

	mlx5_device_heap_free(dev);
}

/*
 * Suspend: unmap BAR0 and free the heap while leaving EQs, VFs and the
 * HCA untouched, so mlx5_vfio_dev_resume() can re-map and continue.
 *
 * @return: always 0.
 */
int mlx5_vfio_dev_suspend(struct vfio_mlx5_dev *dev)
{
	dev_info(dev, "Suspending device");
	dev->state = VFIO_MLX5_STATE_SUSPENDED;

	mlx5_vfio_uninit_bar0(dev);
	/*
	 * Log before the heap is freed: the device log is set up in
	 * heap->devlog (presumably what dev_info() writes to).
	 */
	dev_info(dev, "Device suspended");
	mlx5_device_heap_free(dev);

	return 0;
}

/*
 * Resume a suspended device: re-allocate the heap, re-map BAR0, recompute
 * the async EQ's virtual address and doorbell, then drain pending events.
 *
 * An event-processing failure is only logged; the device is still
 * considered resumed.
 *
 * @return: 0 on success, or the heap-allocation / BAR0-mapping error.
 */
int mlx5_vfio_dev_resume(struct vfio_mlx5_dev *dev,
			 struct page_allocator *page_alloc,
			 const void *vmh_vaddr)
{
	int rc;

	rc = mlx5_device_heap_alloc(dev, page_alloc, vmh_vaddr);
	if (rc) {
		log_error("%s: Failed to allocate mlx5_dev_heap (%s), err(%d)",
			  __func__, dev->pci_bdf, rc);
		return rc;
	}

	dev_info(dev, "Resuming device");

	rc = mlx5_vfio_init_bar0(dev);
	if (rc) {
		dev_err(dev, "Failed to init bar0, err(%d)", rc);
		mlx5_device_heap_free(dev);
		return rc;
	}

	/* TODO: This is not ideal */
	dev->async_eq.vaddr = iova2vaddr(dev, dev->async_eq.iova);
	dev->async_eq.doorbell = mlx5_dev_eq_doorbell_addr(dev);

	dev->state = VFIO_MLX5_STATE_ENABLED;

	rc = vfio_mlx5_events_process(dev);
	if (rc >= 0)
		dev_info(dev, "Device resumed, %d events processed", rc);
	else
		dev_warn(dev,
			 "Device resumed, but failed to process events, err(%d)",
			 rc);

	return 0;
}

/*
 * Snapshot device statistics into @stats.
 *
 * Copies state/index/VF count, page allocator stats, page and command
 * event counters, and the health record. A (rate-limited) health poll is
 * done first, but only while the heap and BAR0 exist. After
 * mlx5_vfio_dev_suspend(), or once the heap is freed, polling would
 * dereference a NULL heap, so it is skipped and stats->health_rec holds
 * the last recorded values.
 *
 * @return: 0, or the error returned by mlx5_vfio_poll_health().
 */
int vfio_mlx5_dev_stats(struct vfio_mlx5_dev *dev, struct mlx5_dev_stats *stats)
{
	int ret = 0;

	stats->dev_info.state = dev->state;
	stats->dev_info.index = dev->index;
	stats->dev_info.num_vfs = dev->num_vfs;

	/* Page allocation statistics */
	mlx5_vfio_page_alloc_stats(dev->page_alloc, &stats->page_stats);
	memcpy(stats->page_events, dev->page_events, sizeof(mlx5_pg_events_t));
	stats->firmware_pages = dev->firmware_pages;
	stats->driver_pages = stats->page_stats.total_pages -
			      stats->page_stats.free_pages -
			      stats->firmware_pages;
	memcpy(stats->cmd_stats, dev->cmd_stats, sizeof(mlx5_cmd_stats_t));

	if (dev->heap && dev->heap->bar_map)
		ret = mlx5_vfio_poll_health(dev);
	stats->health_rec = dev->health_rec;
	return ret;
}

void vfio_mlx5_dev_stats_dump(struct vfio_mlx5_dev *dev,
			      struct mlx5_dev_stats *stats)
{
	struct mlx5_dev_stats tmp_stats;
	uint64_t timestamp = time(NULL);
	char time_str[64];

	if (!stats) {
		vfio_mlx5_dev_stats(dev, &tmp_stats);
		stats = &tmp_stats;
	}

	/* Convert timestamp to readable format */
	const struct tm *tm_info = localtime((time_t *)&timestamp);

	strftime(time_str, sizeof(time_str), "%Y-%m-%d %H:%M:%S", tm_info);

	/* don't use dev_info, as it device might be suspended or disabled */
	log_info("=== VFIO MLX5 Device Statistics ===");
	log_info("Timestamp: %s", time_str);
	log_info("Device: %s", dev->pci_bdf);
	log_info("State: %d", stats->dev_info.state);
	log_info("Index: %d", stats->dev_info.index);
	log_info("Num VFs: %d", stats->dev_info.num_vfs);

	log_debug("--- Cmd Stats ---");
	for (unsigned int i = 0; i < CMD_EVENT_COUNT; i++)
		log_debug("%-19s: %lu", cmd_event_names[i],
			  stats->cmd_stats[i]);

	size_t fwpages_event_count;
	const char **fwpages_event_name =
		mlx5_vfio_fwpages_event_names(&fwpages_event_count);

	log_info("--- Events Stats ---");
	for (unsigned int i = 0; i < fwpages_event_count; i++)
		log_info("%-19s: %lu", fwpages_event_name[i],
			 stats->page_events[i]);

	log_info("--- Page Allocation Stats ---");
	log_info("Total Pages        : %lu", stats->page_stats.total_pages);
	log_info("Free Pages         : %lu", stats->page_stats.free_pages);
	log_info("Used Pages         : %lu",
		 stats->page_stats.total_pages - stats->page_stats.free_pages);
	log_info("Allocations        : %lu", stats->page_stats.allocs);
	log_info("Frees              : %lu", stats->page_stats.frees);
	log_info("Double Frees       : %lu", stats->page_stats.double_frees);
	log_info("Failed Allocs      : %lu", stats->page_stats.allocs_failed);

	log_info("--- FW Pages Stats ---");
	log_info("FW Pages           : %lu", stats->firmware_pages);
	log_info("Driver Pages       : %lu", stats->driver_pages);

	/* Add utilization if valid */
	if (stats->page_stats.total_pages > 0) {
		double utilization = (double)(stats->page_stats.total_pages -
					      stats->page_stats.free_pages) /
				     stats->page_stats.total_pages * 100.0;
		log_info("Page Utilization   : %.2f%%", utilization);
	}

	log_info("====================================");
}
