# Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import socket
import struct
import subprocess
import threading
import time
from typing import Optional, Dict, Any, Set, NamedTuple
from enum import Enum

import ovs.dirs
import ovs.dpif_doca_tcpdump_hooks
from ovs.db import idl
from ovs import jsonrpc
from ovs.poller import Poller
from ovs.stream import Stream
from ovs.fatal_signal import add_hook


class InterfaceConfig(NamedTuple):
    """An interface name together with the set of DOCA tcpdump hooks for it."""
    name: str
    hooks: Set[ovs.dpif_doca_tcpdump_hooks.DOCATcpdumpHook]

    def __str__(self) -> str:
        # Rendered as "<name>:<v1>,<v2>,..."; hook values are sorted so the
        # text does not depend on set iteration order.
        hook_values = sorted(hook.value for hook in self.hooks)
        return f"{self.name}:{','.join(hook_values)}"


# DOCA protocol constants.
# All headers on the DOCA tcpdump socket are packed in network byte order ("!").
DOCA_MAX_IFACE_NAME = 50
# Batch header: two uint32_t fields followed by char[DOCA_MAX_IFACE_NAME].
DOCA_BATCH_HDR_FMT = "!" + "II" + f"{DOCA_MAX_IFACE_NAME}s"
# Per-packet header: a single uint32_t.
DOCA_PKT_HDR_FMT = "!I"
DOCA_BATCH_HDR_SIZE, DOCA_PKT_HDR_SIZE = map(
    struct.calcsize, (DOCA_BATCH_HDR_FMT, DOCA_PKT_HDR_FMT))


class OVSInterfaceManager:
    """Manages OVS interface operations using OVSDB IDL.

    A single IDL connection to the ``Open_vSwitch`` database is shared at
    class level. It is opened lazily on first table access, and ``close()``
    discards it so that a later access reconnects.
    """

    _idl_conn = None  # shared idl.Idl instance, or None when not connected
    _schema = None    # idl.SchemaHelper the current _idl_conn was built from

    @classmethod
    def _ensure_db_connection(cls):
        """Ensure the database connection is established and initially synced.

        The database socket is ``db.sock`` under ``$OVS_RUNDIR`` (falling back
        to ``ovs.dirs.RUNDIR``).

        Raises:
            RuntimeError: If connecting or the initial sync fails. Any
                partially created connection is discarded so the next call
                retries instead of reusing an unsynced IDL.
        """
        if cls._idl_conn is not None:
            return

        rundir = os.environ.get('OVS_RUNDIR', ovs.dirs.RUNDIR)
        db_sock = f'unix:{os.path.join(rundir, "db.sock")}'

        try:
            cls._schema = cls._get_schema(db_sock)
            cls._schema.register_all()
            cls._idl_conn = idl.Idl(db_sock, cls._schema)
            cls._wait_for_db_change()  # Initial sync with DB
        except Exception as e:
            # Drop half-initialized state; otherwise every later call would
            # skip reconnecting and silently use an unsynced connection.
            cls.close()
            raise RuntimeError(f"Failed to connect to OVS database: {e}") from e

    @classmethod
    def _get_schema(cls, db_sock):
        """Fetch the ``Open_vSwitch`` schema from the server at ``db_sock``.

        Returns:
            An ``idl.SchemaHelper`` built from the server's schema.

        Raises:
            Exception: If the socket cannot be opened or the ``get_schema``
                request fails.
        """
        error, strm = Stream.open_block(Stream.open(db_sock))
        if error:
            raise Exception(f"Unable to connect to {db_sock}")
        rpc = jsonrpc.Connection(strm)
        try:
            req = jsonrpc.Message.create_request('get_schema', ['Open_vSwitch'])
            error, resp = rpc.transact_block(req)
        finally:
            # Always release the one-shot RPC connection, even on failure.
            rpc.close()

        if error or resp.error:
            raise Exception('Unable to retrieve schema.')
        return idl.SchemaHelper(None, resp.result)

    @classmethod
    def _wait_for_db_change(cls, timeout=10):
        """Block until the IDL's change sequence number moves or run() reports a change.

        Args:
            timeout: Maximum number of seconds to wait (default 10).

        Raises:
            Exception: ``'Retry Timeout'`` if nothing changes within
                ``timeout`` seconds.
        """
        seq = cls._idl_conn.change_seqno
        stop = time.time() + timeout
        while cls._idl_conn.change_seqno == seq and not cls._idl_conn.run():
            remaining = stop - time.time()
            if remaining <= 0:
                raise Exception('Retry Timeout')
            poller = Poller()
            cls._idl_conn.wait(poller)
            # Bound the wait so the timeout is enforced even when the server
            # sends nothing; without a timer, block() could wait forever.
            poller.timer_wait(int(remaining * 1000))
            poller.block()

    @classmethod
    def _get_table(cls, table_name):
        """Return ``table_name`` from the IDL, connecting first if needed."""
        cls._ensure_db_connection()
        return cls._idl_conn.tables[table_name]

    @classmethod
    def _find_row_by_name(cls, table_name, value):
        """Return the first row in ``table_name`` whose ``name`` equals ``value``, or None."""
        return next(
            (row for row in cls._get_table(table_name).rows.values()
             if row.name == value), None)

    @staticmethod
    def verify_interface_exists(iface: str) -> None:
        """Verify that the specified interface exists as an OVS port.

        Args:
            iface: Interface name to verify

        Raises:
            ValueError: If no row in the ``Port`` table has this name.
            RuntimeError: If the database connection cannot be established.
        """
        port = OVSInterfaceManager._find_row_by_name('Port', iface)
        if not port:
            raise ValueError(f"{iface}: No such device exists")

    @staticmethod
    def list_ports() -> list:
        """List all named OVS ports as numbered entries, plus a final ``any`` entry.

        The status text in brackets is a fixed string. It is not queried
        from OVS.

        Returns:
            List of strings such as ``"1.br0 [Up, Running, Connected]"``,
            followed by one ``any`` pseudo-device entry.
        """
        ports_table = OVSInterfaceManager._get_table('Port')
        ports = [row.name for row in ports_table.rows.values() if row.name]

        formatted = [f"{i}.{port} [Up, Running, Connected]"
                     for i, port in enumerate(ports, start=1)]
        formatted.append(f"{len(ports)+1}.any (Pseudo-device that captures on all interfaces) [Up, Running]")

        return formatted

    @classmethod
    def close(cls):
        """Close the IDL connection, if any, and forget it.

        The class state is reset so that a later access reconnects instead
        of reusing a closed connection. Safe to call repeatedly.
        """
        conn = cls._idl_conn
        cls._idl_conn = None
        cls._schema = None
        if conn is not None:
            conn.close()


class DOCATcpdumpController:
    """Controls DOCA tcpdump capture via ovs-appctl."""

    @staticmethod
    def _build_set_command(interfaces) -> list:
        """Build the ``ovs-appctl dpif-doca/tcpdump-set`` argument list.

        Each interface adds two arguments: its name, then its hook values
        joined by commas. Hook values are sorted because ``config.hooks`` is
        a set, so the emitted command is deterministic.
        """
        cmd = ["ovs-appctl", "dpif-doca/tcpdump-set"]
        for iface_name, config in (interfaces or {}).items():
            hooks_str = ",".join(sorted(hook.value for hook in config.hooks))
            cmd.extend([iface_name, hooks_str])
        return cmd

    @staticmethod
    def start_capture(interfaces: "Dict[str, InterfaceConfig]") -> None:
        """Start DOCA tcpdump capture with interface configurations.

        Args:
            interfaces: Dictionary mapping interface names to their configurations

        Raises:
            RuntimeError: If ovs-appctl cannot be executed or exits non-zero
                (the message includes its stderr).
        """
        cmd = DOCATcpdumpController._build_set_command(interfaces)

        try:
            p = subprocess.Popen(cmd,
                                 stdin=subprocess.PIPE,
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
            _, err = p.communicate()
        except OSError as e:
            raise RuntimeError(f"OSError: {e}") from e

        if p.returncode != 0:
            # errors="replace": a non-UTF-8 byte must not mask the real error.
            raise RuntimeError(
                f"Failed to start DOCA tcpdump: {err.decode(errors='replace')}")

    @staticmethod
    def stop_capture() -> None:
        """Stop DOCA tcpdump capture, best effort.

        Failures are deliberately ignored: an ``OSError`` from launching
        ovs-appctl is swallowed and the exit status is not checked.
        """
        cmd = ["ovs-appctl", "dpif-doca/tcpdump-unset"]

        try:
            p = subprocess.Popen(cmd,
                                 stdin=subprocess.PIPE,
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
            p.communicate()
        except OSError:
            pass


class DOCASocketServer:
    """Manages Unix domain socket server for DOCA packet reception.

    The server listens on ``socket_path``. A background thread accepts a
    single connection, which is then available through ``get_connection()``.
    """

    def __init__(self, socket_path: str):
        self.socket_path = socket_path
        self.server_socket: Optional[socket.socket] = None
        # Filled in by the accept thread under the key 'conn'.
        self.connection_result: Dict[str, Any] = {}

    def _remove_socket_file(self) -> None:
        """Remove the socket file, doing nothing if it does not exist."""
        # EAFP avoids the exists()/unlink() race and also removes a dangling
        # symlink, which os.path.exists() would report as absent.
        try:
            os.unlink(self.socket_path)
        except FileNotFoundError:
            pass

    def create_server(self) -> None:
        """Create the listening Unix domain socket at ``socket_path``.

        Any stale file at that path is removed first.

        Raises:
            OSError: If binding, listening or chmod fails. The new socket is
                closed before the error propagates.
        """
        self._remove_socket_file()

        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.bind(self.socket_path)
            sock.listen(1)
            # NOTE(review): 0o666 lets any local user connect and feed data
            # into the capture; presumably needed so the datapath process
            # (possibly another user) can connect - confirm before tightening.
            os.chmod(self.socket_path, 0o666)  # Read/write for all users
        except OSError:
            sock.close()
            raise
        self.server_socket = sock

    def start_accept_thread(self) -> None:
        """Accept one connection on a background daemon thread.

        The accepted socket is stored in ``connection_result['conn']``. If
        ``accept()`` fails with ``OSError`` (for example because the server
        socket was already closed), the thread exits quietly.
        """
        def accept_thread():
            # Read the attribute once so a concurrent cleanup() resetting it
            # to None cannot cause an AttributeError between check and use.
            server = self.server_socket
            if server is None:
                return
            try:
                conn, _ = server.accept()
            except OSError:
                return
            self.connection_result['conn'] = conn

        threading.Thread(target=accept_thread, daemon=True).start()

    def get_connection(self) -> Optional[socket.socket]:
        """Return the accepted connection, or None if none has been accepted yet."""
        return self.connection_result.get('conn')

    def cleanup(self) -> None:
        """Close the server socket and remove its file. Safe to call repeatedly."""
        if self.server_socket is not None:
            self.server_socket.close()
            self.server_socket = None
        self._remove_socket_file()


class DOCAPacketReceiver:
    """Handles receiving and parsing packets from DOCA socket.

    Wire format: a ``DOCA_BATCH_HDR_FMT`` header (batch size, hook, and a
    NUL-padded device name), then ``batch_size`` packets. Each packet is a
    ``DOCA_PKT_HDR_FMT`` length prefix followed by that many data bytes.
    """

    @staticmethod
    def recv_all(sock: socket.socket, size: int, chunk_size: int = 65536) -> bytes:
        """Receive exactly ``size`` bytes from ``sock``.

        Args:
            sock: Connected stream socket.
            size: Number of bytes to read. ``size <= 0`` returns ``b""``.
            chunk_size: Upper bound on the size passed to each ``recv`` call.
                ``recv(n)`` preallocates ``n`` bytes, and ``size`` comes from
                the peer. Capping it keeps memory proportional to the data
                actually received rather than to an advertised length.

        Raises:
            ConnectionError: If the peer closes before ``size`` bytes arrive.
        """
        chunk_size = max(1, chunk_size)
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = sock.recv(min(remaining, chunk_size))
            if not chunk:
                raise ConnectionError("Socket closed unexpectedly")
            chunks.append(chunk)
            remaining -= len(chunk)
        # One join instead of repeated bytes concatenation (which is quadratic).
        return b"".join(chunks)

    @staticmethod
    def recv_batch(sock: socket.socket) -> Dict[str, Any]:
        """Receive and parse one batch of packets from ``sock``.

        Returns:
            A dict with keys ``batch_size``, ``hook``, ``dev_name`` (decoded
            and truncated at the first NUL), and ``packets``. ``packets`` is a
            list of ``{"len": int, "data": bytes}`` dicts.

        Raises:
            ConnectionError: If the socket closes mid-batch.
        """
        # Read batch header
        raw = DOCAPacketReceiver.recv_all(sock, DOCA_BATCH_HDR_SIZE)
        batch_size, hook, dev_name = struct.unpack(DOCA_BATCH_HDR_FMT, raw)
        dev_name = dev_name.split(b'\x00', 1)[0].decode(errors="ignore")

        packets = []
        # For each packet, read header + data
        for _ in range(batch_size):
            raw = DOCAPacketReceiver.recv_all(sock, DOCA_PKT_HDR_SIZE)
            (pkt_len,) = struct.unpack(DOCA_PKT_HDR_FMT, raw)

            pkt_data = DOCAPacketReceiver.recv_all(sock, pkt_len)
            packets.append({
                "len": pkt_len,
                "data": pkt_data
            })

        return {
            "batch_size": batch_size,
            "hook": hook,
            "dev_name": dev_name,
            "packets": packets
        }


class DOCATcpdumpManager:
    """Main manager class for DOCA tcpdump functionality.

    Combines the listening socket (``DOCASocketServer``), the ovs-appctl
    controls (``DOCATcpdumpController``) and batch parsing
    (``DOCAPacketReceiver``).
    """

    def __init__(self):
        self.socket_server: Optional[DOCASocketServer] = None
        # Connection adopted from socket_server on the first receive.
        self.connection: Optional[socket.socket] = None

    def get_socket_path(self) -> str:
        """Get the path to the DOCA tcpdump Unix domain socket.

        The socket is ``doca-tcpdump.sock`` under ``$OVS_RUNDIR``, falling
        back to ``ovs.dirs.RUNDIR``.
        """
        rundir = os.environ.get('OVS_RUNDIR', ovs.dirs.RUNDIR)
        return os.path.join(rundir, "doca-tcpdump.sock")

    def setup_socket_server(self) -> None:
        """Create the listening socket and start accepting one connection."""
        socket_path = self.get_socket_path()
        self.socket_server = DOCASocketServer(socket_path)
        self.socket_server.create_server()
        self.socket_server.start_accept_thread()

    def start_capture(self, interfaces: "Dict[str, InterfaceConfig]") -> None:
        """Start DOCA tcpdump capture (raises RuntimeError on failure)."""
        DOCATcpdumpController.start_capture(interfaces)

    def stop_capture(self) -> None:
        """Stop DOCA tcpdump capture (best effort)."""
        DOCATcpdumpController.stop_capture()

    def get_connection(self) -> Optional[socket.socket]:
        """Return the server's accepted connection, or None if there is none yet."""
        if self.socket_server:
            return self.socket_server.get_connection()
        return None

    def receive_packet_batch(self) -> Optional[Dict[str, Any]]:
        """Receive one batch of packets from the DOCA socket.

        Returns:
            The parsed batch, or None while no connection has been accepted.

        Raises:
            ConnectionError: If the peer closes the connection mid-batch.
        """
        if not self.connection:
            self.connection = self.get_connection()

        if self.connection:
            return DOCAPacketReceiver.recv_batch(self.connection)
        return None

    def cleanup(self) -> None:
        """Release the data connection and the listening socket.

        Also closes a connection that the accept thread produced but
        ``receive_packet_batch`` never adopted. Safe to call repeatedly.
        """
        conn = self.connection or self.get_connection()
        self.connection = None
        if conn:
            conn.close()
        if self.socket_server:
            self.socket_server.cleanup()
            self.socket_server = None
