#!/usr/bin/env python
#
# Copyright (c) 2017 Mellanox Technologies. All rights reserved.
#
# This Software is licensed under one of the following licenses:
#
# 1) under the terms of the "Common Public License 1.0" a copy of which is
#    available from the Open Source Initiative, see
#    http://www.opensource.org/licenses/cpl.php.
#
# 2) under the terms of the "The BSD License" a copy of which is
#    available from the Open Source Initiative, see
#    http://www.opensource.org/licenses/bsd-license.php.
#
# 3) under the terms of the "GNU General Public License (GPL) Version 2" a
#    copy of which is available from the Open Source Initiative, see
#    http://www.opensource.org/licenses/gpl-license.php.
#
# Licensee has the right to choose one of the above licenses.
#
# Redistributions of source code must retain the above copyright
# notice and one of the license notices.
#
# Redistributions in binary form must reproduce both the above copyright
# notice, one of the license notices in the documentation
# and/or other materials provided with the distribution.
#

#################################################################
#					Mellanox Technologies LTD  				    #
# ib2ib_setup: Tool for generating files required by IB-Router  #
#															    #
#################################################################

from __future__ import print_function
from pprint import pprint
from optparse import OptionParser
from subprocess import Popen, PIPE
import re
import csv
import os
import fcntl
import socket
import struct
SIOCGIFADDR = 0x8915


""" Utils functions """
def read(path):
	"""Return the whole content of the file at `path` as a single string."""
	with open(path) as stream:
		contents = stream.read()
	return contents

def read_subnets(s):
	"""Parse a comma-separated list of A.B.C.D/24 networks.

	Returns the address part (text before the first '/') of every entry,
	in input order. Raises Exception when an entry has no '/' or when its
	mask is anything other than '24'.
	"""
	addresses = []
	for entry in s.split(','):
		fields = entry.split('/')
		# fields[1] is the mask; it does not exist when the entry has no '/'.
		if len(fields) < 2 or fields[1] != '24':
			raise Exception("Please provide mask as /24")
		addresses.append(fields[0])
	return addresses
	
def chunks(l, n):
	"""Generate consecutive slices of `l`, each holding at most `n` items."""
	total = len(l)
	for start in range(0, total, n):
		stop = start + n
		yield l[start:stop]
def dic2csv(my_dict,my_file,no_of_column):
	"""Write `my_dict` to `my_file` as space-delimited rows.

	my_dict      -- mapping to dump; entries whose key is None are skipped
	my_file      -- path of the file to create (an existing file is overwritten)
	no_of_column -- 3 writes "key lid lid", where lid is int(value) rendered
	                as 0x-prefixed, zero-padded hex (e.g. 0x0003);
	                2 writes "key value"; any other number writes no rows
	"""
	import sys
	# The csv module wants a binary file on Python 2 but a text file opened
	# with newline='' on Python 3; 'wb' on Python 3 makes writerow() fail.
	if sys.version_info[0] >= 3:
		csv_file = open(my_file, 'w', newline='')
	else:
		csv_file = open(my_file, 'wb')
	with csv_file:
		writer = csv.writer(csv_file, delimiter=' ')
		for key, value in my_dict.items():
			if key is None:
				continue
			if no_of_column==3:
				lid='{0:#06x}'.format(int(value))
				writer.writerow([key,lid,lid])
			elif no_of_column==2:
				writer.writerow([key, value])
					
					
### Collect info from the network and generate the IP2GID, HOST2IP and GUID2LID files. ###
class ib2ib(object):
	"""Gather fabric data and derive the ip->gid and ip->host tables.

	An instance is configured with the subnet number, the network interface
	name and the list of networks to scan. The collecting methods then fill
	the guid2*/lid2guid/ip2* dictionaries kept on the instance.
	"""

	def __init__(self, sm=None, device=None,networks=None,filename=None):
		"""Store the configuration and create empty lookup tables.

		sm       -- subnet number as a string (zero-padded to 4 chars in GIDs)
		device   -- network interface name, e.g. ib0
		networks -- list of dotted-quad addresses, one per network to scan
		filename -- accepted for interface compatibility; not stored or used
		"""
		self.sm=sm
		self.device=device
		self.networks=networks
		# HCA name and port number, filled in by interface2dev().
		self.PORT=None
		self.HCA=None
		# Data about the local port/host, filled in by the get_local_*() calls.
		self.local_lid=None
		self.local_guid=None
		self.local_ip=None
		self.local_hostname=None
		self.guid2lid={}
		self.guid2host={}
		self.guid2ip={}
		self.guid2gid={}
		self.lid2guid={}
		self.ips_list=[]
		self.ip2gid={}
		self.ip2host={}

	def get_local_lid(self):
		"""Read the local port LID from sysfs, store it and return it as int."""
		lid_path=self.get_dev_path(self.HCA,self.PORT)+'/lid'
		# The sysfs value is parsed as a base-16 number.
		self.local_lid=int(read(lid_path),16)
		return self.local_lid

	def get_local_hostname(self):
		"""Store and return socket.gethostname()."""
		self.local_hostname=socket.gethostname()
		return self.local_hostname

	def get_local_ip(self):
		"""Query the IPv4 address of self.device via the SIOCGIFADDR ioctl.

		Stores the dotted-quad string in self.local_ip and returns it.
		Raises Exception when the ioctl fails.
		"""
		if_name=self.device
		# struct's 's' format needs bytes; on Python 3 the option value is text.
		if not isinstance(if_name, bytes):
			if_name=if_name.encode()
		sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
		try:
			ifreq = struct.pack('16sH14s', if_name, socket.AF_INET, b'\x00'*14)
			result = fcntl.ioctl(sock.fileno(), SIOCGIFADDR, ifreq)
		except IOError as e :
			raise Exception(e)
		finally:
			sock.close()
		# Third unpacked field holds the 4 packed address bytes.
		ip = struct.unpack('16sH2x4s8x', result)[2]
		self.local_ip=socket.inet_ntoa(ip)
		return self.local_ip

	def get_local_guid(self):
		"""Derive the local GUID from the port's gids/0 sysfs entry.

		The last four colon-separated groups of the GID are concatenated,
		lower-cased and prefixed with "0x". The value is stored in
		self.local_guid and returned.
		"""
		guid_path=self.get_dev_path(self.HCA,self.PORT)+'/gids/0'
		groups=read(guid_path).strip().split(':')
		self.local_guid="0x"+"".join(groups[4:8]).lower()
		return  self.local_guid

	def generate_ip2host(self):
		"""Fill and return self.ip2host for GUIDs present in both guid2ip and guid2host."""
		for guid, ip in self.guid2ip.items():
			if guid in self.guid2host:
				self.ip2host[ip]=self.guid2host[guid]
		return self.ip2host

	def calculate_lid_by_LMC(self,lid):
		"""Placeholder; not implemented."""
		pass

	def generate_ip2gid(self):
		"""Fill and return self.ip2gid for GUIDs present in both guid2ip and guid2gid."""
		for guid, ip in self.guid2ip.items():
			if guid in self.guid2gid:
				self.ip2gid[ip]=self.guid2gid[guid]
		return self.ip2gid

	def get_dev_path(self,hca,port):
		"""Return the sysfs directory of `port` on `hca`."""
		return '/sys/class/infiniband/{0}/ports/{1}'.format(hca, port)

	def interface2dev(self):
		"""Resolve self.device to its HCA name and port number via sysfs.

		The port is the interface's dev_id (parsed as hex) plus one. Both
		values are stored in self.HCA / self.PORT and returned as [hca, port].
		Raises Exception when the sysfs entries are missing or malformed.
		"""
		if_name=self.device
		try :
			hca = os.listdir('/sys/class/net/{0}/device/infiniband/'.format(if_name))[0]
			port = int(read('/sys/class/net/{0}/dev_id'.format(if_name)), 16) + 1
		except (OSError, IOError, IndexError, ValueError):
			# Only lookup/parse failures are translated; KeyboardInterrupt
			# and friends are no longer swallowed.
			raise Exception("Please make sure you are using the right interface")
		self.HCA=hca
		self.PORT=port
		print("-I- Using " , hca, " Port", port)
		return [hca,port]

	def get_lid_from_guid1(self,hw_addr):
		"""Extract the LID embedded in `hw_addr` and return it as int.

		A colon-separated address carries the LID in its 8th group; a
		"0x"-prefixed 16-digit string carries it in its last four digits.
		Both are hexadecimal, matching what calculate_gid_by_lid() emits.
		"""
		if ":" in hw_addr :
			hex_lid=hw_addr.split(':')[7]
		else:
			hex_lid=hw_addr[14:18]
		return int(hex_lid,16)

	def guid1_2_guid(self,guid1):
		"""Map a LID-carrying address to the GUID recorded for that LID.

		Looks the decimal LID string up in self.lid2guid. Returns the GUID,
		or None (after printing a hint) when the LID is unknown.
		"""
		lid=str(self.get_lid_from_guid1(guid1))
		if lid in self.lid2guid:
			return self.lid2guid[lid]
		print("Can't convert gid0 to guid due to New lid. please restart driver on hosts")
		return None

	def get_guid2host(self):
		"""Fill self.guid2host from `ibnetdiscover -H --Ca <HCA>`.

		The table is seeded with the local GUID/hostname. From every output
		line field 2 is taken as GUID and field 9 (double quotes removed)
		as host; lines with fewer than 10 fields are skipped.
		Returns the command's return code (always 0); raises Exception when
		the command exits with a non-zero status.
		"""
		host_dict={self.local_guid: self.local_hostname}
		# universal_newlines=True yields text output on Python 2 and 3 alike.
		p = Popen(['ibnetdiscover','-H','--Ca',self.HCA], stdin=PIPE, stdout=PIPE, stderr=PIPE, universal_newlines=True)
		output, err = p.communicate()
		if p.returncode!=0:
			raise Exception('Cannot Collect info from IB Network')
		for line in output.splitlines():
			q=line.split()
			if len(q)<10:
				continue
			host_dict[q[2]]=q[9].replace('"','')
		self.guid2host=host_dict
		return p.returncode

	def generate_ips(self):
		"""Expand every network in self.networks into its 256 host addresses.

		Only the first three octets of each entry are used; the last octet
		runs over 0..255. The result replaces self.ips_list (empty when
		self.networks is None). Returns 0. Raises ValueError for an entry
		that is not a dotted quad.
		"""
		ips_list=[]
		for ip in self.networks or []:
			octets=ip.replace(" ","").split('.')
			if len(octets)!=4:
				raise ValueError("Invalid network address: "+ip)
			prefix=".".join(octets[:3])
			for i in range(256):
				ips_list.append(prefix+"."+str(i))
		self.ips_list=ips_list
		return 0

	def assign_guid_lid(self):
		"""Assign sequential integer LIDs, starting at 1, to the GUIDs of guid2host.

		Updates self.guid2lid and self.lid2guid; returns self.guid2lid.
		"""
		lid=1
		for guid in self.guid2host:
			self.guid2lid[guid]=lid
			self.lid2guid[lid]=guid
			lid=lid+1
		return self.guid2lid

	def get_guid2lid(self):
		"""Fill self.guid2lid / self.lid2guid from `ibnetdiscover -p --Ca <HCA>`.

		Every output line starting with "CA " contributes its field 3 as
		GUID and its field 1 as LID (kept as strings). Only the command
		output is used; the local port is not added separately.
		Returns self.guid2lid; raises Exception when the command fails.
		"""
		p = Popen(['ibnetdiscover','-p','--Ca',self.HCA], stdin=PIPE, stdout=PIPE, stderr=PIPE, universal_newlines=True)
		output, err = p.communicate()
		if p.returncode!=0:
			# Same reaction as get_guid2host(): do not continue with no data.
			raise Exception('Cannot Collect info from IB Network')
		CA_regx= re.compile(r'\b^CA .*$', flags=re.M)
		for CA in CA_regx.findall(output):
			CA_as_list= CA.split()
			if len(CA_as_list)<4:
				continue
			guid=CA_as_list[3]
			lid=CA_as_list[1]
			self.guid2lid[guid]=lid
			self.lid2guid[lid]=guid
		return self.guid2lid

	def calculate_gid_by_lid(self,str_lid):
		"""Build the GID string for a LID.

		The GID is "fec0:0000:0000:" + self.sm zero-padded to 4 characters
		+ ":0014:0500:0000:" + the LID as 4 hex digits.
		str_lid may be an int or a decimal string.
		"""
		my_prefix = "fec0:0000:0000:"
		my_guid = ":0014:0500:0000:"
		lid=int(str_lid)
		return my_prefix+self.sm.zfill(4)+my_guid+'{0:04x}'.format(lid)

	def get_guid2gid(self):
		"""Compute a GID for every GUID in self.guid2lid; return self.guid2gid."""
		for guid, lid in self.guid2lid.items():
			self.guid2gid[guid]=self.calculate_gid_by_lid(lid)
		return self.guid2gid

	def generate_ips_from_file(self,filename):
		"""Append the addresses listed in `filename` (one per line) to self.ips_list.

		Lines are stripped of surrounding whitespace; blank lines are
		ignored. Raises Exception when the file cannot be read.
		"""
		try :
			with open(filename) as f:
				for line in f:
					address=line.strip()
					if address:
						self.ips_list.append(address)
		except (IOError, OSError):
			raise Exception('Cannot open '+filename)
		return

	def collect_IPs2GUID_per_Subnet(self):
		"""Probe every address in self.ips_list with arping and fill self.guid2ip.

		Probes (`arping -w 3 -c 1 -I <device> <ip>`) are started in batches
		of up to 255 processes. For each reply the 20-byte hardware address
		is reduced to its last 16 hex digits; if that value is a known GUID
		it is used directly, otherwise it is translated through
		guid1_2_guid(). Replies that cannot be parsed or translated are
		skipped. Returns self.guid2ip, which always holds the local entry.
		"""
		print("-I- Discovering IPs on Subnet")
		dev=self.device
		guid_ip_dict={}
		# Compiled once; the patterns do not depend on the reply.
		mac_regx= re.compile(r'[\:\-]'.join(['([0-9A-F]{1,2})']*20) , re.IGNORECASE)
		ip_regx=re.compile(r"reply .*(\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b)")
		self.guid2ip[self.local_guid]=self.local_ip
		for ips in chunks(self.ips_list,255):
			# Start the whole batch first so the probes run concurrently.
			processes=[]
			for ip in ips:
				target=ip.replace(" ","")
				p = Popen(['arping','-w','3','-c','1','-I',dev,target], stdout=PIPE, stderr=PIPE, universal_newlines=True)
				processes.append(p)
			for p in processes:
				output, err = p.communicate()
				if "reply" not in output:
					continue
				macs=mac_regx.findall(output)
				found_ips=ip_regx.findall(output)
				if not macs or not found_ips:
					# A reply without a 20-byte address or IP is unusable.
					continue
				found_ip=found_ips[0]
				guid1=("0x"+''.join(macs[0])[24:44]).lower()
				guid_ip_dict[guid1]=found_ip
				if guid1 in self.guid2lid:
					self.guid2ip[guid1]=found_ip
				else :
					real_guid=self.guid1_2_guid(guid1)
					# Unknown LID: do not record the IP under a None key.
					if real_guid is not None:
						self.guid2ip[real_guid]=found_ip
		if (len(self.guid2ip)<=1 ):
			print("-W- Please Check Your Network, only ", len(guid_ip_dict), "IP found")
		else:
			print("-I- Total ", len(self.guid2ip), " IPs found on subnet")
		return self.guid2ip
	
####### Main #######
def main():
	"""Parse the command line, collect the fabric data and write the output files.

	Requires -d (device), -s (subnet number) and one of -n (networks) or
	-f (file with IPs); a missing option ends the program through
	parser.error(). When both -n and -f are given, the file is used.
	Writes the files 'guid2lid', 'ip2gid.db' and 'hosts' into the current
	directory. Failures while collecting data are raised as exceptions.
	"""
	parser = OptionParser(usage="usage: %prog [options] ",
		version="%prog 1.0")
	parser.add_option("-n", "--network",
		type="string",
		action="store",
		dest="address",
		help="network/Mask to scan for IPs. Format is  A.B.C.D/24. Example :11.130.1.1/24,11.130.2.1/24")
	parser.add_option("-s", "--sm",
		dest="sm",
		action="store",
		type="string",
		help="subnet number .Unique number for ib subnet.(0-31)",)
	parser.add_option("-d", "--device",
		type="string",
		action="store",
		dest="dev",
		help="device name. Example: ib0")
	parser.add_option("-f", "--file",type="string", action="store",
		dest="file_ips",
		help="text file which hold IPs.(IP per line)")
	(options, args) = parser.parse_args()

	# Fail early with a usage message instead of an opaque error deep
	# inside the collection steps.
	if options.dev is None:
		parser.error("Please provide the device name (-d)")
	if options.sm is None:
		parser.error("Please provide the subnet number (-s)")
	if options.address is None and options.file_ips is None:
		parser.error("Please provide the networks to scan (-n) or a file with IPs (-f)")

	addresses=None
	interface=options.dev
	sm_id=options.sm
	filename=options.file_ips
	if options.address is not None:
		addresses=read_subnets(options.address)
	##initialize ###
	ibr_setup=ib2ib(sm_id,interface,addresses,filename)
	###Collecting data ###
	ibr_setup.interface2dev()
	ibr_setup.get_local_ip()
	ibr_setup.get_local_guid()
	ibr_setup.get_local_lid()
	ibr_setup.get_local_hostname()

	if filename is not None:
		print("-I- Reading ips from ", filename)
		ibr_setup.generate_ips_from_file(filename)
	else :
		print("-I- Generating IPs from given subnets")
		ibr_setup.generate_ips()

	if (ibr_setup.get_guid2host()!=0 ):
		raise Exception("Is SM down? Please check your SM")
	guid2lid=ibr_setup.get_guid2lid()
	ibr_setup.get_guid2gid()
	ibr_setup.collect_IPs2GUID_per_Subnet()
	### file writing ###
	dic2csv(guid2lid,'guid2lid',3)
	ip2gid=ibr_setup.generate_ip2gid()
	dic2csv(ip2gid,'ip2gid.db',2)

	ip2hosts=ibr_setup.generate_ip2host()
	dic2csv(ip2hosts,'hosts',2)
# Script entry point: run main(), report any failure and reflect it in the
# process exit status.
if __name__ == '__main__':
	try:
		main()
	except Exception as  e :
		print("-E- ", e)
		# Exit non-zero so calling scripts can detect the failure.
		raise SystemExit(1)
	else :
		print("-I- files created :ip2gid.db, guid2lid, hosts")
		print("Completed successfully")
		
