#! /usr/bin/env python2.7
#
# Copyright (c) 2011-2022 NVIDIA CORPORATION & AFFILIATES, ALL RIGHTS RESERVED.
#
# This software product is a proprietary product of NVIDIA CORPORATION &
# AFFILIATES (the "Company") and all right, title, and interest in and to the
# software product, including all associated intellectual property rights, are
# and shall remain exclusively with the Company.
#
# This software product is governed by the End User License Agreement
# provided with the software product.
#

"""
Script that takes in an RXPBench verbose output CSV match file from both an RXP
device and a Hyperscan device test.
The results are parsed to verfiy that the matches detected are the same, even if
the total number returned by each are different.
"""

import sys
import csv
import getopt

# Running totals accumulated while comparing every job/rule, reported by main().
#   exact_matches  - RXP matches with an identical Hyperscan start pointer
#   sub_matches    - RXP matches lying inside a longer Hyperscan match
#   greedy_matches - extra Hyperscan matches that share a start pointer
exact_matches = sub_matches = greedy_matches = 0
# rule_id -> number of RXP matches seen for that rule
rules_matched = dict()

"""
Takes in list of tuples of the form 'job_id, rule_id, start_ptr, length'.
Converts this to a 2-D dictionary.
The 'outer' dictionary takes the job_id as a key and returns a dictionary.
This 'inner' dictionary takes the rule_id and returns a list of tuples.
So dic[job_id][rule_id] gets all the tuples with that job/rule combo.
"""
def matches_to_job_id_dic(match_list, job_id_dic):
    # first line should the title - verify and remove it
    if len(match_list[0]) == 4 and match_list[0][0][0] == "#":
        match_list.pop(0)
    else:
        print 'Error: RXPBench csv - First line not a comment'
        sys.exit(-1)

    # sanity check that each row has 4 values
    for match in match_list:
        if len(match) != 4:
            print 'Error: RXPBench csv - Some lines do not have 4 values'
            sys.exit(-1)

    for m in match_list:
       mid = int(m[0])
       rule_id = int(m[1])
       if mid in job_id_dic.keys():
           if (rule_id in job_id_dic[mid].keys()):
               job_id_dic[mid][rule_id].append(m)
           else:
               job_id_dic[mid][rule_id] = [m]
       else:
           job_id_dic[mid] = {}
           job_id_dic[mid][rule_id] = [m]


"""
RXP returns each ungreedy match while HS returns all matches.
so for ab.*cd, if the data is abcdcd, RXP will return abcd while HS will return
abcd and abcdcd.
Because the 2 HS results will have the same start pointer (with -L or 0
otherwise) then we can assume these are the same match as the abcd from RXP.

However, if the data was abcdabcd, RXP would return 2 matches of abcd with
different start pointers but HS would return abcd and abcdabcd with the same
start pointer - the second 'ab' trigger is ignored as it is part of another
match on the same rule id.

Similarly, for rules that start with optionals, e.g. a.*?bcd, RXP may choose to
remove the a.*? and only match on the bcd. HS can pick up any a's if they exist
before the bcd. Here, RXP and HS may have a different start pointer but will
still be in the same start pointer + length field.

We can therefore assume the matches are the same if the start pointer of matches
are the same, or if the RXP start pointer lies within a full match
(start + length) of a HS match.
"""
def compare_sorted_matches_rxp_to_hs(rxp_matches, hs_matches):
    global exact_matches
    global sub_matches

    found_match = False
    probable_match = False
    job_id = rxp_matches[0][0]
    for rxp in rxp_matches:
        found_match = False
        sub_match = False
        rxp_start = int(rxp[2])
        rxp_end = int(rxp[2]) + int(rxp[3])
        for hs in hs_matches:
            hs_start = int(hs[2])
            hs_end = int(hs[2]) + int(hs[3])
            if rxp_start == hs_start:
                # match has same start pointer so it's valid
                found_match = True
                exact_matches += 1
                break
            elif rxp_start > hs_start and rxp_end <= hs_end:
                sub_match = True
                # continue in case there is a full match
            elif rxp_start < hs_start:
                break
        if found_match == False and sub_match == False:
            print 'MATCHES NOT EQUIVALENT'
            print 'RXP match in job %s missed by Hyperscan' % job_id
            print '- rule %s match with starter pointer %s and length %s' \
                      % (int(rxp[1]), int(rxp[2]), int(rxp[3]))
            print '----------------------------------------------'
            sys.exit(-1)
        elif sub_match == True and found_match == False:
            sub_matches += 1


"""
HS will return all matches for a rule, not just the greedy and ungreedy.
If it picks up a match that is not in RXP the match range will for a rule id
will not include a start pointer that is in RXP.
"""
def compare_sorted_matches_hs_to_rxp(rxp_matches, hs_matches):
    global greedy_matches

    all_start_ptrs = set()
    job_id = rxp_matches[0][0]

    for hs in hs_matches:
        hs_start = int(hs[2])
        hs_end = int(hs[2]) + int(hs[3])
        all_start_ptrs.add(hs_start)
        for rxp in rxp_matches:
            rxp_start = int(rxp[2])
            if hs_start <= rxp_start and hs_end >= rxp_start:
                # have a match
                break
            elif hs_end < rxp_start:
                print 'MATCHES NOT EQUIVALENT'
                print 'Hyperscan match in job %s missed by RXP' % job_id
                print '- rule %s match with starter pointer %s and length %s' \
                          % (int(hs[1]), int(hs[2]), int(hs[3]))
                print '----------------------------------------------'
                sys.exit(-1)

    # matches with the same rule_id and same start pointer are the same match
    greedy_matches += (len(hs_matches) - len(all_start_ptrs))


def compare_rule_matches(rxp_rule_matches, hs_rule_matches):
    """
    Compare the RXP and HS matches of a single job/rule pair in both directions.

    Both row lists ([job_id, rule_id, start_ptr, length]) are sorted by start
    pointer. The number of RXP matches is added to this rule's tally in
    rules_matched, then each direction of the comparison is checked.
    rxp_rule_matches must be non-empty (its first row supplies the rule id).
    """
    global rules_matched

    def start_pointer(row):
        return int(row[2])

    rxp_sorted = sorted(rxp_rule_matches, key=start_pointer)
    hs_sorted = sorted(hs_rule_matches, key=start_pointer)

    rid = int(rxp_sorted[0][1])
    rules_matched[rid] = rules_matched.get(rid, 0) + len(rxp_sorted)

    compare_sorted_matches_rxp_to_hs(rxp_sorted, hs_sorted)
    compare_sorted_matches_hs_to_rxp(rxp_sorted, hs_sorted)


def compare_job_matches(job_id, rxp_job_matches, hs_job_matches):
    """
    Compare all rule matches of one job between RXP and Hyperscan.

    rxp_job_matches / hs_job_matches map rule_id -> list of match rows for
    this job, as built by matches_to_job_id_dic. Both must contain exactly the
    same set of rule ids. Otherwise the rule ids are printed and the script
    exits with -1. Each rule is then compared with compare_rule_matches.
    """
    # Check the rule sets up front. Previously a rule matched only by RXP
    # cleared the failure flag and the mismatch was silently ignored.
    if set(rxp_job_matches) != set(hs_job_matches):
        print('MATCHES NOT EQUIVALENT')
        print('User_id %s matches different rules in files 1 and 2' % job_id)
        print('Rules matched in RXP:')
        print(sorted(rxp_job_matches))
        print('Rules matched in Hyperscan:')
        print(sorted(hs_job_matches))
        print('----------------------------------------------')
        sys.exit(-1)

    for rule in rxp_job_matches:
        compare_rule_matches(rxp_job_matches[rule], hs_job_matches[rule])


def main(argv):
    """
    Command-line entry point: compare an RXP and a Hyperscan RXPBench CSV.

    argv is the argument list without the program name (e.g. sys.argv[1:]).
    Options: -r/--rxp <rxp.csv> and -h/--hyperscan <hs.csv>, both required.

    Prints the reason and exits with -1 on bad arguments, an unreadable file,
    or any non-equivalent match. Otherwise prints a summary of the comparison.
    """
    rxp_file, hs_file = _parse_args(argv)

    # read csv into a list
    rxp_matches = _read_csv(rxp_file)
    hs_matches = _read_csv(hs_file)

    rxp_by_id = {}
    hs_by_id = {}

    print('Comparing RXP and Hyperscan matches...')

    # convert the input files to dicts that allow look up by job and rule id
    # (this also strips the header row from rxp_matches / hs_matches)
    matches_to_job_id_dic(rxp_matches, rxp_by_id)
    matches_to_job_id_dic(hs_matches, hs_by_id)

    print('----------------------------------------------')

    # first check that the same job ids get matches
    only_rxp = sorted(set(rxp_by_id) - set(hs_by_id))
    only_hs = sorted(set(hs_by_id) - set(rxp_by_id))
    if only_rxp or only_hs:
        print('MATCHES NOT EQUIVALENT')
        print('The following user_ids are in RXP but not Hyperscan:')
        print(only_rxp)
        print('The following user_ids are in Hyperscan but not RXP:')
        print(only_hs)
        print('----------------------------------------------')
        sys.exit(-1)

    # check the matches for each job individually
    for job_id in hs_by_id:
        compare_job_matches(job_id, rxp_by_id[job_id], hs_by_id[job_id])

    _print_summary(rxp_matches, hs_matches, rxp_by_id)


def _parse_args(argv):
    """Return (rxp_file, hs_file) from argv; print usage and exit(-1) if invalid."""
    try:
        opts, _ = getopt.getopt(argv, "r:h:", ["rxp=", "hyperscan="])
    except getopt.GetoptError as err:
        print(err)
        print('Usage: rxpbench_match_comparison.py -r <rxp.csv> -h <hs.csv>')
        sys.exit(-1)

    rxp_file = ''
    hs_file = ''
    # getopt only yields the options declared above, so no other case exists.
    for opt, arg in opts:
        if opt in ("-r", "--rxp"):
            rxp_file = arg
        elif opt in ("-h", "--hyperscan"):
            hs_file = arg

    if not rxp_file:
        print('Error: RXP results file required (-r)')
        sys.exit(-1)
    if not hs_file:
        print('Error: Hyperscan results file required (-h)')
        sys.exit(-1)
    return rxp_file, hs_file


def _read_csv(path):
    """Return all rows of the CSV file at path as a list; exit(-1) on I/O error."""
    try:
        # `with` closes the file; the previous open() calls leaked the handle.
        with open(path, 'r') as csv_file:
            return list(csv.reader(csv_file))
    except IOError as err:
        print('Error: cannot read %s: %s' % (path, err))
        sys.exit(-1)


def _print_summary(rxp_matches, hs_matches, rxp_by_id):
    """Print totals from the module-level counters and the top 5 RXP rules."""
    # Most frequently matched rules first; ties ordered by rule id so the
    # output is deterministic.
    sorted_rules = sorted(rules_matched.items(),
                          key=lambda item: (-item[1], item[0]))

    print('MATCHES ARE EQUAL')
    print('')
    print('Summary...')
    print('- Total matches in RXP file: %s' % len(rxp_matches))
    print('- Total matches in Hyperscan file: %s' % len(hs_matches))
    print('- RXP exact matches with Hyperscan: %s' % exact_matches)
    print('- RXP sub matches with Hyperscan: %s' % sub_matches)
    print('- Hyperscan \'duplicate\' matches: %s' % greedy_matches)
    print('')
    print('- Different job ids with matches: %s' % len(rxp_by_id))
    print('- Different rules ids matched (RXP): %s' % len(rules_matched))
    if sorted_rules:
        print('- Most common matches (RXP):')
    for rank, (rule_id, count) in enumerate(sorted_rules[:5], 1):
        print('  - %d. rule id: %s, occurrences: %s' % (rank, rule_id, count))
    print('----------------------------------------------')


if __name__ == "__main__":
    main(sys.argv[1:])
