// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

#include <errno.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "hcacap.h"
#include "cmd.h"
#include "dev.h"
#include "util/util.h"
#include "vfio_mlx5.h"

int mlx5_get_cap(struct vfio_mlx5_dev *dev, enum mlx5_cap_type cap_type,
		 enum mlx5_cap_mode cap_mode, uint8_t func_id, void *out)
{
	uint16_t opmod = (cap_type << 1) | (cap_mode & 0x01);
	uint8_t in[DEVX_ST_SZ_BYTES(query_hca_cap_in)] = {};
	int out_sz = DEVX_ST_SZ_BYTES(query_hca_cap_out);

	DEVX_SET(query_hca_cap_in, in, opcode, MLX5_CMD_OP_QUERY_HCA_CAP);
	DEVX_SET(query_hca_cap_in, in, op_mod, opmod);
	DEVX_SET(query_hca_cap_in, in, other_function, func_id ? 1 : 0);
	DEVX_SET(query_hca_cap_in, in, function_id, func_id);

	return mlx5_vfio_cmd_exec(dev, in, sizeof(in), out, out_sz);
}

static int mlx5_set_cap(struct vfio_mlx5_dev *dev, enum mlx5_cap_type cap_type,
			uint8_t func_id, void *in)
{
	uint16_t opmod = (cap_type << 1);
	uint8_t out[DEVX_ST_SZ_BYTES(set_hca_cap_out)] = {};
	size_t set_size = DEVX_ST_SZ_BYTES(set_hca_cap_in);

	/* reset only mbox header, caps to set are in in->capability */
	memset(in, 0, DEVX_ST_SZ_BYTES(mbox_in));
	DEVX_SET(set_hca_cap_in, in, opcode, MLX5_CMD_OP_SET_HCA_CAP);
	DEVX_SET(set_hca_cap_in, in, op_mod, opmod);
	DEVX_SET(set_hca_cap_in, in, other_function, func_id ? 1 : 0);
	DEVX_SET(set_hca_cap_in, in, function_id, func_id);

	return mlx5_vfio_cmd_exec(dev, in, set_size, out, sizeof(out));
}

/*
 * One buffer serves as both the output of QUERY_HCA_CAP and the input of
 * SET_HCA_CAP. The current capabilities are queried into it, patched in place
 * and sent back.
 *
 * This only works if query_hca_cap_out fits exactly in a set_hca_cap_in
 * buffer and both place the capability field at the same offset.
 */
_Static_assert(DEVX_ST_SZ_BYTES(query_hca_cap_out) ==
		       DEVX_ST_SZ_BYTES(set_hca_cap_in),
	       "query_hca_cap_out and set_hca_cap_in must be the same size");

_Static_assert(
	DEVX_BYTE_OFF(query_hca_cap_out, capability) ==
		DEVX_BYTE_OFF(set_hca_cap_in, capability),
	"query_hca_cap_out and set_hca_cap_in must have the same offset for capability");

/* buf: must be query_hca_cap_out or set_hca_cap_in which are the same size */
static int handle_hca_cap_general(struct vfio_mlx5_dev *dev, void *buf)
{
	int sys_page_shift = ilog32(sysconf(_SC_PAGESIZE) - 1);
	bool uar_4k = false;
	void *hca_cap;
	int ret;

	/* Read max HCA capabilities to find out supported features */
	ret = mlx5_get_cap(dev, MLX5_CAP_GEN, HCA_CAP_OPMOD_MAX, 0, buf);
	if (ret) {
		log_error("Failed to get general capabilities, ret(%d)", ret);
		return ret;
	}
	hca_cap = DEVX_ADDR_OF(query_hca_cap_out, buf, capability);

	dev->caps.general.fast_teardown =
		!!DEVX_GET(cmd_hca_cap, hca_cap, fast_teardown);
	uar_4k = !!DEVX_GET(cmd_hca_cap, hca_cap, uar_4k);

	memset(buf, 0, DEVX_ST_SZ_BYTES(query_hca_cap_out));
	/* Read current HCA capabilities to keep enabled features */
	ret = mlx5_get_cap(dev, MLX5_CAP_GEN, HCA_CAP_OPMOD_CUR, 0, buf);
	if (ret) {
		log_error("Failed to get general capabilities, ret(%d)", ret);
		return ret;
	}

	/* set_cmd already points to buf; */
	hca_cap = DEVX_ADDR_OF(query_hca_cap_out, buf, capability);

	/* overwrite the fields we want to change */
	DEVX_SET(cmd_hca_cap, hca_cap, cmdif_checksum, 0);

	DEVX_SET(cmd_hca_cap, hca_cap, disable_link_up_by_init_hca, 1);

	/* Enable 4K UAR only when HCA supports it */
	if (uar_4k && sysconf(_SC_PAGESIZE) > 4096) {
		dev_info(dev, "Enabling 4K UAR on system with page size(%d)",
			 sysconf(_SC_PAGESIZE));
		DEVX_SET(cmd_hca_cap, hca_cap, uar_4k, 1);
		dev->caps.general.uar_4k = true;
	}

	DEVX_SET(cmd_hca_cap, hca_cap, log_uar_page_sz, sys_page_shift - 12);

	DEVX_SET(cmd_hca_cap, hca_cap, fast_teardown,
		 dev->caps.general.fast_teardown);

	return mlx5_set_cap(dev, MLX5_CAP_GEN, 0, buf);
}

/*
 * Query the current HCA capabilities 2 and cache the delegation-related
 * fields in dev->caps.cap2.
 *
 * Best effort: if the query fails, the error is logged and dev->caps.cap2
 * is left untouched.
 *
 * buf: scratch buffer of at least DEVX_ST_SZ_BYTES(query_hca_cap_out) bytes.
 *      Its previous contents are discarded.
 */
static void query_hca_cap_2(struct vfio_mlx5_dev *dev, void *buf)
{
	const void *caps2;
	int ret;

	/*
	 * buf may still hold a previous command (e.g. a SET_HCA_CAP input).
	 * Clear it so no stale bytes can be read back as caps2, matching the
	 * general-caps path, which clears before re-querying.
	 */
	memset(buf, 0, DEVX_ST_SZ_BYTES(query_hca_cap_out));

	ret = mlx5_get_cap(dev, MLX5_CAP_GEN_2, HCA_CAP_OPMOD_CUR, 0, buf);
	if (ret) {
		log_error("Failed to get caps2 capabilities, ret(%d)", ret);
		return;
	}

	caps2 = DEVX_ADDR_OF(query_hca_cap_out, buf, capability);

	dev->caps.cap2.query_adjacent_functions_id =
		!!DEVX_GET(cmd_hca_cap_2, caps2, query_adjacent_functions_id);

	dev->caps.cap2.delegate_vhca_management_profiles = !!DEVX_GET(
		cmd_hca_cap_2, caps2, delegate_vhca_management_profiles);

	dev->caps.cap2.delegated_vhca_max =
		DEVX_GET(cmd_hca_cap_2, caps2, delegated_vhca_max);

	dev->caps.cap2.delegate_vhca_max =
		DEVX_GET(cmd_hca_cap_2, caps2, delegate_vhca_max);
}

/*
 * Configure the HCA capabilities of @dev.
 *
 * The general capabilities are queried, patched and written back. On
 * success, the HCA caps 2 are queried (best effort) to fill dev->caps.cap2.
 *
 * Returns 0 on success, or the error from configuring the general
 * capabilities. A failed caps 2 query does not change the result.
 */
int mlx5_vfio_set_hca_cap(struct vfio_mlx5_dev *dev)
{
	/*
	 * One zeroed, dword-typed scratch buffer shared by every query/set
	 * below; it is sized for the set_hca_cap_in layout.
	 */
	uint32_t scratch[DEVX_ST_SZ_DW(set_hca_cap_in)] = {};
	int err = handle_hca_cap_general(dev, scratch);

	if (!err)
		query_hca_cap_2(dev, scratch);

	return err;
}
