/*
 * SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
 * Copyright (c) 2001-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: GPL-2.0-only or BSD-2-Clause
 */


#ifndef UTILS_H
#define UTILS_H

#include <chrono>
#include <time.h>
#include <string>
#include <string.h>
#include <ifaddrs.h>
#include <linux/if_ether.h>
#include <exception>

#include "vtypes.h"
#include "vlogger/vlogger.h"
#include "vma/proto/mem_buf_desc.h"
#include "vma/util/vma_stats.h"

struct iphdr; //forward declaration

/* Round x up to the next multiple of y (integer arithmetic; y must be non-zero).
 * Note: y is evaluated three times, so do not pass expressions with side effects. */
#define VMA_ALIGN(x, y) (((x) + (y) - 1) / (y) * (y))

/* Number of elements of a true array. Does NOT work on pointers (e.g. array
 * function parameters, which decay to pointers). The argument is parenthesized
 * in both uses so that any expression expands safely. */
#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))

/**
 * Check whether the file at 'path' is a regular file.
 * NOTE(review): the return-value convention is defined in the implementation — confirm.
 */
int check_if_regular_file (char *path);

/**
 * L3 and L4 header checksum calculation for an outgoing buffer.
 * @param p_mem_buf_desc buffer descriptor holding the packet
 * @param l3_csum presumably selects whether the L3 (IP) checksum is computed
 * @param l4_csum presumably selects whether the L4 checksum is computed
 */
void compute_tx_checksum(mem_buf_desc_t* p_mem_buf_desc, bool l3_csum, bool l4_csum);

/**
 * IP header checksum calculation.
 * @param buf header data viewed as 16-bit words
 * @param nshort_words number of 16-bit words in 'buf' (per the name — confirm)
 * @return the checksum
 */
unsigned short compute_ip_checksum(const unsigned short *buf, unsigned int nshort_words);

/**
 * Get TCP checksum, given the IP header and the TCP segment
 * (assumes the checksum field in the TCP header contains zero).
 * Matches RFC 793.
 */
unsigned short compute_tcp_checksum(const struct iphdr *p_iphdr, const uint16_t *p_ip_payload);

/**
 * Get UDP checksum of a received datagram, given the IP header and the UDP header
 * (assumes the checksum field in the UDP header contains zero).
 * The UDP checksum is defined by RFC 768.
 */
unsigned short compute_udp_checksum_rx(const struct iphdr *p_iphdr, const struct udphdr *udphdrp, mem_buf_desc_t* p_rx_wc_buf_desc);

/**
 * Get the user-space maximum number of open fds using getrlimit.
 * @param def_max_fd default value (1024 when omitted); presumably returned
 *        when getrlimit fails — confirm in the implementation
 */

int get_sys_max_fd_num(int def_max_fd=1024);

/**
 * iovec extension: copy 'sz_data' bytes, starting at byte offset
 * 'sz_src_start_offset' of the iovec array 'p_iov' ('sz_iov' entries), into 'p_dst'.
 * @return total number of bytes copied
 */
int memcpy_fromiovec(u_int8_t* p_dst, const struct iovec* p_iov, size_t sz_iov, size_t sz_src_start_offset, size_t sz_data);

/**
 * Get the base interface of an aliased/VLAN-tagged one, e.g. eth2:1 --> eth2, eth2.1 --> eth2.
 * @param if_name input interface name
 * @param base_ifname output buffer that receives the base interface name
 * @param sz_base_ifname size of 'base_ifname'
 * NOTE(review): the int return-value convention is defined in the implementation — confirm.
 */
int get_base_interface_name(const char *if_name, char *base_ifname, size_t sz_base_ifname);

/**
 * Count the set bits in a netmask.
 */
int netmask_bitcount(uint32_t netmask);


/**
 * Set the fd blocking mode.
 * @param fd the file descriptor on which to operate
 * @param block 'true' to set to blocking,
 *              'false' to set to non-blocking
 */
void set_fd_block_mode(int fd, bool block);

/**
 * Compare two doubles for equality.
 * NOTE(review): presumably compared within a tolerance rather than with '==' — confirm.
 * @param a number
 * @param b number
 * @return true if 'a' and 'b' are equal, else false.
 */
bool compare_double(double a, double b);

/**
 * Run a system command while bypassing the LD_PRELOADed VMA library.
 * @param cmd_line command to be executed without VMA in the process space
 * @param return_str buffer that receives the output of the command
 * @param return_str_len size of 'return_str'
 * NOTE(review): the return-value convention is defined in the implementation — confirm.
 */
int run_and_retreive_system_command(const char* cmd_line, char* return_str, int return_str_len);

/* Presumably returns a printable name for the IP protocol number 'type'. */
const char* iphdr_protocol_type_to_str(const int type);

/**
 * Read the content of the file at 'path' (usually a sysfs file) and
 * store it into the given 'buf', up to 'size' characters.
 * Print a log in case of failure according to the given 'log_level' argument.
 * @return length of content that was read, or -1 upon any error
 */
int priv_read_file(const char *path, char *buf, size_t size, vlog_levels_t log_level = VLOG_ERROR);

/**
 * Like priv_read_file(), but upon success the result in 'buf' is guaranteed to be
 * a NUL-terminated string: at most size - 1 characters are read, and the last byte
 * is reserved for the terminator.
 * @return length read (excluding the terminator), or -1 on error or when size == 0
 */
inline int priv_safe_read_file(const char *path, char *buf, size_t size, vlog_levels_t log_level = VLOG_ERROR)
{
	// No room even for the terminator, so nothing can be read.
	if (size == 0) {
		return -1;
	}

	const int len = priv_read_file(path, buf, size - 1, log_level);
	if (len >= 0) {
		buf[len] = '\0';
	}
	return len;
}


/**
 * like above however make sure that upon success the result in buf is a null terminated string and VLOG_DEBUG
 */
inline int priv_safe_try_read_file(const char *path, char *buf, size_t size) {
	int ret = -1;
	if (size > 0) {
		ret = priv_read_file(path, buf, size - 1, VLOG_DEBUG);
		if (0 <= ret) buf[ret] = '\0';
	}
	return ret;
}

/**
 * Read the content of the file at 'path' (usually a sysfs file) as an int.
 * Upon failure, print an error.
 * @param path file to read
 * @param default_value value returned upon failure
 * @return int value (atoi) of the file content, or 'default_value' upon failure
 */
int read_file_to_int(const char *path, int default_value);

/**
 * Get interface name and flags from a local address.
 *
 * @param local_addr local address to look up
 * @param ifname output interface name; expected to be of size IFNAMSIZ
 * @param ifflags output flags, as from the SIOCGIFFLAGS ioctl
 * @return zero on success
 */
int get_ifinfo_from_ip(const struct sockaddr& local_addr, char* ifname, uint32_t &ifflags);

/**
 * Get port number from interface name.
 * @param ifname input interface name of device (e.g. eth1, ib2);
 *  should be of size IFNAMSIZ
 * @return port number, or zero on failure
 */
int get_port_from_ifname(const char* ifname);

/**
 * Get interface type value from interface name.
 *
 * @param ifname input interface name of device (e.g. eth1, ib2);
 *  should be of size IFNAMSIZ
 * @return interface type on success, or -1 on failure
 */
int get_iftype_from_ifname(const char* ifname);

/**
 * Get interface MTU from interface name.
 *
 * @param ifname input interface name of device (e.g. eth1, ib2);
 *  should be of size IFNAMSIZ
 * @return MTU length, or zero on failure
 */
int get_if_mtu_from_ifname(const char* ifname);

/**
 * Get the OS TCP window scaling factor when tcp_window_scaling is enabled.
 * The value is calculated from the maximum receive buffer value.
 *
 * @param tcp_rmem_max the maximum size of the receive buffer used by each TCP socket
 * @param core_rmem_max the maximum socket receive buffer size, in bytes, that a user
 *        may set using the SO_RCVBUF socket option
 *
 * @return TCP window scaling factor
 */
int get_window_scaling_factor(int tcp_rmem_max, int core_rmem_max);

/**
 * Get IPv4 address from interface name.
 *
 * @param ifname input interface name of device (e.g. eth1, ib2);
 *  should be of size IFNAMSIZ
 * @param addr output: interface inet address
 *
 * @return -1 on failure
 */
int get_ipv4_from_ifname(char *ifname, struct sockaddr_in *addr);

/**
 * Get IPv4 address from interface index.
 *
 * @param ifindex input interface index of device
 * @param addr output: interface inet address
 *
 * @return -1 on failure
 */
int get_ipv4_from_ifindex(int ifindex, struct sockaddr_in *addr);

/**
 * Get VLAN id from interface name.
 *
 * @param ifname input interface name of device (e.g. eth2, eth2.5)
 * @return the VLAN id, or 0 if not a VLAN
 */
uint16_t get_vlan_id_from_ifname(const char* ifname);

/**
 * Get VLAN base name from interface name.
 *
 * @param ifname input interface name of device (e.g. eth2, eth2.5)
 * @param base_ifname output base interface name of device (e.g. eth2)
 * @param sz_base_ifname input: the size of the base_ifname buffer
 * @return the VLAN base name length, or 0 if not a VLAN
 */
size_t get_vlan_base_name_from_ifname(const char* ifname, char* base_ifname, size_t sz_base_ifname);

/* Get the link-layer address of 'ifname' into 'addr' (capacity 'addr_len').
 * NOTE(review): 'is_broadcast' presumably selects the broadcast address — confirm.
 * Upon success returns the actual address length in bytes; upon error returns zero. */
size_t get_local_ll_addr(const char* ifname, unsigned char* addr, int addr_len,  bool is_broadcast);

/* Print warnings while RoCE LAG is enabled on 'interface'; port names are optional (default NULL). */
void print_roce_lag_warnings(char* interface, const char* port1 = NULL, const char* port2 = NULL);

/* Bond / netvsc / device helpers. Outputs go to the OUT parameters, whose buffer
 * size is given by 'sz' where present.
 * NOTE(review): whether each bool result means "success" or "exists" is defined
 * in the implementation — confirm per function. */
bool check_bond_device_exist(const char* ifname);
bool get_bond_active_slave_name(IN const char* bond_name, OUT char* active_slave_name, IN int sz);
bool get_bond_slave_state(IN const char* slave_name, OUT char* curr_state, IN int sz);
bool get_bond_slaves_name_list(IN const char* bond_name, OUT char* slaves_list, IN int sz);
bool check_device_exist(const char* ifname, const char *path);
bool check_device_name_ib_name(const char* ifname, const char* ibname);
bool check_netvsc_device_exist(const char* ifname);
bool get_netvsc_slave(IN const char* ifname, OUT char* slave_name, OUT unsigned int &slave_flags);
bool get_interface_oper_state(IN const char* interface_name, OUT char* slaves_list, IN int sz);
bool validate_user_has_cap_net_raw_privliges();

/* Default huge page size. NOTE(review): presumably in bytes — confirm. */
size_t default_huge_page_size(void);

/**
 * Get the base name of the executable of process 'pid', resolved through the
 * /proc/<pid>/exe symbolic link.
 *
 * @param pid  process id to query
 * @param proc output buffer; on success it receives the NUL-terminated base
 *             name, truncated to fit
 * @param size size of 'proc' in bytes
 * @return 0 on success; -1 if 'proc' is NULL, 'size' is 0, or the link cannot
 *         be read or contains no '/'
 */
static inline int get_procname(int pid, char *proc, size_t size)
{
	char app_full_name[PATH_MAX] = {0};
	char proccess_proc_dir[FILE_NAME_MAX_SIZE] = {0};
	char* app_base_name = NULL;
	int n;

	/* size == 0 would make 'size - 1' wrap around and overrun 'proc'. */
	if (NULL == proc || 0 == size) {
		return -1;
	}

	n = snprintf(proccess_proc_dir, sizeof(proccess_proc_dir), "/proc/%d/exe", pid);
	if (likely((0 < n) && (n < (int)sizeof(proccess_proc_dir)))) {
		/* readlink() does not NUL-terminate, so keep one byte for the terminator. */
		n = readlink(proccess_proc_dir, app_full_name, sizeof(app_full_name) - 1);
		if (n > 0) {
			app_full_name[n] = '\0';
			app_base_name = strrchr(app_full_name, '/');
			if (app_base_name) {
				/* Unlike strncpy, snprintf always NUL-terminates within 'size'. */
				snprintf(proc, size, "%s", app_base_name + 1);
				return 0;
			}
		}
	}

	return -1;
}

/*
 * Convert a CIDR prefix length to a netmask in network byte order.
 * Prefix lengths outside [1, 32] yield 0.
 */
static inline in_addr_t prefix_to_netmask(int prefix_length)
{
	if (prefix_length < 1 || prefix_length > 32)
		return 0;

	/* Keep the top 'prefix_length' bits set (host order), then convert. */
	const in_addr_t all_ones = ~(in_addr_t)0;
	return htonl(all_ones << (32 - prefix_length));
}

/**
 * Create the Ethernet multicast MAC for an IPv4 multicast group address:
 * 01:00:5e followed by the low-order 23 bits of the group address.
 *
 * @param mc_mac output buffer of at least 6 bytes; nothing is written if NULL
 * @param ip     IPv4 multicast group address, in network byte order
 */
static inline void create_multicast_mac_from_ip(unsigned char* mc_mac, in_addr_t ip)
{
	if (mc_mac == NULL)
		return;

	/* Convert to host order so the octet extraction below is independent of the
	 * host's endianness (shifting the raw network-order value only worked on LE). */
	const uint32_t host_ip = ntohl(ip);

	mc_mac[0] = 0x01;
	mc_mac[1] = 0x00;
	mc_mac[2] = 0x5e;
	mc_mac[3] = (uint8_t)((host_ip >> 16) & 0x7f); /* top bit of the 2nd octet is dropped */
	mc_mac[4] = (uint8_t)((host_ip >> 8) & 0xff);
	mc_mac[5] = (uint8_t)(host_ip & 0xff);
}

/**
 * Timer helper with a special design for the rx loop.
 *
 * Only the interface is declared here; the method bodies are defined elsewhere,
 * so the following notes on their behavior are assumptions to confirm:
 *  - start() presumably records the starting time point in m_start;
 *  - time_left_msec() / is_timeout() presumably compare the steady_clock time
 *    elapsed since start() against m_timeout_msec;
 *  - NOTE(review): m_interval_it / m_timer_countdown look like a countdown that
 *    limits how often the clock is sampled — verify in the implementation.
 */
class loops_timer {
public:
	loops_timer();
	void start();
	int  time_left_msec();
	// Store the timeout value (milliseconds, per the name) in m_timeout_msec.
	void set_timeout_msec(int timeout_msec) { m_timeout_msec = timeout_msec; }
	// Return the stored timeout value (m_timeout_msec).
	int  get_timeout_msec() { return m_timeout_msec; }
	bool is_timeout();
	
private:
	std::chrono::time_point<std::chrono::steady_clock> m_start;   // steady_clock time point (see start())
	std::chrono::milliseconds m_elapsed;                           // elapsed duration, in milliseconds
	std::chrono::time_point<std::chrono::steady_clock> m_current; // steady_clock time point
	int m_interval_it;
	int m_timer_countdown;
	int m_timeout_msec;                                            // set/get via *_timeout_msec()
};

// Returns the filesystem's inode number for the given 'fd', using the 'fstat' system call; assumes 32-bit inodes.
// This should be safe for the 'proc' filesystem and for standard filesystems.
uint32_t fd2inode(int fd);


/**
 * @class vma_error
 *
 * Base class for VMA exception classes.
 * Note: VMA code should NOT catch vma_error; VMA code should only catch exceptions of derived classes
 * (i.e. those derived from vma_exception).
 */
class vma_error : public std::exception {
	char formatted_message[512]; // holds the formatted text returned by what() (see the constructor doc)
public:
	// Values of the constructor arguments. NOTE(review): only the pointers are stored,
	// so the strings must outlive the object; the throw macros in this header pass
	// __PRETTY_FUNCTION__ / __FILE__ and literals, which satisfy that.
	const char * const message;
	const char * const function;
	const char * const filename;
	const int lineno;
	const int errnum;

	/**
	 * Create an object that contains const members for all the given arguments, plus a formatted message that will be
	 * available through the 'what()' method of the base class.
	 *
	 * The formatted_message will look like this:
	 * 		"vma_error <create internal epoll> (errno=24 Too many open files) in sock/sockinfo.cpp:61"
	 * The catcher can print it to the log like this:
	 * 		fdcoll_loginfo("recovering from %s", e.what());
	 */
	vma_error(const char* _message, const char* _function, const char* _filename, int _lineno, int _errnum) throw();

	virtual ~vma_error() throw();

	virtual const char* what() const throw();

};

/**
 * @class vma_exception
 * Root of every exception type that VMA code is allowed to catch.
 * NOTE: ALL exceptions that can be caught by VMA should be derived from this class.
 */
class vma_exception : public vma_error
{
public:
	// Forwards every argument unchanged to vma_error.
	vma_exception(const char* _message, const char* _function, const char* _filename,
		      int _lineno, int _errnum) throw()
		: vma_error(_message, _function, _filename, _lineno, _errnum) {}
};


/*
 * Declare exception class 'clsname' derived from 'basecls'. Its constructor takes
 * the same five arguments as vma_error and forwards them to 'basecls'.
 * The invocation must be followed by ';'.
 */
#define create_vma_exception_class(clsname, basecls)                              \
	class clsname : public basecls {                                          \
	public:                                                                   \
		clsname(const char* _message, const char* _function,              \
			const char* _filename, int _lineno, int _errnum) throw()  \
			: basecls(_message, _function, _filename, _lineno, _errnum) {} \
	}

/* vma_unsupported_api derives directly from vma_error (not from vma_exception). */
create_vma_exception_class(vma_unsupported_api, vma_error);

/* Throw a vma_exception carrying 'msg', the call site, and the current errno. */
#define throw_vma_exception(msg) \
	throw vma_exception(msg, __PRETTY_FUNCTION__, __FILE__, __LINE__, errno)
/* Throw an object of '_class' (derived from vma_error with a similar ctor); the message is the class name. */
#define vma_throw_object(_class) \
	throw _class(#_class, __PRETTY_FUNCTION__, __FILE__, __LINE__, errno)
/* Like vma_throw_object, but with the explicit message '_msg'. */
#define vma_throw_object_with_msg(_class, _msg) \
	throw _class(_msg, __PRETTY_FUNCTION__, __FILE__, __LINE__, errno)

/*
 * Round x up to the nearest power of 2. Powers of 2 are returned unchanged;
 * 0 and values above 2^31 wrap around to 0.
 */
static inline uint32_t align32pow2(uint32_t x)
{
	uint32_t v = x - 1;
	unsigned int shift;

	/* Smear the highest set bit into every lower position (shifts 1, 2, 4, 8, 16). */
	for (shift = 1; shift < 32; shift <<= 1)
		v |= v >> shift;

	return v + 1;
}


/*
 * Ceiling of log2(n): the smallest t such that 2^t >= n.
 * Returns 0 for n == 0 and n == 1, and 32 for n > 2^31.
 */
static inline int ilog_2(uint32_t n)
{
	if (n == 0)
		return 0;

	/* A 64-bit comparand never overflows: 2^32 exceeds every uint32_t value, so
	 * the loop terminates with t <= 32. (The old 'int' version shifted into the
	 * sign bit, which is UB, and looped forever for n > 2^30.) */
	int t = 0;
	while (((uint64_t)1 << t) < n)
		++t;

	return t;
}

#endif
