#! /usr/bin/python3
#
# Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import atexit
import fcntl
import os
import signal
import struct
import sys
import time
from typing import Optional, List, Dict, Set

try:
    import ovs.doca_tcpdump_util
    import ovs.dpif_doca_tcpdump_hooks
    import ovs.packet_filter
    from scapy.all import Ether, Raw, wrpcap
    from scapy.packet import Packet
except ModuleNotFoundError as e:
    if "scapy" in str(e).lower():
        print(u"""\
ERROR: Missing Scapy dependency.
Please install Scapy: pip3 install scapy
Alternatively, install it via your package manager: apt-get install python3-scapy""")
    else:
        print(u"""\
ERROR: Missing dependencies.
Please install the Open vSwitch python libraries: python3-doca-openvswitch (version 3.3.0040).
Alternatively, install them from source: ( cd ovs/python ; python3 setup.py install ).
Alternatively, check that your PYTHONPATH is pointing to the correct location.""")
    sys.exit(1)


def parse_interface_config(config_str: str) -> ovs.doca_tcpdump_util.InterfaceConfig:
    """Turn one interface spec ('eth0' or 'eth0:hook1,hook2') into an InterfaceConfig.

    A bare name gets only the RX hook. With a ':' suffix, each comma-separated
    hook name is validated through DOCATcpdumpHook.verify_hook_exists(); blank
    entries are skipped, so a ':' with no hook names yields an empty hook set.
    An unknown hook prints an error and exits with status 1. Whitespace around
    the interface name and each hook name is stripped.
    """
    name, colon, hook_list = config_str.partition(':')
    if not colon:
        return ovs.doca_tcpdump_util.InterfaceConfig(
            config_str.strip(), {ovs.dpif_doca_tcpdump_hooks.DOCATcpdumpHook.RX})

    selected = set()
    for raw_hook in hook_list.split(','):
        candidate = raw_hook.strip()
        if not candidate:
            continue  # tolerate stray commas such as 'eth0:hook1,'
        try:
            selected.add(
                ovs.dpif_doca_tcpdump_hooks.DOCATcpdumpHook.verify_hook_exists(candidate))
        except ValueError as err:
            print(f"Error: {err}")
            sys.exit(1)
    return ovs.doca_tcpdump_util.InterfaceConfig(name.strip(), selected)


def create_argument_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for ovs-doca-tcpdump.

    ``add_help=False`` is passed and ``-h/--help`` is registered as an ordinary
    flag, so the decision to print help is left to handle_special_commands().
    The epilog documents the interface and filter syntax and lists the hook
    names taken from DOCATcpdumpHook.
    """
    hook_names = ", ".join(h.value for h in ovs.dpif_doca_tcpdump_hooks.DOCATcpdumpHook)
    epilog_intro = """
Interface specification format:
  eth0                    # Single interface, default hook
  eth0:hook1,hook2        # Interface with specific hooks
  eth0+eth1               # Multiple interfaces, default hooks
  eth0:hook1,hook2+eth1:hook3   # Multiple interfaces with specific hooks
  any                     # All interfaces, default hooks

Use plus (+) to separate multiple interfaces
Use comma (,) to separate multiple hooks per interface
Available hooks: """
    epilog_filters = """

Filter examples:
  tcp port 80             # HTTP traffic
  host 192.168.1.1        # Traffic to/from specific host
  src host 10.0.0.1       # Traffic from specific source
  dst port 443            # HTTPS traffic
  icmp                    # ICMP packets
  tcp and port 22         # SSH traffic
  udp and port 53         # DNS traffic
  not port 22             # Everything except SSH
  tcp[tcpflags] & (tcp-syn|tcp-fin) != 0  # TCP SYN or FIN packets

Filter syntax:
  ovs-doca-tcpdump -i eth0 "tcp port 80"     # Quoted filter
  ovs-doca-tcpdump -i eth0 tcp port 80       # Unquoted filter (multiple args)
  ovs-doca-tcpdump -i eth0 port 80           # Simple unquoted filter
    """

    parser = argparse.ArgumentParser(
        prog=os.path.basename(sys.argv[0]),
        description='Dump software traffic from an Open vSwitch port using Scapy',
        add_help=False,  # -h/--help is registered below as a plain flag
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=epilog_intro + hook_names + epilog_filters,
    )

    # (names, keyword arguments) in registration order; --help lists
    # arguments in this order.
    argument_specs = [
        (('-i', '--interface'),
         dict(help='Interface specification (see format below)')),
        (('--list-interfaces',),
         dict(action='store_true', help='List available OVS interfaces')),
        (('--list-hooks',),
         dict(action='store_true', help='List available DPIF-DOCA hooks')),
        (('-v', '--verbose'),
         dict(action='store_true', help='Verbose output (show packet details)')),
        (('-c', '--count'),
         dict(type=int, help='Stop after capturing specified number of packets')),
        (('-t', '--timestamp'),
         dict(action='store_true', help='Print timestamp for each packet')),
        (('-x', '--hex'),
         dict(action='store_true', help='Print packet data in hex')),
        (('-s', '--snaplen'),
         dict(type=int, default=65535, help='Snapshot length (default: 65535)')),
        (('-w', '--write'),
         dict(help='Write packets to pcap file')),
        (('filter',),
         dict(nargs='*', action=FilterAction,
              help='BPF filter expression (e.g., "tcp port 80", "host 192.168.1.1")')),
        (('-h', '--help'),
         dict(action='store_true', help='Show this help message and exit')),
    ]
    for names, options in argument_specs:
        parser.add_argument(*names, **options)

    return parser


class FilterAction(argparse.Action):
    """argparse action that stores the positional filter as one string.

    When argparse hands over a list of words (e.g. ``tcp port 80`` given as
    separate arguments), they are joined with single spaces. Any non-list
    value is stored unchanged.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        joined = ' '.join(values) if isinstance(values, list) else values
        setattr(namespace, self.dest, joined)


def handle_special_commands(args: argparse.Namespace) -> bool:
    """Run an informational command, if one was requested.

    Checked in priority order: --list-interfaces, --list-hooks, then -h/--help.
    Only the first one that is set is executed. A RuntimeError while listing
    ports is printed, and the command still counts as handled.

    Returns:
        True if a command was handled and the program should exit,
        False if normal capture should proceed.
    """
    if args.list_interfaces:
        try:
            for line in ovs.doca_tcpdump_util.OVSInterfaceManager.list_ports():
                print(line)
        except RuntimeError as err:
            print(f"Error: {err}")
    elif args.list_hooks:
        print("Available DOCA tcpdump hooks:")
        for hook in ovs.dpif_doca_tcpdump_hooks.DOCATcpdumpHook:
            print(f"  {hook.value:<15} - {hook.description}")
    elif args.help:
        create_argument_parser().print_help()
    else:
        return False
    return True


def get_interface_configs(interface_spec: str) -> Dict[str, ovs.doca_tcpdump_util.InterfaceConfig]:
    """Map interface names to their capture configuration.

    *interface_spec* is a '+'-separated list of per-interface specs (see
    parse_interface_config()). An empty/None spec or the literal "any" yields
    an empty dict. Every named interface other than "any" is checked with
    OVSInterfaceManager.verify_interface_exists(); a failure prints an error
    and exits with status 1. A later spec for the same name replaces an earlier one.
    """
    configs = {}
    if interface_spec and interface_spec != "any":
        for raw_part in interface_spec.split('+'):
            part = raw_part.strip()
            if not part:
                continue
            cfg = parse_interface_config(part)
            if cfg.name != "any":
                try:
                    ovs.doca_tcpdump_util.OVSInterfaceManager.verify_interface_exists(cfg.name)
                except (ValueError, RuntimeError) as err:
                    print(f"Error: {err}")
                    sys.exit(1)
            configs[cfg.name] = cfg
    return configs


def format_timestamp() -> str:
    """Return the current local time formatted as ``HH:MM:SS.mmm``."""
    # time.strftime() has no %f (microseconds) directive -- %f is only
    # understood by datetime.strftime() -- so the milliseconds are derived
    # from time.time() and appended explicitly.
    now = time.time()
    millis = int((now % 1) * 1000)  # always 0..999
    return f"{time.strftime('%H:%M:%S', time.localtime(now))}.{millis:03d}"


def print_metadata(interface: Optional[str] = None, hook: Optional[int] = None) -> None:
    """Print a "  Metadata: ..." line for a packet when any metadata is known.

    Args:
        interface: Device name the batch was captured on; omitted when falsy.
        hook: Hook number. It is looked up by position in DOCATcpdumpHook's
            definition order (assumes the sender numbers hooks that way --
            confirm against the DOCA side). Negative, out-of-range or
            non-integer values are shown as ``UNKNOWN(<value>)``.

    Nothing is printed when *interface* is falsy and *hook* is None.
    """
    metadata_parts = []
    if interface:
        metadata_parts.append(f"interface={interface}")
    if hook is not None:
        hook_values = list(ovs.dpif_doca_tcpdump_hooks.DOCATcpdumpHook)
        # Explicit bounds check: plain indexing would accept negative numbers
        # and silently pick a hook from the end of the list.
        if isinstance(hook, int) and 0 <= hook < len(hook_values):
            metadata_parts.append(f"hook={hook_values[hook].value}")
        else:
            metadata_parts.append(f"hook=UNKNOWN({hook})")
    if metadata_parts:
        print(f"  Metadata: {', '.join(metadata_parts)}")


def _print_hex_dump(data: bytes, indent: str) -> None:
    """Print *data* as rows of up to 16 space-separated hex bytes, each
    prefixed by *indent* and the row's byte offset in 4-digit hex."""
    for offset in range(0, len(data), 16):
        hex_line = ' '.join(f"{byte:02x}" for byte in data[offset:offset + 16])
        print(f"{indent}{offset:04x}:  {hex_line}")


def print_packet(packet_data: bytes, args: argparse.Namespace, packet_count: int,
                 interface: Optional[str] = None, hook: Optional[int] = None) -> None:
    """Print one captured packet, decoded with Scapy, together with its metadata.

    Args:
        packet_data: Raw frame bytes, decoded as Ethernet.
        args: Parsed CLI arguments; reads ``snaplen``, ``timestamp``,
            ``verbose`` and ``hex``.
        packet_count: Sequence number shown in the "Packet #N" header.
        interface: Optional device name for the metadata line.
        hook: Optional hook number for the metadata line.

    If decoding or printing raises struct.error, IndexError or ValueError, the
    packet is shown as raw bytes instead. Any other exception is reported on
    stderr so that one bad packet does not stop the capture.
    """
    try:
        # Truncate to snaplen before decoding so the frame is parsed only once.
        if len(packet_data) > args.snaplen:
            packet_data = packet_data[:args.snaplen]

        packet = Ether(packet_data)

        if args.timestamp:
            print(f"[{format_timestamp()}] ", end="")
        print(f"Packet #{packet_count}:")
        print_metadata(interface, hook)

        if args.verbose:
            packet.show()  # Scapy's detailed per-layer dump
        else:
            print(f"  {packet.summary()}")

        if args.hex:
            print("  Hex dump:")
            _print_hex_dump(packet_data, "    ")

        print()  # Empty line between packets

    except (struct.error, IndexError, ValueError):
        # Scapy could not handle the frame: fall back to raw bytes. The
        # timestamp prefixes the header line, as in the decoded path.
        if args.timestamp:
            print(f"[{format_timestamp()}] ", end="")
        print(f"Packet #{packet_count} (raw data, {len(packet_data)} bytes):")
        print_metadata(interface, hook)

        if args.hex:
            _print_hex_dump(packet_data, "  ")
        else:
            print(f"  Raw data: {packet_data[:100]}{'...' if len(packet_data) > 100 else ''}")
        print()
    except Exception as e:
        # Log any other unexpected errors but continue processing
        print(f"Error processing packet #{packet_count}: {type(e).__name__}: {e}",
              file=sys.stderr)
        print()


class OVSDocaTcpdump:
    """Main class that orchestrates the DOCA tcpdump functionality.

    run() wires everything together: install signal handlers, parse and
    validate the command line, create the DOCA capture manager, then process
    packet batches until the requested count is reached or the capture is
    interrupted.
    """

    def __init__(self):
        # DOCA capture manager; created by setup_components().
        self.doca_manager: Optional[ovs.doca_tcpdump_util.DOCATcpdumpManager] = None
        # Compiled BPF filter; only set when a filter expression was given.
        self.packet_filter: Optional[ovs.packet_filter.PacketFilter] = None
        # True only while run_capture_loop() is processing packets; it also
        # decides how _handle_signal() reacts.
        self.running = False
        # Path given with -w, or None when packets are printed instead.
        self.pcap_file = None
        # Open binary file object for the pcap output, or None.
        self.pcap_fd = None
        # Set after the first failed pcap write so the warning is printed once.
        self._pcap_write_failed = False

    def setup_signal_handlers(self) -> None:
        """Route SIGINT and SIGTERM to _handle_signal()."""
        signal.signal(signal.SIGINT, self._handle_signal)
        signal.signal(signal.SIGTERM, self._handle_signal)

    def _handle_signal(self, signum: int, frame) -> None:
        """Handle SIGINT/SIGTERM.

        While the capture loop is running, the signal is turned into a
        KeyboardInterrupt. The loop then prints its capture summary and shuts
        down through its except/finally clauses. Installing a SIGINT handler
        stops Python from raising KeyboardInterrupt on its own, so without this
        the Ctrl+C summary would never be printed.

        Outside the capture loop (or for a repeated signal), stop the capture,
        close the pcap file and exit with status 0.
        """
        if self.running:
            # Clear the flag first so that a second signal, or the loop's
            # KeyboardInterrupt handler, takes the exit path below.
            self.running = False
            raise KeyboardInterrupt
        if self.doca_manager:
            self.doca_manager.stop_capture()
        self._close_pcap_file()
        sys.exit(0)

    def process_arguments(self, args: List[str]) -> tuple:
        """Parse and validate command-line arguments.

        Exits with status 1 on argparse errors, an invalid interface or a
        filter that fails to compile. Exits with status 0 after an
        informational command (see handle_special_commands()).

        Returns:
            ``(interfaces, parsed_args)``, where *interfaces* maps interface
            names to their InterfaceConfig.
        """
        parser = create_argument_parser()

        try:
            parsed_args = parser.parse_args(args)
        except SystemExit:
            # argparse already printed the error message; normalize status to 1.
            sys.exit(1)

        if handle_special_commands(parsed_args):
            sys.exit(0)

        interfaces = get_interface_configs(parsed_args.interface)

        # FilterAction normally joins the filter words already; the list case
        # is kept as a fallback.
        filter_string = None
        if parsed_args.filter:
            if isinstance(parsed_args.filter, list):
                filter_string = ' '.join(parsed_args.filter)
            else:
                filter_string = parsed_args.filter

        if filter_string:
            try:
                self.packet_filter = ovs.packet_filter.PacketFilter()
                self.packet_filter.compile_filter(filter_string)
                print(f"Filter compiled successfully: '{filter_string}'")
            except RuntimeError as e:
                print(f"Error compiling filter '{filter_string}': {e}")
                sys.exit(1)

        return interfaces, parsed_args

    def setup_components(self, interfaces: Dict[str, ovs.doca_tcpdump_util.InterfaceConfig]) -> None:
        """Create the DOCA manager, set up its socket server and start capture
        on *interfaces*."""
        self.doca_manager = ovs.doca_tcpdump_util.DOCATcpdumpManager()
        self.doca_manager.setup_socket_server()
        self.doca_manager.start_capture(interfaces)

    def _setup_pcap_file(self, filename: str, snaplen: int = 65535) -> None:
        """Create *filename* (truncating it) and write the pcap global header.

        Args:
            filename: Output path.
            snaplen: Snapshot length recorded in the header. It should match
                the truncation applied to captured packets (-s) so readers do
                not see packets larger than the advertised snapshot length.

        Exits with status 1 if the file cannot be created or written.
        """
        try:
            self.pcap_file = filename
            self.pcap_fd = open(filename, 'wb')
            # pcap global header (24 bytes, little-endian): magic number,
            # version 2.4, thiszone, sigfigs, snaplen, link type 1 (Ethernet).
            header = struct.pack('<LHHLLLL', 0xa1b2c3d4, 2, 4, 0, 0, snaplen, 1)
            self.pcap_fd.write(header)
            self.pcap_fd.flush()
        except (OSError, struct.error) as e:
            print(f"Error creating pcap file: {e}")
            self._close_pcap_file()
            sys.exit(1)

    def _write_packet_to_pcap(self, packet_data: bytes, orig_len: Optional[int] = None) -> None:
        """Append one packet record to the pcap file, if one is open.

        Args:
            packet_data: Captured (possibly snaplen-truncated) bytes.
            orig_len: Packet length before truncation; defaults to
                ``len(packet_data)``.

        Best effort: a failed write skips the packet and prints one warning
        on stderr for the whole capture.
        """
        if not self.pcap_fd:
            return
        if orig_len is None:
            orig_len = len(packet_data)

        now = time.time()
        ts_sec = int(now)
        ts_usec = int((now - ts_sec) * 1000000)
        try:
            # Record header (16 bytes): ts seconds, ts microseconds,
            # captured length, original length.
            header = struct.pack('<LLLL', ts_sec, ts_usec, len(packet_data), orig_len)
            self.pcap_fd.write(header)
            self.pcap_fd.write(packet_data)
            # Flush per packet so the file stays usable if we are killed.
            self.pcap_fd.flush()
        except (OSError, ValueError, struct.error) as e:
            if not self._pcap_write_failed:
                self._pcap_write_failed = True
                print(f"Warning: failed to write packet to {self.pcap_file}: {e}",
                      file=sys.stderr)

    def _close_pcap_file(self) -> None:
        """Close the pcap file if it is open; safe to call repeatedly."""
        if self.pcap_fd:
            self.pcap_fd.close()
            self.pcap_fd = None

    def run_capture_loop(self, args: argparse.Namespace) -> None:
        """Receive packet batches and print or record the matching packets.

        Each packet is truncated to ``args.snaplen`` bytes and checked against
        the optional BPF filter. It is then either appended to the pcap file
        (-w) or printed with the batch's metadata. The loop ends when
        ``args.count`` matching packets have been handled or the capture is
        interrupted. The pcap file is closed and cleanup() runs in every case.
        """
        packet_count = 0    # packets received from the DOCA manager
        filtered_count = 0  # packets that passed the filter
        self.pcap_file = args.write

        if self.pcap_file:
            print(f"Writing packets to: {self.pcap_file}")
            self._setup_pcap_file(self.pcap_file, args.snaplen)
        else:
            print("Starting packet capture...")
            if self.packet_filter:
                print(f"Filter: {args.filter}")
            print("Press Ctrl+C to stop")
            print()

        # Set only now: while running, _handle_signal() raises
        # KeyboardInterrupt, which must surface inside the try below.
        self.running = True
        try:
            while self.running:
                batch = self.doca_manager.receive_packet_batch()
                if not batch:
                    continue
                # Metadata is per batch, shared by all of its packets.
                interface = batch.get("dev_name")
                hook = batch.get("hook")

                for pkt in batch["packets"]:
                    packet_count += 1
                    packet_data = pkt['data']
                    # Length as received, before snaplen truncation; the pcap
                    # record stores it as the original length.
                    orig_len = len(packet_data)
                    if orig_len > args.snaplen:
                        packet_data = packet_data[:args.snaplen]

                    if self.packet_filter and not self.packet_filter.matches_packet(packet_data):
                        continue  # Skip this packet

                    filtered_count += 1

                    if self.pcap_file:
                        # Write mode: no metadata is stored in the pcap
                        self._write_packet_to_pcap(packet_data, orig_len)
                    else:
                        # Display mode: show packets with metadata
                        print_packet(packet_data, args, filtered_count, interface, hook)

                    if args.count and filtered_count >= args.count:
                        print(f"Captured {filtered_count} packets (filtered from {packet_count} total). Stopping.")
                        # NOTE(review): unlike the signal paths, stop_capture()
                        # is not called here; presumably cleanup() covers it --
                        # confirm.
                        self.running = False
                        break
        except KeyboardInterrupt:
            # Clear the flag so _handle_signal() takes its exit path even for
            # a KeyboardInterrupt that did not come from our handler.
            self.running = False
            print(f"\nCaptured {filtered_count} packets (filtered from {packet_count} total). Stopping.")
            self._handle_signal(signal.SIGINT, None)
        finally:
            self.running = False
            self._close_pcap_file()
            self.cleanup()

    def cleanup(self) -> None:
        """Release the packet filter and the DOCA manager, if they were created."""
        if self.packet_filter:
            self.packet_filter.cleanup()
        if self.doca_manager:
            self.doca_manager.cleanup()

    def run(self, args: List[str]) -> None:
        """Run the application with command-line arguments *args* (without argv[0])."""
        self.setup_signal_handlers()
        interfaces, parsed_args = self.process_arguments(args)
        self.setup_components(interfaces)
        self.run_capture_loop(parsed_args)


def acquire_lock():
    """Acquire a file lock to prevent parallel execution.

    The lock file is ``ovs-doca-tcpdump.lock`` in ``$OVS_RUNDIR`` (falling
    back to ``ovs.dirs.RUNDIR``). An exclusive, non-blocking flock() is taken
    on it, and then this process's PID is written into it.

    Returns:
        The lock file descriptor, or None if another process holds the lock or
        the lock file cannot be opened or written.
    """
    import ovs.dirs
    lock_file = os.path.join(os.environ.get('OVS_RUNDIR', ovs.dirs.RUNDIR), 'ovs-doca-tcpdump.lock')

    # A few attempts: a previous holder may unlink the file between our
    # open() and flock(), leaving us with a lock on a file that no longer
    # lives at lock_file.
    for _attempt in range(3):
        try:
            # No O_TRUNC: truncating before the lock is held would erase the
            # PID written by an instance that is already running.
            lock_fd = os.open(lock_file, os.O_CREAT | os.O_WRONLY, 0o644)
        except OSError:
            return None

        try:
            fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
        except OSError:
            # Lock held by another process; don't leak the descriptor.
            os.close(lock_fd)
            return None

        try:
            stale = not os.path.samestat(os.fstat(lock_fd), os.stat(lock_file))
        except OSError:
            stale = True
        if stale:
            # The locked file was removed or replaced meanwhile; try again.
            os.close(lock_fd)
            continue

        try:
            # Now that the lock is ours, replace any previous content with our PID.
            os.ftruncate(lock_fd, 0)
            os.write(lock_fd, str(os.getpid()).encode())
            os.fsync(lock_fd)
        except OSError:
            os.close(lock_fd)
            return None
        return lock_fd

    return None

def release_lock(lock_fd):
    """Release the lock taken by acquire_lock() and remove the lock file.

    Does nothing for None. OS errors are ignored because this runs on the
    exit path.

    The file is unlinked *before* the lock is dropped, and only if *lock_fd*
    still refers to the file at the lock path. Unlinking after unlocking would
    let another instance lock the old file just before it disappears; a third
    instance could then create and lock a fresh file, so two would run at once.
    """
    if lock_fd is None:
        return
    try:
        try:
            import ovs.dirs
            lock_file = os.path.join(os.environ.get('OVS_RUNDIR', ovs.dirs.RUNDIR), 'ovs-doca-tcpdump.lock')
            if os.path.samestat(os.fstat(lock_fd), os.stat(lock_file)):
                os.unlink(lock_file)
        finally:
            # Always drop the lock and the descriptor, even if removal failed.
            fcntl.flock(lock_fd, fcntl.LOCK_UN)
            os.close(lock_fd)
    except OSError:
        pass

# Global variable to hold lock file descriptor (None when no lock is held)
_lock_fd = None


def main():
    """Main function that creates and runs the OVS DOCA tcpdump application.

    Only one instance may run at a time (see acquire_lock()). The lock is
    released exactly once, by the ``finally`` clause below or, as a fallback,
    by the atexit hook.
    """
    global _lock_fd

    _lock_fd = acquire_lock()
    if _lock_fd is None:
        # NOTE(review): acquire_lock() also returns None when the lock file
        # cannot be opened (e.g. missing run directory), in which case this
        # message may be misleading.
        import ovs.dirs
        lock_file = os.path.join(os.environ.get('OVS_RUNDIR', ovs.dirs.RUNDIR), 'ovs-doca-tcpdump.lock')
        print("Error: Another instance of ovs-doca-tcpdump is already running.")
        print(f"Lock file: {lock_file}")
        sys.exit(1)

    def _release() -> None:
        # Clear the global after releasing, so the second call (atexit after
        # the finally clause) is a no-op. Otherwise it would act on an
        # already-closed, possibly reused, file descriptor.
        global _lock_fd
        release_lock(_lock_fd)
        _lock_fd = None

    atexit.register(_release)

    try:
        app = OVSDocaTcpdump()
        app.run(sys.argv[1:])
    finally:
        _release()


if __name__ == "__main__":
    main()
