"""
Copyright © 2019-2023 NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
 
This software product is a proprietary product of Nvidia Corporation and its affiliates
(the "Company") and all right, title, and interest in and to the software
product, including all associated intellectual property rights, are and
shall remain exclusively with the Company.

This software product is governed by the End User License Agreement
provided with the software product.
"""
import json
import re
import subprocess
import time
import logging
from datetime import datetime

# placeholder value used wherever a piece of information could not be obtained
NA = "N/A"
# keys of the message envelope built by TimerCollector.collect
MESSAGE = "message"
MESSAGE_TYPE = "message_type"
TRANS_MODE = "trans_mode"
# values stored under the TRANS_MODE key
FULL_TRANSFORM = 0
PARTIAL_TRANSFORM = 1
# NOTE(review): not referenced anywhere else in this file
SCRIPT_TAG = "[collect_host_inventory]"
# when true, the device type is also added to the per-message general metadata;
# overridden from the configuration in init()
LEGACY_MODE = False
# timeout, in seconds, handed to subprocess.run for every external command
TIMEOUT = 30
# seconds during which a command that timed out is not executed again;
# overridden from the configuration in init()
TIMEOUT_RETRY_DELAY = 60*60
DEVICE_TYPE = "device_type"

# names of the general configuration options
LOG_LEVEL_OPTION = "log-level"
LEGACY_MODE_OPTION = "legacy-mode"
TIMEOUT_RETRY_DELAY_OPTION = "timeout-retry-delay"

# names of the per-collector sampling interval options
NETWORK_INTERFACES_SAMPLING_OPTION = "network-interfaces-sampling-interval"
IB_DEVICES_SAMPLING_OPTION = "ib-device-sampling-interval"
LSHW_NETWORK_SAMPLING_OPTION = "lshw-network-sampling-interval"
LSPCI_SAMPLING_OPTION = "lspci-sampling-interval"
ETHTOOL_SAMPLING_OPTION = "ethtool-sampling-interval"
ROCE_SAMPLING_OPTION = "roce-sampling-interval"
# NOTE: unlike the options above, this key ends with "-option" rather than "-interval"
NODE_SAMPLING_OPTION = "node-sampling-option"

# name of the option listing the enabled message types
ENABLED = "enabled"

# help text per option name
# NOTE: LEGACY_MODE_OPTION and TIMEOUT_RETRY_DELAY_OPTION have no entry here
OPTION_NAME_TO_DESCRIPTION = {
    LOG_LEVEL_OPTION: "set logging level debug >= 7, info >= 6, warning >= 4, error >= 3",
    NETWORK_INTERFACES_SAMPLING_OPTION: "set collection sampling time interval of network interfaces message. use [s/m/h/d] suffix",
    IB_DEVICES_SAMPLING_OPTION: "set collection sampling time interval of InfiniBand devices message. use [s/m/h/d] suffix",
    LSHW_NETWORK_SAMPLING_OPTION: "set collection sampling time interval of lshw network message. use [s/m/h/d] suffix",
    LSPCI_SAMPLING_OPTION: "set collection sampling time interval of lspci message. use [s/m/h/d] suffix",
    ETHTOOL_SAMPLING_OPTION: "set collection sampling time interval of ethtool message. use [s/m/h/d] suffix",
    ROCE_SAMPLING_OPTION: "set collection sampling time interval of roce message. use [s/m/h/d] suffix",
    NODE_SAMPLING_OPTION: "set collection sampling time interval of node message. use [s/m/h/d] suffix",

    ENABLED: "enabled message types to collect"
}

# module-wide state
# NOTE(review): SECTION_TO_COLLECTOR is not referenced elsewhere in this file;
# the registry actually used is MainCollector.section_to_collector
SECTION_TO_COLLECTOR = {}
# wall-clock time (seconds since the epoch) at which this module was loaded
SCRIPT_START_TIME = time.time()
LOGGER = logging.getLogger('collect_host_inventory')

def arg_to_bool(arg):
    """
    Interpret a loosely typed value as a boolean.

    A string is true only when, ignoring case, it is one of "t", "true"
    or "1". Any other value is judged by its ordinary truthiness.
    """
    if isinstance(arg, str):
        return arg.lower() in ("t", "true", "1")
    return bool(arg)

def bool_from_dict(d, arg, default):
    """
    Read key `arg` from mapping `d` (using `default` when the key is absent)
    and convert the value to a boolean with arg_to_bool.
    """
    return arg_to_bool(d.get(arg, default))

def int_from_dict(d, arg, default):
    """
    Read key `arg` from mapping `d`, using `default` when the key is absent.

    A string value is converted with int(); if the conversion fails,
    `default` is returned instead. Non-string values are returned untouched.
    """
    raw = d.get(arg, default)
    if not isinstance(raw, str):
        return raw
    try:
        return int(raw)
    except Exception:
        return default

def get_command_output(command, check_rc=True):
    """
    Run `command` through the shell and return its standard output as text.

    Returns None when:
      * the command exceeded TIMEOUT seconds (it is then skipped, without
        being executed, until TIMEOUT_RETRY_DELAY seconds have elapsed),
      * the command is still inside such a skip window, or
      * `check_rc` is true and the command exited with a non-zero code.
    """
    global TIMEOUT, TIMEOUT_RETRY_DELAY

    # the registry of timed-out commands lives on the function object itself
    # and is created on first use: {command: time of the last timeout}
    try:
        postponed = get_command_output.delayed_commands
    except AttributeError:
        postponed = get_command_output.delayed_commands = {}

    timed_out_at = postponed.get(command)
    if timed_out_at:
        if time.time() - timed_out_at < TIMEOUT_RETRY_DELAY:
            return None
        # the skip window is over - forget the timeout and try again
        del postponed[command]

    try:
        completed = subprocess.run(command, shell=True, capture_output=True, text=True, timeout=TIMEOUT)
    except subprocess.TimeoutExpired:
        LOGGER.error("got timeout for command %s, will not run it for the next %d seconds", command, TIMEOUT_RETRY_DELAY)
        postponed[command] = time.time()
        return None

    if check_rc and completed.returncode != 0:
        return None
    return completed.stdout


def get_tag():
    """Return the constant tag string of this script."""
    tag = "host_inventory"
    return tag


def get_options_description():
    """
    Return the table that maps every supported option name to its help text.

    Returns:
        dict: the module-level OPTION_NAME_TO_DESCRIPTION mapping (the same
        object, not a copy).
    """
    # The table is named OPTION_NAME_TO_DESCRIPTION; the former spelling
    # (OPTIONS_NAME_TO_DESCRIPTION) referred to a global that is never
    # defined, so every call ended in a NameError.
    return OPTION_NAME_TO_DESCRIPTION


def get_hostname():
    """
    Return the output of the `hostname` command with surrounding whitespace
    removed, or NA when the command produced no usable result.
    """
    raw = get_command_output("hostname")
    return (NA if raw is None else raw).strip()

class GeneralDataCollector():
    """
    Utility holder of the metadata that accompanies every message.
    Part of it goes into each data item (general metadata) and part of it
    into the message header (header metadata).
    """
    GENERAL_METADATA_KEYS = {"timestamp", "hostname"}
    # replaced with the real host name by init()
    hostname = NA
    device_type = "host"

    @classmethod
    def get_general_metadata(cls):
        """
        Build the per-item metadata: the keys listed in
        GeneralDataCollector.GENERAL_METADATA_KEYS, plus the device type
        when LEGACY_MODE is on.
        """
        # time.time() is in seconds; scaling by 1e6 yields microseconds
        stamp = round(time.time() * 1e6)
        general = {"timestamp": stamp, "hostname": cls.hostname}
        if LEGACY_MODE:
            general[DEVICE_TYPE] = cls.device_type
        return general

    @classmethod
    def get_header_metadata(cls):
        """
        Build the metadata placed in a message header.
        """
        header = {}
        # identifier entry of the message; the machine hostname is used
        header["hostname"] = cls.hostname
        header[TRANS_MODE] = PARTIAL_TRANSFORM
        header[DEVICE_TYPE] = cls.device_type
        return header


class TimerException(Exception):
    """Raised by Timer when an interval string cannot be interpreted."""


class Timer:
    """
    Timing functionalities - All collectors have a collection time interval
    """

    # number of seconds represented by each accepted unit suffix
    _UNIT_TO_SECONDS = {"s": 1, "m": 60, "h": 60 * 60, "d": 60 * 60 * 24}

    def __init__(self, interval):
        """
        interval is a string of an integer with last char of units
        s - seconds
        m - minutes
        h - hours
        d - days

        Raises TimerException when the string does not have that form.
        """
        # -1 makes the very first passed() call succeed for any sane interval
        self._last_ts = -1
        self._interval = -1
        self.set_interval(interval)

    def set_interval(self, interval):
        """
        Parse `interval` (e.g. "30s", "20m", "1H", "2d") and store it in seconds.

        The stored interval is only replaced once the whole string has been
        validated, so a rejected value leaves the timer unchanged.

        Raises:
            TimerException: the value is not a string made of an integer
                followed by one of the unit letters s/m/h/d (any case).
        """
        try:
            amount = int(interval[:-1])
            seconds_per_unit = self._UNIT_TO_SECONDS[interval[-1].lower()]
        except (TypeError, ValueError, IndexError, KeyError, AttributeError) as err:
            # only parsing problems are translated; KeyboardInterrupt and
            # SystemExit are deliberately left to propagate
            raise TimerException("invalid input") from err
        self._interval = amount * seconds_per_unit

    def passed(self):
        """Return True when at least one interval has elapsed since the last reset()."""
        now = time.time()
        return now >= self._last_ts + self._interval

    def reset(self):
        """Start a new interval from the current time."""
        self._last_ts = time.time()


class TimerCollector:
    """
    Base class for collectors that sample on a fixed time interval.
    Subclasses supply the actual sampling by overriding _collect.
    """

    def __init__(self, interval, post_collection_func=None):
        """
        interval: interval string accepted by Timer (e.g. "20m")
        post_collection_func: optional callable invoked as
            func(msg, metadata, section_name) that returns the
            (msg, metadata) pair to publish
        """
        self._timer = Timer(interval)
        self._post_collection_func = post_collection_func if post_collection_func else None

    def collect(self, section_name):
        """
        Sample if the interval has elapsed and wrap the result in a message.

        Returns the header metadata dict with the collected item(s) stored as
        a list under MESSAGE, or an empty dict when it is not yet time to
        sample or nothing was collected.
        """
        if not self._timer.passed():
            return {}
        # the timer restarts before sampling, whether or not sampling succeeds
        self._timer.reset()

        payload = self._collect(GeneralDataCollector.get_general_metadata())
        if not payload:
            return {}

        header = GeneralDataCollector.get_header_metadata()
        if self._post_collection_func:
            payload, header = self._post_collection_func(payload, header, section_name)
        else:
            header[TRANS_MODE] = FULL_TRANSFORM
            header[MESSAGE_TYPE] = section_name
        if not payload:
            return {}

        # the message body is always a list of items
        header[MESSAGE] = payload if isinstance(payload, list) else [payload]
        return header

    def _collect(self, gm):
        """Sample the data; `gm` is the general metadata dict to build the item on."""
        raise NotImplementedError("Must override _collect")


class UtilMethods:
    """
    Lookups shared by several collectors: PCI bus -> network interface ->
    RDMA device resolution and the raw mlxconfig query.
    """

    @classmethod
    def get_map_bus_interface(cls):
        """
        Map each HCA PCI bus address to a network interface found under it.

        For every bus reported by LSPCICollector.get_hca_buses, the interfaces
        listed under the bus in sysfs are tried in order and the first one
        whose `ethtool -i` driver line names an mlx*_core/en/ib driver is kept.
        Buses with no such interface are left out of the result.
        """
        bus_to_interface = {}
        for bus in LSPCICollector.get_hca_buses():
            # colons are backslash-escaped for the shell command line
            escaped_bus = bus.replace(r':', r'\:')
            listing = get_command_output(command=f'ls /sys/bus/pci/devices/{escaped_bus}/net/')
            if listing is None:
                continue
            # listing example :
            # ls /sys/bus/pci/devices/0000\:03\:00.0/net/
            # en3f0pf0sf0  p0  pf0hpf
            for candidate in re.findall(r'\w+', listing):
                driver_info = get_command_output(command=f'ethtool -i {candidate}')
                if driver_info is None:
                    continue
                if re.search(r'driver:\s*(mlx\d+_(core|en|ib))', driver_info):
                    bus_to_interface[bus] = candidate
                    break
        return bus_to_interface

    @classmethod
    def get_map_bus_rdma_and_interface(cls):
        """
        Extend get_map_bus_interface with the RDMA device of each interface.

        Returns a list of {"bus", "rdma", "interface"} dicts; interfaces for
        which no mlx<N>_<M> device is listed in sysfs are skipped.
        """
        triples = []
        for bus, interface in cls.get_map_bus_interface().items():
            listing = get_command_output(command=f'ls /sys/class/net/{interface}/device/infiniband/')
            # listing example :
            # ls /sys/class/net/p0/device/infiniband/
            # mlx5_0
            found = re.search(r'(mlx\d+_\d+)', listing) if listing is not None else None
            if found:
                triples.append({"bus": bus, "rdma": found.group(1), "interface": interface})
        return triples

    @classmethod
    def get_mlxconfig_query(cls, bus):
        """Return the output of `mlxconfig -d <bus> q`, or '' when it is unavailable."""
        result = get_command_output(command=f'mlxconfig -d {bus} q')
        return '' if result is None else result


class IbDevicesCollector(TimerCollector):
    """
    Collector for ib devices info (parsed from the `ibstat` command)
    """

    def __init__(self, config):
        """Read the sampling interval from `config` (default: 20 minutes)."""
        sampling_interval = config.get(IB_DEVICES_SAMPLING_OPTION, "20m")
        super().__init__(sampling_interval)

    def _collect(self, gm):
        """
        Add the parsed ibstat data to `gm` under "devices".
        Returns {} (and logs an error) when nothing could be parsed.
        """
        devices = self.get_ibstat_info()
        if not devices:
            LOGGER.error("cannot collect ibstat")
            return {}
        data = gm
        data.update({"devices": devices})
        return data

    @classmethod
    def get_ibstat_info(cls):
        """
        Parse `ibstat` output into a nested dict:

            {device: {field: value, ..., "Port N": {field: value, ...}}}

        A line without a colon starts a new device; the mlx<N>_<M> name is
        used as the key when the line contains one, otherwise the whole line.
        A "Port ...:" line with nothing after the colon opens a port section,
        and the following "field: value" lines go into that section.

        Returns {} when the command is unavailable.
        """
        data = {}
        rdma_device = None
        port = ""
        output = get_command_output(command='ibstat')
        if output is None:
            return data
        for raw_line in output.splitlines():
            line = raw_line.strip()
            if not line:
                # a blank line carries no data and must not open a device named ''
                continue
            # split on the first colon only, so values that themselves
            # contain a colon are kept rather than discarded
            name, colon, value = line.partition(':')
            if not colon:
                match = re.search(r'(mlx\d+_\d+)', line)
                rdma_device = match.group(1) if match else line
                data[rdma_device] = {}
                port = ""
                continue
            if rdma_device is None:
                # attribute line before any device header: nothing to attach it to
                continue
            name = name.strip()
            value = value.strip()
            if not value and 'Port' in name:
                port = name
                data[rdma_device][port] = {}
            elif port:
                data[rdma_device][port][name] = value
            else:
                data[rdma_device][name] = value
        return data


class NetworkInterfacesCollector(TimerCollector):
    """
    Collector for network interface addresses (`ip --json addr show`)
    """

    def __init__(self, config):
        """Read the sampling interval from `config` (default: 20 minutes)."""
        super().__init__(config.get(NETWORK_INTERFACES_SAMPLING_OPTION, "20m"))

    def _collect(self, gm):
        """
        Add the interface list to `gm` under "interfaces".
        Returns {} (and logs an error) when the list is empty.
        """
        interfaces = self.get_ip_addr_show()
        if interfaces:
            gm.update({"interfaces": interfaces})
            return gm
        LOGGER.error("cannot collect network interface addresses")
        return {}

    @classmethod
    def get_ip_addr_show(cls):
        """Return the decoded JSON output of `ip --json addr show`, or [] when unavailable."""
        raw = get_command_output(command='ip --json addr show')
        return [] if raw is None else json.loads(raw)


class LshwNetworkCollector(TimerCollector):
    """
    Collector for the network class of `lshw`
    """

    def __init__(self, config):
        """Read the sampling interval from `config` (default: 1 hour)."""
        super().__init__(config.get(LSHW_NETWORK_SAMPLING_OPTION, "1h"))

    def _collect(self, gm):
        """
        Add the lshw network data to `gm` under "interfaces".
        Returns {} (and logs an error) when there is no data.
        """
        network_info = self.get_lshw_network()
        if network_info:
            gm.update({"interfaces": network_info})
            return gm
        LOGGER.error("cannot collect lshw network")
        return {}

    @classmethod
    def get_lshw_network(cls):
        """Return the decoded JSON output of `lshw -json -class network`, or [] when unavailable."""
        raw = get_command_output(command='lshw -json -class network')
        return [] if raw is None else json.loads(raw)


class LSPCICollector(TimerCollector):
    """
    Collector for LSPCI info: class/vendor/device description, serial number
    and PCI link state of every Mellanox network device found by lspci.
    """
    # example
    # LnkSta: Speed 16GT/s (ok), Width x16 (ok)
    PCI_LINK_STATE_RE = re.compile(r'LnkSta:\s+Speed\s+(\S+)\s+\((\S+)\),\s+Width\s+(\S+)\s+\((\S+)\)')
    # example
    # [SN] Serial number: MT2116T00290
    PCI_SERIAL_NUMBER_RE = re.compile(r'Serial number:\s+(\S+)')
    # example
    # Device: 0000:e2:00.0
    # Device: MT28908 Family [ConnectX-6]
    PCI_DEVICE_RE = re.compile(r"Device\:\s+(\S+.*)")
    # example
    # Class:  Infiniband controller
    PCI_CLASS_RE = re.compile(r"Class\:\s+(\S+.*)")
    # example
    # Vendor: Mellanox Technologies
    PCI_VENDOR_RE = re.compile(r"Vendor\:\s+(\S+.*)")
    # example
    # 0000:81:00.0 Infiniband controller: Mellanox Technologies MT28908 Family [ConnectX-6]
    # the dot before the function number is escaped so that it only matches a literal "."
    PCI_FULL_BUS_ADDRESS_RE = re.compile(r"([0-9a-f]{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9a-f]+)\s*")

    def __init__(self, config):
        """Read the sampling interval from `config` (default: 1 day)."""
        sampling_interval = config.get(LSPCI_SAMPLING_OPTION, "1d")
        super().__init__(sampling_interval)

    def _collect(self, gm):
        """
        Add the per-bus information to `gm` under "bus_info".
        Returns {} (and logs an error) when no bus information was found.
        """
        data = gm
        bus_info = self.get_bus_info()
        if not bus_info:
            LOGGER.error("cannot collect lspci info")
            return {}
        data["bus_info"] = bus_info
        return data

    @classmethod
    def get_hca_buses(cls):
        """
        Return the full PCI addresses (domain:bus:device.function) of the
        non-virtual Mellanox Ethernet/Infiniband/Network devices listed by lspci.
        Returns [] when the command is unavailable.
        """
        cmd = 'lspci -D |grep -i Mellanox | egrep -i "(Ethernet|Infiniband|Network)" | grep -iv Virtual'
        res = get_command_output(command=cmd)
        if res is None:
            return []
        return list(cls.PCI_FULL_BUS_ADDRESS_RE.findall(res))

    @classmethod
    def get_hca_buses_description(cls, port_buses):
        """
        Return {bus: {"class": ..., "vendor": ..., "device": ...}} for every
        bus of `port_buses` described by `lspci -D -m -vvv -d 15b3:`.
        A description may miss keys that lspci did not print.
        """
        cmd = 'lspci -D -m -vvv -d 15b3:'
        # output example
        #
        # Device: 0000:81:00.0
        # Class:  Infiniband controller
        # Vendor: Mellanox Technologies
        # Device: MT28908 Family [ConnectX-6]
        # SVendor:        Mellanox Technologies
        # SDevice:        MT28908 Family [ConnectX-6]
        # NUMANode:       1
        #
        # Device: 0000:81:00.1
        # Class:  Infiniband controller
        # Vendor: Mellanox Technologies
        # Device: MT28908 Family [ConnectX-6]
        # SVendor:        Mellanox Technologies
        # SDevice:        MT28908 Family [ConnectX-6]
        # NUMANode:       1

        # step 1: split the output into one paragraph (list of lines) per device;
        # paragraphs are separated by empty lines
        paragraphs = []
        res = get_command_output(command=cmd)
        if res is not None:
            current = []
            for line in res.splitlines():
                if line:
                    current.append(line)
                    continue
                if current:
                    paragraphs.append(current)
                current = []
            if current:
                paragraphs.append(current)

        # step 2: extract address, class, vendor and device name of each paragraph
        wanted = set(port_buses)
        description = {}
        for paragraph in paragraphs:
            bus_description = {}
            current_bus = None
            for line in paragraph:
                if len(bus_description) == 3:
                    # class, vendor and device name were all found
                    break
                match = cls.PCI_DEVICE_RE.match(line)
                if match:
                    value = match.group(1)
                    address = cls.PCI_FULL_BUS_ADDRESS_RE.match(value)
                    if address:
                        current_bus = address.group(1).strip()
                    else:
                        # is device name and not device number
                        bus_description["device"] = value.strip()
                    continue
                match = cls.PCI_CLASS_RE.match(line)
                if match:
                    bus_description["class"] = match.group(1).strip()
                    continue
                match = cls.PCI_VENDOR_RE.match(line)
                if match:
                    bus_description["vendor"] = match.group(1).strip()
                    continue
            if current_bus in wanted:
                description[current_bus] = bus_description
        return description

    @classmethod
    def _parse_pci_serial_number(cls, output : str):
        """Return the serial number found in verbose lspci `output`, or NA."""
        match = cls.PCI_SERIAL_NUMBER_RE.search(output)
        return match.group(1) if match else NA

    @classmethod
    def _parse_pci_link_state(cls, output : str):
        """
        Return the link speed/width values and their statuses found in the
        "LnkSta:" line of verbose lspci `output` as a dict, or NA when absent.
        """
        match = cls.PCI_LINK_STATE_RE.search(output)
        if match is None:
            return NA
        return {
            "speed_value": match.group(1),
            "speed_status": match.group(2),
            "width_value": match.group(3),
            "width_status": match.group(4),
        }

    @classmethod
    def _parse_pci_description(cls, bus : str, output : str):
        """
        Return the device description from the first line of lspci `output`:
        everything after the first word (the bus address). Returns NA when
        the line has no second word. `bus` is accepted but not used.
        """
        description : str = NA
        if output:
            line = output.split("\n", 1)[0].strip()
            # description starts at the second word
            # example:
            # "81:00.0 Infiniband controller: Mellanox Technologies MT28908 Family [ConnectX-6]"
            # the space is located in the stripped line - the same string that
            # is sliced - so leading whitespace in `output` cannot shift the cut
            _, space, rest = line.partition(" ")
            if space:
                description = rest.strip()
        return description

    @classmethod
    def get_bus_info(cls):
        """
        Return a list of {bus: info} dicts, one per described HCA bus, where
        info holds the description plus "serial_number" and "link_state"
        (the latter two only when `lspci -vvv -s <bus>` succeeded).
        """
        data = []
        port_buses = cls.get_hca_buses()
        buses_description = cls.get_hca_buses_description(port_buses)
        for bus, info in buses_description.items():
            res = get_command_output(command=f'lspci -vvv -s {bus}')
            if res is not None:
                info["serial_number"] = cls._parse_pci_serial_number(res)
                info["link_state"] = cls._parse_pci_link_state(res)
            data.append({bus: info})
        return data

class EthtoolCollector(TimerCollector):
    """
    Collector for ethtool info (connector type of each HCA interface)
    """

    def __init__(self, config):
        """Read the sampling interval from `config` (default: 1 hour)."""
        super().__init__(config.get(ETHTOOL_SAMPLING_OPTION, "1h"))

    def _collect(self, gm):
        """
        Add the per-interface data to `gm` under "standard_info".
        Returns {} (and logs an error) when no interface was found.
        """
        standard_info = self.get_standard_info()
        if standard_info:
            gm["standard_info"] = standard_info
            return gm
        LOGGER.error("cannot collect ethtool info")
        return {}

    @classmethod
    def get_standard_info(cls):
        """
        Return a list of {interface: {'connector_type': text}} dicts, one per
        interface of UtilMethods.get_map_bus_interface. The text is what
        follows "Port:" in the ethtool output, or '' when it is not found.
        """
        entries = []
        for interface in UtilMethods.get_map_bus_interface().values():
            if not interface:
                continue
            connector_type = ''
            port_line = get_command_output(command=f'ethtool {interface}|grep -i "Port:"')
            if port_line is not None:
                # output example: ethtool p4p1 |grep -i "Port:"
                #   Port: Direct Attach Copper
                found = re.search(r'Port:\s+(.*)', port_line)
                if found:
                    connector_type = found.group(1)
            entries.append({interface: {'connector_type': connector_type}})
        return entries


class RoCEConfCollector(TimerCollector):
    """
    Collector for RoCE config info.

    For every (bus, rdma device, interface) triple returned by
    UtilMethods.get_map_bus_rdma_and_interface it gathers values from
    mlxconfig, cma_roce_mode, cma_roce_tos, sysfs, mlxreg and mlnx_qos.
    """
    # "<key> | <value>" rows of the `mlxreg ... --get` output (one per line)
    ROCE_ACCL_KEY_VAL_RE = re.compile(r"^(\S*)\s*\|\s*(\S+)\s*", re.MULTILINE)
    # patterns for the `mlnx_qos -i <interface>` output:
    # the "Priority trust state: <value>" line
    MLNX_QOS_PTS_RE = re.compile(r"Priority trust state\:\s*(\S+)")
    # header line of the dscp2prio section
    MLNX_QOS_DSCP2PRIO_RE = re.compile(r"dscp2prio mapping\:\s*")
    # one "prio:<n> dscp:<a>,<b>,..." entry of the dscp2prio section
    MLNX_QOS_DSCP2PRIO_ENTRY_RE = re.compile(r"\s*prio:(\d+)\s*dscp\:((\d+,)+)")
    # header line of the PFC section
    MLNX_QOS_PFC_RE = re.compile(r"PFC configuration\:\s*")
    # the "priority", "enabled" and "buffer" rows of the PFC section,
    # each followed by whitespace-separated numbers
    MLNX_QOS_PFC_PRIORITY_RE = re.compile(r"\s*priority\s*((\d+\s*)+)")
    MLNX_QOS_PFC_ENABLED_RE = re.compile(r"\s*enabled\s*((\d+\s*)+)")
    MLNX_QOS_PFC_BUFFER_RE = re.compile(r"\s*buffer\s*((\d+\s*)+)")

    def __init__(self, config):
        """Read the sampling interval from `config` (default: 1 hour)."""
        sampling_interval = config.get(ROCE_SAMPLING_OPTION, "1h")
        super().__init__(sampling_interval)

    def _collect(self, gm):
        """
        Add the RoCE configuration to `gm` under "roce_info".
        Returns {} (and logs an error) when nothing was collected.
        """
        roce_info = self.get_roce_config_info()
        if not roce_info:
            LOGGER.error("cannot collect roce config info")
            return {}
        data = gm
        data["roce_info"] = roce_info
        return data

    @classmethod
    def get_roce_config_info(cls):
        """
        Return a list with one dict per bus/rdma/interface triple.

        NOTE: an entry is appended for every triple, so it stays an empty dict
        when the triple has no bus or the bus holds no ".<digits>" part.
        """
        data = []
        bus_rdma_interface_list = UtilMethods.get_map_bus_rdma_and_interface()
        for info in bus_rdma_interface_list:
            port_data_info = {}
            bus = info.get('bus', '')
            rdma_device = info.get('rdma', '')
            interface = info.get('interface', '')
            if bus:
                # the digits after the dot of the bus address, plus one, are
                # used as the port number (e.g. "0000:81:00.0" -> port "1")
                match = re.search(r'\.(\d+)', bus)
                if match:
                    port = str(int(match.group(1)) + 1)
                    ident_data = {"bus": bus, "rdma_device": rdma_device, "port": port, "interface": interface}
                    res = UtilMethods.get_mlxconfig_query(bus=bus)
                    # the two mlxconfig based entries exist only when the query returned output
                    if res:
                        # cnp priority
                        cnp_res = cls.get_cnp_priority(port=port, interface=interface, mlxconfig_output=res)
                        port_data_info.update({"CNP_Priority": cnp_res})

                        # ECN DCQCN
                        ecn_dcqcn_info = cls.get_ecn_dcqcn(port=port, interface=interface, mlxconfig_output=res)
                        port_data_info.update({"ECN_DCQCN": ecn_dcqcn_info})
                    # roce mode
                    roce_mode_info = cls.get_roce_mode(rdma_device=rdma_device)
                    port_data_info.update({"RoCE_Mode": roce_mode_info})

                    # roce tos
                    roce_tos_info = cls.get_roce_tos(rdma_device=rdma_device)
                    port_data_info.update({"RoCE_TOS": roce_tos_info})

                    # DSCP
                    dscp_info = cls.get_dscp_traffic_class(rdma_device=rdma_device)
                    port_data_info.update({"DSCP": {"traffic_class": dscp_info}})

                    # ROCE_ACCL
                    roce_accl = cls.get_roce_accl(bus=bus)
                    port_data_info.update({"ROCE_ACCL": roce_accl})

                    # mlnx_qos
                    pfc, dscp2prio, pts = cls.get_mlnx_qos_entries(interface)
                    # extends the "DSCP" entry created a few lines above
                    port_data_info["DSCP"].update({"dscp2prio": dscp2prio})
                    port_data_info.update({"Priority_trust_state": pts})
                    port_data_info.update({"PFC": pfc})

                    port_data_info.update(ident_data)
            data.append(port_data_info)
        return data

    @classmethod
    def get_mlnx_qos_entries(cls, interface):
        """
        Parse the output of `mlnx_qos -i <interface>`.

        Returns a (pfc, dscp2prio, pts) tuple:
          pfc       - {priority: {"enabled": ..., "buffer": ...}}; filled only
                      when the priority, enabled and buffer rows hold the same
                      non-zero number of values
          dscp2prio - {prio: [dscp, ...]}
          pts       - the priority trust state value, or NA
        All values are kept as strings. The defaults ({}, {}, NA) are returned
        when `interface` is empty or the command is unavailable.
        """
        pfc, dscp2prio, pts = {}, {}, NA
        if not interface:
            return pfc, dscp2prio, pts
        res = get_command_output(command=f"mlnx_qos -i {interface}")
        if res is not None:
            # the parser is a small state machine: in_dscp2prio / in_pfc tell
            # which section the current line is expected to belong to
            in_dscp2prio = False
            in_pfc = False
            pfc_prios = []
            pfc_enabled = []
            pfc_buffers = []
            lines = res.splitlines()
            for line in lines:
                if in_dscp2prio:
                    match = cls.MLNX_QOS_DSCP2PRIO_ENTRY_RE.match(line)
                    if match:
                        prio = match.group(1)
                        dscp = match.group(2)
                        dscp = dscp.split(",")
                        vals = []
                        for v in dscp:
                            v = v.strip()
                            if v:
                                vals.append(v)
                        dscp2prio[prio] = vals
                        continue
                    # first line that is not an entry ends the section; the
                    # line itself is still examined by the checks below
                    in_dscp2prio = False
                if in_pfc:
                    match = cls.MLNX_QOS_PFC_PRIORITY_RE.match(line)
                    if match:
                        prios = match.group(1)
                        prios = prios.split(" ")
                        for p in prios:
                            p = p.strip()
                            if p:
                                pfc_prios.append(p)
                        continue
                    match = cls.MLNX_QOS_PFC_ENABLED_RE.match(line)
                    if match:
                        enabled = match.group(1)
                        enabled = enabled.split(" ")
                        for e in enabled:
                            e = e.strip()
                            if e:
                                pfc_enabled.append(e)
                        continue
                    match = cls.MLNX_QOS_PFC_BUFFER_RE.match(line)
                    if match:
                        buffers = match.group(1)
                        buffers = buffers.split(" ")
                        for b in buffers:
                            b = b.strip()
                            if b:
                                pfc_buffers.append(b)
                        continue
                    # a line that is none of the three rows ends the PFC section
                    in_pfc = False
                match = cls.MLNX_QOS_PFC_RE.match(line)
                if match:
                    # no `continue` here: the header line also goes through
                    # the two checks below
                    in_pfc = True
                match = cls.MLNX_QOS_PTS_RE.match(line)
                if match:
                    pts = match.group(1).strip()
                    continue
                match = cls.MLNX_QOS_DSCP2PRIO_RE.match(line)
                if match:
                    in_dscp2prio = True
                    continue
            # combine the three PFC rows column by column
            if pfc_prios and len(pfc_prios) == len(pfc_enabled) == len(pfc_buffers):
                for i, prio in enumerate(pfc_prios):
                    pfc[prio] = {
                        "enabled": pfc_enabled[i],
                        "buffer": pfc_buffers[i],
                    }
        return pfc, dscp2prio, pts

    @classmethod
    def get_roce_accl(cls, bus):
        """
        Return the "key | value" rows printed by
        `mlxreg -d <bus> --reg_name ROCE_ACCL --get` as a {key: value} dict
        ({} when the command is unavailable).
        """
        roce_accl = {}
        res = get_command_output(command=f"mlxreg -d {bus} --reg_name ROCE_ACCL --get")
        if res is not None:
            matches = cls.ROCE_ACCL_KEY_VAL_RE.findall(res)
            for match in matches:
                key = match[0]
                value = match[1]
                roce_accl[key] = value
        return roce_accl


    @classmethod
    def get_roce_mode(cls, rdma_device):
        """Return the stripped output of `cma_roce_mode -d <rdma_device>`, or ''."""
        mode = ''
        res = get_command_output(command=f'cma_roce_mode -d {rdma_device}')
        if res is not None:
            mode = res.strip()
        return mode

    @classmethod
    def get_roce_tos(cls, rdma_device):
        """Return the stripped output of `cma_roce_tos -d <rdma_device>`, or ''."""
        tos = ''
        res = get_command_output(command=f'cma_roce_tos -d {rdma_device}')
        if res is not None:
            tos = res.strip()
        return tos

    @classmethod
    def get_dscp_traffic_class(cls, rdma_device):
        """
        Return the lines of /sys/class/infiniband/<rdma_device>/tc/1/traffic_class.

        NOTE: the result is a list of lines on success but the empty string ''
        when the file could not be read.
        """
        tc = ''
        res = get_command_output(command=f'cat /sys/class/infiniband/{rdma_device}/tc/1/traffic_class')
        if res is not None:
            tc = res.splitlines()
        return tc

    @classmethod
    def get_cnp_priority(cls, port, interface, mlxconfig_output):
        """
        Return the CNP priority settings of one port as a dict:
        the CNP_DSCP_P<port> and CNP_802P_PRIO_P<port> values found in
        `mlxconfig_output`, plus the contents of the interface's
        ecn/roce_np/cnp_dscp and ecn/roce_np/cnp_802p_prio sysfs files.
        Values that cannot be found or read are left out.
        """
        data = {}
        fields = [f'CNP_DSCP_P{port}', f'CNP_802P_PRIO_P{port}']
        for field in fields:
            match = re.search(rf'{field}\s+(\S+)', mlxconfig_output)
            if match:
                field_value = match.group(1)
                data.update({field: field_value})

        dscp_cnp_data = {"ecn_roce_np_cnp_dscp": f'/sys/class/net/{interface}/ecn/roce_np/cnp_dscp',
                         "ecn_roce_np_cnp_802p_prio": f'/sys/class/net/{interface}/ecn/roce_np/cnp_802p_prio'}
        for k, p in dscp_cnp_data.items():
            res = get_command_output(command=f'cat {p}')
            if res is not None:
                cnp_value = res.strip()
                data.update({f'{k}': cnp_value})
        return data

    @classmethod
    def get_ecn_dcqcn(cls, port, interface, mlxconfig_output):
        """
        Return the ECN settings of one port as a dict:
        the ROCE_CC_PRIO_MASK_P<port> value found in `mlxconfig_output`, plus
        the contents of the interface's ecn/roce_np/enable/<i> and
        ecn/roce_rp/enable/<i> sysfs files for i in 0..7, stored under
        "ecn_roce_np_enable_<i>" / "ecn_roce_rp_enable_<i>".
        Values that cannot be found or read are left out.
        """
        data = {}

        field = f'ROCE_CC_PRIO_MASK_P{port}'
        match = re.search(rf'{field}\s+(\S+)', mlxconfig_output)
        if match:
            field_value = match.group(1)
            data.update({field: field_value})
        ecn_dcqcn_np_rp_data = {"ecn_roce_np_enable": f'/sys/class/net/{interface}/ecn/roce_np/enable/',
                                "ecn_roce_rp_enable": f'/sys/class/net/{interface}/ecn/roce_rp/enable/'}
        for k, p in ecn_dcqcn_np_rp_data.items():
            for i in range(0,8):
                res = get_command_output(command=f'cat {p}{i}')
                if res is not None:
                    enable_i_value = res.strip()
                    data.update({f'{k}_{i}': enable_i_value})
        return data

class NodeCollector(TimerCollector):
    """
    Collector for node status (uptime, last boot, IP)
    Serves as heartbeat
    """

    # first dotted-quad looking token of a text
    IPV4_RE = re.compile(r"(\d+\.\d+\.\d+\.\d+)")

    def __init__(self, config):
        """Read the sampling interval from `config` (default: 30 seconds)."""
        sampling_interval = config.get(NODE_SAMPLING_OPTION, "30s")
        super().__init__(sampling_interval)

    def _collect(self, gm):
        """
        Extend `gm` with the node status fields and return it.

        "lastboot" is SCRIPT_START_TIME, i.e. the time this script was
        loaded; the time reported by `uptime -s` goes under "sys_uptime".
        """
        data = gm
        data["lastboot"] = SCRIPT_START_TIME
        data = self._get_sys_uptime(data)
        data = self._get_ip(data)
        return data

    @staticmethod
    def _get_sys_uptime(data):
        """
        Store the time printed by `uptime -s`, converted to a POSIX timestamp,
        under "sys_uptime". NA is stored when the text cannot be parsed; the
        key is not set at all when the command is unavailable.
        """
        key = "sys_uptime"
        txt = get_command_output(command="uptime -s")
        if txt is not None:
            try:
                timestamp = datetime.strptime(txt.strip(), "%Y-%m-%d %H:%M:%S").timestamp()
            except Exception as exc:
                LOGGER.error(f"failed to convert '{txt}' to timestamp: {exc}")
                timestamp = NA
            data[key] = timestamp
        return data

    @classmethod
    def _get_ip(cls, data):
        """
        Store the first IPv4 address printed by `hostname -i` under "ipv4",
        or NA (with an error log) when there is none.
        """
        key = "ipv4"
        txt = get_command_output(command="hostname -i")
        data[key] = NA
        if txt is not None:
            match = cls.IPV4_RE.search(txt)
            if match:
                data[key] = match.group(1)
        if data[key] == NA:
            # txt is None when the command failed or timed out; calling
            # .strip() on it would raise and discard the whole node message
            raw_output = NA if txt is None else txt.strip()
            LOGGER.error(f"failed to set '{key}'. raw output : {raw_output}")
        return data

class MainCollector:
    """
    Runs every registered collector and gathers the messages they return.
    """
    # {section name: collector}, filled by set_section_to_collector
    section_to_collector = {}
    _initialized = False

    def __init__(self):
        self._collected_data = []

    def collect(self):
        """
        Call collect(section) on each registered collector and return the
        list of non-empty results. A collector that raises is logged and
        skipped, so it cannot prevent the others from running.
        """
        for section_name, section_collector in MainCollector.section_to_collector.items():
            try:
                message = section_collector.collect(section_name)
                if not message:
                    continue
                self._collected_data.append(message)
            except Exception as err:
                section = section_name
                LOGGER.error(f"while collecting {section} data: {err}")
        return self._collected_data


def set_section_to_collector(config):
    """
    Register the enabled collectors in MainCollector.section_to_collector.

    The ENABLED option of `config` is a comma separated list of section names
    (a list/tuple of names is accepted as well); whitespace around the names
    is ignored. When the option is missing or blank, every section is enabled.
    Names that match no section are ignored.

    Only the enabled collectors are constructed, so an invalid sampling
    interval configured for a disabled collector has no effect.

    Raises:
        TimerException: an enabled collector was given an invalid interval.
    """
    section_to_factory = {
        "NetIFAddr": NetworkInterfacesCollector,
        "IBDevices": IbDevicesCollector,
        "LSHWNetwork": LshwNetworkCollector,
        "LSPCI": LSPCICollector,
        "Ethtool": EthtoolCollector,
        "RoCEConf": RoCEConfCollector,
        "Node": NodeCollector,
    }
    requested = config.get(ENABLED, "")
    if isinstance(requested, str):
        # "A, B" must enable both A and B, hence the per-name strip
        requested = [name.strip() for name in requested.split(',')] if requested.strip() else []
    elif requested:
        requested = [str(name).strip() for name in requested]
    enabled_components = set(requested) if requested else set(section_to_factory)
    # iterate the table (not the request) to keep the registration order fixed
    for section, factory in section_to_factory.items():
        if section in enabled_components:
            MainCollector.section_to_collector[section] = factory(config)


# configuration

def set_logging(config):
    """
    Configure the root logger from the LOG_LEVEL_OPTION value of `config`.

    The value is a number mapped as follows:
        >= 7 debug, >= 6 info, >= 4 warning, >= 3 error, below 3 critical only.
    The default is 3 (error); it is also used, with an error message, when
    the configured value is not a number.
    """
    default_num = 3
    raw = config.get(LOG_LEVEL_OPTION, default_num)
    try:
        num = int(raw)
        invalid = False
    except (TypeError, ValueError):
        num = default_num
        invalid = True
    if num >= 7:
        level = logging.DEBUG
    elif num >= 6:
        level = logging.INFO
    elif num >= 4:
        level = logging.WARNING
    elif num >= 3:
        level = logging.ERROR
    else:
        # NOTSET must not be used here: on the root logger it lets every
        # message through, the opposite of what a low number asks for
        level = logging.CRITICAL
    logging.basicConfig(level=level, format='%(asctime)s   %(name)-17s   %(levelname)-8s   %(message)s')
    if invalid:
        # reported only now, once logging is configured
        LOGGER.error("invalid %s value %r, using %d", LOG_LEVEL_OPTION, raw, default_num)


def init(config):
    """
    One-time setup from `config`: update the LEGACY_MODE and
    TIMEOUT_RETRY_DELAY globals, configure logging, resolve the host name and
    register the enabled collectors. Always returns True.
    """
    global LEGACY_MODE, TIMEOUT_RETRY_DELAY
    # the current global values serve as defaults for missing options
    legacy_mode = bool_from_dict(config, LEGACY_MODE_OPTION, LEGACY_MODE)
    retry_delay = int_from_dict(config, TIMEOUT_RETRY_DELAY_OPTION, TIMEOUT_RETRY_DELAY)
    LEGACY_MODE, TIMEOUT_RETRY_DELAY = legacy_mode, retry_delay

    set_logging(config)
    GeneralDataCollector.hostname = get_hostname()
    set_section_to_collector(config)
    return True


class ScriptContext:
    """Holds the flag telling whether init() has already completed."""
    # set to True by collect() after a successful init()
    initialized = False


def get_tag():
    """
    Return the constant tag string of this script.

    NOTE: a function with the same name and body is defined earlier in this
    file; this later definition is the one that remains bound.
    """
    tag = "host_inventory"
    return tag


def collect(config):
    """
    Entry point: initialise the script on the first call, then run all
    registered collectors once.

    Returns the list of collected messages, or {} when initialisation failed.
    """
    if not ScriptContext.initialized:
        if not init(config):
            return {}
        ScriptContext.initialized = True
    return MainCollector().collect()


if __name__ == "__main__":
    # stand-alone run: collect once with an empty configuration and print the result as JSON
    res = json.dumps(collect({}))
    print(res)
