#! /usr/bin/python3
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

# Bail out early, right after importing sys, on interpreters older than 3.8
# so the user gets a clear message instead of an obscure import failure.
# The rest of this file must therefore avoid syntax that older
# interpreters cannot even parse (e.g. f-strings, walrus).
if sys.version_info[:2] < (3, 8):
    _running = sys.version_info[:3]
    sys.stderr.write(
        "ERROR: ovs-exporter requires Python 3.8 or later.\n"
        "Current version: Python %d.%d.%d\n"
        "This feature is not supported on operating systems with older Python versions.\n"
        % _running
    )
    sys.exit(1)

import argparse
import configparser
import signal
import time

try:
    import ovs.daemon
    import ovs.dirs
    from ovs.metrics import get_metrics_families, MetricsReadError
    import ovs.poller
    import ovs.vlog
    import ovs.unixctl
    import ovs.unixctl.server
    import ovs.util
except ModuleNotFoundError:
    print(u"""\
ERROR: Missing dependencies.
Please install the Open vSwitch python libraries: python3-doca-openvswitch (version 3.3.0040).
Alternatively, install them from source: ( cd ovs/python ; python3 setup.py install ).
Alternatively, check that your PYTHONPATH is pointing to the correct location.""")
    sys.exit(1)

try:
    import prometheus_client
    from prometheus_client import start_http_server
    from prometheus_client.core import REGISTRY
except Exception as e:
    print("ERROR: Missing python module: %s" % e.name)
    sys.exit(1)

# Shutdown request flag: set to True by the SIGINT handler or the unixctl
# "exit" command, and checked by the main wait loop.
exiting = False
vlog = ovs.vlog.Vlog(ovs.util.PROGRAM_NAME)
DEFAULT_PORT = 6103
DEFAULT_CONFIG_FILE = "%s/%s.conf" % (ovs.dirs.DBDIR, ovs.util.PROGRAM_NAME)

# Export only OVS metrics, not metrics describing this Python process and its
# web server: drop the collectors prometheus_client registers by default.
for _default_collector in (prometheus_client.GC_COLLECTOR,
                           prometheus_client.PLATFORM_COLLECTOR,
                           prometheus_client.PROCESS_COLLECTOR):
    REGISTRY.unregister(_default_collector)
del _default_collector


class OVSCollector(object):
    """Prometheus custom collector exposing OVS-DOCA metrics.

    collect() returns whatever ovs.metrics.get_metrics_families() produces
    for the configured extended/debug selection.

    :param extended: also request the extended metrics.
    :param debug: also request the debug metrics.
    :param registry: registry to register this collector with; a falsy
        value (e.g. None) skips registration.
    """

    def __init__(self, extended=False, debug=False, registry=REGISTRY):
        # Store the selection before registering, in case the registry
        # calls collect() while registering.
        self._extended = extended
        self._debug = debug
        if registry:
            registry.register(self)
        vlog.dbg("Registered OVS metrics collector: extended=%s, debug=%s"
                 % (extended, debug))
        # Marks the end of construction.  collect() only logs once this
        # attribute exists -- presumably to stay quiet during a collect()
        # triggered by register() above (NOTE(review): confirm intent).
        self._inited = True

    def collect(self):
        """Return the current OVS metric families.

        A MetricsReadError is logged and turned into an empty result
        instead of propagating to the caller.
        """
        if hasattr(self, '_inited'):
            vlog.info("Collecting requested metrics")
        try:
            families = get_metrics_families(extended=self._extended,
                                            debug=self._debug)
        except MetricsReadError as read_error:
            vlog.err(str(read_error))
            families = {}
        return families


def sigint_handler(sig, frame):
    """Handle SIGINT by requesting a clean shutdown.

    It only raises the module-level 'exiting' flag and logs; the main wait
    loop sees the flag and returns.  *sig* and *frame* are unused.
    """
    global exiting
    exiting = True
    vlog.info("Exiting")


signal.signal(signal.SIGINT, sigint_handler)


def unixctl_exit(conn, unused_argv, aux):
    """unixctl "exit" command callback.

    Logs, raises the module-level 'exiting' flag, and answers the client
    on *conn* with an empty (None) reply.  *unused_argv* and *aux* are
    ignored.
    """
    global exiting
    vlog.info("Exiting")
    exiting = True
    conn.reply(None)


def uint(s):
    """Parse *s* as a non-negative integer.

    :param s: value accepted by int(), typically a string.
    :returns: the parsed integer; 0 is allowed.
    :raises ValueError: if *s* is not an integer literal (raised by int())
        or if the value is negative.
    """
    i = int(s)
    if i < 0:
        # ValueError (not bare Exception) matches what int() raises for bad
        # input, so callers can handle both failures the same way.
        raise ValueError("%s is not a non-negative integer" % s)
    return i


def push_config_into_args(config, args):
    """Fill options not given on the command line from the config file.

    Values come from the ``[exporter]`` section of *config*.  A file value
    is used only when the matching command-line option was not supplied,
    so the command line keeps precedence.

    :param config: a configparser.ConfigParser that has already been read.
    :param args: the argparse.Namespace built in main(); its ``port``,
        ``extended`` and ``debug`` attributes may be updated in place.
    """
    if 'exporter' not in config:
        return

    exporter = config['exporter']

    if args.port is None and 'port' in exporter:
        # Validate exactly like -p/--port, including the [0, 65535] range.
        # An out-of-range value would otherwise only fail later, inside
        # start_http_server(), with an uncaught OverflowError.
        try:
            args.port = argparse_port(exporter['port'])
        except (argparse.ArgumentTypeError, configparser.Error) as e:
            vlog.err("Failed to parse port: %s" % str(e))

    # Only the literal value 'yes' enables these options.
    if not args.extended and 'extended' in exporter:
        args.extended = exporter['extended'] == 'yes'

    if not args.debug and 'debug' in exporter:
        args.debug = exporter['debug'] == 'yes'


def argparse_port(s):
    """argparse ``type=`` callback validating a TCP port number.

    :param s: the raw option value.
    :returns: the port as an int in [0, 65535].
    :raises argparse.ArgumentTypeError: if uint() rejects *s* (its message
        is reused) or the value is above 65535.
    """
    highest_port = 2**16 - 1
    try:
        port = uint(s)
    except Exception as err:
        # argparse reports ArgumentTypeError messages as usage errors.
        raise argparse.ArgumentTypeError(str(err))
    if port <= highest_port:
        return port
    raise argparse.ArgumentTypeError(
        "%s is not within [0, %d]" % (s, highest_port))


def _build_arg_parser():
    """Build the command-line parser for the exporter.

    --port has no default (None) and --extended/--debug count from 0, so
    the caller can tell whether they were given on the command line and let
    the configuration file fill in the rest.
    """
    configfile = DEFAULT_CONFIG_FILE

    parser = argparse.ArgumentParser(
        description=u'Open vSwitch metrics exporter.')

    group = parser.add_argument_group(title="Exporter Options")
    group.add_argument('-p', '--port', type=argparse_port, metavar='<uint>',
                       help='TCP port to listen on')
    group.add_argument('-x', '--extended', action='count', default=0,
                       help='Also export the extended metrics page')
    group.add_argument('-d', '--debug', action='count', default=0,
                       help='Also export the debug metrics page')
    # With nargs="?", a bare '-c' stores 'const'.  Without it the value
    # would be None, which ConfigParser.read() cannot handle.
    group.add_argument('-c', '--config', nargs="?", default=configfile,
                       const=configfile,
                       help='Read configuration from file (default %s)'
                       % configfile)
    group.add_argument('--unixctl', action='store_true',
                       help='Create a unixctl server')
    group.add_argument('--version', action='store_true',
                       help='Print the version')

    ovs.daemon.add_args(parser)
    ovs.vlog.add_args(parser)
    return parser


def _read_config(path):
    """Return a new ConfigParser loaded from *path*.

    ConfigParser.read() silently skips a file it cannot open, so a missing
    file yields an empty configuration.  If the file exists but cannot be
    parsed or decoded, the error goes to stderr and an empty configuration
    is returned instead of a half-read one.
    """
    config = configparser.ConfigParser()
    try:
        config.read(path)
    except (configparser.Error, UnicodeDecodeError) as e:
        sys.stderr.write("Unable to read configuration file %s: %s\n"
                         % (path, e))
        return configparser.ConfigParser()
    return config


def _run_until_exit(server):
    """Block until the module-level 'exiting' flag becomes True.

    :param server: unixctl server to service while waiting, or None to
        just sleep.
    """
    if server is None:
        while not exiting:
            time.sleep(1)
        return

    vlog.info("Entering run loop.")
    poller = ovs.poller.Poller()
    while not exiting:
        server.run()
        server.wait(poller)
        if exiting:
            poller.immediate_wake()
        # Wake at least once per second.  A SIGINT arriving while blocked
        # does not end the wait, because Python retries interrupted system
        # calls (PEP 475).  Without a timeout the flag set by sigint_handler
        # would go unnoticed until the next unixctl activity.
        poller.timer_wait(1000)
        poller.block()
    server.close()


def main():
    """Parse options, serve OVS metrics over HTTP, and run until asked to exit.

    For --port, --extended and --debug, the command line wins over the
    [exporter] section of the configuration file, which wins over the
    built-in defaults.
    """
    args = _build_arg_parser().parse_args()

    if args.version:
        print("%s 3.3.0040" % ovs.util.PROGRAM_NAME)
        sys.exit(0)

    ovs.daemon.handle_args(args)
    ovs.vlog.handle_args(args)

    push_config_into_args(_read_config(args.config), args)

    # Apply defaults only now, after the configuration had its chance.
    # argparse defaults would hide whether an option was actually given.
    # Port 0 is an accepted value, so test for None rather than truthiness.
    if args.port is None:
        args.port = DEFAULT_PORT
    args.extended = bool(args.extended)
    args.debug = bool(args.debug)

    ovs.daemon.daemonize_start()

    server = None
    if args.unixctl:
        error, server = ovs.unixctl.server.UnixctlServer.create(None)
        if error:
            ovs.util.ovs_fatal(error, "could not create unixctl server", vlog)
        ovs.unixctl.command_register("exit", "", 0, 0, unixctl_exit,
                                     "aux_exit")

    # The collector registers itself with the default registry, which holds
    # on to it; no local reference is needed.
    OVSCollector(extended=args.extended, debug=args.debug)

    vlog.info("Starting HTTP server on port %d" % args.port)
    try:
        start_http_server(args.port)
    except (OSError, OverflowError) as e:
        # Without a listening socket the exporter serves nothing, so exit
        # rather than completing daemonization and idling.
        ovs.util.ovs_fatal(0, "Failed to start HTTP server: %s" % e, vlog)

    ovs.daemon.daemonize_complete()

    _run_until_exit(server)
    vlog.info("Stopping the HTTP server")


if __name__ == "__main__":
    # Start the exporter only when run as a script, not when imported.
    main()

# vi: filetype=python
