"""
Copyright © 2019-2022 NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.

This software product is a proprietary product of Nvidia Corporation and its affiliates
(the "Company") and all right, title, and interest in and to the software
product, including all associated intellectual property rights, are and
shall remain exclusively with the Company.

This software product is governed by the End User License Agreement
provided with the software product.
"""
# this script periodically calls DTS gRPC server (blueman.proto)
import os
import sys
import time
import hashlib
import logging
from typing import Union, Callable
import grpc

# Make the generated blueman gRPC modules (blueman_pb2, blueman_pb2_grpc) and the
# CollectX root importable. CLX_ROOT overrides the default install location.
clx_root = os.getenv("CLX_ROOT", "/opt/mellanox/collectx")
blueman_dir = os.path.join(clx_root, "services", "blueman", "grpc", "server")
sys.path.extend([blueman_dir, clx_root])
import blueman_pb2
import blueman_pb2_grpc

# file holding the internal id used to authorize calls to the blueman gRPC server
DTS_INTERNAL_ID_FILE: str = "/var/tmp/dts_internal_id.txt"
# gRPC call-metadata key under which that id is sent
DTS_INTERNAL_ID: str = "dts_internal_id"
# module-level BluemanCollector instance; created lazily by init()
CLIENT = None

def get_host_name():
    """Return the host name reported by the CollectX "General" system module.

    The module is queried through a freshly created DpeClient connection.
    The CollectX services are imported inside the function rather than at
    the top of the file.
    """
    from services.modules.system.general.general import General
    from services.dpe.client.dpe_client import DpeClient
    general_module = General(DpeClient())
    return general_module.get_host_name()

def current_time_usec():
    """Return the current wall-clock time as integer microseconds since the epoch."""
    return round(1e6 * time.time())

# Resolved once, at import time, through get_host_name(); any exception raised
# there propagates and makes importing this module fail.
HOST_NAME = get_host_name()

class Timer:
    """basic class for time scheduling

    Timestamps are wall-clock microseconds (current_time_usec).
    ``last_sampling_timestamp`` is read by MessageHandler as the message
    timestamp, which is why the wall clock is used instead of a monotonic one.
    """
    def __init__(self, interval : int):
        """class constructor

        Args:
            interval (int): firing interval (secs)
        """
        self._interval = interval * 1e6  # convert to usec
        # 0 means "never fired", so the first time_passed() call returns True
        self.last_sampling_timestamp = 0

    def time_passed(self) -> bool:
        """check if self passed the time interval. when True, reset the timer

        A negative elapsed time means the wall clock was stepped backwards.
        That case is treated as "passed"; otherwise the timer would stay
        silent until the clock caught up with the old timestamp.

        Returns:
            bool: True if time passed, otherwise False
        """
        now = current_time_usec()
        elapsed = now - self.last_sampling_timestamp
        if elapsed >= self._interval or elapsed < 0:
            self.last_sampling_timestamp = now
            return True
        return False

class MessageHandler:
    """Couple one collection callable with its own Timer.

    The callable runs only once the timer reports that the configured
    interval has elapsed. Its result is wrapped in a message envelope.
    """

    def __init__(self, handle : Callable, interval : int, message_name : str):
        """class constructor

        Args:
            handle (Callable): zero-argument function returning the message content
            interval (int): collection time interval (secs)
            message_name (str): message name, emitted as "message_type"
        """
        self._timer = Timer(interval)
        self._handle = handle
        self._message_name = message_name

    def _to_dict(self, content):
        """Wrap ``content`` in the envelope, stamped with the timer's last firing time (usec)."""
        envelope = {
            "hostname": HOST_NAME,
            "timestamp": self._timer.last_sampling_timestamp,
            "message_type": self._message_name,
            "message": content,
        }
        return envelope

    def get_message_dict(self) -> Union[None, dict]:
        """Run the handle if its interval elapsed and return the wrapped result.

        Returns:
            Union[None, dict]: envelope dict when data was collected, otherwise None
        """
        if self._timer.time_passed():
            logging.debug("running method of %s", self._message_name)
            return self._to_dict(self._handle())
        return None

    @property
    def message_name(self) -> str:
        """Name of the message type this handler produces."""
        return self._message_name

class BluemanCollector():
    """grpc client to run blueman.proto rpc calls

    Every ``_collect_*`` method issues one RPC on the BluemanService stub and
    converts the protobuf response into plain python lists/dicts.
    ``collect`` runs the handlers whose interval has elapsed. A handler whose
    last attempt raised is skipped for ``retry_delay`` seconds.
    """
    def __init__(self, interval : int, uid: str, conf : dict, retry_delay : int):
        """class constructor

        Args:
            interval (int): default collection interval time (seconds), used by
                every message type without its own "<name>_interval" entry
            uid (str): UUID used for grpc authorization (sent as call metadata)
            conf (dict): runtime configuration
            retry_delay (int): seconds to skip a message type after its
                collection raised an exception
        """
        self._call_metadata = ((DTS_INTERNAL_ID, uid),)
        # TODO: update the blueman server address once it is updated (hard-coded for now)
        self._channel = grpc.insecure_channel("localhost:50051")
        self._stub = blueman_pb2_grpc.BluemanServiceStub(self._channel)
        self._message_handlers = self._generate_message_handlers(conf, interval)
        # NOTE(review): not referenced anywhere else in this class
        self._supported_data_types = (list, dict, str, int, float, bool)
        # md5 digests of the last installed-packages / firmware payloads, served by
        # the *_hash message types ('' until the corresponding info was collected)
        self._installed_package_hash = ''
        self._firmware_info_hash = ''
        self._retry_delay = retry_delay
        # message name -> time.monotonic() of its last failed collection attempt
        self._delayed_messages = {}

    def __del__(self):
        """class destructor - close instance channel

        __init__ may have raised before the channel was created, so the
        attribute is looked up defensively instead of raising AttributeError
        during teardown.
        """
        channel = getattr(self, "_channel", None)
        if channel is not None:
            channel.close()

    def _generate_message_handlers(self, conf : dict, interval : int) -> list:
        """create list of message handler

        Args:
            conf (dict): runtime configuration; "<message name>_interval"
                overrides the interval of that one message type
            interval (int): default interval value (secs)

        Returns:
            list: MessageHandler instances, based on input
        """
        key_to_method = {
            "general_info":            self._collect_general_info,
            "installed_packages_info": self._collect_installed_packages_info,
            "installed_packages_info_hash": self._get_installed_packages_info_hash,
            "cpu_info":                self._collect_cpu_info,
            "firmware_info":           self._collect_firmware_info,
            "firmware_info_hash":      self._get_firmware_info_hash,
            "dpu_operation_mode_info": self._collect_dpu_operation_mode_info,
            "nvme_info":               self._collect_nvme_info,
            "emmc_info":               self._collect_emmc_info,
            "cpu_health":              self._collect_cpu_health,
            "memory_health":           self._collect_memory_health,
            "disk_health":             self._collect_disk_health,
            "dpu_temperature_health":  self._collect_dpu_temperature_health,
            "cpu_frequency_health":    self._collect_cpu_frequency_health,
            "power_consumption_health": self._collect_power_consumption_health,
            "system_services_health":  self._collect_system_services_health,
            "dpu_rshim_log_health":    self._collect_dpu_rshim_log_health,
            "nvme_log_health":         self._collect_nvme_log_health,
            "kernel_modules_health":   self._collect_kernel_modules_health,
            "port_status_health":      self._collect_port_status_health,
            "doca_services":           self._collect_doca_services,
        }

        # every key is resolved against the same, untouched default `interval`,
        # so an override for one message type never affects the others
        return [
            MessageHandler(method, self._resolve_interval(conf, key, interval), key)
            for key, method in key_to_method.items()
        ]

    @staticmethod
    def _resolve_interval(conf : dict, key : str, default : int) -> int:
        """return the collection interval (secs) configured for message ``key``

        Args:
            conf (dict): runtime configuration
            key (str): message name; "<key>_interval" is looked up in conf
            default (int): interval used when the key is absent or invalid

        Returns:
            int: resolved interval
        """
        key_interval_str = f"{key}_interval"
        raw = conf.get(key_interval_str, default)
        try:
            return int(raw)
        except (TypeError, ValueError):
            logging.warning("invalid %s value %r, using default %s", key_interval_str, raw, default)
            return int(default)

    @staticmethod
    def _get_repeated_data_dict_list(response) -> list:
        """convert protobuf response to python list

        Args:
            response : blueman.proto response whose repeated ``data`` field
                holds messages with a ``data`` map

        Returns:
            list: parsed rpc responses list
        """
        if not response:
            return []
        return [dict(item.data) for item in response.data]

    @staticmethod
    def _nested_map_to_dict(response) -> dict:
        """convert a response whose ``data`` map holds messages with their own
        ``data`` map into a nested python dict

        dict() alone cannot turn a nested protobuf Map into nested dicts.

        Args:
            response : blueman.proto response

        Returns:
            dict: {outer key: dict(inner data)}, or {} for a falsy response
        """
        if not response:
            return {}
        return {key: dict(val.data) for key, val in dict(response.data).items()}

    def _collect_general_info(self) -> dict:
        """collect general info via rpc

        Returns:
            dict: parsed rpc response
        """
        res = {}
        request = blueman_pb2.GeneralInfoRequest()
        response = self._stub.general_info(request, metadata=self._call_metadata)
        if response:
            res = dict(response.data)
        return res

    def _collect_installed_packages_info(self) -> list:
        """collect installed packages info via rpc

        Also refreshes the digest served by installed_packages_info_hash.

        Returns:
            list: parsed rpc response
        """
        res = []
        request = blueman_pb2.InstalledPackagesInfoRequest()
        response = self._stub.installed_packages_info(request, metadata=self._call_metadata)
        if response:
            res = [
                {"name": pckg.package_name, "version": pckg.version}
                for pckg in response.package_data
            ]
        self._installed_package_hash = hashlib.md5(str(res).encode()).hexdigest()
        return res

    def _get_installed_packages_info_hash(self) -> dict:
        """
        get installed packages info hash (from the last installed packages collection)
        Returns:
            dict: hash key and hash value
        """
        return {'hash': self._installed_package_hash}

    def _collect_cpu_info(self) -> dict:
        """collect cpu info via rpc

        Returns:
            dict: parsed rpc response
        """
        res = {}
        request = blueman_pb2.CpuInfoRequest()
        response = self._stub.cpu_info(request, metadata=self._call_metadata)
        if response:
            res = dict(response.data)
        return res

    def _collect_firmware_info(self) -> dict:
        """collect firmware info via rpc

        Also refreshes the digest served by firmware_info_hash.

        Returns:
            dict: port number (as str) -> list of firmware field dicts
        """
        dict_of_port_data = {}
        request = blueman_pb2.FirmwareInfoRequest()
        response = self._stub.firmware_info(request, metadata=self._call_metadata)
        if response:
            for port_fw_info in response.data:
                res = []
                for fw_info in port_fw_info.data:
                    as_dict = {
                        "field_name":          fw_info.field_name,
                        "default":             fw_info.default,
                        "current":             fw_info.current,
                        "next_boot":           fw_info.next_boot,
                        "description":         fw_info.description,
                        "field_config_params": fw_info.field_config_params,
                        "changed":             fw_info.changed,
                        }
                    res.append(as_dict)
                dict_of_port_data[str(port_fw_info.port_number)] = res
        self._firmware_info_hash = hashlib.md5(str(dict_of_port_data).encode()).hexdigest()
        return dict_of_port_data

    def _get_firmware_info_hash(self) -> dict:
        """
        get firmware info hash (from the last firmware info collection)
        Returns:
            dict: hash key and hash value
        """
        return {'hash': self._firmware_info_hash}

    def _collect_dpu_operation_mode_info(self) -> dict:
        """collect dpu operation mode info via rpc

        Returns:
            dict: parsed rpc response (nested map flattened to nested dicts)
        """
        request = blueman_pb2.DpuOperationModeInfoRequest()
        response = self._stub.dpu_operation_mode_info(request, metadata=self._call_metadata)
        return self._nested_map_to_dict(response)

    def _collect_nvme_info(self) -> list:
        """
        collect nvme info

        Returns: list of dict
        """
        request = blueman_pb2.NvmeInfoRequest()
        response = self._stub.nvme_info(request, metadata=self._call_metadata)
        return self._get_repeated_data_dict_list(response)

    def _collect_emmc_info(self) -> dict:
        """
        collect eMMC identification and config info
        Returns: dict of data
        """
        data = {}
        request = blueman_pb2.EmmcInfoRequest()
        response = self._stub.emmc_info(request, metadata=self._call_metadata)
        data["identification"] = dict(response.identification)
        # NOTE(review): config_info is passed through unconverted; its type is
        # defined by blueman.proto, which is not visible here
        data["config_info"] = response.config_info

        return data

    def _collect_cpu_health(self) -> dict:
        """collect cpu health via rpc

        Returns:
            dict: parsed rpc response
        """
        res = {}
        request = blueman_pb2.CpuHealthRequest()
        response = self._stub.cpu_health(request, metadata=self._call_metadata)
        if response:
            for i in response.cpu_usages:
                cores_data = response.cpu_usages[i]
                res[i] = {
                    "usage":  cores_data.usage,
                    "status": cores_data.status,
                    "processes": [dict(proc.data) for proc in cores_data.processes],
                }
        return res

    def _collect_memory_health(self) -> dict:
        """collect memory health via rpc

        Returns:
            dict: parsed rpc response
        """
        res = {}
        request = blueman_pb2.MemoryHealthRequest()
        response = self._stub.memory_health(request, metadata=self._call_metadata)
        if response:
            for k in response.ram_data:
                ram_data = response.ram_data[k]
                res[k] = {
                    "total":  dict(ram_data.total),
                    "used":   dict(ram_data.used),
                    "free":   dict(ram_data.free),
                    "usage":  dict(ram_data.usage),
                    "status": ram_data.status
                }
        return res

    def _collect_disk_health(self) -> dict:
        """collect disk health via rpc

        Returns:
            dict: parsed rpc response
        """
        res = {}
        request = blueman_pb2.DiskHealthRequest()
        response = self._stub.disk_health(request, metadata=self._call_metadata)
        if response:
            for k in response.hard_disk_data:
                hd_data = response.hard_disk_data[k]
                res[k] = {
                    "size":        dict(hd_data.size),
                    "used":        dict(hd_data.used),
                    "available":   dict(hd_data.available),
                    "usage":       dict(hd_data.usage),
                    "status":      hd_data.status
                }
        return res

    def _collect_dpu_temperature_health(self) -> dict:
        """collect dpu temperature health via rpc

        Returns:
            dict: parsed rpc response
        """
        res = {}
        request = blueman_pb2.DpuTemperatureHealthRequest()
        response = self._stub.dpu_temperature_health(request, metadata=self._call_metadata)
        if response:
            for k in response.dpu_temperature_data:
                dpu_temp = response.dpu_temperature_data[k]
                res[k] = {
                    "temperature": dict(dpu_temp.temperature),
                    "status":      dpu_temp.status,
                }
        return res

    def _collect_cpu_frequency_health(self) -> dict:
        """collect cpu frequency health via grpc

        Returns:
            dict: parsed grpc response
        """
        res = {}
        request = blueman_pb2.CpuFrequencyHealthRequest()
        response = self._stub.cpu_frequency_health(request, metadata=self._call_metadata)
        if response:
            res['frequency_value'] = response.frequency_value
            res['status'] = response.status
            res['unit'] = response.unit
        return res

    def _collect_power_consumption_health(self) -> dict:
        """collect power consumption health via grpc

        Returns:
            dict: parsed grpc response
        """
        res = {}
        request = blueman_pb2.PowerConsumptionHealthRequest()
        response = self._stub.power_consumption_health(request, metadata=self._call_metadata)
        if response:
            res['power_consumption'] = response.power_consumption
            res['status'] = response.status
            res['unit'] = response.unit
        return res

    def _collect_system_services_health(self) -> list:
        """collect system services health via rpc

        Returns:
            list: parsed rpc response
        """
        request = blueman_pb2.SystemServicesHealthRequest()
        response = self._stub.system_services_health(request, metadata=self._call_metadata)
        return self._get_repeated_data_dict_list(response)

    def _collect_dpu_rshim_log_health(self) -> list:
        """
        collect dpu rshim log via grpc
        Returns:
            list of log
        """
        data = []
        request = blueman_pb2.RshimLogHealthRequest()
        response = self._stub.dpu_rshim_log_health(request, metadata=self._call_metadata)
        if response:
            data = list(response.data)
        return data

    def _collect_nvme_log_health(self) -> dict:
        """
        collect nvme log health via grpc

        Returns:
            dict: parsed rpc response (nested map flattened to nested dicts)
        """
        request = blueman_pb2.NvmeLogHealthRequest()
        response = self._stub.nvme_log_health(request, metadata=self._call_metadata)
        return self._nested_map_to_dict(response)

    def _collect_kernel_modules_health(self) -> list:
        """collect kernel modules health via rpc

        Returns:
            list: parsed rpc response
        """
        request = blueman_pb2.KernelModulesHealthRequest()
        response = self._stub.kernel_modules_health(request, metadata=self._call_metadata)
        return self._get_repeated_data_dict_list(response)

    def _collect_port_status_health(self) -> list:
        """collect port status health via rpc

        Returns:
            list: parsed rpc response
        """
        request = blueman_pb2.PortStatusHealthRequest()
        response = self._stub.port_status_health(request, metadata=self._call_metadata)
        return self._get_repeated_data_dict_list(response)

    def _collect_doca_services(self) -> list:
        """collect doca services via rpc

        Returns:
            list: one dict (name, container_id, state) per service
        """
        res = []
        request = blueman_pb2.DocaServicesRequest()
        response = self._stub.doca_services(request, metadata=self._call_metadata)
        if response:
            res = [
                {
                    "name": service_data.name,
                    "container_id": service_data.container_id,
                    "state": service_data.state
                }
                for service_data in response.data
            ]
        return res

    def collect(self) -> list:
        """collect all rpc calls data

        A handler whose collection raised is skipped for ``retry_delay``
        seconds before it is attempted again.

        Returns:
            list: collected message dicts (may be empty)
        """
        data = []
        for handler in self._message_handlers:
            name = handler.message_name
            failed_at = self._delayed_messages.get(name)
            if failed_at is not None:
                # monotonic clock: wall-clock steps must not stretch or cut the back-off
                if time.monotonic() - failed_at < self._retry_delay:
                    logging.debug("Skipping collecting %s", name)
                    continue
                del self._delayed_messages[name]
            try:
                message = handler.get_message_dict()
            except Exception as err:  # any rpc/conversion failure: back off, keep the cycle going
                logging.error("the following error occurred while collecting %s: %s", name, err)
                self._delayed_messages[name] = time.monotonic()
                logging.info("Will retry collecting %s in %d seconds", name, self._retry_delay)
                continue

            if message is not None:
                data.append(message)
        return data

def get_dts_internal_uid(path: Union[str, None] = None) -> Union[str, None]:
    """get an internal UUID used for inner gRPC authorization

    The id is read from DTS_INTERNAL_ID_FILE, or from ``path`` when given.
    Surrounding whitespace is stripped: a trailing newline in the file would
    otherwise end up in the gRPC metadata value, which may not contain line
    breaks.

    Args:
        path (str, optional): file to read instead of DTS_INTERNAL_ID_FILE

    Returns:
        Union[str, None]: uuid string if success, otherwise None (missing,
            unreadable or empty file)
    """
    if path is None:
        path = DTS_INTERNAL_ID_FILE
    try:
        with open(path, "r", encoding="utf-8") as id_file:
            uid = id_file.read().strip()
    except (OSError, ValueError):
        # best effort: the caller reports the missing credentials
        return None
    return uid or None

def to_int(val : Union[int, float, str]) -> Union[int, None]:
    """return input as int

    Args:
        val (Union[int, float, str]): raw integer description. ints (bools
            included, since bool subclasses int) are returned as-is; integral
            floats and numeric strings are converted.

    Returns:
        Union[int, None]: resolved integer on success, otherwise None (the
            failure is logged for every unsupported input, not only strings)
    """
    if isinstance(val, int):
        return val
    if isinstance(val, float) and val.is_integer():
        return int(val)
    if isinstance(val, str):
        try:
            return int(val)
        except ValueError:
            pass
    logging.error("could not convert '%s' to integer", val)
    return None

def init(conf : dict) -> bool:
    """script setup: configure logging and create the module-level CLIENT

    Args:
        conf (dict): configuration dictionary. Keys read here:
            "verbose" (INFO logging when truthy),
            "interval" (secs, default 600),
            "retry-delay" (secs, default 3600; invalid or negative values
            fall back to the default).

    Returns:
        bool: True on success, False when the interval is invalid
    """
    global CLIENT
    default_retry_delay = 60 * 60  # 1 hour as seconds

    # set logging
    level = logging.INFO if conf.get("verbose", False) else logging.WARNING
    logging.basicConfig(level=level, format='%(asctime)s : %(levelname)s : %(message)s')

    uuid = get_dts_internal_uid()
    if not uuid:
        # not fatal: the client is still created below with the missing uid
        logging.error("Cannot connect to gRPC backend - no credentials found")

    # set CLIENT
    interval = to_int(conf.get("interval", 10 * 60))  # 10 mins as seconds
    if not isinstance(interval, int) or interval <= 0:
        logging.error("invalid interval value")
        return False

    # set retry delay; to_int() returns None for unparsable values, which must
    # not reach the `< 0` comparison (TypeError on None)
    retry_delay = to_int(conf.get("retry-delay", default_retry_delay))
    if retry_delay is None or retry_delay < 0:
        retry_delay = default_retry_delay
        logging.warning("invalid retry delay value, using default %d", retry_delay)
    CLIENT = BluemanCollector(interval, uuid, conf, retry_delay)
    return True

def collect(conf : Union[dict, None] = None) -> Union[list, None]:
    """script entry point (for python provider)

    Creates the module-level CLIENT through init() on first use. If a
    previous init() failed, the next call retries it.

    Args:
        conf (dict, optional): configuration dictionary, only used while
            CLIENT has not been created yet. Defaults to an empty configuration.

    Returns:
        Union[list, None]: collected messages, or None when init() failed
    """
    # None instead of a shared mutable `{}` default
    if CLIENT is None and not init({} if conf is None else conf):
        return None
    return CLIENT.collect()

def get_tag():
    """Return the fixed tag string associated with this script."""
    tag = "bluefield"
    return tag

if __name__ == "__main__":
    # debug: run one collection with an empty configuration and print the result
    result = collect()
    print(result)
