// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

#include <sys/errno.h>
#include <stdbool.h>

#include "dev.h"
#include "ifc.h"

/*
 * Issue ENABLE_HCA for function @func_id.
 *
 * Returns 0 on success, otherwise the status returned by mlx5_vfio_cmd_exec()
 * (which is also logged).
 */
int mlx5_func_enable_hca(struct vfio_mlx5_dev *dev, uint16_t func_id)
{
	uint32_t cmd_in[DEVX_ST_SZ_DW(enable_hca_in)] = {};
	uint32_t cmd_out[DEVX_ST_SZ_DW(enable_hca_out)] = {};
	int status;

	DEVX_SET(enable_hca_in, cmd_in, opcode, MLX5_CMD_OP_ENABLE_HCA);
	DEVX_SET(enable_hca_in, cmd_in, function_id, func_id);

	status = mlx5_vfio_cmd_exec(dev, cmd_in, sizeof(cmd_in), cmd_out,
				    sizeof(cmd_out));
	if (!status)
		return 0;

	dev_err(dev, "%s failed func_id(%d), err(%d)", __func__, func_id,
		status);
	return status;
}

/*
 * Issue DISABLE_HCA for function @func_id. Best-effort: a failure is only
 * logged, never returned.
 */
static void mlx5_func_disable_hca(struct vfio_mlx5_dev *dev, uint16_t func_id)
{
	uint32_t cmd_in[DEVX_ST_SZ_DW(disable_hca_in)] = {};
	uint32_t cmd_out[DEVX_ST_SZ_DW(disable_hca_out)] = {};
	int status;

	DEVX_SET(disable_hca_in, cmd_in, opcode, MLX5_CMD_OP_DISABLE_HCA);
	DEVX_SET(disable_hca_in, cmd_in, function_id, func_id);

	status = mlx5_vfio_cmd_exec(dev, cmd_in, sizeof(cmd_in), cmd_out,
				    sizeof(cmd_out));
	if (status == 0)
		return;

	dev_err(dev, "%s failed func_id(%d), err(%d)", __func__, func_id,
		status);
}

/*
 * Issue MODIFY_VPORT_STATE to set the admin state of @vport.
 *
 * @opmod:       command op_mod
 * @vport:       vport number
 * @other_vport: value for the other_vport field
 * @state:       admin_state value to program
 *
 * Best-effort: callers do not get a status back, but a failure is now
 * logged (like the enable/disable HCA helpers) instead of silently dropped.
 */
static void mlx5_modify_vport_admin_state(struct vfio_mlx5_dev *dev,
					  uint8_t opmod, uint16_t vport,
					  uint8_t other_vport, uint8_t state)
{
	uint32_t out[DEVX_ST_SZ_DW(modify_vport_state_out)] = {};
	uint32_t in[DEVX_ST_SZ_DW(modify_vport_state_in)] = {};
	int ret;

	DEVX_SET(modify_vport_state_in, in, opcode,
		 MLX5_CMD_OP_MODIFY_VPORT_STATE);
	DEVX_SET(modify_vport_state_in, in, op_mod, opmod);
	DEVX_SET(modify_vport_state_in, in, vport_number, vport);
	DEVX_SET(modify_vport_state_in, in, other_vport, other_vport);
	DEVX_SET(modify_vport_state_in, in, admin_state, state);

	ret = mlx5_vfio_cmd_exec(dev, in, sizeof(in), out, sizeof(out));
	if (ret)
		dev_err(dev, "%s failed vport(%d) state(%d), err(%d)", __func__,
			vport, state, ret);
}

/*
 * Forward declaration: the definition lives below with the other
 * function-delegation helpers, but mlx5_func_delegate_set() needs it first.
 */
static int mlx5_func_delegate_vhca(struct vfio_mlx5_dev *dev,
				   uint16_t func_vhca_id, uint16_t adj_vhca_id,
				   uint16_t management_profile);

/*
 * Management profile requested when delegating a VF's vHCA to the adjacent
 * PFs. A profile of 0 is passed instead to un-delegate.
 */
#define MGMT_PROFILE_1 0x1

/*
 * Delegate (@mgmt_profile != 0) or un-delegate (@mgmt_profile == 0)
 * management of function @func_id's vHCA to every adjacent PF recorded in
 * dev->adj_pfs[0..adj_pfs_count-1].
 *
 * The VF's vhca_id is read from its current general HCA capabilities.
 *
 * Delegation is all-or-nothing: if delegating to one adjacent PF fails, the
 * delegations already made by this call are undone and that error is
 * returned.
 *
 * Un-delegation is best-effort: every adjacent PF is attempted even if an
 * earlier one fails (so none is left holding the delegation), and the first
 * error is returned.
 *
 * Returns 0 on success, the mlx5_get_cap() error, or a delegation error.
 */
static int mlx5_func_delegate_set(struct vfio_mlx5_dev *dev, uint16_t func_id,
				  uint16_t mgmt_profile)
{
	uint32_t out[DEVX_ST_SZ_DW(query_hca_cap_out)] = {};
	const void *vf_cap;
	uint16_t vf_vhca_id;
	int err, first_err = 0, i;

	err = mlx5_get_cap(dev, MLX5_CAP_GEN, HCA_CAP_OPMOD_CUR, func_id, out);
	if (err)
		return err;

	vf_cap = DEVX_ADDR_OF(query_hca_cap_out, out, capability);
	vf_vhca_id = DEVX_GET(cmd_hca_cap, vf_cap, vhca_id);

	for (i = 0; i < dev->adj_pfs_count; i++) {
		/* pci_function holds the PCI device/function byte: dev << 3 | fn */
		uint8_t slot = (dev->adj_pfs[i].pci_function & 0xff) >> 3;
		uint8_t function = dev->adj_pfs[i].pci_function & 0x7;

		dev_info(
			dev,
			"%s vf[%d]: vhca_id 0x%x %s adj_pf[%d]: vhca_id 0x%x pci: %02x:%02x.%x",
			mgmt_profile ? "Delegating" : "Un-delegating",
			func_id - 1, vf_vhca_id, mgmt_profile ? "to" : "from",
			i, dev->adj_pfs[i].vhca_id, dev->adj_pfs[i].pci_bus,
			slot, function);

		err = mlx5_func_delegate_vhca(
			dev, vf_vhca_id, dev->adj_pfs[i].vhca_id, mgmt_profile);
		if (!err)
			continue;

		if (mgmt_profile)
			goto rollback;

		/*
		 * Un-delegating: "rolling back" would mean re-delegating, which
		 * is not wanted during teardown. Keep going so the remaining
		 * PFs are still released, and report the first failure.
		 */
		if (!first_err)
			first_err = err;
	}

	return first_err;

rollback:
	/* Undo only the delegations that succeeded: entries [0, i). */
	while (i--)
		mlx5_func_delegate_vhca(dev, vf_vhca_id,
					dev->adj_pfs[i].vhca_id, 0);
	return err;
}

/*
 * Bring up VF index @vf (function id @vf + 1):
 *   1. set its e-switch vport admin state,
 *   2. delegate it to the adjacent PFs, if any were discovered,
 *   3. issue ENABLE_HCA.
 *
 * If ENABLE_HCA fails, the delegation from step 2 is undone. The vport
 * admin state change is not reverted.
 *
 * Returns 0 on success, otherwise the delegation or ENABLE_HCA error.
 */
static int mlx5_func_setup_vf(struct vfio_mlx5_dev *dev, unsigned int vf)
{
	const uint16_t fid = vf + 1;
	const bool has_adj_pfs = dev->adj_pfs_count != 0;
	int rc;

	mlx5_modify_vport_admin_state(dev, MLX5_VPORT_STATE_OP_MOD_ESW_VPORT,
				      fid, 1, 1);

	if (has_adj_pfs) {
		rc = mlx5_func_delegate_set(dev, fid, MGMT_PROFILE_1);
		if (rc)
			return rc;
	}

	rc = mlx5_func_enable_hca(dev, fid);
	if (rc && has_adj_pfs)
		mlx5_func_delegate_set(dev, fid, 0);

	return rc;
}

/*
 * Undo mlx5_func_setup_vf() for VF index @vf (function id @vf + 1): disable
 * its HCA, then un-delegate it from the adjacent PFs if any are recorded.
 * Best-effort; failures are only logged by the helpers.
 */
static void mlx5_func_teardown_vf(struct vfio_mlx5_dev *dev, unsigned int vf)
{
	const uint16_t fid = vf + 1;

	mlx5_func_disable_hca(dev, fid);

	if (dev->adj_pfs_count == 0)
		return;

	(void)mlx5_func_delegate_set(dev, fid, 0);
}

/*
 * Forward declaration (defined below): fills dev->adj_pfs[] and
 * dev->adj_pfs_count, which VF setup/teardown consult.
 */
static int mlx5_func_query_adjacent_pfs(struct vfio_mlx5_dev *dev);

/*
 * Bring up @nvfs VFs (indices 0..nvfs-1, function ids 1..nvfs).
 *
 * Adjacent PFs are queried first. When any are recorded, each VF is also
 * delegated to them, and @nvfs must not exceed the delegated_vhca_max
 * capability.
 *
 * On failure, the VFs already brought up by this call are torn down in
 * reverse order; the failing VF has already cleaned up after itself.
 *
 * Returns 0 on success, -EINVAL if @nvfs exceeds delegated_vhca_max, or the
 * per-VF setup error.
 */
int mlx5_func_setup_vfs(struct vfio_mlx5_dev *dev, unsigned int nvfs)
{
	unsigned int vf;
	int err;

	/*
	 * Adjacent-PF discovery stays best-effort (VF setup proceeds either
	 * way), but surface the failure instead of dropping it silently.
	 */
	err = mlx5_func_query_adjacent_pfs(dev);
	if (err)
		dev_warn(dev, "Failed to query adjacent PFs, err(%d)", err);

	if (dev->adj_pfs_count &&
	    nvfs > MLX5_CAP_GEN_2(dev, delegated_vhca_max)) {
		dev_err(dev,
			"Number of vfs %u exceeds max allowed delegated vhca %u",
			nvfs, MLX5_CAP_GEN_2(dev, delegated_vhca_max));
		return -EINVAL;
	}

	for (vf = 0; vf < nvfs; vf++) {
		/* %u: vf is unsigned (a %d/unsigned mismatch is UB for printf). */
		dev_info(dev, "Setting up vf %u", vf);
		err = mlx5_func_setup_vf(dev, vf);
		if (err)
			goto teardown;
	}

	return 0;

teardown:
	/* Unwind VFs [0, vf) that were fully set up. */
	while (vf--)
		mlx5_func_teardown_vf(dev, vf);

	return err;
}

/*
 * Tear down VFs 0..@nvfs-1 in ascending order. For each one, disable its
 * HCA, then undo any delegation to adjacent PFs.
 *
 * Best-effort: the helpers log failures and nothing is propagated.
 */
void mlx5_func_teardown_vfs(struct vfio_mlx5_dev *dev, unsigned int nvfs)
{
	unsigned int vf;

	for (vf = 0; vf < nvfs; vf++) {
		/* %u: vf is unsigned. */
		dev_info(dev, "Tearing down vf %u", vf);
		mlx5_func_teardown_vf(dev, vf);
	}
}

/* function delegation */

/*
 * Command opcodes for the adjacent-function / vHCA delegation commands.
 * They are written into the opcode field of the command input below.
 */
#define MLX5_CMD_OP_QUERY_ADJACENT_FUNCTIONS_ID 0x730
#define MLX5_CMD_OP_DELEGATE_VHCA_MANAGEMENT 0x731
#define MLX5_CMD_OP_QUERY_DELEGATED_VHCA 0x732

/*
 * Issue DELEGATE_VHCA_MANAGEMENT: delegate management of vHCA @func_vhca_id
 * to vHCA @adj_vhca_id with @management_profile. A profile of 0 un-delegates.
 *
 * Every bit set in @management_profile must be advertised in the
 * delegate_vhca_management_profiles capability. Previously a profile was
 * accepted if ANY of its bits was supported.
 *
 * Returns 0 on success, -EOPNOTSUPP for an unsupported profile, or the
 * mlx5_vfio_cmd_exec() error (which is also logged).
 */
static int mlx5_func_delegate_vhca(struct vfio_mlx5_dev *dev,
				   uint16_t func_vhca_id, uint16_t adj_vhca_id,
				   uint16_t management_profile)
{
	/*
	 * Use dword arrays like every other command here: DEVX_SET/DEVX_GET
	 * access the buffer through 32-bit big-endian pointers, so a byte
	 * array could be misaligned.
	 */
	uint32_t in[DEVX_ST_SZ_DW(delegate_vhca_management_in)] = {};
	uint32_t out[DEVX_ST_SZ_DW(delegate_vhca_management_out)] = {};
	uint32_t supported;
	int err;

	supported = MLX5_CAP_GEN_2(dev, delegate_vhca_management_profiles);
	if ((supported & management_profile) != management_profile) {
		dev_err(dev, "Management profile 0x%x not supported",
			management_profile);
		return -EOPNOTSUPP;
	}

	DEVX_SET(delegate_vhca_management_in, in, opcode,
		 MLX5_CMD_OP_DELEGATE_VHCA_MANAGEMENT);
	DEVX_SET(delegate_vhca_management_in, in, op_mod, 0);
	DEVX_SET(delegate_vhca_management_in, in, managed_vhca_id,
		 func_vhca_id);
	DEVX_SET(delegate_vhca_management_in, in, dest_vhca_id, adj_vhca_id);
	DEVX_SET(delegate_vhca_management_in, in, management_profile,
		 management_profile);

	err = mlx5_vfio_cmd_exec(dev, in, sizeof(in), out, sizeof(out));
	if (err)
		dev_err(dev,
			"Failed to %s managed_vhca_id 0x%x dest vhca_id 0x%x, mgmt_profile 0x%x, err(%d)",
			management_profile ? "delegate" : "un-delegate",
			func_vhca_id, adj_vhca_id, management_profile, err);

	return err;
}

/*
 * Read @field from the function_vhca_rid_info sub-structure embedded in the
 * adjacent_function_vhca_rid_info_reg entry that @rid_info points to.
 */
#define MLX5_GET_RID_INFO(rid_info, field)                                   \
	DEVX_GET(function_vhca_rid_info_reg,                                 \
		 DEVX_ADDR_OF(adjacent_function_vhca_rid_info_reg, rid_info, \
			      function_vhca_rid_info),                       \
		 field)

/*
 * function_type value for a PCI physical function. It is sent as the query
 * filter and also checked on each returned entry.
 */
#define MLX5_FUNC_TYPE_PCI_PF 0x1

/*
 * Query the PCI PFs adjacent to this device and cache them in
 * dev->adj_pfs[0..adj_pfs_count-1] for later VF delegation.
 *
 * If the device lacks the query_adjacent_functions_id or
 * delegate_vhca_management_profiles capability, this returns 0 and records
 * no adjacent PFs.
 *
 * Only entries whose function_type is MLX5_FUNC_TYPE_PCI_PF are kept. They
 * are packed densely (non-PF entries leave no holes) and capped at both
 * MLX5_MAX_ADJ_PFS and the delegate_vhca_max capability.
 *
 * Returns 0 on success, -ENOMEM if the output buffer cannot be allocated,
 * or the mlx5_vfio_cmd_exec() error. On any error dev->adj_pfs_count is 0.
 */
static int mlx5_func_query_adjacent_pfs(struct vfio_mlx5_dev *dev)
{
	uint32_t in[DEVX_ST_SZ_DW(query_adjacent_functions_id_input_reg)] = {};
	/* Number of rid_info entries the output buffer has room for. */
	const uint16_t max_adj_func_ids = 64;
	uint16_t functions_count;
	unsigned int count = 0;
	int outlen, err, i;
	uint8_t *out;

	/* Never leave a stale count behind from an earlier query. */
	dev->adj_pfs_count = 0;

	if (!MLX5_CAP_GEN_2(dev, query_adjacent_functions_id) ||
	    !MLX5_CAP_GEN_2(dev, delegate_vhca_management_profiles))
		return 0;

	outlen = DEVX_ST_SZ_BYTES(query_adjacent_functions_id_output_reg) +
		 max_adj_func_ids *
			 DEVX_ST_SZ_BYTES(adjacent_function_vhca_rid_info_reg);

	out = calloc(outlen, 1);
	if (!out) {
		dev_err(dev,
			"Failed to allocate memory for query_adjacent_functions_id_out");
		return -ENOMEM;
	}

	DEVX_SET(query_adjacent_functions_id_input_reg, in, opcode,
		 MLX5_CMD_OP_QUERY_ADJACENT_FUNCTIONS_ID);
	DEVX_SET(query_adjacent_functions_id_input_reg, in, function_type,
		 MLX5_FUNC_TYPE_PCI_PF);

	err = mlx5_vfio_cmd_exec(dev, in, sizeof(in), out, outlen);
	if (err)
		goto free_out;

	functions_count = DEVX_GET(query_adjacent_functions_id_output_reg,
				   out, functions_count);

	dev_info(dev,
		 "Adjacent functions count %u/%u; management profiles 0x%x",
		 functions_count, MLX5_CAP_GEN_2(dev, delegate_vhca_max),
		 MLX5_CAP_GEN_2(dev, delegate_vhca_management_profiles));

	/*
	 * The count comes from firmware. Never index rid_info entries beyond
	 * the buffer allocated above.
	 */
	if (functions_count > max_adj_func_ids) {
		dev_warn(dev,
			 "\tAdjacent functions count %u exceeds buffer capacity %u, truncating.",
			 functions_count, max_adj_func_ids);
		functions_count = max_adj_func_ids;
	}

	for (i = 0; i < functions_count; i++) {
		uint8_t *rinfo =
			DEVX_ADDR_OF(query_adjacent_functions_id_output_reg,
				     out, adjacent_function_vhca_rid_info[i]);

		if (MLX5_GET_RID_INFO(rinfo, function_type) !=
		    MLX5_FUNC_TYPE_PCI_PF)
			continue;

		uint16_t vhca_id = MLX5_GET_RID_INFO(rinfo, vhca_id);
		uint16_t function_id = MLX5_GET_RID_INFO(rinfo, function_id);
		uint8_t host_number = MLX5_GET_RID_INFO(rinfo, host_number);
		uint8_t pci_bus = MLX5_GET_RID_INFO(rinfo, host_pci_bus);
		uint8_t pci_bus_assigned =
			MLX5_GET_RID_INFO(rinfo, pci_bus_assigned);
		uint8_t pci_function =
			MLX5_GET_RID_INFO(rinfo, host_pci_device_function);
		uint8_t slot = (pci_function & 0xff) >> 3;
		uint8_t function = pci_function & 0x7;

		dev_info(
			dev,
			"\t pci: %02x:%02x.%x, vhca_id: 0x%x, func_id: 0x%x, host: 0x%x, bus_assigned: %d",
			pci_bus, slot, function, vhca_id, function_id,
			host_number, pci_bus_assigned);

		/* Cap on PFs actually stored, not on list position. */
		if (count >= MLX5_MAX_ADJ_PFS ||
		    count >= MLX5_CAP_GEN_2(dev, delegate_vhca_max)) {
			dev_warn(
				dev,
				"\tAdjacent functions count %u exceeds max %u/%u, skipping.",
				count + 1, MLX5_MAX_ADJ_PFS,
				MLX5_CAP_GEN_2(dev, delegate_vhca_max));
			break;
		}
		dev->adj_pfs[count].func_id = function_id;
		dev->adj_pfs[count].vhca_id = vhca_id;
		dev->adj_pfs[count].pci_bus = pci_bus;
		/*
		 * Store the raw device/function byte: mlx5_func_delegate_set()
		 * decodes both the slot (>> 3) and function (& 0x7) from it.
		 */
		dev->adj_pfs[count].pci_function = pci_function;
		count++;
	}
	dev->adj_pfs_count = count;

free_out:
	free(out);
	return err;
}
