# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ctypes
import ctypes.util
import os
import sys
from typing import Optional, Union
from ctypes import Structure, c_uint, c_int, c_char_p, c_void_p, POINTER, c_ubyte, c_ushort


class BPFInstruction(Structure):
    """One BPF instruction, laid out to match libpcap's ``struct bpf_insn``.

    ctypes places fields in declaration order, so the order and widths below
    must stay exactly as they are: ``bpf_filter`` is handed a pointer to an
    array of these (see ``BPFProgram.bf_insns``) and reads it as C memory.
    """
    _fields_ = [
        ("code", c_ushort),    # opcode
        ("jt", c_ubyte),       # jump offset taken when the test is true
        ("jf", c_ubyte),       # jump offset taken when the test is false
        ("k", c_uint),         # generic multiuse operand field
    ]


class BPFProgram(Structure):
    """Compiled filter program, laid out to match libpcap's ``struct bpf_program``.

    An instance is filled in by ``pcap_compile`` (which allocates the
    instruction array) and must be released with ``pcap_freecode``; field
    order and widths must not change because libpcap writes into it directly.
    """
    _fields_ = [
        ("bf_len", c_uint),    # number of instructions in bf_insns
        ("bf_insns", POINTER(BPFInstruction)),  # libpcap-allocated instruction array
    ]


class PcapPacketHeader(Structure):
    """Per-packet header made of four ``c_uint`` fields.

    NOTE(review): four fixed-width words (ts_sec, ts_usec, caplen, origlen)
    look like the per-record header of the pcap *savefile* format, not
    libpcap's in-memory ``struct pcap_pkthdr`` (whose timestamp is a
    ``struct timeval`` and is wider on LP64 platforms). Do not pass
    instances to libpcap APIs that expect ``pcap_pkthdr`` — confirm the
    intended use with callers.
    """
    _fields_ = [
        ("ts_sec", c_uint),    # timestamp, whole seconds
        ("ts_usec", c_uint),   # timestamp, microseconds part
        ("caplen", c_uint),    # number of bytes actually captured
        ("origlen", c_uint),   # original length of the packet on the wire
    ]


class PacketFilter:
    """Packet filtering using libpcap via ctypes.

    Filters are compiled against a dummy ("dead") pcap handle created with
    ``pcap_open_dead``, so no capture device is opened.  Typical use::

        pf = PacketFilter()
        pf.compile_filter("tcp port 80")
        if pf.matches_packet(frame):
            ...
        pf.cleanup()

    Calling :meth:`compile_filter` again releases the previously compiled
    program and handle before compiling the new filter.
    """

    # DLT_* link-layer type used by default: DLT_EN10MB (Ethernet).
    DLT_EN10MB = 1
    # Default snapshot length passed to pcap_open_dead().
    DEFAULT_SNAPLEN = 65535

    def __init__(self):
        self.libpcap = None      # ctypes.CDLL for libpcap, set by _load_libpcap()
        self.bpf_program = None  # BPFProgram while a filter is compiled, else None
        self.pcap_handle = None  # pcap_t* from pcap_open_dead(), or None
        self._load_libpcap()

    def _load_libpcap(self) -> None:
        """Load the libpcap shared library and declare the functions used.

        Candidates are tried in order: every path ``ctypes.util.find_library``
        resolves for the known names, then the bare names themselves (letting
        the dynamic loader search its default paths).  The first candidate
        that loads wins, so a ``find_library`` hit that cannot actually be
        loaded no longer prevents the fallbacks from being tried.

        Raises:
            RuntimeError: If no candidate could be loaded.
        """
        lib_names = ['libpcap.so.1', 'libpcap.so', 'pcap']

        candidates = []
        for lib_name in lib_names:
            try:
                found = ctypes.util.find_library(lib_name)
            except (OSError, AttributeError):
                continue
            if found:
                candidates.append(found)
        # Bare names last: CDLL() hands them to dlopen(), which searches the
        # default library path itself.
        candidates.extend(lib_names)

        # dict.fromkeys de-duplicates while preserving the priority order.
        for candidate in dict.fromkeys(candidates):
            try:
                self.libpcap = ctypes.CDLL(candidate)
                break
            except OSError:
                continue

        if self.libpcap is None:
            raise RuntimeError(
                "Could not load libpcap library. "
                "Please install libpcap development package: "
                "apt-get install libpcap-dev (Ubuntu/Debian) or "
                "yum install libpcap-devel (RHEL/CentOS)"
            )

        self._setup_function_signatures()

    def _setup_function_signatures(self) -> None:
        """Declare ctypes argument/return types for the libpcap calls used here."""
        # pcap_t *pcap_open_dead(int linktype, int snaplen)
        self.libpcap.pcap_open_dead.argtypes = [c_int, c_int]
        self.libpcap.pcap_open_dead.restype = c_void_p

        # int pcap_compile(pcap_t *, struct bpf_program *, const char *,
        #                  int optimize, bpf_u_int32 netmask)
        self.libpcap.pcap_compile.argtypes = [
            c_void_p,                    # pcap_t *p
            POINTER(BPFProgram),         # struct bpf_program *fp
            c_char_p,                    # const char *str
            c_int,                       # int optimize
            c_uint,                      # bpf_u_int32 netmask
        ]
        self.libpcap.pcap_compile.restype = c_int

        # void pcap_freecode(struct bpf_program *)
        self.libpcap.pcap_freecode.argtypes = [POINTER(BPFProgram)]
        self.libpcap.pcap_freecode.restype = None

        # void pcap_close(pcap_t *)
        self.libpcap.pcap_close.argtypes = [c_void_p]
        self.libpcap.pcap_close.restype = None

        # char *pcap_geterr(pcap_t *): text of the last error on the handle
        self.libpcap.pcap_geterr.argtypes = [c_void_p]
        self.libpcap.pcap_geterr.restype = c_char_p

        # u_int bpf_filter(const struct bpf_insn *, const u_char *,
        #                  u_int wirelen, u_int buflen)
        self.libpcap.bpf_filter.argtypes = [
            POINTER(BPFInstruction),     # const struct bpf_insn *insns
            c_char_p,                    # const u_char *pkt
            c_uint,                      # u_int wirelen
            c_uint,                      # u_int buflen
        ]
        self.libpcap.bpf_filter.restype = c_uint

    def _last_error(self) -> str:
        """Return libpcap's error text for the current handle ('' if unavailable)."""
        if not self.pcap_handle:
            return ""
        raw = self.libpcap.pcap_geterr(self.pcap_handle)
        return raw.decode('utf-8', errors='replace') if raw else ""

    def compile_filter(self, filter_string: str, *,
                       linktype: int = DLT_EN10MB,
                       snaplen: int = DEFAULT_SNAPLEN) -> bool:
        """Compile a tcpdump-style filter string.

        Any previously compiled program and its pcap handle are released
        first, so the object can be reused without leaking native memory.

        Args:
            filter_string: BPF filter expression (e.g., "tcp port 80",
                "host 192.168.1.1").  Empty or whitespace-only means
                "match all packets".
            linktype: DLT_* link-layer type of the packets the filter will
                be applied to.  Defaults to Ethernet (DLT_EN10MB), in which
                case :meth:`matches_packet` expects Ethernet frames.
            snaplen: Snapshot length used for the dummy pcap handle.

        Returns:
            True on success.  Failure is always reported by raising, never
            by returning False.

        Raises:
            RuntimeError: If the dummy handle cannot be created or libpcap
                rejects the filter.  The message includes libpcap's own
                diagnostic when one is available.
        """
        # Without this, recompiling (or switching to an empty filter) would
        # drop the old program/handle without pcap_freecode/pcap_close.
        self.cleanup()

        if not filter_string or not filter_string.strip():
            # Empty filter means match all packets
            return True

        # Encode before opening the handle so an encoding error cannot leak it.
        filter_bytes = filter_string.encode('utf-8')

        self.pcap_handle = self.libpcap.pcap_open_dead(linktype, snaplen)
        if not self.pcap_handle:
            raise RuntimeError("Failed to create pcap handle for filter compilation")

        # Zero-initialised by ctypes, so bf_insns is NULL until compile succeeds.
        self.bpf_program = BPFProgram()

        result = self.libpcap.pcap_compile(
            self.pcap_handle,
            ctypes.byref(self.bpf_program),
            filter_bytes,
            1,  # optimize = True
            0,  # netmask = 0 (not used for offline compilation)
        )

        if result != 0:
            # Read the diagnostic before cleanup() closes the handle that
            # owns libpcap's error buffer.
            detail = self._last_error()
            self.cleanup()
            message = f"Failed to compile filter: '{filter_string}'"
            if detail:
                message = f"{message}: {detail}"
            raise RuntimeError(message)

        return True

    def matches_packet(self, packet_data: Union[bytes, bytearray, memoryview], *,
                       wirelen: Optional[int] = None) -> bool:
        """Check if a packet matches the compiled filter.

        Args:
            packet_data: Raw packet bytes of the link type the filter was
                compiled for (Ethernet frames by default).  Any bytes-like
                object is accepted.
            wirelen: Original on-the-wire length when ``packet_data`` is a
                truncated capture; defaults to ``len(packet_data)``.  Filters
                such as ``len``/``greater`` test this value.

        Returns:
            True if the packet matches (or no filter is compiled), False
            otherwise.  An empty packet never matches a compiled filter.

        Raises:
            ValueError: If ``wirelen`` is smaller than the captured length.
        """
        if not self.bpf_program:
            # No filter means match all packets
            return True

        if not packet_data:
            return False

        # bpf_filter's pkt argument is declared c_char_p, which only accepts
        # bytes; copy other buffer types once.
        if not isinstance(packet_data, bytes):
            packet_data = bytes(packet_data)

        buflen = len(packet_data)
        if wirelen is None:
            wirelen = buflen
        elif wirelen < buflen:
            raise ValueError(
                f"wirelen ({wirelen}) must be >= len(packet_data) ({buflen})"
            )

        result = self.libpcap.bpf_filter(
            self.bpf_program.bf_insns,
            packet_data,
            wirelen,
            buflen,
        )

        # bpf_filter returns 0 if packet doesn't match, non-zero if it matches
        return result != 0

    def cleanup(self) -> None:
        """Release the compiled program and pcap handle, if any.

        Idempotent, and safe on a partially initialised instance (for example
        when ``__init__`` never ran), so it can be called from ``__del__``.
        """
        program = getattr(self, 'bpf_program', None)
        if program is not None:
            self.libpcap.pcap_freecode(ctypes.byref(program))
        self.bpf_program = None

        handle = getattr(self, 'pcap_handle', None)
        if handle:
            self.libpcap.pcap_close(handle)
        self.pcap_handle = None

    def __del__(self):
        """Release libpcap resources when the object is garbage collected."""
        self.cleanup()


def validate_filter_string(filter_string: str) -> bool:
    """Report whether *filter_string* compiles as a BPF filter.

    An empty or whitespace-only expression is treated as valid (it means
    "match everything").  Any ``RuntimeError`` raised while loading libpcap
    or compiling the expression counts as invalid.

    Args:
        filter_string: BPF filter expression to validate

    Returns:
        True if filter is valid, False otherwise
    """
    has_expression = bool(filter_string) and bool(filter_string.strip())
    if has_expression:
        try:
            probe = PacketFilter()
            probe.compile_filter(filter_string)
            probe.cleanup()
        except RuntimeError:
            return False
    return True


def get_filter_error_message(filter_string: str) -> str:
    """Explain whether, and why, *filter_string* fails to compile.

    Args:
        filter_string: BPF filter expression to diagnose

    Returns:
        ``"Empty filter string"`` for an empty or whitespace-only expression,
        ``"Filter is valid"`` when it compiles, and otherwise the text of the
        ``RuntimeError`` raised while loading libpcap or compiling.
    """
    if not (filter_string and filter_string.strip()):
        return "Empty filter string"

    outcome = "Filter is valid"
    try:
        checker = PacketFilter()
        checker.compile_filter(filter_string)
        checker.cleanup()
    except RuntimeError as err:
        outcome = str(err)
    return outcome
