#!/usr/bin/env python3
# ex:ts=4:sw=4:sts=4:et
# -*- tab-width: 4; c-basic-offset: 4; indent-tabs-mode: nil -*-
###############################################################################
#
# Copyright 2020 NVIDIA Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
###############################################################################

import os
import sys
import argparse
import subprocess
import glob
import time
import re
import json
import errno

__author__ = "Vladimir Sokolovsky <vlad@nvidia.com>"
__version__ = "1.1"

# Program name as invoked; used in version output and as the syslog tag.
prog = os.path.basename(sys.argv[0])
# Pin PATH so the external tools run below are resolved from these
# directories only, with the Mellanox iproute2 directory searched first.
os.environ['PATH'] = '/opt/mellanox/iproute2/sbin:/usr/sbin:/usr/bin:/sbin:/bin'

MLXREG = '/usr/bin/mlxreg'
MLXDEVM = 'mlxdevm'
DEVLINK = 'devlink'
SUPPORTED_ACTIONS = ["create", "show", "delete"]
# Timeouts (seconds) used while polling during SF creation.
RDMA_DEV_TIMEOUT = 120
DEVLINK_DEV_TIMEOUT = 20
SF_CFG_DRIVER_TIMEOUT = 20

# Module-wide verbosity flag; main() overwrites it from --verbose.
verbose = False

# Pick whichever firmware configuration tool is installed.
if os.path.exists("/usr/bin/mlxconfig"):
    MLXCONFIG = "/usr/bin/mlxconfig"
elif os.path.exists("/usr/bin/mstconfig"):
    MLXCONFIG = "/usr/bin/mstconfig"
else:
    # sf_log() is defined further down the file and does not exist yet at
    # this point of the import, so report on stderr directly.
    # sys.exit() with a string prints it to stderr and exits with status 1.
    sys.exit("ERROR: MFT or mstflint package is required")


class SF:
    """
    Manage mlx5 Scalable Function (SF) ports through mlxdevm/devlink.

    An instance is configured from the parsed command-line arguments and
    offers one method per supported action: show(), create() and delete().
    Every action returns a dict with two keys: 'status' (0 on success,
    otherwise a non-zero error code) and 'output' (text for the user).
    """

    def __init__(self, args):
        """
        Store the command-line options relevant to SF management.

        args -- object with the attributes device, action, verbose, json,
                pretty, enable_trust, enable_eswitch, disable_roce, sfnum,
                hwaddr, sfindex and cpu_list (normally the argparse
                namespace built in main()).
        """
        self.device = args.device
        # The PF number is the last character of the device name
        # (presumably the PCI function digit -- confirm against callers).
        self.pfnum = None
        if self.device:
            self.pfnum = self.device[-1]
        self.action = args.action
        self.verbose = args.verbose
        self.json = args.json
        self.pretty = args.pretty
        self.enable_trust = args.enable_trust
        self.enable_eswitch = args.enable_eswitch
        self.disable_roce = args.disable_roce
        # Empty strings are normalised to None.
        self.sfnum = args.sfnum or None
        self.hw_addr = args.hwaddr or None
        self.sfindex = args.sfindex or None
        self.netif = None
        self.aux_dev = None
        self.sf_netdev = None
        self.rdma_dev = None
        self.cpu_list = args.cpu_list or None
        self.info = {}
        self.info_error = 0

    @staticmethod
    def _wait_until(check, timeout, interval):
        """
        Poll check() until it returns a truthy value or the timeout expires.

        The timeout is measured with the monotonic wall clock, so time spent
        sleeping between polls counts towards it. check() is always called
        at least once.

        Returns the truthy value produced by check(), or None on timeout.
        """
        deadline = time.monotonic() + timeout
        while True:
            value = check()
            if value:
                return value
            if time.monotonic() >= deadline:
                return None
            time.sleep(interval)

    @staticmethod
    def _parse_port_add(output, device):
        """
        Extract (sfindex, netif) from the output of 'mlxdevm port add'.

        Only lines whose first token starts with 'pci/<device>' are
        considered; if several match, the last one wins. Both values are
        None when no line matches, and netif is None when the matching line
        carries no usable token.
        """
        sfindex = None
        netif = None
        marker = 'pci/' + device
        for line in output.split('\n'):
            tokens = line.split()
            if not tokens or not tokens[0].startswith(marker):
                continue

            # The port handle is printed with a trailing ':'.
            sfindex = tokens[0].rstrip(':')
            if 'netdev' in tokens[:-1]:
                netif = tokens[tokens.index('netdev') + 1]
            elif len(tokens) > 4:
                # Positional fallback used when no 'netdev' keyword is found.
                netif = tokens[4]
            else:
                netif = None

        return sfindex, netif

    def _run(self, cmd):
        """
        Run cmd and return its outcome as a {'status', 'output'} dict.
        """
        status, output = get_status_output(cmd, self.verbose)
        return {'status': status, 'output': output}

    def _abort_create(self):
        """
        Give up on a half-created SF: log the failure, delete the SF and
        terminate the program with exit status 1.
        """
        sf_log("ERROR: Failed to create SF {} for device {}".format(self.sfnum, self.device), True)
        self.delete()
        sys.exit(1)

    def _rebind_to_sf_driver(self):
        """
        Move the auxiliary device from the mlx5_core.sf_cfg driver to the
        mlx5_core.sf driver. Failures are logged and otherwise ignored.
        """
        fname = None
        try:
            for fname in ("/sys/bus/auxiliary/drivers/mlx5_core.sf_cfg/unbind",
                          "/sys/bus/auxiliary/drivers/mlx5_core.sf/bind"):
                with open(fname, 'w') as f:
                    f.write(self.aux_dev)
        except FileNotFoundError:
            # Must precede OSError: FileNotFoundError is one of its subclasses.
            sf_log("ERROR: File {} does not exist".format(fname))
        except OSError as e:
            sf_log("I/O error({0}): {1}".format(e.errno, e.strerror))
        except Exception as e:
            sf_log("Unexpected error: {}".format(e), True)

    def show(self):
        """
        Show configurations

        Queries mlxdevm for the existing ports (only self.sfindex when set)
        and enriches every entry with its parent device, auxiliary device,
        SF netdev and RDMA device.

        Returns a {'status', 'output'} dict. 'output' is JSON when
        self.json is set (indented when self.pretty is set), otherwise a
        plain-text listing.
        """
        # Get existing devices
        cmd = MLXDEVM + " --pretty --json port show "
        if self.sfindex:
            cmd += self.sfindex

        self.info_error, output = get_status_output(cmd)
        if self.info_error:
            # Always hand back a dict: the caller indexes the result.
            return {'status': self.info_error, 'output': output}

        try:
            self.info = json.loads(output)
        except ValueError as e:
            self.info_error = 1
            return {
                'status': self.info_error,
                'output': "ERROR: Cannot parse {} output: {}".format(MLXDEVM, e),
            }

        result = {'status': self.info_error, 'output': 'No SF device found'}

        if not self.info:
            return result

        info = self.info.get("port", {})

        for sfindex, port in info.items():
            # The port handle has the form <bus>/<device>/<index>.
            port["device"] = sfindex.split('/')[1]
            port["sfindex"] = sfindex
            port["aux_dev"] = lookup_aux_dev(port["device"], port.get("sfnum"))
            port["sf_netdev"] = get_sf_netdev(port["aux_dev"])
            port["rdma_dev"] = get_sf_rdmadev(port["aux_dev"])
            function = port.setdefault("function", {})
            function.setdefault("trust", "NA")
            function.setdefault("eswitch", "NA")

        if self.json:
            if self.pretty:
                result['output'] = json.dumps(info, indent=4)
            else:
                result['output'] = json.dumps(info)
            return result

        lines = []
        for sfindex, port in info.items():
            function = port["function"]
            lines += [
                "",
                "SF Index: " + sfindex,
                "  Parent PCI dev: " + port["device"],
                "  Representor netdev: " + str(port.get("netdev", "NA")),
                "  Function HWADDR: " + str(function.get("hw_addr", "NA")),
                "  Function trust: " + str(function["trust"]),
                "  Function roce: " + str(function.get("roce", "NA")),
                "  Function eswitch: " + str(function["eswitch"]),
                "  Auxiliary device: " + port["aux_dev"],
                "    netdev: " + port["sf_netdev"],
                "    RDMA dev: " + port["rdma_dev"],
            ]
        result['output'] = "".join(line + "\n" for line in lines)

        return result

    def create(self):
        """
        Create SF

        Adds the SF port, applies the optional function settings, activates
        the SF, waits for its auxiliary device, rebinds it from the config
        driver to the SF driver and optionally sets the CPU affinity.

        Returns a {'status', 'output'} dict. If the auxiliary device or its
        driver link never appears, the SF is deleted again and the program
        exits with status 1.
        """
        self.sfindex = None

        cmd = MLXDEVM + " port add pci/{device} flavour pcisf pfnum {pfnum} sfnum {sfnum}".format(
                device=self.device, pfnum=self.pfnum, sfnum=self.sfnum
                )
        result = self._run(cmd)
        if result['status']:
            return result

        self.sfindex, self.netif = self._parse_port_add(result['output'], self.device)
        if not self.sfindex:
            result['output'] = "ERROR: Cannot find sfindex for {device} sfnum {sfnum}".format(device=self.device, sfnum=self.sfnum)
            result['status'] = 1
            sf_log(result['output'])
            return result

        # Optional settings first, activation last; stop at the first failure.
        commands = []
        if self.hw_addr:
            commands.append(MLXDEVM + " port function set {idx} hw_addr {hw_addr}".format(
                    idx=self.sfindex, hw_addr=self.hw_addr
                    ))
        if self.enable_trust:
            commands.append(MLXDEVM + " port function set {idx} trust on".format(
                    idx=self.sfindex
                    ))
        if self.disable_roce:
            commands.append(MLXDEVM + " port function cap set {} roce false".format(self.sfindex))
        if self.enable_eswitch:
            commands.append(MLXDEVM + " port function cap set {idx} eswitch true".format(idx=self.sfindex))
        # Activate SF
        commands.append(MLXDEVM + " port function set {} state active".format(self.sfindex))

        for cmd in commands:
            result = self._run(cmd)
            if result['status']:
                return result

        # Find the auxiliary device of the SF
        self.aux_dev = self._wait_until(
            lambda: lookup_aux_dev(self.device, self.sfnum),
            RDMA_DEV_TIMEOUT, 0.001
        ) or ""
        if not self.aux_dev:
            sf_log("ERROR: timed out waiting for the auxiliary device of SF {} on device {}".format(self.sfnum, self.device), True)
            self._abort_create()

        # Wait for sf_cfg link to be created
        driver_link = "/sys/bus/auxiliary/devices/{}/driver".format(self.aux_dev)
        if not self._wait_until(lambda: os.path.exists(driver_link),
                                SF_CFG_DRIVER_TIMEOUT, 0.01):
            sf_log("ERROR: timed out waiting for the config driver {} to appear".format(driver_link), True)
            self._abort_create()

        # Unbind the SF from the default config driver and bind the actual SF driver
        if os.path.basename(os.readlink(driver_link)) == "mlx5_core.sf_cfg":
            self._rebind_to_sf_driver()

        # Wait for devlink to see SF before running devlink commands
        probe_cmd = DEVLINK + " dev show auxiliary/{}".format(self.aux_dev)

        def devlink_sees_sf():
            result['status'], result['output'] = get_status_output(probe_cmd, self.verbose)
            return result['status'] == 0

        if not self._wait_until(devlink_sees_sf, DEVLINK_DEV_TIMEOUT, 0.05):
            # NOTE(review): only a warning, yet the failing probe status stays
            # in result and becomes the exit status when no cpu_list is given
            # (unchanged from the previous behaviour) -- confirm intent.
            sf_log("WARNING: timed out waiting for SF device {} to appear: {}".format(self.aux_dev, result))

        if self.cpu_list:
            cmd = MLXDEVM + " dev param set auxiliary/{} name cpu_affinity value {} cmode driverinit".format(
                    self.aux_dev, self.cpu_list
                    )
            result = self._run(cmd)
            if result['status']:
                return result
            # The affinity is set with cmode driverinit; reload the device.
            cmd = DEVLINK + " dev reload auxiliary/{}".format(self.aux_dev)
            result = self._run(cmd)
            if result['status']:
                return result

        result["output"] = "Created SF: index={} hw_addr={} sfnum={} netif={}".format(self.sfindex, self.hw_addr or "00:00:00:00:00:00", self.sfnum, self.netif)
        sf_log(result["output"])
        return result

    def delete(self):
        """
        delete SF

        Deactivates the port named by self.sfindex and then removes it.
        Returns a {'status', 'output'} dict describing the first failing
        command, or the result of the final delete command.
        """
        # Deactivate SF
        result = self._run(MLXDEVM + " port function set {} state inactive".format(self.sfindex))
        if result['status']:
            return result

        # Delete SF
        return self._run(MLXDEVM + " port del {}".format(self.sfindex))


def lookup_aux_dev(device, sfnum, sysfs_dir="/sys/bus/auxiliary/devices"):
    """
    Find the auxiliary device of SF number sfnum on the parent device.

    device    -- name of the parent PCI device as it appears in sysfs
    sfnum     -- SF number (int or str) to match against the device's
                 'sfnum' attribute
    sysfs_dir -- directory listing the auxiliary devices

    Returns the name of the matching auxiliary device, or an empty string
    when there is none (including when sysfs_dir cannot be listed).
    """
    try:
        candidates = os.listdir(sysfs_dir)
    except OSError:
        # No auxiliary bus directory: there is nothing to find.
        return ""

    wanted = str(sfnum)
    for aux_dev in candidates:
        if "mlx5_core.sf" not in aux_dev:
            continue

        dev_path = os.path.join(sysfs_dir, aux_dev)
        try:
            with open(os.path.join(dev_path, "sfnum")) as f:
                dev_sfnum = f.readline().strip()
        except OSError as e:
            sf_log("I/O error({0}): {1}".format(e.errno, e.strerror))
            # Without a readable sfnum this entry cannot be matched; skip it
            # rather than comparing against a value left from another entry.
            continue

        if wanted != dev_sfnum:
            continue

        try:
            link = os.readlink(dev_path)
        except OSError:
            # The entry vanished or is not a symlink: not a usable device.
            continue

        # second last component of the link target is the parent pci device
        hierarchy = link.split("/")
        if len(hierarchy) < 2 or hierarchy[-2] != device:
            continue

        return aux_dev

    return ""


def get_sf_netdev(aux_dev):
    """
    Return the name of the first entry in the 'net' directory of the given
    auxiliary device, or an empty string when the device name is empty or
    the directory is missing or empty.
    """
    if not len(aux_dev):
        return ""

    net_dir = "/sys/bus/auxiliary/devices/" + aux_dev + "/net/"
    entries = os.listdir(net_dir) if os.path.exists(net_dir) else []
    return entries[0] if entries else ""


def get_sf_rdmadev(aux_dev):
    """
    Return the name of the first entry in the 'infiniband' directory of the
    given auxiliary device, or an empty string when the device name is
    empty or the directory is missing or empty.
    """
    if not len(aux_dev):
        return ""

    ib_dir = "/sys/bus/auxiliary/devices/" + aux_dev + "/infiniband/"
    entries = os.listdir(ib_dir) if os.path.exists(ib_dir) else []
    return entries[0] if entries else ""


def version():
    """
    Print the program name and its version, separated by a space.
    """
    print("{} {}".format(prog, __version__))


def get_status_output(cmd, verbose=None):
    """
    Run cmd through the shell and capture its output.

    cmd     -- command line, passed to the shell as a single string
    verbose -- truthy to trace the command, its failure and its output on
               stdout; when omitted, the module-level verbose flag is
               consulted at call time

    Returns a tuple (rc, output): rc is the exit status (0 on success) and
    output is the combined stdout/stderr text. The output of a failing
    command is stripped of surrounding whitespace; the output of a
    successful one is returned as-is.
    """
    if verbose is None:
        # The parameter shadows the module-level flag, so read the flag from
        # the module globals. Doing it here rather than in the signature
        # picks up the value main() sets after this function was defined.
        verbose = globals().get('verbose', False)

    rc, output = (0, '')

    if verbose:
        print("Running command:", cmd)

    try:
        output = subprocess.check_output(cmd, stderr=subprocess.STDOUT,
                                         shell=True, universal_newlines=True)
    except subprocess.CalledProcessError as e:
        rc, output = (e.returncode, e.output.strip())

    if rc and verbose:
        print("Running {} failed (error[{}])".format(cmd, rc))

    if verbose:
        print("Output:\n", output)

    return rc, output


def sf_log(msg, level=None):
    """
    Send msg to syslog through logger(1), optionally echoing it on stdout.

    msg   -- message to log
    level -- truthy to also print msg; when omitted, the module-level
             verbose flag is consulted at call time

    Always returns 0; a failure to run logger is ignored.
    """
    if level is None:
        # Resolved here rather than in the signature so that the value
        # main() assigns to the flag after definition is honoured.
        level = verbose
    if level:
        print(msg)

    # Pass the message as its own argument instead of through a quoted shell
    # string, so quotes or shell metacharacters in msg cannot break or alter
    # the command. '--' keeps a message starting with '-' from being parsed
    # as an option.
    try:
        subprocess.call(["logger", "-t", prog, "-i", "--", str(msg)],
                        stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except OSError:
        # Logging is best effort: carry on when logger cannot be executed.
        pass
    return 0


def verify_args(args):
    """
    Check that the parsed arguments are sufficient for the chosen action.

    Returns a tuple (rc, msg): (0, "") when the arguments are acceptable,
    otherwise 1 and an error message.
    """
    action = args.action
    if action not in SUPPORTED_ACTIONS:
        return 1, "ERROR: Action {} is not supported".format(action)

    if action == 'create':
        device_missing = not args.device
        sfnum_missing = not args.sfnum
        # When both are missing, the sfnum message is the one reported.
        if sfnum_missing:
            return 1, "ERROR: Action create requires '--sfnum' parameter"
        if device_missing:
            return 1, "ERROR: Action create requires '--device' parameter"
        return 0, ""

    if action == 'delete' and not args.sfindex:
        return 1, "ERROR: Action delete requires '--sfindex' parameter"

    return 0, ""


def main():
    """
    Command-line entry point.

    Requires root, parses and validates the options, runs the requested SF
    action, prints its output and exits with the action's status.
    """
    global verbose

    if os.geteuid() != 0:
        sys.exit('root privileges are required to run this script!')

    parser = argparse.ArgumentParser(description='Manage SF interfaces')
    parser.add_argument('-d', '--device', help="Network device name")
    parser.add_argument('-a', '--action', required=True, choices=SUPPORTED_ACTIONS, help="Action")
    parser.add_argument('-n', '--sfnum', help="SF number")
    parser.add_argument('-i', '--sfindex', help="SF index")
    parser.add_argument('-m', '--hwaddr', help="SF hw_addr address")
    parser.add_argument('-t', '--enable-trust', action='store_true', help="Set SF as trusted", default=False)
    parser.add_argument('-e', '--enable-eswitch', action='store_true', help="Enable Eswitch on SF", default=False)
    parser.add_argument('-r', '--disable-roce', action='store_true', help="Disable RoCE for SF", default=False)
    parser.add_argument('-c', '--cpu-list', default=None,
                        help="Interpret mask as numerical list of cores instead of a "
                             "bitmask. Numbers are separated by commas and may include "
                             "ranges. For example: 0,5,8-11.")
    parser.add_argument('-j', '--json', action='store_true', help="Generate JSON output with show action", default=False)
    parser.add_argument('-p', '--pretty', action='store_true', help="Generate pretty JSON output with show action", default=False)
    parser.add_argument('-v', '--verbose', action='store_true', help="Print verbose information", default=False)
    # The 'version' action prints and exits as soon as the option is parsed,
    # before required options are checked, so both -V and --version work
    # without --action.
    parser.add_argument('-V', '--version', action='version', version=prog + ' ' + __version__,
                        help='Display program version information and exit')

    args = parser.parse_args()

    verbose = args.verbose
    if verbose:
        print(args)

    rc, msg = verify_args(args)
    if rc:
        print(msg)
        sys.exit(rc)

    sf = SF(args)
    handlers = {'show': sf.show, 'create': sf.create, 'delete': sf.delete}
    result = handlers[sf.action]()
    if result is None:
        # Defensive: an action that bailed out without a result must still
        # produce a message and a failing exit status instead of a crash.
        result = {"status": sf.info_error or 1, "output": "ERROR: Action {} failed".format(sf.action)}

    if sf.action == 'delete' and args.json:
        # NOTE(review): this prints the Python dict repr, not strict JSON;
        # left unchanged in case existing consumers rely on it.
        print(result)
    else:
        print(result["output"])
    sys.exit(result['status'])


# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
