Commit 215be907 authored by Pedro Oliveira's avatar Pedro Oliveira

first try... its not working yet... just backup

parent bcefc317
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimAssert import PacketPimAssert
import Main
import traceback
class Assert:
TYPE = 5
def __init__(self):
Main.add_protocol(Assert.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
interface = packet.interface
interface_name = interface.interface_name
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_assert = packet.payload.payload # type: PacketPimAssert
metric = pkt_assert.metric
metric_preference = pkt_assert.metric_preference
source = pkt_assert.source_address
group = pkt_assert.multicast_group_address
source_group = (source, group)
interface_name = packet.interface.interface_name
interface_index = Main.kernel.vif_name_to_index_dic[interface_name]
try:
#Main.kernel.routing[source_group].recv_assert_msg(interface_index, packet)
Main.kernel.get_routing_entry(source_group).recv_assert_msg(interface_index, packet)
except:
traceback.print_exc()
"""Generic linux daemon base class for python 3.x."""
import sys, os, time, atexit, signal
class Daemon:
"""A generic Daemon class.
Usage: subclass the Daemon class and override the run() method."""
def __init__(self, pidfile): self.pidfile = pidfile
def daemonize(self):
"""Deamonize class. UNIX double fork mechanism."""
try:
pid = os.fork()
if pid > 0:
# exit first parent
sys.exit(0)
except OSError as err:
sys.stderr.write('fork #1 failed: {0}\n'.format(err))
sys.exit(1)
# decouple from parent environment
#os.chdir('/')
#os.setsid()
#os.umask(0)
# do second fork
try:
pid = os.fork()
if pid > 0:
# exit from second parent
sys.exit(0)
except OSError as err:
sys.stderr.write('fork #2 failed: {0}\n'.format(err))
sys.exit(1)
# redirect standard file descriptors
sys.stdout.flush()
sys.stderr.flush()
si = open(os.devnull, 'r')
so = open('stdout', 'a+')
se = open('stderror', 'a+')
os.dup2(si.fileno(), sys.stdin.fileno())
os.dup2(so.fileno(), sys.stdout.fileno())
os.dup2(se.fileno(), sys.stderr.fileno())
# write pidfile
atexit.register(self.delpid)
pid = str(os.getpid())
with open(self.pidfile, 'w+') as f:
f.write(pid + '\n')
def delpid(self):
os.remove(self.pidfile)
def start(self):
"""Start the Daemon."""
# Check for a pidfile to see if the Daemon already runs
if self.is_running():
message = "pidfile {0} already exist. " + \
"Daemon already running?\n"
sys.stderr.write(message.format(self.pidfile))
sys.exit(1)
# Start the Daemon
self.daemonize()
self.run()
def stop(self):
"""Stop the Daemon."""
# Get the pid from the pidfile
try:
with open(self.pidfile, 'r') as pf:
pid = int(pf.read().strip())
except IOError:
pid = None
if not pid:
message = "pidfile {0} does not exist. " + \
"Daemon not running?\n"
sys.stderr.write(message.format(self.pidfile))
return # not an error in a restart
# Try killing the Daemon process
try:
while 1:
#os.killpg(os.getpgid(pid), signal.SIGTERM)
os.kill(pid, signal.SIGTERM)
time.sleep(0.1)
except OSError as err:
e = str(err.args)
if e.find("No such process") > 0:
if os.path.exists(self.pidfile):
os.remove(self.pidfile)
else:
print(str(err.args))
sys.exit(1)
def restart(self):
"""Restart the Daemon."""
self.stop()
self.start()
def run(self):
"""You should override this method when you subclass Daemon.
It will be called after the process has been daemonized by
start() or restart()."""
def is_running(self):
try:
with open(self.pidfile, 'r') as pf:
pid = int(pf.read().strip())
except IOError:
return False
""" Check For the existence of a unix pid. """
try:
os.kill(pid, 0)
return True
except:
return False
from Packet.Packet import Packet
from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
class Graft:
TYPE = 6
def __init__(self):
Main.add_protocol(Graft.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
print("GRAFT!!")
interface = packet.interface
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_join_prune = packet.payload.payload # type: PacketPimJoinPrune
# if im not upstream neighbor ignore message
if pkt_join_prune.upstream_neighbor_address != interface.ip_interface:
#return
pass
interface_name = interface.interface_name
interface_index = Main.kernel.vif_name_to_index_dic[interface_name]
join_prune_groups = pkt_join_prune.groups
for group in join_prune_groups:
multicast_group = group.multicast_group
joined_src_addresses = group.joined_src_addresses
for source_address in joined_src_addresses:
source_group = (source_address, multicast_group)
try:
Main.kernel.get_routing_entry(source_group).recv_graft_msg(interface_index, packet)
except:
try:
#import time
#time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_graft_msg(interface_index, packet)
except:
pass
# todo o que fazer quando n existe arvore para (s,g) ???
traceback.print_exc()
print("ATENCAO!!!!")
print(Main.kernel.routing)
continue
from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
class GraftAck:
TYPE = 7
def __init__(self):
Main.add_protocol(GraftAck.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
print("GRAFT ACK!!")
interface = packet.interface
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_join_prune = packet.payload.payload # type: PacketPimJoinPrune
# if im not upstream neighbor ignore message
if pkt_join_prune.upstream_neighbor_address != interface.ip_interface:
#return
pass
interface_name = interface.interface_name
interface_index = Main.kernel.vif_name_to_index_dic[interface_name]
join_prune_groups = pkt_join_prune.groups
for group in join_prune_groups:
multicast_group = group.multicast_group
joined_src_addresses = group.joined_src_addresses
for source_address in joined_src_addresses:
source_group = (source_address, multicast_group)
try:
Main.kernel.get_routing_entry(source_group).recv_graft_ack_msg(interface_index, packet)
except:
try:
#import time
#time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_graft_ack_msg(interface_index, packet)
except:
pass
# todo o que fazer quando n existe arvore para (s,g) ???
traceback.print_exc()
print("ATENCAO!!!!")
print(Main.kernel.routing)
continue
from Packet.ReceivedPacket import ReceivedPacket
import Main
from Neighbor import Neighbor
class Hello:
TYPE = 0
TRIGGERED_HELLO_DELAY = 16 # TODO: configure via external file??
def __init__(self):
Main.add_protocol(Hello.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
interface = packet.interface
ip = packet.ip_header.ip_src
print("ip = ", ip)
options = packet.payload.payload.get_options()
if (1 in options) and (20 in options):
hello_hold_time = options[1]
generation_id = options[20]
else:
raise Exception
with interface.neighbors_lock.genWlock():
if ip in interface.neighbors:
neighbor = interface.neighbors[ip]
else:
interface.neighbors[ip] = Neighbor(interface, ip, generation_id, hello_hold_time)
return
neighbor.receive_hello(generation_id, hello_hold_time)
"""
with neighbor.neighbor_lock:
# Already know Neighbor
print("neighbor conhecido")
neighbor.heartbeat()
if neighbor.hello_hold_time != hello_hold_time:
print("keep alive period diferente")
neighbor.set_hello_hold_time(hello_hold_time)
if neighbor.generation_id != generation_id:
print("neighbor reiniciado")
neighbor.set_generation_id(generation_id)
with interface.neighbors_lock.genWlock():
#if interface.get_neighbor(ip) is None:
if ip in interface.neighbors:
# Unknown Neighbor
if (1 in options) and (20 in options):
try:
#Main.add_neighbor(packet.interface, ip, options[20], options[1])
print("non neighbor and options inside")
except Exception:
# Received Neighbor with Timeout
print("non neighbor and options inside but neighbor timedout")
pass
return
print("non neighbor and required options not inside")
else:
# Already know Neighbor
print("neighbor conhecido")
neighbor = Main.get_neighbor(ip)
neighbor.heartbeat()
if 1 in options and neighbor.hello_hold_time != options[1]:
print("keep alive period diferente")
neighbor.set_hello_hold_time(options[1])
if 20 in options and neighbor.generation_id != options[20]:
print("neighbor reiniciado")
neighbor.remove()
Main.add_neighbor(packet.interface, ip, options[20], options[1])
"""
\ No newline at end of file
from Packet.ReceivedPacket import ReceivedPacket
from utils import *
from ipaddress import IPv4Address
class IGMP:
# receive handler
@staticmethod
def receive_handle(packet: ReceivedPacket):
interface = packet.interface
ip_src = packet.ip_header.ip_src
ip_dst = packet.ip_header.ip_dst
#print("ip = ", ip_src)
igmp_hdr = packet.payload
igmp_type = igmp_hdr.type
igmp_group = igmp_hdr.group_address
# source ip can't be 0.0.0.0 or multicast
if ip_src == "0.0.0.0" or IPv4Address(ip_src).is_multicast:
return
if igmp_type == Version_1_Membership_Report and ip_dst == igmp_group and IPv4Address(igmp_group).is_multicast:
interface.interface_state.receive_v1_membership_report(packet)
elif igmp_type == Version_2_Membership_Report and ip_dst == igmp_group and IPv4Address(igmp_group).is_multicast:
interface.interface_state.receive_v2_membership_report(packet)
elif igmp_type == Leave_Group and ip_dst == "224.0.0.2" and IPv4Address(igmp_group).is_multicast:
interface.interface_state.receive_leave_group(packet)
elif igmp_type == Membership_Query and (ip_dst == igmp_group or (ip_dst == "224.0.0.1" and igmp_group == "0.0.0.0")):
interface.interface_state.receive_query(packet)
else:
raise Exception("Exception igmp packet: type={}; ip_dst={}; packet_group_report={}".format(igmp_type, ip_dst, igmp_group))
import socket
import threading
import random
import netifaces
from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
from RWLock.RWLock import RWLockWrite
class Interface(object):
MCAST_GRP = '224.0.0.13'
# substituir ip por interface ou algo parecido
def __init__(self, interface_name: str):
self.interface_name = interface_name
ip_interface = netifaces.ifaddresses(interface_name)[netifaces.AF_INET][0]['addr']
self.ip_interface = ip_interface
s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_PIM)
# allow other sockets to bind this port too
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# explicitly join the multicast group on the interface specified
#s.setsockopt(socket.SOL_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(Interface.MCAST_GRP) + socket.inet_aton(ip_interface))
s.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
socket.inet_aton(Interface.MCAST_GRP) + socket.inet_aton(ip_interface))
s.setsockopt(socket.SOL_SOCKET, 25, str(interface_name + '\0').encode('utf-8'))
# set socket output interface
s.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_IF, socket.inet_aton(ip_interface))
# set socket TTL to 1
s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1)
# don't receive outgoing packets
s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 0)
self.socket = s
self.interface_enabled = True
# generation id
#self.generation_id = random.getrandbits(32)
# todo neighbors
#self.neighbors = {}
#self.neighbors_lock = RWLockWrite()
# run receive method in background
#receive_thread = threading.Thread(target=self.receive)
#receive_thread.daemon = True
#receive_thread.start()
def receive(self):
try:
(raw_packet, (ip, _)) = self.socket.recvfrom(256 * 1024)
if raw_packet:
packet = ReceivedPacket(raw_packet, self)
else:
packet = None
return packet
except Exception:
return None
"""
while self.interface_enabled:
try:
(raw_packet, (ip, _)) = self.socket.recvfrom(256 * 1024)
if raw_packet:
packet = ReceivedPacket(raw_packet, self)
Main.protocols[packet.payload.get_pim_type()].receive_handle(packet) # TODO: perceber se existe melhor maneira de fazer isto
except Exception:
traceback.print_exc()
continue
"""
def send(self, data: bytes, group_ip: str):
if self.interface_enabled and data:
self.socket.sendto(data, (group_ip, 0))
def remove(self):
self.interface_enabled = False
try:
self.socket.shutdown(socket.SHUT_RDWR)
except Exception:
pass
self.socket.close()
def is_enabled(self):
return self.interface_enabled
def get_ip(self):
return self.ip_interface
"""
def add_neighbor(self, ip, random_number, hello_hold_time):
with self.neighbors_lock.genWlock():
if ip not in self.neighbors:
print("ADD NEIGHBOR")
from Neighbor import Neighbor
n = Neighbor(self, ip, random_number, hello_hold_time)
self.neighbors[ip] = n
Main.protocols[0].force_send(self)
def get_neighbors(self):
with self.neighbors_lock.genRlock():
return self.neighbors.values()
def get_neighbor(self, ip):
with self.neighbors_lock.genRlock():
return self.neighbors[ip]
"""
\ No newline at end of file
import socket
import struct
import threading
import netifaces
from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
if not hasattr(socket, 'SO_BINDTODEVICE'):
socket.SO_BINDTODEVICE = 25
class InterfaceIGMP(object):
ETH_P_IP = 0x0800 # Internet Protocol packet
PACKET_MR_ALLMULTI = 2
def __init__(self, interface_name: str):
# RECEIVE SOCKET
rcv_s = socket.socket(socket.PF_PACKET, socket.SOCK_RAW, socket.htons(InterfaceIGMP.ETH_P_IP))
# allow all multicast packets
rcv_s.setsockopt(socket.SOL_SOCKET, InterfaceIGMP.PACKET_MR_ALLMULTI, struct.pack("i HH BBBBBBBB", 0, InterfaceIGMP.PACKET_MR_ALLMULTI, 0, 0,0,0,0,0,0,0,0))
# bind to interface
rcv_s.bind((interface_name, 0))
self.recv_socket = rcv_s
# SEND SOCKET
snd_s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
# bind to interface
snd_s.setsockopt(socket.SOL_SOCKET, socket.SO_BINDTODEVICE, str(interface_name + "\0").encode('utf-8'))
self.send_socket = snd_s
self.interface_enabled = True
self.interface_name = interface_name
from igmp.RouterState import RouterState
self.interface_state = RouterState(self)
# run receive method in background
receive_thread = threading.Thread(target=self.receive)
receive_thread.daemon = True
receive_thread.start()
def get_ip(self):
return netifaces.ifaddresses(self.interface_name)[netifaces.AF_INET][0]['addr']
def send(self, data: bytes, address: str="224.0.0.1"):
if self.interface_enabled:
self.send_socket.sendto(data, (address, 0))
def receive(self):
while self.interface_enabled:
try:
(raw_packet, x) = self.recv_socket.recvfrom(256 * 1024)
if raw_packet:
raw_packet = raw_packet[14:]
from Packet.PacketIpHeader import PacketIpHeader
(verhlen, tos, iplen, ipid, frag, ttl, proto, cksum, src, dst) = \
struct.unpack(PacketIpHeader.IP_HDR, raw_packet[:PacketIpHeader.IP_HDR_LEN])
#print(proto)
if proto != socket.IPPROTO_IGMP:
continue
#print((raw_packet, x))
packet = ReceivedPacket(raw_packet, self)
Main.igmp.receive_handle(packet)
except Exception:
traceback.print_exc()
continue
def remove(self):
self.interface_enabled = False
self.recv_socket.close()
self.send_socket.close()
import threading
import random
from Interface import Interface
from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
from RWLock.RWLock import RWLockWrite
from Packet.PacketPimHello import PacketPimHello
from Packet.PacketPimHeader import PacketPimHeader
from Packet.Packet import Packet
from Hello import Hello
from utils import HELLO_HOLD_TIME_TIMEOUT
from threading import Timer
class InterfacePim(Interface):
MCAST_GRP = '224.0.0.13'
HELLO_PERIOD = 30
PROPAGATION_DELAY = 0.5
OVERRIDE_INTERNAL = 2.5
MAX_TRIGGERED_HELLO_PERIOD = 5
def __init__(self, interface_name: str):
super().__init__(interface_name)
# generation id
self.generation_id = random.getrandbits(32)
# When PIM is enabled on an interface or when a router first starts, the Hello Timer (HT)
# MUST be set to random value between 0 and Triggered_Hello_Delay
hello_timer_time = random.uniform(0, Hello.TRIGGERED_HELLO_DELAY)
self.hello_timer = Timer(hello_timer_time, self.send_hello)
self.hello_timer.start()
# todo: state refresh capable
self._state_refresh_capable = False
# todo: lan delay enabled
self._lan_delay_enabled = False
# todo: propagation delay
self._propagation_delay = self.PROPAGATION_DELAY
# todo: override interval
self._override_interval = self.OVERRIDE_INTERNAL
# pim neighbors
self.neighbors = {}
self.neighbors_lock = RWLockWrite()
# run receive method in background
receive_thread = threading.Thread(target=self.receive)
receive_thread.daemon = True
receive_thread.start()
def receive(self):
while self.is_enabled():
try:
packet = super().receive()
if packet:
Main.protocols[packet.payload.get_pim_type()].receive_handle(packet)
except:
traceback.print_exc()
continue
"""
while self.interface_enabled:
(raw_packet, (ip, _)) = self.socket.recvfrom(256 * 1024)
if raw_packet:
packet = ReceivedPacket(raw_packet, self)
Main.protocols[packet.payload.get_pim_type()].receive_handle(packet) # TODO: perceber se existe melhor maneira de fazer isto
except Exception:
traceback.print_exc()
continue
"""
def send(self, data: bytes, group_ip: str=MCAST_GRP):
super().send(data=data, group_ip=group_ip)
def send_hello(self):
self.hello_timer.cancel()
pim_payload = PacketPimHello()
pim_payload.add_option(1, 3.5 * Hello.TRIGGERED_HELLO_DELAY)
pim_payload.add_option(20, self.generation_id)
ph = PacketPimHeader(pim_payload)
packet = Packet(payload=ph)
self.send(packet.bytes())
# reschedule hello_timer
self.hello_timer = Timer(Hello.TRIGGERED_HELLO_DELAY, self.send_hello)
self.hello_timer.start()
def remove(self):
self.hello_timer.cancel()
self.hello_timer = None
# send pim_hello timeout message
pim_payload = PacketPimHello()
pim_payload.add_option(1, HELLO_HOLD_TIME_TIMEOUT)
pim_payload.add_option(20, self.generation_id)
ph = PacketPimHeader(pim_payload)
packet = Packet(payload=ph)
self.send(packet.bytes())
super().remove()
def add_neighbor(self, ip, random_number, hello_hold_time):
with self.neighbors_lock.genWlock():
if ip not in self.neighbors:
print("ADD NEIGHBOR")
from Neighbor import Neighbor
n = Neighbor(self, ip, random_number, hello_hold_time)
self.neighbors[ip] = n
Main.protocols[0].force_send(self)
def get_neighbors(self):
with self.neighbors_lock.genRlock():
return self.neighbors.values()
def get_neighbor(self, ip):
with self.neighbors_lock.genRlock():
return self.neighbors[ip]
def remove_neighbor(self, ip):
with self.neighbors_lock.genWlock():
del self.neighbors[ip]
from Packet.Packet import Packet
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
from Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
from Interface import Interface
import Main
import traceback
class JoinPrune:
TYPE = 3
def __init__(self):
Main.add_protocol(JoinPrune.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
interface = packet.interface
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_join_prune = packet.payload.payload # type: PacketPimJoinPrune
# if im not upstream neighbor ignore message
if pkt_join_prune.upstream_neighbor_address != interface.ip_interface:
#return
pass
interface_name = interface.interface_name
interface_index = Main.kernel.vif_name_to_index_dic[interface_name]
# todo holdtime
holdtime = pkt_join_prune.hold_time
join_prune_groups = pkt_join_prune.groups
for group in join_prune_groups:
multicast_group = group.multicast_group
joined_src_addresses = group.joined_src_addresses
pruned_src_addresses = group.pruned_src_addresses
for source_address in joined_src_addresses:
source_group = (source_address, multicast_group)
try:
#Main.kernel.routing[source_group].recv_join_msg(interface_index, packet)
Main.kernel.get_routing_entry(source_group).recv_join_msg(interface_index, packet)
except:
try:
#import time
#time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_join_msg(interface_index, packet)
except:
pass
# todo o que fazer quando n existe arvore para (s,g) ???
traceback.print_exc()
print("ATENCAO!!!!")
print(Main.kernel.routing)
continue
for source_address in pruned_src_addresses:
source_group = (source_address, multicast_group)
try:
#Main.kernel.routing[source_group].recv_prune_msg(interface_index, packet)
Main.kernel.get_routing_entry(source_group).recv_prune_msg(interface_index, packet)
except:
try:
#import time
#time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_prune_msg(interface_index, packet)
except:
pass
# todo o que fazer quando n existe arvore para (s,g) ???
traceback.print_exc()
print("ATENCAO!!!!")
print(Main.kernel.routing)
continue
import socket
import struct
import netifaces
import threading
import traceback
from RWLock.RWLock import RWLockWrite
import Main
from tree.tree_if_upstream import *
from tree.tree_if_downstream import *
from tree.KernelEntry import KernelEntry
"""
class KernelEntry:
def __init__(self, source_ip: str, group_ip: str, inbound_interface_index: int):
self.source_ip = source_ip
self.group_ip = group_ip
# decide inbound interface based on rpf check
self.inbound_interface_index = Main.kernel.vif_dic[self.check_rpf()]
# all other interfaces = outbound
#self.outbound_interfaces = [1] * Kernel.MAXVIFS
#self.outbound_interfaces[self.inbound_interface_index] = 0
self._lock = threading.Lock()
# todo
self.state = {} # type: Dict[int, SFRMTreeInterface]
for i in range(Kernel.MAXVIFS):
if i == self.inbound_interface_index:
self.state[i] = SFRMRootInterface(self, i, False)
else:
self.state[i] = SFRMNonRootInterface(self, i)
def lock(self):
self._lock.acquire()
def unlock(self):
self._lock.release()
def get_inbound_interface_index(self):
return self.inbound_interface_index
def get_outbound_interfaces_indexes(self):
# todo check state of outbound interfaces
outbound_indexes = [0]*Kernel.MAXVIFS
for (index, state) in self.state.items():
outbound_indexes[index] = state.is_forwarding()
return outbound_indexes
def check_rpf(self):
from pyroute2 import IPRoute
# from utils import if_indextoname
ipr = IPRoute()
# obter index da interface
# rpf_interface_index = ipr.get_routes(family=socket.AF_INET, dst=ip)[0]['attrs'][2][1]
# interface_name = if_indextoname(rpf_interface_index)
# return interface_name
# obter ip da interface de saida
rpf_interface_source = ipr.get_routes(family=socket.AF_INET, dst=socket.inet_ntoa(self.source_ip))[0]['attrs'][3][1]
return rpf_interface_source
def recv_data_msg(self, index):
self.state[index].recv_data_msg()
def recv_assert_msg(self, index, packet):
self.state[index].recv_assert_msg(packet, None)
def recv_prune_msg(self, index, packet):
self.state[index].recv_prune_msg(None, None)
def recv_join_msg(self, index, packet):
self.state[index].recv_join_msg(None, None)
def change(self):
# todo: changes on unicast routing or multicast routing...
Main.kernel.set_multicast_route(self)
def delete(self):
Main.kernel.remove_multicast_route(self)
"""
class Kernel:
# MRT
MRT_BASE = 200
MRT_INIT = (MRT_BASE) # /* Activate the kernel mroute code */
MRT_DONE = (MRT_BASE + 1) # /* Shutdown the kernel mroute */
MRT_ADD_VIF = (MRT_BASE + 2) # /* Add a virtual interface */
MRT_DEL_VIF = (MRT_BASE + 3) # /* Delete a virtual interface */
MRT_ADD_MFC = (MRT_BASE + 4) # /* Add a multicast forwarding entry */
MRT_DEL_MFC = (MRT_BASE + 5) # /* Delete a multicast forwarding entry */
MRT_VERSION = (MRT_BASE + 6) # /* Get the kernel multicast version */
MRT_ASSERT = (MRT_BASE + 7) # /* Activate PIM assert mode */
MRT_PIM = (MRT_BASE + 8) # /* enable PIM code */
MRT_TABLE = (MRT_BASE + 9) # /* Specify mroute table ID */
#MRT_ADD_MFC_PROXY = (MRT_BASE + 10) # /* Add a (*,*|G) mfc entry */
#MRT_DEL_MFC_PROXY = (MRT_BASE + 11) # /* Del a (*,*|G) mfc entry */
#MRT_MAX = (MRT_BASE + 11)
# Max Number of Virtual Interfaces
MAXVIFS = 32
# SIGNAL MSG TYPE
IGMPMSG_NOCACHE = 1
IGMPMSG_WRONGVIF = 2
IGMPMSG_WHOLEPKT = 3 # NOT USED ON PIM-DM
def __init__(self):
# Kernel is running
self.running = True
# KEY : interface_ip, VALUE : vif_index
self.vif_dic = {}
self.vif_index_to_name_dic = {}
self.vif_name_to_index_dic = {}
# KEY : (source_ip, group_ip), VALUE : KernelEntry ???? TODO
self.routing = {}
s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
# MRT INIT
s.setsockopt(socket.IPPROTO_IP, Kernel.MRT_INIT, 1)
# MRT PIM
s.setsockopt(socket.IPPROTO_IP, Kernel.MRT_PIM, 0)
s.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ASSERT, 1)
self.socket = s
self.rwlock = RWLockWrite()
# Create virtual interfaces
interfaces = netifaces.interfaces()
for interface in interfaces:
try:
# ignore localhost interface
if interface == 'lo':
continue
addrs = netifaces.ifaddresses(interface)
addr = addrs[netifaces.AF_INET][0]['addr']
self.create_virtual_interface(ip_interface=addr, interface_name=interface)
except Exception:
continue
# receive signals from kernel with a background thread
handler_thread = threading.Thread(target=self.handler)
handler_thread.daemon = True
handler_thread.start()
'''
Structure to create/remove virtual interfaces
struct vifctl {
vifi_t vifc_vifi; /* Index of VIF */
unsigned char vifc_flags; /* VIFF_ flags */
unsigned char vifc_threshold; /* ttl limit */
unsigned int vifc_rate_limit; /* Rate limiter values (NI) */
union {
struct in_addr vifc_lcl_addr; /* Local interface address */
int vifc_lcl_ifindex; /* Local interface index */
};
struct in_addr vifc_rmt_addr; /* IPIP tunnel addr */
};
'''
def create_virtual_interface(self, ip_interface: str or bytes, interface_name: str, index: int = None, flags=0x0):
if type(ip_interface) is str:
ip_interface = socket.inet_aton(ip_interface)
if index is None:
index = len(self.vif_dic)
struct_mrt_add_vif = struct.pack("HBBI 4s 4s", index, flags, 1, 0, ip_interface, socket.inet_aton("0.0.0.0"))
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_VIF, struct_mrt_add_vif)
self.vif_dic[socket.inet_ntoa(ip_interface)] = index
self.vif_index_to_name_dic[index] = interface_name
self.vif_name_to_index_dic[interface_name] = index
def remove_virtual_interface(self, ip_interface):
index = self.vif_dic[ip_interface]
struct_vifctl = struct.pack("HBBI 4s 4s", index, 0, 0, 0, socket.inet_aton("0.0.0.0"), socket.inet_aton("0.0.0.0"))
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_DEL_VIF, struct_vifctl)
del self.vif_dic[ip_interface]
del self.vif_name_to_index_dic[self.vif_index_to_name_dic[index]]
del self.vif_index_to_name_dic[index]
# TODO alterar MFC's para colocar a 0 esta interface
'''
/* Cache manipulation structures for mrouted and PIMd */
struct mfcctl {
struct in_addr mfcc_origin; /* Origin of mcast */
struct in_addr mfcc_mcastgrp; /* Group in question */
vifi_t mfcc_parent; /* Where it arrived */
unsigned char mfcc_ttls[MAXVIFS]; /* Where it is going */
unsigned int mfcc_pkt_cnt; /* pkt count for src-grp */
unsigned int mfcc_byte_cnt;
unsigned int mfcc_wrong_if;
int mfcc_expire;
};
'''
def set_multicast_route(self, kernel_entry: KernelEntry):
source_ip = socket.inet_aton(kernel_entry.source_ip)
group_ip = socket.inet_aton(kernel_entry.group_ip)
outbound_interfaces = kernel_entry.get_outbound_interfaces_indexes()
if len(outbound_interfaces) != Kernel.MAXVIFS:
raise Exception
#outbound_interfaces_and_other_parameters = list(kernel_entry.outbound_interfaces) + [0]*4
outbound_interfaces_and_other_parameters = outbound_interfaces + [0]*4
#outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5
#struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0)
struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, kernel_entry.inbound_interface_index, *outbound_interfaces_and_other_parameters)
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_MFC, struct_mfcctl)
# TODO: ver melhor tabela routing
#self.routing[(socket.inet_ntoa(source_ip), socket.inet_ntoa(group_ip))] = {"inbound_interface_index": inbound_interface_index, "outbound_interfaces": outbound_interfaces}
def flood(self, ip_src, ip_dst, iif):
source_ip = socket.inet_aton(ip_src)
group_ip = socket.inet_aton(ip_dst)
outbound_interfaces = [1]*Kernel.MAXVIFS
outbound_interfaces[iif] = 0
outbound_interfaces_and_other_parameters = outbound_interfaces + [0]*4
#outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5
#struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0)
struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, iif, *outbound_interfaces_and_other_parameters)
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_MFC, struct_mfcctl)
def remove_multicast_route(self, kernel_entry: KernelEntry):
source_ip = socket.inet_aton(kernel_entry.source_ip)
group_ip = socket.inet_aton(kernel_entry.group_ip)
outbound_interfaces_and_other_parameters = [0] + [0]*Kernel.MAXVIFS + [0]*4
struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, *outbound_interfaces_and_other_parameters)
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_DEL_MFC, struct_mfcctl)
def exit(self):
self.running = False
# MRT DONE
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_DONE, 1)
self.socket.close()
'''
/* This is the format the mroute daemon expects to see IGMP control
* data. Magically happens to be like an IP packet as per the original
*/
struct igmpmsg {
__u32 unused1,unused2;
unsigned char im_msgtype; /* What is this */
unsigned char im_mbz; /* Must be zero */
unsigned char im_vif; /* Interface (this ought to be a vifi_t!) */
unsigned char unused3;
struct in_addr im_src,im_dst;
};
'''
def handler(self):
while self.running:
try:
msg = self.socket.recv(5000)
#print(len(msg))
(_, _, im_msgtype, im_mbz, im_vif, _, im_src, im_dst) = struct.unpack("II B B B B 4s 4s", msg[:20])
print((im_msgtype, im_mbz, socket.inet_ntoa(im_src), socket.inet_ntoa(im_dst)))
if im_mbz != 0:
continue
print(im_msgtype)
print(im_mbz)
print(im_vif)
print(socket.inet_ntoa(im_src))
print(socket.inet_ntoa(im_dst))
#print((im_msgtype, im_mbz, socket.inet_ntoa(im_src), socket.inet_ntoa(im_dst)))
ip_src = socket.inet_ntoa(im_src)
ip_dst = socket.inet_ntoa(im_dst)
if im_msgtype == Kernel.IGMPMSG_NOCACHE:
print("IGMP NO CACHE")
self.igmpmsg_nocache_handler(ip_src, ip_dst, im_vif)
elif im_msgtype == Kernel.IGMPMSG_WRONGVIF:
print("WRONG VIF HANDLER")
self.igmpmsg_wrongvif_handler(ip_src, ip_dst, im_vif)
else:
raise Exception
except Exception:
traceback.print_exc()
continue
# receive multicast (S,G) packet and multicast routing table has no (S,G) entry
def igmpmsg_nocache_handler(self, ip_src, ip_dst, iif):
source_group_pair = (ip_src, ip_dst)
"""
with self.rwlock.genWlock():
if source_group_pair in self.routing:
kernel_entry = self.routing[(ip_src, ip_dst)]
else:
kernel_entry = KernelEntry(ip_src, ip_dst, iif)
self.routing[(ip_src, ip_dst)] = kernel_entry
self.set_multicast_route(kernel_entry)
kernel_entry.recv_data_msg(iif)
"""
"""
with self.rwlock.genRlock():
if source_group_pair in self.routing:
kernel_entry = self.routing[(ip_src, ip_dst)]
with self.rwlock.genWlock():
if source_group_pair in self.routing:
kernel_entry = self.routing[(ip_src, ip_dst)]
else:
kernel_entry = KernelEntry(ip_src, ip_dst, iif)
self.routing[(ip_src, ip_dst)] = kernel_entry
self.set_multicast_route(kernel_entry)
kernel_entry.recv_data_msg(iif)
"""
self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg(iif)
# receive multicast (S,G) packet in a outbound_interface
def igmpmsg_wrongvif_handler(self, ip_src, ip_dst, iif):
#kernel_entry = self.routing[(ip_src, ip_dst)]
source_group_pair = (ip_src, ip_dst)
self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg(iif)
#kernel_entry.recv_data_msg(iif)
"""
def get_routing_entry(self, source_group: tuple):
with self.rwlock.genRlock():
return self.routing[source_group]
"""
def get_routing_entry(self, source_group: tuple, create_if_not_existent=False):
ip_src = source_group[0]
ip_dst = source_group[1]
with self.rwlock.genRlock():
if source_group in self.routing:
return self.routing[(ip_src, ip_dst)]
with self.rwlock.genWlock():
if source_group in self.routing:
return self.routing[(ip_src, ip_dst)]
elif create_if_not_existent:
kernel_entry = KernelEntry(ip_src, ip_dst, 0)
self.routing[source_group] = kernel_entry
#self.set_multicast_route(kernel_entry)
return kernel_entry
else:
return None
def neighbor_removed(self, interface_name, neighbor_ip):
# todo
pass
import netifaces
import time
from prettytable import PrettyTable
from InterfacePIM import InterfacePim
from InterfaceIGMP import InterfaceIGMP
from Kernel import Kernel
from threading import Lock
import UnicastRouting
interfaces = {} # interfaces with multicast routing enabled
igmp_interfaces = {} # igmp interfaces
protocols = {}
kernel = None
igmp = None
def add_interface(interface_name, pim=False, igmp=False):
if pim is True and interface_name not in interfaces:
interface = InterfacePim(interface_name)
interfaces[interface_name] = interface
if igmp is True and interface_name not in igmp_interfaces:
interface = InterfaceIGMP(interface_name)
igmp_interfaces[interface_name] = interface
def remove_interface(interface_name, pim=False, igmp=False):
if pim is True and ((interface_name in interfaces) or interface_name == "*"):
if interface_name == "*":
interface_name_list = list(interfaces.keys())
else:
interface_name_list = [interface_name]
for if_name in interface_name_list:
interface_obj = interfaces.pop(if_name)
interface_obj.remove()
#interfaces[if_name].remove()
#del interfaces[if_name]
print("removido interface")
print(interfaces)
if igmp is True and ((interface_name in igmp_interfaces) or interface_name == "*"):
if interface_name == "*":
interface_name_list = list(igmp_interfaces.keys())
else:
interface_name_list = [interface_name]
for if_name in interface_name_list:
igmp_interfaces[if_name].remove()
del igmp_interfaces[if_name]
print("removido interface")
print(igmp_interfaces)
"""
def add_neighbor(contact_interface, ip, random_number, hello_hold_time):
global neighbors
with neighbors_lock:
if ip not in neighbors:
print("ADD NEIGHBOR")
n = Neighbor(contact_interface, ip, random_number, hello_hold_time)
neighbors[ip] = n
protocols[0].force_send(contact_interface)
# todo check neighbor in interface
contact_interface.neighbors[ip] = n
def get_neighbor(ip) -> Neighbor:
global neighbors
with neighbors_lock:
if ip not in neighbors:
return None
return neighbors[ip]
def remove_neighbor(ip):
global neighbors
with neighbors_lock:
if ip in neighbors:
del neighbors[ip]
print("removido neighbor")
"""
def add_protocol(protocol_number, protocol_obj):
global protocols
protocols[protocol_number] = protocol_obj
def list_neighbors():
interfaces_list = interfaces.values()
t = PrettyTable(['Interface', 'Neighbor IP', 'Hello Hold Time', "Generation ID", "Uptime"])
check_time = time.time()
for interface in interfaces_list:
for neighbor in interface.get_neighbors():
uptime = check_time - neighbor.time_of_last_update
uptime = 0 if (uptime < 0) else uptime
t.add_row(
[interface.interface_name, neighbor.ip, neighbor.hello_hold_time, neighbor.generation_id, time.strftime("%H:%M:%S", time.gmtime(uptime))])
print(t)
return str(t)
def list_enabled_interfaces():
global interfaces
# TESTE DE PIM JOIN/PRUNE
for interface in interfaces:
from Packet.Packet import Packet
from Packet.PacketPimHeader import PacketPimHeader
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
from Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
ph = PacketPimJoinPrune("10.0.0.13", 210)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.123", ["1.1.1.1", "10.1.1.1"], []))
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.124", ["1.1.1.2", "10.1.1.2"], []))
pckt = Packet(payload=PacketPimHeader(ph))
interfaces[interface].send(pckt.bytes())
ph = PacketPimJoinPrune("ff08::1", 210)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("2001:1:a:b:c::1", ["1.1.1.1", "2001:1:a:b:c::2"], []))
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.123", ["1.1.1.1"], ["2001:1:a:b:c::3"]))
pckt = Packet(payload=PacketPimHeader(ph))
interfaces[interface].send(pckt.bytes())
from Packet.PacketPimAssert import PacketPimAssert
ph = PacketPimAssert("224.12.12.12", "10.0.0.2", 210, 2)
pckt = Packet(payload=PacketPimHeader(ph))
interfaces[interface].send(pckt.bytes())
from Packet.PacketPimGraft import PacketPimGraft
ph = PacketPimGraft("10.0.0.13")
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.124", ["1.1.1.2", "10.1.1.2"], []))
pckt = Packet(payload=PacketPimHeader(ph))
interfaces[interface].send(pckt.bytes())
t = PrettyTable(['Interface', 'IP', 'PIM/IGMP Enabled', 'IGMP State'])
for interface in netifaces.interfaces():
try:
# TODO: fix same interface with multiple ips
ip = netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr']
pim_enabled = interface in interfaces
igmp_enabled = interface in igmp_interfaces
enabled = str(pim_enabled) + "/" + str(igmp_enabled)
if igmp_enabled:
state = igmp_interfaces[interface].interface_state.print_state()
else:
state = "-"
t.add_row([interface, ip, enabled, state])
except Exception:
continue
print(t)
return str(t)
def list_state():
state_text = "IGMP State:\n" + list_igmp_state() + "\n\n\n\n" + "Multicast Routing State:\n" + list_routing_state()
return state_text
def list_igmp_state():
t = PrettyTable(['Interface', 'RouterState', 'Group Adress', 'GroupState'])
for (interface_name, interface_obj) in list(igmp_interfaces.items()):
interface_state = interface_obj.interface_state
state_txt = interface_state.print_state()
print(interface_state.group_state.items())
for (group_addr, group_state) in list(interface_state.group_state.items()):
print(group_addr)
group_state_txt = group_state.print_state()
t.add_row([interface_name, state_txt, group_addr, group_state_txt])
return str(t)
def list_routing_state():
routing_entries = kernel.routing.values()
vif_indexes = kernel.vif_index_to_name_dic.keys()
t = PrettyTable(['SourceIP', 'GroupIP', 'Interface', 'PruneState', 'AssertState', "Is Forwarding?"])
for entry in routing_entries:
ip = entry.source_ip
group = entry.group_ip
upstream_if_index = entry.inbound_interface_index
for index in vif_indexes:
interface_state = entry.interface_state[index]
interface_name = kernel.vif_index_to_name_dic[index]
is_forwarding = interface_state.is_forwarding()
try:
if index != upstream_if_index:
prune_state = type(interface_state._prune_state).__name__
assert_state = type(interface_state._assert_state).__name__
else:
prune_state = type(interface_state._graft_prune_state).__name__
assert_state = "-"
except:
prune_state = "-"
assert_state = "-"
t.add_row([ip, group, interface_name, prune_state, assert_state, is_forwarding])
return str(t)
def stop():
remove_interface("*", pim=True, igmp=True)
kernel.exit()
UnicastRouting.stop()
def main():
from Hello import Hello
from IGMP import IGMP
from Assert import Assert
from JoinPrune import JoinPrune
from GraftAck import GraftAck
from Graft import Graft
Hello()
Assert()
JoinPrune()
Graft()
GraftAck()
global kernel
kernel = Kernel()
global igmp
igmp = IGMP()
from threading import Timer
import time
from utils import HELLO_HOLD_TIME_NO_TIMEOUT, HELLO_HOLD_TIME_TIMEOUT, TYPE_CHECKING
from threading import Lock
import Main
if TYPE_CHECKING:
from InterfacePIM import InterfacePim
class Neighbor:
def __init__(self, contact_interface: "InterfacePim", ip, generation_id: int, hello_hold_time: int):
if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT:
raise Exception
self.contact_interface = contact_interface
self.ip = ip
self.generation_id = generation_id
# todo lan prune delay
# todo override interval
# todo state refresh capable
self.neighbor_liveness_timer = None
self.hello_hold_time = None
self.set_hello_hold_time(hello_hold_time)
self.time_of_last_update = time.time()
self.neighbor_lock = Lock()
# send hello to new neighbor
#self.contact_interface.send_hello()
# todo RANDOM DELAY??? => DO NOTHING... EVENTUALLY THE HELLO MESSAGE WILL BE SENT
def set_hello_hold_time(self, hello_hold_time: int):
self.hello_hold_time = hello_hold_time
if self.neighbor_liveness_timer is not None:
self.neighbor_liveness_timer.cancel()
if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT:
self.remove()
elif hello_hold_time != HELLO_HOLD_TIME_NO_TIMEOUT:
self.neighbor_liveness_timer = Timer(hello_hold_time, self.remove)
self.neighbor_liveness_timer.start()
else:
self.neighbor_liveness_timer = None
def set_generation_id(self, generation_id):
# neighbor restarted
if self.generation_id != generation_id:
self.generation_id = generation_id
self.contact_interface.send_hello()
self.reset()
"""
def heartbeat(self):
if (self.hello_hold_time != HELLO_HOLD_TIME_TIMEOUT) and \
(self.hello_hold_time != HELLO_HOLD_TIME_NO_TIMEOUT):
print("HEARTBEAT")
if self.neighbor_liveness_timer is not None:
self.neighbor_liveness_timer.cancel()
self.neighbor_liveness_timer = Timer(self.hello_hold_time, self.remove)
self.neighbor_liveness_timer.start()
self.time_of_last_update = time.time()
"""
def remove(self):
print('HELLO TIMER EXPIRED... remove neighbor')
if self.neighbor_liveness_timer is not None:
self.neighbor_liveness_timer.cancel()
#Main.remove_neighbor(self.ip)
interface_name = self.contact_interface.interface_name
neighbor_ip = self.ip
Main.kernel.neighbor_removed(interface_name, neighbor_ip)
del self.contact_interface.neighbors[self.ip]
def reset(self):
interface_name = self.contact_interface.interface_name
neighbor_ip = self.ip
Main.kernel.neighbor_removed(interface_name, neighbor_ip)
# todo new neighbor
def receive_hello(self, generation_id, hello_hold_time):
if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT:
self.set_hello_hold_time(hello_hold_time)
else:
self.time_of_last_update = time.time()
self.set_generation_id(generation_id)
self.set_hello_hold_time(hello_hold_time)
from .PacketIpHeader import PacketIpHeader
from .PacketPayload import PacketPayload
class Packet(object):
def __init__(self, ip_header: PacketIpHeader = None, payload: PacketPayload = None):
self.ip_header = ip_header
self.payload = payload
def bytes(self) -> bytes:
return self.payload.bytes()
import struct
from utils import checksum
import socket
from .PacketPayload import PacketPayload
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type | Max Resp Time | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Group Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Resv |S| QRV | QQIC | Number of Sources (N) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address [1] |
+- -+
| Source Address [2] |
+- . -+
. . .
. . .
+- -+
| Source Address [N] |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketIGMPHeader(PacketPayload):
IGMP_TYPE = 2
IGMP_HDR = "! BB H 4s"
IGMP_HDR_LEN = struct.calcsize(IGMP_HDR)
IGMP3_SRC_ADDR_HDR = "! BB H "
IGMP3_SRC_ADDR_HDR_LEN = struct.calcsize(IGMP3_SRC_ADDR_HDR)
IPv4_HDR = "! 4s"
IPv4_HDR_LEN = struct.calcsize(IPv4_HDR)
Membership_Query = 0x11
Version_2_Membership_Report = 0x16
Leave_Group = 0x17
Version_1_Membership_Report = 0x12
def __init__(self, type: int, max_resp_time: int, group_address: str="0.0.0.0"):
# todo check type
self.type = type
self.max_resp_time = max_resp_time
self.group_address = group_address
def bytes(self) -> bytes:
# obter mensagem e criar checksum
msg_without_chcksum = struct.pack(PacketIGMPHeader.IGMP_HDR, self.type, self.max_resp_time, 0,
socket.inet_aton(self.group_address))
igmp_checksum = checksum(msg_without_chcksum)
msg = msg_without_chcksum[0:2] + struct.pack("! H", igmp_checksum) + msg_without_chcksum[4:]
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
#print("parseIGMPHdr: ", data)
igmp_hdr = data[0:PacketIGMPHeader.IGMP_HDR_LEN]
(type, max_resp_time, rcv_checksum, group_address) = struct.unpack(PacketIGMPHeader.IGMP_HDR, igmp_hdr)
#print(type, max_resp_time, rcv_checksum, group_address)
msg_to_checksum = data[0:2] + b'\x00\x00' + data[4:]
#print("checksum calculated: " + str(checksum(msg_to_checksum)))
if checksum(msg_to_checksum) != rcv_checksum:
#print("wrong checksum")
raise Exception("wrong checksum")
igmp_hdr = igmp_hdr[PacketIGMPHeader.IGMP_HDR_LEN:]
group_address = socket.inet_ntoa(group_address)
pkt = PacketIGMPHeader(type, max_resp_time, group_address)
return pkt
\ No newline at end of file
import struct
import socket
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|Version| IHL |Type of Service| Total Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Identification |Flags| Fragment Offset |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Time to Live | Protocol | Header Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Destination Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Options | Padding |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketIpHeader:
IP_HDR = "! BBH HH BBH 4s 4s"
IP_HDR_LEN = struct.calcsize(IP_HDR)
def __init__(self, ver, hdr_len, ttl, proto, ip_src, ip_dst):
self.version = ver
self.hdr_length = hdr_len
self.ttl = ttl
self.proto = proto
self.ip_src = ip_src
self.ip_dst = ip_dst
def __len__(self):
return self.hdr_length
@staticmethod
def parse_bytes(data: bytes):
(verhlen, tos, iplen, ipid, frag, ttl, proto, cksum, src, dst) = \
struct.unpack(PacketIpHeader.IP_HDR, data)
ver = (verhlen & 0xf0) >> 4
hlen = (verhlen & 0x0f) * 4
'''
"VER": ver,
"HLEN": hlen,
"TOS": tos,
"IPLEN": iplen,
"IPID": ipid,
"FRAG": frag,
"TTL": ttl,
"PROTO": proto,
"CKSUM": cksum,
"SRC": socket.inet_ntoa(src),
"DST": socket.inet_ntoa(dst)
'''
src_ip = socket.inet_ntoa(src)
dst_ip = socket.inet_ntoa(dst)
return PacketIpHeader(ver, hlen, ttl, proto, src_ip, dst_ip)
import abc
class PacketPayload(object):
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def bytes(self) -> bytes:
"""Get packet payload in bytes format"""
@abc.abstractmethod
def __len__(self):
"""Get packet payload length"""
@staticmethod
@abc.abstractmethod
def parse_bytes(data: bytes):
"""From bytes create a object payload"""
import struct
import socket
from Packet.PacketPimEncodedGroupAddress import PacketPimEncodedGroupAddress
from Packet.PacketPimEncodedUnicastAddress import PacketPimEncodedUnicastAddress
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|PIM Ver| Type | Reserved | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|R| Metric Preference |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Metric |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimAssert:
PIM_TYPE = 5
PIM_HDR_ASSERT = "! %ss %ss LL"
PIM_HDR_ASSERT_WITHOUT_ADDRESS = "! LL"
PIM_HDR_ASSERT_v4 = PIM_HDR_ASSERT % (PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN)
PIM_HDR_ASSERT_v6 = PIM_HDR_ASSERT % (PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6)
PIM_HDR_ASSERT_WITHOUT_ADDRESS_LEN = struct.calcsize(PIM_HDR_ASSERT_WITHOUT_ADDRESS)
PIM_HDR_ASSERT_v4_LEN = struct.calcsize(PIM_HDR_ASSERT_v4)
PIM_HDR_ASSERT_v6_LEN = struct.calcsize(PIM_HDR_ASSERT_v6)
def __init__(self, multicast_group_address: str or bytes, source_address: str or bytes, metric_preference, metric):
if type(multicast_group_address) is bytes:
multicast_group_address = socket.inet_ntoa(multicast_group_address)
if type(source_address) is bytes:
source_address = socket.inet_ntoa(source_address)
self.multicast_group_address = multicast_group_address
self.source_address = source_address
self.metric_preference = metric_preference
self.metric = metric
def bytes(self) -> bytes:
multicast_group_address = PacketPimEncodedGroupAddress(self.multicast_group_address).bytes()
source_address = PacketPimEncodedUnicastAddress(self.source_address).bytes()
msg = multicast_group_address + source_address + struct.pack(PacketPimAssert.PIM_HDR_ASSERT_WITHOUT_ADDRESS,
0x7FFFFFFF & self.metric_preference,
self.metric)
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
multicast_group_addr_obj = PacketPimEncodedGroupAddress.parse_bytes(data)
multicast_group_addr_len = len(multicast_group_addr_obj)
data = data[multicast_group_addr_len:]
source_addr_obj = PacketPimEncodedUnicastAddress.parse_bytes(data)
source_addr_len = len(source_addr_obj)
data = data[source_addr_len:]
(metric_preference, metric) = struct.unpack(PacketPimAssert.PIM_HDR_ASSERT_WITHOUT_ADDRESS, data[:PacketPimAssert.PIM_HDR_ASSERT_WITHOUT_ADDRESS_LEN])
pim_payload = PacketPimAssert(multicast_group_addr_obj.group_address, source_addr_obj.unicast_address, metric_preference, metric)
return pim_payload
import ipaddress
import struct
import socket
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Addr Family | Encoding Type |B| Reserved |Z| Mask Len |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Group Multicast Address
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+...
'''
class PacketPimEncodedGroupAddress:
PIM_ENCODED_GROUP_ADDRESS_HDR = "! BBBB %s"
PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_MULTICAST_ADDRESS = "! BBBB"
IPV4_HDR = "4s"
IPV6_HDR = "16s"
# TODO ver melhor versao ip
PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_ADDRESS_LEN = struct.calcsize(PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_MULTICAST_ADDRESS)
PIM_ENCODED_GROUP_ADDRESS_HDR_LEN = struct.calcsize(PIM_ENCODED_GROUP_ADDRESS_HDR % IPV4_HDR)
PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6 = struct.calcsize(PIM_ENCODED_GROUP_ADDRESS_HDR % IPV6_HDR)
FAMILY_RESERVED = 0
FAMILY_IPV4 = 1
FAMILY_IPV6 = 2
RESERVED = 0
def __init__(self, group_address, mask_len=None):
if type(group_address) not in (str, bytes):
raise Exception
if type(group_address) is bytes:
group_address = socket.inet_ntoa(group_address)
self.group_address = group_address
self.mask_len = mask_len
def bytes(self) -> bytes:
(string_ip_hdr, hdr_addr_family, socket_family) = PacketPimEncodedGroupAddress.get_ip_info(self.group_address)
mask_len = self.mask_len
if mask_len is None:
mask_len = 8 * struct.calcsize(string_ip_hdr)
ip = socket.inet_pton(socket_family, self.group_address)
msg = struct.pack(PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR % string_ip_hdr, hdr_addr_family, 0,
PacketPimEncodedGroupAddress.RESERVED, mask_len, ip)
return msg
@staticmethod
def get_ip_info(ip):
version = ipaddress.ip_address(ip).version
if version == 4:
return (PacketPimEncodedGroupAddress.IPV4_HDR, PacketPimEncodedGroupAddress.FAMILY_IPV4, socket.AF_INET)
elif version == 6:
return (PacketPimEncodedGroupAddress.IPV6_HDR, PacketPimEncodedGroupAddress.FAMILY_IPV6, socket.AF_INET6)
else:
raise Exception
def __len__(self):
version = ipaddress.ip_address(self.group_address).version
if version == 4:
return self.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN
elif version == 6:
return self.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6
else:
raise Exception
@staticmethod
def parse_bytes(data: bytes):
data_without_group_addr = data[0:PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_ADDRESS_LEN]
(addr_family, encoding, _, mask_len) = struct.unpack(PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_MULTICAST_ADDRESS, data_without_group_addr)
data_group_addr = data[PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_ADDRESS_LEN:]
ip = None
if addr_family == PacketPimEncodedGroupAddress.FAMILY_IPV4:
(ip,) = struct.unpack("! " + PacketPimEncodedGroupAddress.IPV4_HDR, data_group_addr[:4])
ip = socket.inet_ntop(socket.AF_INET, ip)
elif addr_family == PacketPimEncodedGroupAddress.FAMILY_IPV6:
(ip,) = struct.unpack("! " + PacketPimEncodedGroupAddress.IPV6_HDR, data_group_addr[:16])
ip = socket.inet_ntop(socket.AF_INET6, ip)
if encoding != 0:
print("unknown encoding")
raise Exception
return PacketPimEncodedGroupAddress(ip, mask_len)
import ipaddress
import struct
import socket
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Addr Family | Encoding Type | Rsrvd |S|W|R| Mask Len |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+...
'''
class PacketPimEncodedSourceAddress:
PIM_ENCODED_SOURCE_ADDRESS_HDR = "! BBBB %s"
PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS = "! BBBB"
IPV4_HDR = "4s"
IPV6_HDR = "16s"
# TODO ver melhor versao ip
PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS_LEN = struct.calcsize(PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS)
PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN = struct.calcsize(PIM_ENCODED_SOURCE_ADDRESS_HDR % IPV4_HDR)
PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN_IPV6 = struct.calcsize(PIM_ENCODED_SOURCE_ADDRESS_HDR % IPV6_HDR)
FAMILY_RESERVED = 0
FAMILY_IPV4 = 1
FAMILY_IPV6 = 2
RESERVED_AND_SWR_BITS = 0
def __init__(self, source_address, mask_len=None):
if type(source_address) not in (str, bytes):
raise Exception
if type(source_address) is bytes:
source_address = socket.inet_ntoa(source_address)
self.source_address = source_address
self.mask_len = mask_len
def bytes(self) -> bytes:
(string_ip_hdr, hdr_addr_family, socket_family) = PacketPimEncodedSourceAddress.get_ip_info(self.source_address)
mask_len = self.mask_len
if mask_len is None:
mask_len = 8 * struct.calcsize(string_ip_hdr)
ip = socket.inet_pton(socket_family, self.source_address)
msg = struct.pack(PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR % string_ip_hdr, hdr_addr_family, 0,
PacketPimEncodedSourceAddress.RESERVED_AND_SWR_BITS, mask_len, ip)
return msg
@staticmethod
def get_ip_info(ip):
version = ipaddress.ip_address(ip).version
if version == 4:
return (PacketPimEncodedSourceAddress.IPV4_HDR, PacketPimEncodedSourceAddress.FAMILY_IPV4, socket.AF_INET)
elif version == 6:
return (PacketPimEncodedSourceAddress.IPV6_HDR, PacketPimEncodedSourceAddress.FAMILY_IPV6, socket.AF_INET6)
else:
raise Exception
def __len__(self):
version = ipaddress.ip_address(self.source_address).version
if version == 4:
return self.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN
elif version == 6:
return self.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN_IPV6
else:
raise Exception
@staticmethod
def parse_bytes(data: bytes):
data_without_source_addr = data[0:PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS_LEN]
(addr_family, encoding, _, mask_len) = struct.unpack(PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS, data_without_source_addr)
data_source_addr = data[PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS_LEN:]
ip = None
if addr_family == PacketPimEncodedSourceAddress.FAMILY_IPV4:
(ip,) = struct.unpack("! " + PacketPimEncodedSourceAddress.IPV4_HDR, data_source_addr[:4])
ip = socket.inet_ntop(socket.AF_INET, ip)
elif addr_family == PacketPimEncodedSourceAddress.FAMILY_IPV6:
(ip,) = struct.unpack("! " + PacketPimEncodedSourceAddress.IPV6_HDR, data_source_addr[:16])
ip = socket.inet_ntop(socket.AF_INET6, ip)
if encoding != 0:
print("unknown encoding")
raise Exception
return PacketPimEncodedSourceAddress(ip, mask_len)
import ipaddress
import struct
import socket
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Addr Family | Encoding Type | Unicast Address
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+...
'''
class PacketPimEncodedUnicastAddress:
PIM_ENCODED_UNICAST_ADDRESS_HDR = "! BB %s"
PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS = "! BB"
IPV4_HDR = "4s"
IPV6_HDR = "16s"
# TODO ver melhor versao ip
PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS_LEN = struct.calcsize(PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS)
PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN = struct.calcsize(PIM_ENCODED_UNICAST_ADDRESS_HDR % IPV4_HDR)
PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6 = struct.calcsize(PIM_ENCODED_UNICAST_ADDRESS_HDR % IPV6_HDR)
FAMILY_RESERVED = 0
FAMILY_IPV4 = 1
FAMILY_IPV6 = 2
def __init__(self, unicast_address):
if type(unicast_address) not in (str, bytes):
raise Exception
if type(unicast_address) is bytes:
unicast_address = socket.inet_ntoa(unicast_address)
self.unicast_address = unicast_address
def bytes(self) -> bytes:
(string_ip_hdr, hdr_addr_family, socket_family) = PacketPimEncodedUnicastAddress.get_ip_info(self.unicast_address)
ip = socket.inet_pton(socket_family, self.unicast_address)
msg = struct.pack(PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR % string_ip_hdr, hdr_addr_family, 0, ip)
return msg
@staticmethod
def get_ip_info(ip):
version = ipaddress.ip_address(ip).version
if version == 4:
return (PacketPimEncodedUnicastAddress.IPV4_HDR, PacketPimEncodedUnicastAddress.FAMILY_IPV4, socket.AF_INET)
elif version == 6:
return (PacketPimEncodedUnicastAddress.IPV6_HDR, PacketPimEncodedUnicastAddress.FAMILY_IPV6, socket.AF_INET6)
else:
raise Exception
def __len__(self):
version = ipaddress.ip_address(self.unicast_address).version
if version == 4:
return self.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN
elif version == 6:
return self.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6
else:
raise Exception
@staticmethod
def parse_bytes(data: bytes):
data_without_unicast_addr = data[0:PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS_LEN]
(addr_family, encoding) = struct.unpack(PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS, data_without_unicast_addr)
data_unicast_addr = data[PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS_LEN:]
if addr_family == PacketPimEncodedUnicastAddress.FAMILY_IPV4:
(ip,) = struct.unpack("! " + PacketPimEncodedUnicastAddress.IPV4_HDR, data_unicast_addr[:4])
ip = socket.inet_ntop(socket.AF_INET, ip)
elif addr_family == PacketPimEncodedUnicastAddress.FAMILY_IPV6:
(ip,) = struct.unpack("! " + PacketPimEncodedUnicastAddress.IPV6_HDR, data_unicast_addr[:16])
ip = socket.inet_ntop(socket.AF_INET6, ip)
if encoding != 0:
print("unknown encoding")
raise Exception
return PacketPimEncodedUnicastAddress(ip)
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Upstream Neighbor Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Reserved | Num Groups | Hold Time |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address 1 (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address m (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimGraft(PacketPimJoinPrune):
PIM_TYPE = 6
def __init__(self, upstream_neighbor_address, holdtime=0):
super().__init__(upstream_neighbor_address=upstream_neighbor_address, hold_time=holdtime)
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Upstream Neighbor Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Reserved | Num Groups | Hold Time |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address 1 (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address m (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimGraftAck(PacketPimJoinPrune):
PIM_TYPE = 7
def __init__(self, upstream_neighbor_address, holdtime=0):
super().__init__(upstream_neighbor_address, hold_time=holdtime)
import struct
from Packet.PacketPimHello import PacketPimHello
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
from Packet.PacketPimAssert import PacketPimAssert
from Packet.PacketPimGraft import PacketPimGraft
from Packet.PacketPimGraftAck import PacketPimGraftAck
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
from utils import checksum
from .PacketPayload import PacketPayload
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|PIM Ver| Type | Reserved | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimHeader(PacketPayload):
PIM_VERSION = 2
PIM_HDR = "! BB H"
PIM_HDR_LEN = struct.calcsize(PIM_HDR)
PIM_MSG_TYPES = {0: PacketPimHello,
3: PacketPimJoinPrune,
5: PacketPimAssert,
6: PacketPimGraft,
7: PacketPimGraftAck,
9: PacketPimStateRefresh
}
def __init__(self, payload):
self.payload = payload
def get_pim_type(self):
return self.payload.PIM_TYPE
def bytes(self) -> bytes:
# obter mensagem e criar checksum
pim_vrs_type = (PacketPimHeader.PIM_VERSION << 4) + self.get_pim_type()
msg_without_chcksum = struct.pack(PacketPimHeader.PIM_HDR, pim_vrs_type, 0, 0)
msg_without_chcksum += self.payload.bytes()
pim_checksum = checksum(msg_without_chcksum)
msg = msg_without_chcksum[0:2] + struct.pack("! H", pim_checksum) + msg_without_chcksum[4:]
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
print("parsePimHdr: ", data)
pim_hdr = data[0:PacketPimHeader.PIM_HDR_LEN]
(pim_ver_type, reserved, rcv_checksum) = struct.unpack(PacketPimHeader.PIM_HDR, pim_hdr)
print(pim_ver_type, reserved, rcv_checksum)
pim_version = (pim_ver_type & 0xF0) >> 4
pim_type = pim_ver_type & 0x0F
if pim_version != PacketPimHeader.PIM_VERSION:
print("Version of PIM packet received not known (!=2)")
raise Exception
msg_to_checksum = data[0:2] + b'\x00\x00' + data[4:]
print("checksum calculated: " + str(checksum(msg_to_checksum)))
if checksum(msg_to_checksum) != rcv_checksum:
print("wrong checksum")
raise Exception
pim_payload = data[PacketPimHeader.PIM_HDR_LEN:]
pim_payload = PacketPimHeader.PIM_MSG_TYPES[pim_type].parse_bytes(pim_payload)
'''
if pim_type == 0: # hello
pim_payload = PacketPimHello.parse_bytes(pim_payload)
elif pim_type == 3: # join/prune
pim_payload = PacketPimJoinPrune.parse_bytes(pim_payload)
print("hold_time = ", pim_payload.hold_time)
print("upstream_neighbor = ", pim_payload.upstream_neighbor_address)
for i in pim_payload.groups:
print(i.multicast_group)
print(i.joined_src_addresses)
print(i.pruned_src_addresses)
elif pim_type == 5: # assert
pim_payload = PacketPimAssert.parse_bytes(pim_payload)
else:
raise Exception
'''
return PacketPimHeader(pim_payload)
import struct
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Option Type | Option Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Option Value |
| ... |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Option Type | Option Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Option Value |
| ... |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimHello:
PIM_TYPE = 0
PIM_HDR_OPTS = "! HH"
PIM_HDR_OPTS_LEN = struct.calcsize(PIM_HDR_OPTS)
PIM_MSG_TYPES_LENGTH = {1: 2,
20: 4,
21: 4,
}
# todo: pensar melhor na implementacao state refresh capable option...
def __init__(self):
self.options = {}
def add_option(self, option_type: int, option_value: int or float):
option_value = int(option_value)
# if option_value requires more bits than the bits available for that field: option value will have all field bits = 1
if option_type in self.PIM_MSG_TYPES_LENGTH and self.PIM_MSG_TYPES_LENGTH[option_type] * 8 < option_value.bit_length():
option_value = (1 << (self.PIM_MSG_TYPES_LENGTH[option_type] * 8)) - 1
self.options[option_type] = option_value
def get_options(self):
return self.options
def bytes(self) -> bytes:
res = b''
for (option_type, option_value) in self.options.items():
option_length = PacketPimHello.PIM_MSG_TYPES_LENGTH[option_type]
type_length_hdr = struct.pack(PacketPimHello.PIM_HDR_OPTS, option_type, option_length)
res += type_length_hdr + struct.pack("! " + str(option_length) + "s", option_value.to_bytes(option_length, byteorder='big'))
return res
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
pim_payload = PacketPimHello()
while data != b'':
(option_type, option_length) = struct.unpack(PacketPimHello.PIM_HDR_OPTS,
data[:PacketPimHello.PIM_HDR_OPTS_LEN])
print(option_type, option_length)
data = data[PacketPimHello.PIM_HDR_OPTS_LEN:]
print(data)
(option_value,) = struct.unpack("! " + str(option_length) + "s", data[:option_length])
option_value_number = int.from_bytes(option_value, byteorder='big')
print("option value: ", option_value_number)
'''
options_list.append({"OPTION TYPE": option_type,
"OPTION LENGTH": option_length,
"OPTION VALUE": option_value_number
})
'''
pim_payload.add_option(option_type, option_value_number)
data = data[option_length:]
return pim_payload
import struct
import socket
from Packet.PacketPimEncodedUnicastAddress import PacketPimEncodedUnicastAddress
from Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Upstream Neighbor Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Reserved | Num Groups | Hold Time |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimJoinPrune:
PIM_TYPE = 3
PIM_HDR_JOIN_PRUNE = "! %ss BBH"
PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS = "! BBH"
PIM_HDR_JOIN_PRUNE_v4 = PIM_HDR_JOIN_PRUNE % PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN
PIM_HDR_JOIN_PRUNE_v6 = PIM_HDR_JOIN_PRUNE % PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6
PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS_LEN = struct.calcsize(PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS)
PIM_HDR_JOIN_PRUNE_v4_LEN = struct.calcsize(PIM_HDR_JOIN_PRUNE_v4)
PIM_HDR_JOIN_PRUNE_v6_LEN = struct.calcsize(PIM_HDR_JOIN_PRUNE_v6)
def __init__(self, upstream_neighbor_address, hold_time):
if type(upstream_neighbor_address) not in (str, bytes):
raise Exception
if type(upstream_neighbor_address) is bytes:
upstream_neighbor_address = socket.inet_ntoa(upstream_neighbor_address)
self.groups = []
self.upstream_neighbor_address = upstream_neighbor_address
self.hold_time = hold_time
def add_multicast_group(self, group: PacketPimJoinPruneMulticastGroup):
# TODO verificar se grupo ja esta na msg
self.groups.append(group)
def bytes(self) -> bytes:
upstream_neighbor_address = PacketPimEncodedUnicastAddress(self.upstream_neighbor_address).bytes()
msg = upstream_neighbor_address + struct.pack(PacketPimJoinPrune.PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS, 0,
len(self.groups), self.hold_time)
for multicast_group in self.groups:
msg += multicast_group.bytes()
return msg
def __len__(self):
return len(self.bytes())
@classmethod
def parse_bytes(cls, data: bytes):
upstream_neighbor_addr_obj = PacketPimEncodedUnicastAddress.parse_bytes(data)
upstream_neighbor_addr_len = len(upstream_neighbor_addr_obj)
data = data[upstream_neighbor_addr_len:]
(_, num_groups, hold_time) = struct.unpack(PacketPimJoinPrune.PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS,
data[:PacketPimJoinPrune.PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS_LEN])
data = data[PacketPimJoinPrune.PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS_LEN:]
pim_payload = cls(upstream_neighbor_addr_obj.unicast_address, hold_time)
for i in range(0, num_groups):
group = PacketPimJoinPruneMulticastGroup.parse_bytes(data)
group_len = len(group)
pim_payload.add_multicast_group(group)
data = data[group_len:]
return pim_payload
import struct
import socket
from Packet.PacketPimEncodedGroupAddress import PacketPimEncodedGroupAddress
from Packet.PacketPimEncodedSourceAddress import PacketPimEncodedSourceAddress
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address 1 (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimJoinPruneMulticastGroup:
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP = "! %ss HH"
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS = "! HH"
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_v4_LEN_ = struct.calcsize(
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP % PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN)
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_v6_LEN_ = struct.calcsize(
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP % PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6)
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS_LEN = struct.calcsize(
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS)
PIM_HDR_JOINED_PRUNED_SOURCE = "! %ss"
PIM_HDR_JOINED_PRUNED_SOURCE_v4_LEN = PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN
PIM_HDR_JOINED_PRUNED_SOURCE_v6_LEN = PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN_IPV6
def __init__(self, multicast_group: str or bytes, joined_src_addresses: list=[], pruned_src_addresses: list=[]):
if type(multicast_group) not in (str, bytes):
raise Exception
elif type(multicast_group) is bytes:
multicast_group = socket.inet_ntoa(multicast_group)
if type(joined_src_addresses) is not list:
raise Exception
if type(pruned_src_addresses) is not list:
raise Exception
self.multicast_group = multicast_group
self.joined_src_addresses = joined_src_addresses
self.pruned_src_addresses = pruned_src_addresses
def bytes(self) -> bytes:
multicast_group_address = PacketPimEncodedGroupAddress(self.multicast_group).bytes()
msg = multicast_group_address + struct.pack(self.PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS,
len(self.joined_src_addresses), len(self.pruned_src_addresses))
for joined_src_address in self.joined_src_addresses:
joined_src_address_bytes = PacketPimEncodedSourceAddress(joined_src_address).bytes()
msg += joined_src_address_bytes
for pruned_src_address in self.pruned_src_addresses:
pruned_src_address_bytes = PacketPimEncodedSourceAddress(pruned_src_address).bytes()
msg += pruned_src_address_bytes
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
multicast_group_addr_obj = PacketPimEncodedGroupAddress.parse_bytes(data)
multicast_group_addr_len = len(multicast_group_addr_obj)
data = data[multicast_group_addr_len:]
number_join_prune_data = data[:PacketPimJoinPruneMulticastGroup.PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS_LEN]
(number_joined_sources, number_pruned_sources) = struct.unpack(PacketPimJoinPruneMulticastGroup.PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS, number_join_prune_data)
joined = []
pruned = []
data = data[PacketPimJoinPruneMulticastGroup.PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS_LEN:]
for i in range(0, number_joined_sources):
joined_obj = PacketPimEncodedSourceAddress.parse_bytes(data)
joined_obj_len = len(joined_obj)
data = data[joined_obj_len:]
joined.append(joined_obj.source_address)
for i in range(0, number_pruned_sources):
pruned_obj = PacketPimEncodedSourceAddress.parse_bytes(data)
pruned_obj_len = len(pruned_obj)
data = data[pruned_obj_len:]
pruned.append(pruned_obj.source_address)
return PacketPimJoinPruneMulticastGroup(multicast_group_addr_obj.group_address, joined, pruned)
import struct
import socket
from Packet.PacketPimEncodedUnicastAddress import PacketPimEncodedUnicastAddress
from Packet.PacketPimEncodedGroupAddress import PacketPimEncodedGroupAddress
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|PIM Ver| Type | Reserved | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Originator Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|R| Metric Preference |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Metric |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Masklen | TTL |P|N|O|Reserved | Interval |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimStateRefresh:
PIM_TYPE = 9
PIM_HDR_STATE_REFRESH = "! %ss %ss %ss I I BBBB"
PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES = "! I I BBBB"
PIM_HDR_STATE_REFRESH_v4 = PIM_HDR_STATE_REFRESH % (PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN)
PIM_HDR_STATE_REFRESH_v6 = PIM_HDR_STATE_REFRESH % (PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6)
PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES_LEN = struct.calcsize(PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES)
PIM_HDR_STATE_REFRESH_v4_LEN = struct.calcsize(PIM_HDR_STATE_REFRESH_v4)
PIM_HDR_STATE_REFRESH_v6_LEN = struct.calcsize(PIM_HDR_STATE_REFRESH_v6)
def __init__(self, multicast_group_adress: str or bytes, source_address: str or bytes, originator_adress: str or bytes,
metric_preference: int, metric: int, mask_len: int, ttl: int, prune_indicator_flag: bool,
prune_now_flag: bool, assert_override_flag: bool, interval: int):
if type(multicast_group_adress) is bytes:
multicast_group_adress = socket.inet_ntoa(multicast_group_adress)
if type(source_address) is bytes:
source_address = socket.inet_ntoa(source_address)
if type(originator_adress) is bytes:
originator_adress = socket.inet_ntoa(originator_adress)
self.multicast_group_adress = multicast_group_adress
self.source_address = source_address
self.originator_adress = originator_adress
self.metric_preference = metric_preference
self.metric = metric
self.mask_len = mask_len
self.ttl = ttl
self.prune_indicator_flag = prune_indicator_flag
self.prune_now_flag = prune_now_flag
self.assert_override_flag = assert_override_flag
self.interval = interval
def bytes(self) -> bytes:
multicast_group_adress = PacketPimEncodedGroupAddress(self.multicast_group_adress).bytes()
source_address = PacketPimEncodedUnicastAddress(self.source_address).bytes()
originator_adress = PacketPimEncodedUnicastAddress(self.originator_adress).bytes()
prune_and_assert_flags = (self.prune_indicator_flag << 7) | (self.prune_now_flag << 6) | (self.assert_override_flag << 5)
msg = multicast_group_adress + source_address + originator_adress + \
struct.pack(self.PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES, 0x7FFFFFFF & self.metric_preference,
self.metric, self.mask_len, self.ttl, prune_and_assert_flags, self. interval)
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
multicast_group_adress_obj = PacketPimEncodedGroupAddress.parse_bytes(data)
multicast_group_adress_len = len(multicast_group_adress_obj)
data = data[multicast_group_adress_len:]
source_address_obj = PacketPimEncodedUnicastAddress.parse_bytes(data)
source_address_len = len(source_address_obj)
data = data[source_address_len:]
originator_address_obj = PacketPimEncodedUnicastAddress.parse_bytes(data)
originator_address_len = len(originator_address_obj)
data = data[originator_address_len:]
(metric_preference, metric, mask_len, ttl, reserved_and_prune_and_assert_flags, interval) = struct.unpack(PacketPimStateRefresh.PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES, data[:PacketPimStateRefresh.PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES_LEN])
metric_preference = 0x7FFFFFFF & metric_preference
prune_indicator_flag = (0x80 & reserved_and_prune_and_assert_flags) >> 7
prune_now_flag = (0x40 & reserved_and_prune_and_assert_flags) >> 6
assert_override_flag = (0x20 & reserved_and_prune_and_assert_flags) >> 5
pim_payload = PacketPimStateRefresh(multicast_group_adress_obj.group_address, source_address_obj.unicast_address,
originator_address_obj.unicast_address, metric_preference, metric, mask_len,
ttl, prune_indicator_flag, prune_now_flag, assert_override_flag, interval)
return pim_payload
from Packet.Packet import Packet
from Packet.PacketIpHeader import PacketIpHeader
from Packet.PacketIGMPHeader import PacketIGMPHeader
from .PacketPimHeader import PacketPimHeader
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from Interface import Interface
class ReceivedPacket(Packet):
# choose payload protocol class based on ip protocol number
payload_protocol = {2: PacketIGMPHeader, 103: PacketPimHeader}
def __init__(self, raw_packet: bytes, interface: 'Interface'):
self.interface = interface
# Parse ao packet e preencher objeto Packet
packet_ip_hdr = raw_packet[:PacketIpHeader.IP_HDR_LEN]
ip_header = PacketIpHeader.parse_bytes(packet_ip_hdr)
protocol_number = ip_header.proto
packet_without_ip_hdr = raw_packet[ip_header.hdr_length:]
payload = ReceivedPacket.payload_protocol[protocol_number].parse_bytes(packet_without_ip_hdr)
super().__init__(ip_header=ip_header, payload=payload)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Read Write Lock
"""
import threading
import time
class RWLockRead(object):
"""
A Read/Write lock giving preference to Reader
"""
def __init__(self):
self.V_ReadCount = 0
self.A_Resource = threading.Lock()
self.A_LockReadCount = threading.Lock()
class _aReader(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockReadCount.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
self.A_RWLock.V_ReadCount += 1
if self.A_RWLock.V_ReadCount == 1:
if not self.A_RWLock.A_Resource.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.V_ReadCount -= 1
self.A_RWLock.A_LockReadCount.release()
return False
self.A_RWLock.A_LockReadCount.release()
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_LockReadCount.acquire()
self.A_RWLock.V_ReadCount -= 1
if self.A_RWLock.V_ReadCount == 0:
self.A_RWLock.A_Resource.release()
self.A_RWLock.A_LockReadCount.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
class _aWriter(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
self.V_Locked = self.A_RWLock.A_Resource.acquire(blocking, timeout)
return self.V_Locked
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_Resource.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
def genRlock(self):
"""
Generate a reader lock
"""
return RWLockRead._aReader(self)
def genWlock(self):
"""
Generate a writer lock
"""
return RWLockRead._aWriter(self)
class RWLockWrite(object):
"""
A Read/Write lock giving preference to Writer
"""
def __init__(self):
self.V_ReadCount = 0
self.V_WriteCount = 0
self.A_LockReadCount = threading.Lock()
self.A_LockWriteCount = threading.Lock()
self.A_LockReadEntry = threading.Lock()
self.A_LockReadTry = threading.Lock()
self.A_Resource = threading.Lock()
class _aReader(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockReadEntry.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
if not self.A_RWLock.A_LockReadTry.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockReadEntry.release()
return False
if not self.A_RWLock.A_LockReadCount.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockReadEntry.release()
return False
self.A_RWLock.V_ReadCount += 1
if (self.A_RWLock.V_ReadCount == 1):
if not self.A_RWLock.A_Resource.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockReadEntry.release()
self.A_RWLock.V_ReadCount -= 1
self.A_RWLock.A_LockReadCount.release()
return False
self.A_RWLock.A_LockReadCount.release()
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockReadEntry.release()
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_LockReadCount.acquire()
self.A_RWLock.V_ReadCount -= 1
if (self.A_RWLock.V_ReadCount == 0):
self.A_RWLock.A_Resource.release()
self.A_RWLock.A_LockReadCount.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
class _aWriter(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockWriteCount.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
self.A_RWLock.V_WriteCount += 1
if (self.A_RWLock.V_WriteCount == 1):
if not self.A_RWLock.A_LockReadTry.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.V_WriteCount -= 1
self.A_RWLock.A_LockWriteCount.release()
return False
self.A_RWLock.A_LockWriteCount.release()
if not self.A_RWLock.A_Resource.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockWriteCount.acquire()
self.A_RWLock.V_WriteCount -= 1
if self.A_RWLock.V_WriteCount == 0:
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockWriteCount.release()
return False
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_Resource.release()
self.A_RWLock.A_LockWriteCount.acquire()
self.A_RWLock.V_WriteCount -= 1
if (self.A_RWLock.V_WriteCount == 0):
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockWriteCount.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
def genRlock(self):
"""
Generate a reader lock
"""
return RWLockWrite._aReader(self)
def genWlock(self):
"""
Generate a writer lock
"""
return RWLockWrite._aWriter(self)
class RWLockFair(object):
"""
A Read/Write lock giving fairness to both Reader and Writer
"""
def __init__(self):
self.V_ReadCount = 0
self.A_LockReadCount = threading.Lock()
self.A_LockRead = threading.Lock()
self.A_LockWrite = threading.Lock()
class _aReader(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockRead.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
if not self.A_RWLock.A_LockReadCount.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockRead.release()
return False
self.A_RWLock.V_ReadCount += 1
if self.A_RWLock.V_ReadCount == 1:
if not self.A_RWLock.A_LockWrite.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.V_ReadCount -= 1
self.A_RWLock.A_LockReadCount.release()
self.A_RWLock.A_LockRead.release()
return False
self.A_RWLock.A_LockReadCount.release()
self.A_RWLock.A_LockRead.release()
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_LockReadCount.acquire()
self.A_RWLock.V_ReadCount -= 1
if self.A_RWLock.V_ReadCount == 0:
self.A_RWLock.A_LockWrite.release()
self.A_RWLock.A_LockReadCount.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
class _aWriter(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockRead.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
if not self.A_RWLock.A_LockWrite.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockRead.release()
return False
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_LockWrite.release()
self.A_RWLock.A_LockRead.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
def genRlock(self):
"""
Generate a reader lock
"""
return RWLockFair._aReader(self)
def genWlock(self):
"""
Generate a writer lock
"""
return RWLockFair._aWriter(self)
#!/usr/bin/env python
from Daemon.Daemon import Daemon
import Main
import _pickle as pickle
import socket
import sys
import os
import argparse
import traceback
def client_socket(data_to_send):
# Create a UDS socket
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
# Connect the socket to the port where the server is listening
server_address = './uds_socket'
#print('connecting to %s' % server_address)
try:
sock.connect(server_address)
sock.sendall(pickle.dumps(data_to_send))
data_rcv = sock.recv(1024 * 256)
if data_rcv:
print(pickle.loads(data_rcv))
except socket.error:
pass
finally:
#print('closing socket')
sock.close()
class MyDaemon(Daemon):
def run(self):
Main.main()
server_address = './uds_socket'
# Make sure the socket does not already exist
try:
os.unlink(server_address)
except OSError:
if os.path.exists(server_address):
raise
# Create a UDS socket
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
# Bind the socket to the port
sock.bind(server_address)
# Listen for incoming connections
sock.listen(1)
while True:
try:
connection, client_address = sock.accept()
data = connection.recv(256 * 1024)
print(sys.stderr, 'sending data back to the client')
print(pickle.loads(data))
args = pickle.loads(data)
if args.list_interfaces:
connection.sendall(pickle.dumps(Main.list_enabled_interfaces()))
elif args.list_neighbors:
connection.sendall(pickle.dumps(Main.list_neighbors()))
elif args.list_state:
connection.sendall(pickle.dumps(Main.list_state()))
elif args.add_interface:
Main.add_interface(args.add_interface[0], pim=True)
connection.shutdown(socket.SHUT_RDWR)
elif args.add_interface_igmp:
Main.add_interface(args.add_interface_igmp[0], igmp=True)
connection.shutdown(socket.SHUT_RDWR)
elif args.remove_interface:
Main.remove_interface(args.remove_interface[0], pim=True)
connection.shutdown(socket.SHUT_RDWR)
elif args.remove_interface_igmp:
Main.remove_interface(args.remove_interface_igmp[0], igmp=True)
connection.shutdown(socket.SHUT_RDWR)
elif args.stop:
Main.stop()
connection.shutdown(socket.SHUT_RDWR)
except Exception:
connection.shutdown(socket.SHUT_RDWR)
traceback.print_exc()
finally:
# Clean up the connection
connection.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='PIM')
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("-start", "--start", action="store_true", default=False, help="Start PIM")
group.add_argument("-stop", "--stop", action="store_true", default=False, help="Stop PIM")
group.add_argument("-restart", "--restart", action="store_true", default=False, help="Restart PIM")
group.add_argument("-li", "--list_interfaces", action="store_true", default=False, help="List All PIM Interfaces")
group.add_argument("-ln", "--list_neighbors", action="store_true", default=False, help="List All PIM Neighbors")
group.add_argument("-ls", "--list_state", action="store_true", default=False, help="List state of IGMP")
group.add_argument("-mr", "--multicast_routes", action="store_true", default=False, help="List Multicast Routing table")
group.add_argument("-ai", "--add_interface", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface")
group.add_argument("-aiigmp", "--add_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Add IGMP interface")
group.add_argument("-ri", "--remove_interface", nargs=1, metavar='INTERFACE_NAME', help="Remove PIM interface")
group.add_argument("-riigmp", "--remove_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Remove IGMP interface")
group.add_argument("-v", "--verbose", action="store_true", default=False, help="Verbose (print all debug messages)")
args = parser.parse_args()
print(parser.parse_args())
daemon = MyDaemon('/tmp/Daemon-pim.pid')
if args.start:
print("start")
daemon.start()
sys.exit(0)
elif args.stop:
client_socket(args)
daemon.stop()
sys.exit(0)
elif args.restart:
daemon.restart()
sys.exit(0)
elif args.verbose:
os.system("tailf stdout")
sys.exit(0)
elif args.multicast_routes:
os.system("ip mroute show")
sys.exit(0)
elif not daemon.is_running():
print("PIM is not running")
parser.print_usage()
sys.exit(0)
client_socket(args)
import random
from threading import Timer
from Packet.Packet import Packet
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimHello import PacketPimHello
from Packet.PacketPimHeader import PacketPimHeader
from Interface import Interface
import Main
from utils import HELLO_HOLD_TIME_TIMEOUT
class StateRefresh:
TYPE = 9
def __init__(self):
Main.add_protocol(StateRefresh.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_join_prune = packet.payload.payload
# TODO
raise Exception
\ No newline at end of file
import socket
import time
import struct
# ficheiros importantes: /usr/include/linux/mroute.h
MRT_BASE = 200
MRT_INIT = (MRT_BASE) # Activate the kernel mroute code */
MRT_DONE = (MRT_BASE+1) #/* Shutdown the kernel mroute */
MRT_ADD_VIF = (MRT_BASE+2) #/* Add a virtual interface */
MRT_DEL_VIF = (MRT_BASE+3) #/* Delete a virtual interface */
MRT_ADD_MFC = (MRT_BASE+4) #/* Add a multicast forwarding entry */
MRT_DEL_MFC = (MRT_BASE+5) #/* Delete a multicast forwarding entry */
MRT_VERSION = (MRT_BASE+6) #/* Get the kernel multicast version */
MRT_ASSERT = (MRT_BASE+7) #/* Activate PIM assert mode */
MRT_PIM = (MRT_BASE+8) #/* enable PIM code */
MRT_TABLE = (MRT_BASE+9) #/* Specify mroute table ID */
MRT_ADD_MFC_PROXY = (MRT_BASE+10) #/* Add a (*,*|G) mfc entry */
MRT_DEL_MFC_PROXY = (MRT_BASE+11) #/* Del a (*,*|G) mfc entry */
MRT_MAX = (MRT_BASE+11)
IGMPMSG_NOCACHE = 1
IGMPMSG_WRONGVIF = 2
IGMPMSG_WHOLEPKT = 3
s2 = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
#MRT INIT
s2.setsockopt(socket.IPPROTO_IP, MRT_INIT, 1)
#MRT PIM
s2.setsockopt(socket.IPPROTO_IP, MRT_PIM, 1)
#ADD VIRTUAL INTERFACE
#estrutura = struct.pack("HBBI 4s 4s", 1, 0x4, 0, 0, socket.inet_aton("192.168.1.112"), socket.inet_aton("224.1.1.112"))
estrutura = struct.pack("HBBI 4s 4s", 0, 0x0, 1, 0, socket.inet_aton("10.0.0.1"), socket.inet_aton("0.0.0.0"))
print(estrutura)
s2.setsockopt(socket.IPPROTO_IP, MRT_ADD_VIF, estrutura)
estrutura = struct.pack("HBBI 4s 4s", 1, 0x0, 1, 0, socket.inet_aton("192.168.2.2"), socket.inet_aton("0.0.0.0"))
print(estrutura)
s2.setsockopt(socket.IPPROTO_IP, MRT_ADD_VIF, estrutura)
#time.sleep(5)
while True:
print("recv:")
msg = s2.recv(5000)
print(len(msg))
(_, _, im_msgtype, im_mbz, im_vif, _, im_src, im_dst, _) = struct.unpack("II B B B B 4s 4s 8s", msg)
print(im_msgtype)
print(im_mbz)
print(im_vif)
print(socket.inet_ntoa(im_src))
print(socket.inet_ntoa(im_dst))
if im_msgtype == IGMPMSG_NOCACHE:
print("^^ IGMP NO CACHE")
print(struct.unpack("II B B B B 4s 4s 8s", msg))
#s2.setsockopt(socket.IPPROTO_IP, MRT_PIM, 1)
#print(s2.getsockopt(socket.IPPROTO_IP, 208))
#s2.setsockopt(socket.IPPROTO_IP, 208, 0)
#ADD MULTICAST FORWARDING ENTRY
estrutura = struct.pack("4s 4s H BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB IIIi", socket.inet_aton("10.0.0.2"), socket.inet_aton("224.1.1.113"), 0, 0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0, 0, 0, 0)
s2.setsockopt(socket.IPPROTO_IP, MRT_ADD_MFC, estrutura)
time.sleep(30)
#MRT DONE
s2.setsockopt(socket.IPPROTO_IP, MRT_DONE, 1)
s2.close()
exit(1)
from pyroute2 import IPDB, IPRoute
import socket
#ipdb = IPDB()
ipr = IPRoute()
def get_route(ip_dst: str):
with IPDB() as ipdb:
ip_bytes = socket.inet_aton(ip_dst)
ip_int = int.from_bytes(ip_bytes, byteorder='big')
info = None
for mask_len in range(32, 0, -1):
ip_bytes = (ip_int & (0xFFFFFFFF << (32 - mask_len))).to_bytes(4, "big")
ip_dst = socket.inet_ntoa(ip_bytes) + "/" + str(mask_len)
print(ip_dst)
try:
info = ipdb.routes[ip_dst]
break
except:
continue
if not info:
print("0.0.0.0/0")
info = ipdb.routes["default"]
print(info)
return info
# get metrics (routing preference and cost) to IP ip_dst
def get_metric(ip_dst: str):
unicast_routing_entry = get_route(ip_dst)
entry_protocol = unicast_routing_entry["proto"]
entry_cost = unicast_routing_entry["priority"]
return (entry_protocol, entry_cost)
"""
def get_rpf(ip_dst: str):
unicast_routing_entry = get_route(ip_dst)
#interface_oif = unicast_routing_entry['oif']
if not unicast_routing_entry['multipath']:
interface_oif = unicast_routing_entry['oif']
else:
multiple_entries = unicast_routing_entry['multipath']
print(multiple_entries)
(entry0, _) = multiple_entries
print(entry0)
interface_oif = entry0['oif']
print("ola")
print(ipdb.interfaces[interface_oif]['ipaddr'])
for i in range(len(ipdb.interfaces[interface_oif]['ipaddr'])):
print("ola2")
interface = ipdb.interfaces[interface_oif]['ipaddr'][i]
print(interface)
if interface['family'] == socket.AF_INET:
return interface['address']
return None
"""
# get output interface IP, used to send data to IP ip_dst
# (root interface IP to ip_dst)
def check_rpf(ip_dst):
# obter index da interface
# rpf_interface_index = ipr.get_routes(family=socket.AF_INET, dst=ip)[0]['attrs'][2][1]
# interface_name = if_indextoname(rpf_interface_index)
# return interface_name
# obter ip da interface de saida
rpf_interface_source = ipr.get_routes(family=socket.AF_INET, dst=ip_dst)[0]['attrs'][3][1]
return rpf_interface_source
"""
def get_metric(ip_dst: str):
ip_bytes = socket.inet_aton(ip_dst)
ip_int = int.from_bytes(ip_bytes, byteorder='big')
info = None
for mask_len in range(32, 0, -1):
ip_bytes = (ip_int & (0xFFFFFFFF << (32 - mask_len))).to_bytes(4, "big")
ip_dst = socket.inet_ntoa(ip_bytes) + "/" + str(mask_len)
print(ip_dst)
try:
info = ipdb.routes[ip_dst]
break
except:
continue
if not info:
print("0.0.0.0/0")
info = ipdb.routes["default"]
print(info)
print("metric=", info["priority"])
print("proto=", info["proto"])
#print(info.keys())
#if info["gateway"]:
# print("next_hop=", info["gateway"])
#elif info["prefsrc"]:
# print("next_hop=", info["prefsrc"])
return (info["proto"], info["priority"])
def check_rpf(ip_dst: str):
from pyroute2 import IPRoute
# from utils import if_indextoname
ipr = IPRoute()
# obter index da interface
# rpf_interface_index = ipr.get_routes(family=socket.AF_INET, dst=ip)[0]['attrs'][2][1]
# interface_name = if_indextoname(rpf_interface_index)
# return interface_name
# obter ip da interface de saida
rpf_interface_source = ipr.get_routes(family=socket.AF_INET, dst=ip_dst)[0]['attrs'][3][1]
return rpf_interface_source
"""
def stop():
ipr.close()
#ip = input("ip=")
#get_metric(ip)
\ No newline at end of file
from threading import Timer
from .wrapper import NoMembersPresent
from utils import GroupMembershipInterval, LastMemberQueryInterval, TYPE_CHECKING
from threading import Lock
if TYPE_CHECKING:
from .RouterState import RouterState
class GroupState(object):
def __init__(self, router_state: 'RouterState', group_ip: str):
self.router_state = router_state
self.group_ip = group_ip
self.state = NoMembersPresent
self.timer = None
self.v1_host_timer = None
self.retransmit_timer = None
# lock
self.lock = Lock()
# KernelEntry's instances to notify change of igmp state
self.multicast_interface_state = []
self.multicast_interface_state_lock = Lock()
def print_state(self):
return self.state.print_state()
###########################################
# Set timers
###########################################
def set_timer(self, alternative: bool=False, max_response_time: int=None):
self.clear_timer()
if not alternative:
time = GroupMembershipInterval
else:
time = self.router_state.interface_state.get_group_membership_time(max_response_time)
timer = Timer(time, self.group_membership_timeout)
timer.start()
self.timer = timer
def clear_timer(self):
if self.timer is not None:
self.timer.cancel()
def set_v1_host_timer(self):
self.clear_v1_host_timer()
v1_host_timer = Timer(GroupMembershipInterval, self.group_membership_v1_timeout)
v1_host_timer.start()
self.v1_host_timer = v1_host_timer
def clear_v1_host_timer(self):
if self.v1_host_timer is not None:
self.v1_host_timer.cancel()
def set_retransmit_timer(self):
self.clear_retransmit_timer()
retransmit_timer = Timer(LastMemberQueryInterval, self.retransmit_timeout)
retransmit_timer.start()
self.retransmit_timer = retransmit_timer
def clear_retransmit_timer(self):
if self.retransmit_timer is not None:
self.retransmit_timer.cancel()
###########################################
# Get group state from specific interface state
###########################################
def get_interface_group_state(self):
return self.state.get_state(self.router_state)
###########################################
# Timer timeout
###########################################
def group_membership_timeout(self):
with self.lock:
self.get_interface_group_state().group_membership_timeout(self)
def group_membership_v1_timeout(self):
with self.lock:
self.get_interface_group_state().group_membership_v1_timeout(self)
def retransmit_timeout(self):
with self.lock:
self.get_interface_group_state().retransmit_timeout(self)
###########################################
# Receive Packets
###########################################
def receive_v1_membership_report(self):
with self.lock:
self.get_interface_group_state().receive_v1_membership_report(self)
def receive_v2_membership_report(self):
with self.lock:
self.get_interface_group_state().receive_v2_membership_report(self)
def receive_leave_group(self):
with self.lock:
self.get_interface_group_state().receive_leave_group(self)
def receive_group_specific_query(self, max_response_time: int):
with self.lock:
self.get_interface_group_state().receive_group_specific_query(self, max_response_time)
###########################################
# Notify Routing
###########################################
def notify_routing_add(self):
with self.multicast_interface_state_lock:
print("notify+", self.multicast_interface_state)
for interface_state in self.multicast_interface_state:
interface_state.notify_igmp(has_members=True)
def notify_routing_remove(self):
with self.multicast_interface_state_lock:
print("notify-", self.multicast_interface_state)
for interface_state in self.multicast_interface_state:
interface_state.notify_igmp(has_members=False)
def add_multicast_routing_entry(self, kernel_entry):
with self.multicast_interface_state_lock:
self.multicast_interface_state.append(kernel_entry)
return self.has_members()
def remove_multicast_routing_entry(self, kernel_entry):
with self.multicast_interface_state_lock:
self.multicast_interface_state.remove(kernel_entry)
def has_members(self):
return self.state is not NoMembersPresent
from Packet.PacketIGMPHeader import PacketIGMPHeader
from Packet.ReceivedPacket import ReceivedPacket
from threading import Timer
from utils import Membership_Query, QueryResponseInterval, QueryInterval, OtherQuerierPresentInterval, TYPE_CHECKING
from .querier.Querier import Querier
from .nonquerier.NonQuerier import NonQuerier
from .GroupState import GroupState
from RWLock.RWLock import RWLockWrite
if TYPE_CHECKING:
from InterfaceIGMP import InterfaceIGMP
class RouterState(object):
def __init__(self, interface: 'InterfaceIGMP'):
# interface of the router connected to the network
self.interface = interface
# state of the router (Querier/NonQuerier)
self.interface_state = Querier
# state of each group
# Key: GroupIPAddress, Value: GroupState object
self.group_state = {}
self.group_state_lock = RWLockWrite()
# send general query
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=QueryResponseInterval*10)
self.interface.send(packet.bytes())
# set initial general query timer
timer = Timer(QueryInterval, self.general_query_timeout)
timer.start()
self.general_query_timer = timer
# present timer
self.other_querier_present_timer = None
# Send packet via interface
def send(self, data: bytes, address: str):
self.interface.send(data, address)
############################################
# interface_state methods
############################################
def print_state(self):
return self.interface_state.state_name()
def set_general_query_timer(self):
self.clear_general_query_timer()
general_query_timer = Timer(QueryInterval, self.general_query_timeout)
general_query_timer.start()
self.general_query_timer = general_query_timer
def clear_general_query_timer(self):
if self.general_query_timer is not None:
self.general_query_timer.cancel()
def set_other_querier_present_timer(self):
self.clear_other_querier_present_timer()
other_querier_present_timer = Timer(OtherQuerierPresentInterval, self.other_querier_present_timeout)
other_querier_present_timer.start()
self.other_querier_present_timer = other_querier_present_timer
def clear_other_querier_present_timer(self):
if self.other_querier_present_timer is not None:
self.other_querier_present_timer.cancel()
def general_query_timeout(self):
self.interface_state.general_query_timeout(self)
def other_querier_present_timeout(self):
self.interface_state.other_querier_present_timeout(self)
def change_interface_state(self, querier: bool):
if querier:
self.interface_state = Querier
else:
self.interface_state = NonQuerier
############################################
# group state methods
############################################
def get_group_state(self, group_ip):
with self.group_state_lock.genRlock():
if group_ip in self.group_state:
return self.group_state[group_ip]
with self.group_state_lock.genWlock():
if group_ip in self.group_state:
group_state = self.group_state[group_ip]
else:
group_state = GroupState(self, group_ip)
self.group_state[group_ip] = group_state
return group_state
def receive_v1_membership_report(self, packet: ReceivedPacket):
igmp_group = packet.payload.group_address
#if igmp_group not in self.group_state:
# self.group_state[igmp_group] = GroupState(self, igmp_group)
#self.group_state[igmp_group].receive_v1_membership_report()
self.get_group_state(igmp_group).receive_v1_membership_report()
def receive_v2_membership_report(self, packet: ReceivedPacket):
igmp_group = packet.payload.group_address
#if igmp_group not in self.group_state:
# self.group_state[igmp_group] = GroupState(self, igmp_group)
#self.group_state[igmp_group].receive_v2_membership_report()
self.get_group_state(igmp_group).receive_v2_membership_report()
def receive_leave_group(self, packet: ReceivedPacket):
igmp_group = packet.payload.group_address
#if igmp_group in self.group_state:
# self.group_state[igmp_group].receive_leave_group()
self.get_group_state(igmp_group).receive_leave_group()
def receive_query(self, packet: ReceivedPacket):
self.interface_state.receive_query(self, packet)
igmp_group = packet.payload.group_address
# process group specific query
if igmp_group != "0.0.0.0" and igmp_group in self.group_state:
#if igmp_group != "0.0.0.0":
max_response_time = packet.payload.max_resp_time
#self.group_state[igmp_group].receive_group_specific_query(max_response_time)
self.get_group_state(igmp_group).receive_group_specific_query(max_response_time)
\ No newline at end of file
from ..wrapper import NoMembersPresent
from ..wrapper import MembersPresent
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.state = NoMembersPresent
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'):
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
receive_v2_membership_report(group_state)
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.set_timer()
group_state.state = MembersPresent
def receive_leave_group(group_state: 'GroupState'):
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
# do nothing
return
from ..wrapper import NoMembersPresent
from ..wrapper import CheckingMembership
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.state = NoMembersPresent
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'):
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
receive_v2_membership_report(group_state)
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.set_timer()
def receive_leave_group(group_state: 'GroupState'):
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.set_timer(alternative=True, max_response_time=max_response_time)
group_state.state = CheckingMembership
from ..wrapper import MembersPresent
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
# do nothing
return
def group_membership_v1_timeout(group_state: 'GroupState'):
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
receive_v2_membership_report(group_state)
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.set_timer()
group_state.state = MembersPresent
# NOTIFY ROUTING + !!!!
group_state.notify_routing_add()
def receive_leave_group(group_state: 'GroupState'):
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
# do nothing
return
from utils import Membership_Query, QueryResponseInterval, LastMemberQueryCount, TYPE_CHECKING
from Packet.PacketIGMPHeader import PacketIGMPHeader
from Packet.ReceivedPacket import ReceivedPacket
from . import NoMembersPresent, MembersPresent, CheckingMembership
from ipaddress import IPv4Address
if TYPE_CHECKING:
from ..RouterState import RouterState
class NonQuerier:
@staticmethod
def general_query_timeout(router_state: 'RouterState'):
# do nothing
return
@staticmethod
def other_querier_present_timeout(router_state: 'RouterState'):
#change state to Querier
router_state.change_interface_state(querier=True)
# send general query
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=QueryResponseInterval*10)
router_state.interface.send(packet.bytes())
# set general query timer
router_state.set_general_query_timer()
@staticmethod
def receive_query(router_state: 'RouterState', packet: ReceivedPacket):
source_ip = packet.ip_header.ip_src
# if source ip of membership query not lower than the ip of the received interface => ignore
if IPv4Address(source_ip) >= IPv4Address(router_state.interface.get_ip()):
return
# reset other present querier timer
router_state.set_other_querier_present_timer()
# TODO ver se existe uma melhor maneira de fazer isto
@staticmethod
def state_name():
return "Non Querier"
@staticmethod
def get_group_membership_time(max_response_time: int):
return (max_response_time/10.0) * LastMemberQueryCount
# State
@staticmethod
def get_checking_membership_state():
return CheckingMembership
@staticmethod
def get_members_present_state():
return MembersPresent
@staticmethod
def get_no_members_present_state():
return NoMembersPresent
@staticmethod
def get_version_1_members_present_state():
return NonQuerier.get_members_present_state()
from Packet.PacketIGMPHeader import PacketIGMPHeader
from ..wrapper import NoMembersPresent, MembersPresent, Version1MembersPresent
from utils import Membership_Query, LastMemberQueryInterval, TYPE_CHECKING
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.clear_retransmit_timer()
group_state.state = NoMembersPresent
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'):
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_addr = group_state.group_ip
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=LastMemberQueryInterval*10, group_address=group_addr)
group_state.router_state.send(data=packet.bytes(), address=group_addr)
group_state.set_retransmit_timer()
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.set_timer()
group_state.set_v1_host_timer()
group_state.state = Version1MembersPresent
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.set_timer()
group_state.state = MembersPresent
def receive_leave_group(group_state: 'GroupState'):
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
# do nothing
return
from Packet.PacketIGMPHeader import PacketIGMPHeader
from ..wrapper import Version1MembersPresent, CheckingMembership, NoMembersPresent
from utils import Membership_Query, LastMemberQueryInterval, TYPE_CHECKING
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.state = NoMembersPresent
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'):
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.set_timer()
group_state.set_v1_host_timer()
group_state.state = Version1MembersPresent
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.set_timer()
def receive_leave_group(group_state: 'GroupState'):
group_ip = group_state.group_ip
group_state.set_timer(alternative=True)
group_state.set_retransmit_timer()
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=LastMemberQueryInterval*10, group_address=group_ip)
group_state.router_state.send(data=packet.bytes(), address=group_ip)
group_state.state = CheckingMembership
def receive_group_specific_query(group_state: 'GroupState', max_response_time):
# do nothing
return
from ..wrapper import MembersPresent
from ..wrapper import Version1MembersPresent
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
# do nothing
return
def group_membership_v1_timeout(group_state: 'GroupState'):
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.set_timer()
group_state.set_v1_host_timer()
group_state.state = Version1MembersPresent
# NOTIFY ROUTING + !!!!
group_state.notify_routing_add()
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.set_timer()
group_state.state = MembersPresent
# NOTIFY ROUTING + !!!!
group_state.notify_routing_add()
def receive_leave_group(group_state: 'GroupState'):
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
# do nothing
return
from Packet.PacketIGMPHeader import PacketIGMPHeader
from Packet.ReceivedPacket import ReceivedPacket
from utils import Membership_Query, QueryResponseInterval, LastMemberQueryCount, LastMemberQueryInterval
from . import CheckingMembership, MembersPresent, Version1MembersPresent, NoMembersPresent
from ipaddress import IPv4Address
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
class Querier:
@staticmethod
def general_query_timeout(router_state: 'RouterState'):
# send general query
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=QueryResponseInterval*10)
router_state.interface.send(packet.bytes())
# set general query timer
router_state.set_general_query_timer()
@staticmethod
def other_querier_present_timeout(router_state: 'RouterState'):
# do nothing
return
@staticmethod
def receive_query(router_state: 'RouterState', packet: ReceivedPacket):
source_ip = packet.ip_header.ip_src
# if source ip of membership query not lower than the ip of the received interface => ignore
if IPv4Address(source_ip) >= IPv4Address(router_state.interface.get_ip()):
return
# if source ip of membership query lower than the ip of the received interface => change state
# change state of interface
# Querier -> Non Querier
router_state.change_interface_state(querier=False)
# set other present querier timer
router_state.clear_general_query_timer()
router_state.set_other_querier_present_timer()
# TODO ver se existe uma melhor maneira de fazer isto
@staticmethod
def state_name():
return "Querier"
@staticmethod
def get_group_membership_time(max_response_time: int):
return LastMemberQueryInterval * LastMemberQueryCount
# State
@staticmethod
def get_checking_membership_state():
return CheckingMembership
@staticmethod
def get_members_present_state():
return MembersPresent
@staticmethod
def get_no_members_present_state():
return NoMembersPresent
@staticmethod
def get_version_1_members_present_state():
return Version1MembersPresent
from ..wrapper import NoMembersPresent
from ..wrapper import MembersPresent
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.state = NoMembersPresent
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'):
group_state.state = MembersPresent
def retransmit_timeout(group_state: 'GroupState'):
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.set_timer()
group_state.set_v1_host_timer()
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.set_timer()
def receive_leave_group(group_state: 'GroupState'):
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
# do nothing
return
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_checking_membership_state()
def print_state():
return "CheckingMembership"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
\ No newline at end of file
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_members_present_state()
def print_state():
return "MembersPresent"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
\ No newline at end of file
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_no_members_present_state()
def print_state():
return "NoMembersPresent"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_version_1_members_present_state()
def print_state():
return "Version1MembersPresent"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
import Main
import socket
from tree.originator import OriginatorState
from tree.tree_if_upstream import TreeInterfaceUpstream
from tree.tree_if_downstream import TreeInterfaceDownstream
from .tree_interface import TreeInterface
from threading import Timer, Lock, RLock
import UnicastRouting
class KernelEntry:
TREE_TIMEOUT = 180
def __init__(self, source_ip: str, group_ip: str, inbound_interface_index: int):
self.source_ip = source_ip
self.group_ip = group_ip
# ip of neighbor of the rpf
self._rpf_node = None
# (S,G) starts IG state
self._was_olist_null = None
# todo
#self._rpf_is_origin = False
self._originator_state = OriginatorState.NotOriginator
# decide inbound interface based on rpf check
self.inbound_interface_index = Main.kernel.vif_dic[self.check_rpf()]
Main.kernel.flood(source_ip, group_ip, self.inbound_interface_index)
self.interface_state = {} # type: Dict[int, TreeInterface]
for i in Main.kernel.vif_index_to_name_dic.keys():
try:
if i == self.inbound_interface_index:
self.interface_state[i] = TreeInterfaceUpstream(self, i, False)
else:
self.interface_state[i] = TreeInterfaceDownstream(self, i)
except:
import traceback
print(traceback.print_exc())
continue
self._multicast_change = Lock()
self._lock_test2 = RLock()
self.CHANGE_STATE_LOCK = RLock()
#self._was_olist_null = self.is_olist_null()
print('Tree created')
#self._liveliness_timer = None
#if self.is_originater():
# self.set_liveliness_timer()
# print('set SAT')
#self._lock = threading.RLock()
def get_inbound_interface_index(self):
return self.inbound_interface_index
def get_outbound_interfaces_indexes(self):
outbound_indexes = [0]*Main.kernel.MAXVIFS
for (index, state) in self.interface_state.items():
outbound_indexes[index] = state.is_forwarding()
return outbound_indexes
def check_rpf(self):
return UnicastRouting.check_rpf(self.source_ip)
#################################
# Receive (S,G) packet
#################################
def recv_data_msg(self, index):
print("recv data")
self.interface_state[index].recv_data_msg()
def recv_assert_msg(self, index, packet):
print("recv assert")
self.interface_state[index].recv_assert_msg()
def recv_prune_msg(self, index, packet):
print("recv prune msg")
self.interface_state[index].recv_prune_msg()
def recv_join_msg(self, index, packet):
print("recv join msg")
print("type: ")
self.interface_state[index].recv_join_msg()
def recv_graft_msg(self, index, packet):
print("recv graft msg")
self.interface_state[index].recv_graft_msg()
def recv_graft_ack_msg(self, index, packet):
print("recv graft ack msg")
self.interface_state[index].recv_graft_ack_msg()
def recv_state_refresh_msg(self, index, packet):
print("recv state refresh msg")
prune_indicator = 1
self.interface_state[index].recv_state_refresh_msg(prune_indicator)
def network_update(self, change, args):
#todo
return
def update(self, caller, arg):
#todo
return
def nbr_event(self, link, node, event):
# todo
return
def is_olist_null(self):
for interface in self.interface_state.values():
if interface.is_forwarding():
return False
return True
def evaluate_olist_change(self):
with self._lock_test2:
is_olist_null = self.is_olist_null()
if self._was_olist_null != is_olist_null:
if is_olist_null:
self.interface_state[self.inbound_interface_index].olist_is_null()
else:
self.interface_state[self.inbound_interface_index].olist_is_not_null()
self._was_olist_null = is_olist_null
def get_source(self):
return self.source_ip
def get_group(self):
return self.group_ip
def change(self):
# todo: changes on unicast routing or multicast routing...
with self._multicast_change:
Main.kernel.set_multicast_route(self)
def delete(self):
for state in self.interface_state.values():
state.delete()
Main.kernel.remove_multicast_route(self)
from abc import ABCMeta, abstractstaticmethod
import tree.globals as pim_globals
from .metric import AssertMetric
class AssertStateABC(metaclass=ABCMeta):
@abstractstaticmethod
def receivedDataFromDownstreamIf(interface):
"""
An (S,G) Data packet received on downstream interface
@type interface: TreeInterface
"""
raise NotImplementedError()
@abstractstaticmethod
def receivedInferiorMetricFromWinner(interface):
"""
Receive Inferior (Assert OR State Refresh) from Assert Winner
@type interface: TreeInterface
"""
raise NotImplementedError()
@abstractstaticmethod
def receivedInferiorMetricFromNonWinner_couldAssertIsTrue(interface):
"""
Receive Inferior (Assert OR State Refresh) from non-Assert Winner
AND CouldAssert==TRUE
@type interface: TreeInterface
"""
raise NotImplementedError()
@abstractstaticmethod
def receivedPreferedMetric(interface, assert_time, better_metric):
"""
Receive Preferred Assert OR State Refresh
@type interface: TreeInterface
@type assert_time: int
@type better_metric: AssertMetric
"""
raise NotImplementedError()
@abstractstaticmethod
def sendStateRefresh(interface, time):
"""
Send State Refresh
@type interface: TreeInterface
@type time: int
"""
raise NotImplementedError()
@abstractstaticmethod
def assertTimerExpires(interface):
"""
AT(S,G) Expires
@type interface: TreeInterface
"""
raise NotImplementedError()
@abstractstaticmethod
def couldAssertIsNowFalse(interface):
"""
CouldAssert -> FALSE
@type interface: TreeInterface
"""
raise NotImplementedError()
@abstractstaticmethod
def couldAssertIsNowTrue(interface):
"""
CouldAssert -> TRUE
@type interface: TreeInterface
"""
raise NotImplementedError()
@abstractstaticmethod
def winnerLivelinessTimerExpires(interface):
"""
Winner’s NLT(N,I) Expires
@type interface: TreeInterface
"""
raise NotImplementedError()
@abstractstaticmethod
def receivedPruneOrJoinOrGraft(interface):
"""
Receive Prune(S,G), Join(S,G) or Graft(S,G)
@type interface: TreeInterface
"""
raise NotImplementedError()
def _sendAssert_setAT(interface):
interface.send_assert()
interface.assert_timer.set_timer(pim_globals.ASSERT_TIME)
interface.assert_timer.reset()
@staticmethod
def rprint(interface, msg, *entrys):
'''
Method used for simplifiyng the process of reporting changes in a assert state
Tree Interface.
@type interface: TreeInterface
'''
interface.rprint(msg, 'assert state', *entrys)
# Override
def __str__(self) -> str:
return "PruneSM:" + self.__class__.__name__
class LoserState(AssertStateABC):
'''
I am Assert Loser (L)
This router has lost an (S,G) Assert on interface I. It must not
forward packets from S destined for G onto interface I.
'''
@staticmethod
def receivedDataFromDownstreamIf(interface):
"""
@type interface: TreeInterface
"""
interface.rprint('receivedDataFromDownstreamIf, L -> L')
@staticmethod
def receivedInferiorMetricFromWinner(interface):
LoserState._to_NoInfo(interface)
interface.rprint('receivedInferiorMetricFromWinner, L -> NI')
@staticmethod
def receivedInferiorMetricFromNonWinner_couldAssertIsTrue(interface):
interface.rprint(
'receivedInferiorMetricFromNonWinner_couldAssertIsTrue, L -> L')
@staticmethod
def receivedPreferedMetric(interface, assert_time, better_metric):
'''
@type better_metric: AssertMetric
'''
interface.assert_timer.set_timer(assert_time)
interface.assert_timer.reset()
has_winner_changed = interface.assert_winner_metric.node != better_metric.node
interface.assert_winner_metric = better_metric
if interface.could_assert() and has_winner_changed:
interface.send_prune()
interface.rprint('receivedPreferedMetric, L -> L', 'from:',
better_metric.node)
@staticmethod
def sendStateRefresh(interface, time):
assert False, "this should never ocurr"
@staticmethod
def assertTimerExpires(interface):
LoserState._to_NoInfo(interface)
interface.rprint('assertTimerExpires, L -> NI')
@staticmethod
def couldAssertIsNowFalse(interface):
LoserState._to_NoInfo(interface)
interface.rprint('couldAssertIsNowFalse, L -> NI')
@staticmethod
def couldAssertIsNowTrue(interface):
LoserState._to_NoInfo(interface)
interface.rprint('couldAssertIsNowTrue, L -> NI')
@staticmethod
def winnerLivelinessTimerExpires(interface):
LoserState._to_NoInfo(interface)
interface.rprint('winnerLivelinessTimerExpires, L -> NI')
@staticmethod
def receivedPruneOrJoinOrGraft(interface):
interface.send_assert()
interface.rprint('receivedPruneOrJoinOrGraft, L -> L')
@staticmethod
def _to_NoInfo(interface):
interface.assert_timer.stop()
interface.assert_state = AssertState.NoInfo
interface.assert_winner_metric = AssertMetric.infinite_assert_metric()
class NoInfoState(AssertStateABC):
'''
NoInfoState (NI)
This router has no (S,G) Assert state on interface I.
'''
@staticmethod
def receivedDataFromDownstreamIf(interface):
"""
@type interface: TreeInterface
"""
NoInfoState._sendAssert_setAT(interface)
interface.assert_state = AssertState.Winner
interface.assert_winner_metric = interface.assert_metric
interface.rprint('receivedDataFromDownstreamIf, NI -> W')
@staticmethod
def receivedInferiorMetricFromWinner(interface):
assert False, "this should never ocurr"
@staticmethod
def receivedInferiorMetricFromNonWinner_couldAssertIsTrue(interface):
NoInfoState._sendAssert_setAT(interface)
interface.assert_state = AssertState.Winner
interface.assert_winner_metric = interface.assert_metric
interface.rprint(
'receivedInferiorMetricFromNonWinner_couldAssertIsTrue, NI -> W')
@staticmethod
def receivedPreferedMetric(interface, assert_time, better_metric):
'''
@type interface: TreeInterface
'''
interface.assert_timer.set_timer(assert_time)
interface.assert_timer.reset()
interface.assert_state = AssertState.Loser
interface.assert_winner_metric = better_metric
if interface.could_assert():
interface.send_prune()
interface.rprint('receivedPreferedMetric, NI -> L')
@staticmethod
def sendStateRefresh(interface, time):
pass
@staticmethod
def assertTimerExpires(interface):
assert False, "this should never ocurr"
@staticmethod
def couldAssertIsNowFalse(interface):
interface.rprint('couldAssertIsNowFalse, NI -> NI')
@staticmethod
def couldAssertIsNowTrue(interface):
interface.rprint('couldAssertIsNowTrue, NI -> NI')
@staticmethod
def winnerLivelinessTimerExpires(interface):
assert False, "this should never ocurr"
@staticmethod
def receivedPruneOrJoinOrGraft(interface):
interface.rprint('receivedPruneOrJoinOrGraft, NI -> NI')
class WinnerState(AssertStateABC):
'''
I am Assert Winner (W)
This router has won an (S,G) Assert on interface I. It is now
responsible for forwarding traffic from S destined for G via
interface I.
'''
@staticmethod
def receivedDataFromDownstreamIf(interface):
"""
@type interface: TreeInterface
"""
WinnerState._sendAssert_setAT(interface)
interface.rprint('receivedDataFromDownstreamIf, W -> W')
@staticmethod
def receivedInferiorMetricFromWinner(interface):
assert False, "this should never ocurr"
@staticmethod
def receivedInferiorMetricFromNonWinner_couldAssertIsTrue(interface):
WinnerState._sendAssert_setAT(interface)
interface.rprint(
'receivedInferiorMetricFromNonWinner_couldAssertIsTrue, W -> W')
@staticmethod
def receivedPreferedMetric(interface, assert_time, better_metric):
'''
@type better_metric: AssertMetric
'''
interface.assert_timer.set_timer(assert_time)
interface.assert_timer.reset()
interface.assert_winner_metric = better_metric
interface.assert_state = AssertState.Loser
if interface.could_assert:
interface.send_prune()
interface.rprint('receivedPreferedMetric, W -> L', 'from:',
str(better_metric.node))
@staticmethod
def sendStateRefresh(interface, time):
interface.assert_timer.set_timer(time)
interface.assert_timer.reset()
@staticmethod
def assertTimerExpires(interface):
interface.assert_state = AssertState.NoInfo
interface.assert_winner_metric = AssertMetric.infinite_assert_metric()
interface.rprint('assertTimerExpires, W -> NI')
@staticmethod
def couldAssertIsNowFalse(interface):
interface.send_assert_cancel()
interface.assert_timer.stop()
interface.assert_state = AssertState.NoInfo
interface.assert_winner_metric = AssertMetric.infinite_assert_metric()
interface.rprint('couldAssertIsNowFalse, W -> NI')
@staticmethod
def couldAssertIsNowTrue(interface):
assert False, "this should never ocurr"
@staticmethod
def winnerLivelinessTimerExpires(interface):
assert False, "this should never ocurr"
@staticmethod
def receivedPruneOrJoinOrGraft(interface):
pass
class AssertState():
NoInfo = NoInfoState()
Winner = WinnerState()
Loser = LoserState()
from threading import Timer
class AssertWinnerState(object):
def GraftPruneState(self):
self._assert_state = AssertState.Winner
self._assert_timer = None
self._assert_winner_ip = None
self._assert_winner_metric = None
def set_assert_timer(self):
self.clear_assert_timer()
self._assert_timer= Timer()
def clear_assert_timer(self):
if self._assert_timer is not None:
self._assert_timer.cancel()
from abc import ABCMeta, abstractstaticmethod
from tree import globals as pim_globals
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from .tree_if_downstream import TreeInterfaceDownstream
class DownstreamStateABS(metaclass=ABCMeta):
@abstractstaticmethod
def receivedPrune(interface: "TreeInterfaceDownstream", holdtime):
"""
Receive Prune(S,G)
@type interface: Downstream
"""
raise NotImplementedError()
@abstractstaticmethod
def receivedJoin(interface: "TreeInterfaceDownstream"):
"""
Receive Join(S,G)
@type interface: Downstream
"""
raise NotImplementedError()
@abstractstaticmethod
def receivedGraft(interface: "TreeInterfaceDownstream"):
"""
Receive Graft(S,G)
@type interface: Downstream
"""
raise NotImplementedError()
@abstractstaticmethod
def PPTexpires(interface: "TreeInterfaceDownstream", prune_holdtime):
"""
PPT(S,G) Expires
@type interface: Downstream
"""
raise NotImplementedError()
@abstractstaticmethod
def PTexpires(interface: "TreeInterfaceDownstream"):
"""
PT(S,G) Expires
@type interface: Downstream
"""
raise NotImplementedError()
@abstractstaticmethod
def is_now_RPF_Interface(interface: "TreeInterfaceDownstream"):
"""
RPF_Interface(S) becomes I
@type interface: Downstream
"""
raise NotImplementedError()
@abstractstaticmethod
def send_state_refresh(interface: "TreeInterfaceDownstream"):
"""
Send State Refresh(S,G) out I
@type interface: Downstream
"""
raise NotImplementedError()
def __str__(self):
return "Downstream." + self.__class__.__name__
class NoInfo(DownstreamStateABS):
'''
NoInfo(NI)
The interface has no (S,G) Prune state, and neither the Prune
timer (PT(S,G,I)) nor the PrunePending timer ((PPT(S,G,I)) is
running.
'''
@staticmethod
def receivedPrune(interface: "TreeInterfaceDownstream", holdtime):
"""
Receive Prune(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
interface.set_prune_state(DownstreamState.PrunePending)
time = 0
if len(interface.get_interface().neighbors) > 1:
time = pim_globals.JT_OVERRIDE_INTERVAL
#timer = interface.get_ppt().start(time)
interface.set_prune_pending_timer(time)
print("receivedPrune, NI -> PP")
@staticmethod
def receivedJoin(interface: "TreeInterfaceDownstream"):
"""
Receive Join(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
# Do nothing
print("receivedJoin, NI -> NI")
@staticmethod
def receivedGraft(interface: "TreeInterfaceDownstream"):
"""
Receive Graft(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
# todo why pt stop???!!!
#interface.get_pt().stop()
interface.send_graft_ack()
print('receivedGraft, NI -> NI')
@staticmethod
def PPTexpires(interface: "TreeInterfaceDownstream", prune_holdtime):
"""
PPT(S,G) Expires
@type interface: TreeInterfaceDownstreamDownstream
"""
assert False
@staticmethod
def PTexpires(interface: "TreeInterfaceDownstream"):
"""
PT(S,G) Expires
@type interface: TreeInterfaceDownstreamDownstream
"""
assert False
@staticmethod
def is_now_RPF_Interface(interface: "TreeInterfaceDownstream"):
"""
RPF_Interface(S) becomes I
@type interface: TreeInterfaceDownstreamDownstream
"""
pass
# Do nothing
@staticmethod
def send_state_refresh(interface: "TreeInterfaceDownstream"):
"""
Send State Refresh(S,G) out I
@type interface: TreeInterfaceDownstreamDownstream
"""
pass
# Do nothing
class PrunePending(DownstreamStateABS):
'''
PrunePending(PP)
The router has received a Prune(S,G) on this interface from a
downstream neighbor and is waiting to see whether the prune will
be overridden by another downstream router. For forwarding
purposes, the PrunePending state functions exactly like the
NoInfo state.
'''
@staticmethod
def receivedPrune(interface: "TreeInterfaceDownstream", holdtime):
"""
Receive Prune(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
print('receivedPrune, PP -> PP')
@staticmethod
def receivedJoin(interface: "TreeInterfaceDownstream"):
"""
Receive Join(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
#interface.get_ppt().stop()
interface.clear_prune_pending_timer()
interface.set_prune_state(DownstreamState.NoInfo)
print('receivedJoin, PP -> NI')
@staticmethod
def receivedGraft(interface: "TreeInterfaceDownstream"):
"""
Receive Graft(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
# todo why prune timer and not prune pending timer???!
#interface.get_pt().stop()
interface.clear_prune_pending_timer()
interface.set_prune_state(DownstreamState.NoInfo)
interface.send_graft_ack()
print('receivedGraft, PP -> NI')
@staticmethod
def PPTexpires(interface: "TreeInterfaceDownstream", prune_holdtime):
"""
PPT(S,G) Expires
@type interface: TreeInterfaceDownstreamDownstream
"""
interface.set_prune_state(DownstreamState.Pruned)
#pt = interface.get_pt()
#pt.start(interface.get_lpht() - pim_globals.JT_OVERRIDE_INTERVAL)
interface.set_prune_timer(prune_holdtime - pim_globals.JT_OVERRIDE_INTERVAL)
interface.send_pruneecho()
print('PPTexpires, PP -> P')
@staticmethod
def PTexpires(interface: "TreeInterfaceDownstream"):
"""
PT(S,G) Expires
@type interface: TreeInterfaceDownstreamDownstream
"""
assert False
@staticmethod
def is_now_RPF_Interface(interface: "TreeInterfaceDownstream"):
"""
RPF_Interface(S) becomes I
@type interface: TreeInterfaceDownstreamDownstream
"""
# todo understand better
#interface.get_ppt().stop()
interface.clear_prune_pending_timer()
print('is_now_RPF_Interface, PP -> NI')
@staticmethod
def send_state_refresh(interface: "TreeInterfaceDownstream"):
"""
Send State Refresh(S,G) out I
@type interface: TreeInterfaceDownstreamDownstream
"""
pass
class Pruned(DownstreamStateABS):
'''
Pruned(P)
The router has received a Prune(S,G) on this interface from a
downstream neighbor, and the Prune was not overridden. Data from
S addressed to group G is no longer being forwarded on this
interface.
'''
@staticmethod
def receivedPrune(interface: "TreeInterfaceDownstream", holdtime):
"""
Receive Prune(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
# todo ppt???! should be pt
#ppt = interface.get_ppt()
#if interface.get_lpht() > ppt.time_left():
# ppt.set_timer(interface.get_lpht())
# ppt.reset()
# todo nao percebo... corrigir 0
if holdtime > 0:
interface.set_prune_timer(holdtime)
print('receivedPrune, P -> P')
@staticmethod
def receivedJoin(interface: "TreeInterfaceDownstream"):
"""
Receive Join(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
#interface.get_pt().stop()
interface.clear_prune_timer()
interface.set_prune_state(DownstreamState.NoInfo)
print('receivedPrune, P -> NI')
@staticmethod
def receivedGraft(interface: "TreeInterfaceDownstream"):
"""
Receive Graft(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
#interface.get_pt().stop()
interface.clear_prune_timer()
interface.set_prune_state(DownstreamState.NoInfo)
interface.send_graft_ack()
print('receivedGraft, P -> NI')
@staticmethod
def PPTexpires(interface: "TreeInterfaceDownstream", prune_holdtime):
"""
PPT(S,G) Expires
@type interface: TreeInterfaceDownstreamDownstream
"""
assert False
@staticmethod
def PTexpires(interface: "TreeInterfaceDownstream"):
"""
PT(S,G) Expires
@type interface: TreeInterfaceDownstreamDownstream
"""
interface.set_prune_state(DownstreamState.NoInfo)
print('PTexpires, P -> NI')
@staticmethod
def is_now_RPF_Interface(interface: "TreeInterfaceDownstream"):
"""
RPF_Interface(S) becomes I
@type interface: TreeInterfaceDownstreamDownstream
"""
# todo ver melhor
#interface.get_pt().stop()
interface.clear_prune_timer()
print('is_now_RPF_Interface, P -> NI')
@staticmethod
def send_state_refresh(interface: "TreeInterfaceDownstream"):
"""
Send State Refresh(S,G) out I
@type interface: TreeInterfaceDownstreamDownstream
"""
#pt = interface.get_pt()
#pt.set_timer(interface.get_lpht())
#pt.reset()
interface.set_prune_timer(interface.get_lpht())
print('send_state_refresh, P -> P')
class DownstreamState():
NoInfo = NoInfo()
Pruned = Pruned()
PrunePending = PrunePending()
'''
Created on Feb 23, 2015
This module is intended to have all constants and global values for pim_dm
@author: alex
'''
ASSERT_TIME = 180
GRAFT_RETRY_PERIOD = 3
JT_OVERRIDE_INTERVAL = 3.0
OVERRIDE_INTERVAL = 2.5
REFRESH_INTERVAL = 60 # State Refresh Interval
SOURCE_LIFETIME = 210
T_LIMIT = 210
from threading import Timer
class GraftPruneState(object):
def GraftPruneState(self):
self._prune_state = SFMRPruneState.DIP
self._prune_pending_timer = None # type: Timer
self._prune_timer = None # type: Timer
def set_prune_pending_timer(self):
self.clear_prune_pending_timer()
self._prune_pending_time = Timer()
def clear_prune_pending_timer(self):
if self._prune_pending_time is not None:
self._prune_pending_time.cancel()
def set_prune_timer(self):
self.clear_prune_timer()
self._prune_timer = Timer()
def clear_prune_timer(self):
if self._prune_timer is not None:
self._prune_timer.cancel()
'''
Created on Sep 8, 2014
@author: alex
'''
class AssertMetric(object):
'''
Note: we consider the node name the ip of the metric.
'''
def __init__(self):
'''
@type tree_if: TreeInterface
'''
self._pref = None
self._metric = None
self._node = None
def is_worse_than(self, other):
if self._pref != other.pref:
return self._pref > other.pref
elif self._metric != other.metric:
return self._metric > other.metric
else:
return self._node.__str__() <= other.node.__str__()
@property
def pref(self):
return self._pref
@property
def metric(self):
return self._metric
@property
def node(self):
return self._node
@staticmethod
def infinite_assert_metric():
'''
@type metric: AssertMetric
'''
metric = AssertMetric()
metric._pref = 1
metric._metric = float("Inf")
metric._node = ""
return metric
@staticmethod
def spt_assert_metric(tree_if):
'''
@type metric: AssertMetric
@type tree_if: TreeInterface
'''
metric = AssertMetric()
metric._pref = 1 # TODO: ver isto melhor no route preference
metric._metric = tree_if.metric
metric._node = tree_if.node
return metric
# overrides
def __str__(self):
return "AssertMetric<{}:{}:{}>".format(self._pref, self._metric,
self._node)
from abc import ABCMeta, abstractstaticmethod
from tree import globals as pim_globals
class OriginatorStateABC(metaclass=ABCMeta):
def recvDataMsgFromSource(tree):
pass
@abstractstaticmethod
def SRTexpires(tree):
pass
@abstractstaticmethod
def SATexpires(tree):
pass
@abstractstaticmethod
def SourceNotConnected(tree):
pass
class Originator(OriginatorStateABC):
@staticmethod
def recvDataMsgFromSource(tree):
tree.source_active_timer.reset()
@staticmethod
def SRTexpires(tree):
'''
@type tree: Tree
'''
tree.rprint('SRT expired, O to O')
tree.state_refresh_timer.reset()
tree.send_state_refresh_msg()
@staticmethod
def SATexpires(tree):
tree.rprint('SAT expired, O to NO')
tree.source_active_timer.stop()
tree.state_refresh_timer.stop()
tree.originator_state = OriginatorState.NotOriginator
@staticmethod
def SourceNotConnected(tree):
tree.rprint('Source no longer directly connected, O to NO')
tree.source_active_timer.stop()
tree.state_refresh_timer.stop()
tree.originator_state = OriginatorState.NotOriginator
class NotOriginator(OriginatorStateABC):
@staticmethod
def recvDataMsgFromSource(tree):
'''
@type interface: Tree
'''
tree.originator_state = OriginatorState.Originator
tree.state_refresh_timer.start()
tree.source_active_timer.start()
tree.rprint('new DataMsg from Source, NO to O')
# Since the recording of the TTL is common to both states,its registering is made on the
# Tree.new_state_refresh_msg(...) method
@staticmethod
def SRTexpires(tree):
assert False
@staticmethod
def SATexpires(tree):
assert False
@staticmethod
def SourceNotConnected(tree):
pass
class OriginatorState():
NotOriginator = NotOriginator()
Originator = Originator()
from threading import Timer
class OriginatorState(object):
def OriginatorState(self):
self._source_active_timer = None # type: Timer
self._state_refresh_timer = None # type: Timer
def set_source_active_timer(self):
self.clear_source_active_timer()
self._source_active_timer = Timer()
def clear_source_active_timer(self):
if self._source_active_timer is not None:
self._source_active_timer.cancel()
def set_state_refresh_timer(self):
self.clear_state_refresh_timer()
self._state_refresh_timer = Timer()
def clear_state_refresh_timer(self):
if self._state_refresh_timer is not None:
self._state_refresh_timer.cancel()
from threading import Timer
class PruneState(object):
def PruneState(self):
self._prune_state = SFMRPruneState.DIP
self._prune_pending_timer = None
self._prune_timer = None
def set_prune_pending_timer(self):
self.clear_prune_pending_timer()
self._prune_pending_timer= Timer()
def clear_prune_pending_timer(self):
if self._prune_pending_timer is not None:
self._prune_pending_timer.cancel()
def set_prune_timer(self):
self.clear_prune_timer()
self._prune_timer= Timer()
def clear_prune_timer(self):
if self._prune_timer is not None:
self._prune_timer.cancel()
'''
Created on Jul 16, 2015
@author: alex
'''
#from convergence import Convergence
#from des.event.timer import Timer
from threading import Timer
from .assert_ import AssertState, AssertStateABC
#from .messages.assert_msg import SFMRAssertMsg
#from .messages.reset import SFMResetMsg
from .metric import AssertMetric
from .downstream_prune import DownstreamState, DownstreamStateABS
from .tree_interface import TreeInterface
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimAssert import PacketPimAssert
from threading import Lock
class TreeInterfaceDownstream(TreeInterface):
def __init__(self, kernel_entry, interface_id):
TreeInterface.__init__(self, kernel_entry, interface_id)
# State
self._local_membership_state = None # todo NoInfo or Include
# Prune State
self._prune_state = DownstreamState.NoInfo
self._prune_pending_timer = None
self._prune_timer = None
# Assert Winner State
self._assert_state = AssertState.Winner
self._assert_timer = None
self._assert_winner_ip = None
self._assert_winner_metric = None
#self.set_dipt_timer()
#self.send_prune()
##########################################
# Set state
##########################################
def set_prune_state(self, new_state: DownstreamStateABS):
with self.get_state_lock():
if new_state != self._prune_state:
self._prune_state = new_state
self.change_tree()
self.evaluate_ingroup()
##########################################
# Check timers
##########################################
def is_prune_pending_timer_running(self):
return self._prune_pending_timer is not None and self._prune_pending_timer.is_alive()
def is_prune_timer_running(self):
return self._prune_timer is not None and self._prune_timer.is_alive()
##########################################
# Set timers
##########################################
def set_prune_pending_timer(self, time):
self.clear_prune_pending_timer()
self._prune_pending_timer = Timer(time, self.prune_pending_timeout)
self._prune_pending_timer.start()
def clear_prune_pending_timer(self):
if self._prune_pending_timer is not None:
self._prune_pending_timer.cancel()
def set_prune_timer(self, time):
self.clear_prune_timer()
self._prune_timer = Timer(time, self.prune_timeout)
self._prune_timer.start()
def clear_prune_timer(self):
if self._prune_timer is not None:
self._prune_timer.cancel()
###########################################
# Timer timeout
###########################################
def prune_pending_timeout(self):
self._prune_state.PPTexpires(self, 10)
def prune_timeout(self):
self._prune_state.PTexpires(self)
###########################################
# Recv packets
###########################################
# Override
def recv_prune_msg(self):
self._prune_state.receivedPrune(self, 0)
# Override
def recv_join_msg(self):
self._prune_state.receivedJoin(self)
# Override
def recv_graft_msg(self):
self._prune_state.receivedGraft(self)
# Override
def is_forwarding(self):
return ((len(self.get_interface().neighbors) >= 1 and not self.is_pruned()) or self.igmp_has_members()) and not self.lost_assert()
# todo wtf is boundary??!!
#return self._assert_state == AssertState.Winner and self.is_in_group()
def is_pruned(self):
return self._prune_state == DownstreamState.Pruned
def lost_assert(self):
return self._assert_state == AssertState.Loser
# Override
def nbr_connected(self):
self._prune_state.new_nbr(self)
# Override
def delete(self):
TreeInterface.delete(self)
#self._get_dipt().cancel()
def get_metric(self):
return AssertMetric.spt_assert_metric(self)
def _set_assert_state(self, value: AssertStateABC):
with self.get_state_lock():
if value != self._assert_state:
self._assert_state = value
self.change_tree()
self.evaluate_ingroup()
#Convergence.mark_change()
def _get_winner_metric(self):
'''
@rtype: SFMRAssertMetric
'''
return self._assert_metric
def _set_winner_metric(self, value):
assert isinstance(value, AssertMetric) or value is None
# todo
self._assert_metric = value
def is_downstream(self):
return True
'''
Created on Jul 16, 2015
@author: alex
'''
#from des.addr import Addr
#from .messages.assert_msg import SFMRAssertMsg
#from .messages.join import SFMRJoinMsg
from .tree_interface import TreeInterface
from .upstream_prune import UpstreamState
from threading import Timer
from .globals import *
import random
class TreeInterfaceUpstream(TreeInterface):
def __init__(self, kernel_entry, interface_id, is_originater: bool):
TreeInterface.__init__(self, kernel_entry, interface_id)
self._graft_prune_state = UpstreamState.Forward
self._graft_retry_timer = None
self._override_timer = None
self._prune_limit_timer = None
self._originator_state = None
##########################################
# Set state
##########################################
def set_state(self, new_state):
with self.get_state_lock():
if new_state != self._graft_prune_state:
self._graft_prune_state = new_state
self.change_tree()
self.evaluate_ingroup()
##########################################
# Check timers
##########################################
def is_graft_retry_timer_running(self):
return self._graft_retry_timer is not None and self._graft_retry_timer.is_alive()
def is_override_timer_running(self):
return self._override_timer is not None and self._override_timer.is_alive()
def is_prune_limit_timer_running(self):
return self._prune_limit_timer is not None and self._prune_limit_timer.is_alive()
##########################################
# Set timers
##########################################
def set_graft_retry_timer(self, time=GRAFT_RETRY_PERIOD):
self.clear_graft_retry_timer()
self._graft_retry_timer = Timer(time, self.graft_retry_timeout)
self._graft_retry_timer.start()
def clear_graft_retry_timer(self):
if self._graft_retry_timer is not None:
self._graft_retry_timer.cancel()
def set_override_timer(self):
self.clear_override_timer()
self._override_timer = Timer(self.t_override, self.override_timeout)
self._override_timer.start()
def clear_override_timer(self):
if self._override_timer is not None:
self._override_timer.cancel()
def set_prune_limit_timer(self, time=T_LIMIT):
self.clear_prune_limit_timer()
self._prune_limit_timer = Timer(time, self.prune_limit_timeout)
self._prune_limit_timer.start()
def clear_prune_limit_timer(self):
if self._prune_limit_timer is not None:
self._prune_limit_timer.cancel()
###########################################
# Timer timeout
###########################################
def graft_retry_timeout(self):
self._graft_prune_state.GRTexpires(self)
def override_timeout(self):
self._graft_prune_state.OTexpires(self)
def prune_limit_timeout(self):
return
###########################################
# Recv packets
###########################################
def recv_data_msg(self):
# todo check olist
if self.is_olist_null() and not self.is_prune_limit_timer_running():
self._graft_prune_state.dataArrivesRPFinterface_OListNull_PLTstoped(self)
def recv_state_refresh_msg(self, prune_indicator: int):
# todo check rpf nbr
if prune_indicator == 1:
self._graft_prune_state.stateRefreshArrivesRPFnbr_pruneIs1(self)
elif prune_indicator == 0 and not self.is_prune_limit_timer_running():
self._graft_prune_state.stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped(self)
def recv_join_msg(self):
# todo check rpf nbr
self._graft_prune_state.seeJoinToRPFnbr(self)
def recv_prune_msg(self):
self._graft_prune_state.seePrune(self)
def recv_graft_ack_msg(self):
# todo check rpf nbr
self._graft_prune_state.recvGraftAckFromRPFnbr(self)
###########################################
# Change olist
###########################################
def olist_is_null(self):
self._graft_prune_state.olistIsNowNull(self)
def olist_is_not_null(self):
self._graft_prune_state.olistIsNowNotNull(self)
###########################################
# Changes on Unicast Routing Table
###########################################
# todo
#Override
def is_forwarding(self):
return False
#Override
def delete(self):
super().delete()
def is_downstream(self):
return False
#-------------------------------------------------------------------------
# Properties
#-------------------------------------------------------------------------
@property
def t_override(self):
oi = self.get_interface()._override_internal
return random.uniform(0, oi)
'''
Created on Jul 16, 2015
@author: alex
'''
from abc import ABCMeta, abstractmethod
import Main
from threading import Lock, RLock
import traceback
#from convergence import Convergence
#from sfmr.messages.prune import SFMRPruneMsg
#from .router_interface import SFMRInterface
from .downstream_prune import DownstreamState
from .assert_ import AssertState
from Packet.PacketPimGraft import PacketPimGraft
from Packet.PacketPimGraftAck import PacketPimGraftAck
from Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
from Packet.PacketPimHeader import PacketPimHeader
from Packet.Packet import Packet
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
from Packet.PacketPimAssert import PacketPimAssert
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
class TreeInterface(metaclass=ABCMeta):
def __init__(self, kernel_entry, interface_id):
'''
@type interface: SFMRInterface
@type node: Node
'''
#assert isinstance(interface, SFMRInterface)
self._kernel_entry = kernel_entry
self._interface_id = interface_id
#self._interface = interface
#self._node = node
#self._tree_id = tree_id
#self._cost = cost
#self._evaluate_ig = evaluate_ig_cb
try:
interface_name = Main.kernel.vif_index_to_name_dic[interface_id]
igmp_interface = Main.igmp_interfaces[interface_name] # type: InterfaceIGMP
group_state = igmp_interface.interface_state.get_group_state(kernel_entry.group_ip)
self._igmp_has_members = group_state.add_multicast_routing_entry(self)
except:
#traceback.print_exc()
self._igmp_has_members = False
# Local Membership State
self._local_membership_state = None # todo NoInfo or Include
# Prune State
self._prune_state = DownstreamState.NoInfo
self._prune_pending_timer = None
self._prune_timer = None
# Assert Winner State
self._assert_state = AssertState.Winner
self._assert_timer = None
self._assert_winner_ip = None
self._assert_winner_metric = None
self._igmp_lock = RLock()
#self.rprint('new ' + self.__class__.__name__)
def recv_data_msg(self):
pass
def recv_assert_msg(self):
pass
def recv_reset_msg(self):
pass
def recv_prune_msg(self):
pass
def recv_join_msg(self):
pass
def recv_graft_msg(self):
pass
def recv_graft_ack_msg(self):
pass
def recv_state_refresh_msg(self, prune_indicator):
pass
def forward_state_reset_msg(self):
raise NotImplemented
######################################
# Send messages
######################################
def send_graft(self):
print("send graft")
try:
(source, group) = self.get_tree_id()
# todo self.get_rpf_()
ph = PacketPimGraft("10.0.0.13")
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes())
#msg = GraftMsg(self.get_tree().tree_id, self.get_rpf_())
#self.pim_if.send_mcast(msg)
except:
return
def send_graft_ack(self):
print("send graft ack")
try:
(source, group) = self.get_tree_id()
# todo endereco?!!
ph = PacketPimGraftAck("10.0.0.13")
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes())
#msg = GraftAckMsg(self.get_tree().tree_id, self.get_node())
#self.pim_if.send_mcast(msg)
except:
return
def send_prune(self):
print("send prune")
try:
(source, group) = self.get_tree_id()
# todo help ip of ph
ph = PacketPimJoinPrune("123.123.123.123", 210)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, pruned_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes())
print('sent prune msg')
except:
return
def send_pruneecho(self):
print("send prune echo")
# todo
#msg = PruneMsg(self.get_tree().tree_id,
# self.get_node(), self._assert_timer.time_left())
#self.pim_if.send_mcast(msg)
return
def send_join(self):
print("send join")
try:
(source, group) = self.get_tree_id()
# todo help ip of ph
ph = PacketPimJoinPrune("123.123.123.123", 210)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes())
#msg = JoinMsg(self.get_tree().tree_id, self.get_rpf_())
#self.pim_if.send_mcast(msg)
except:
return
def send_assert(self):
print("send assert")
import UnicastRouting
try:
(source, group) = self.get_tree_id()
(entry_protocol, entry_cost) = UnicastRouting.get_metric(source)
# todo help ip of ph
ph = PacketPimAssert(multicast_group_address=group, source_address=source, metric_preference=entry_protocol, metric=entry_cost)
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes())
#msg = AssertMsg(self.tree_id, self.assert_metric)
#self.pim_if.send_mcast(msg)
except:
return
def send_assert_cancel(self):
print("send cancel")
#msg = AssertMsg.new_assert_cancel(self.tree_id)
#self.pim_if.send_mcast(msg)
pass
@abstractmethod
def is_forwarding(self):
pass
def nbr_died(self, node):
pass
def nbr_connected(self):
pass
#@abstractmethod
def is_now_root(self):
pass
@abstractmethod
def delete(self):
print('Tree Interface deleted')
def is_olist_null(self):
return self._kernel_entry.is_olist_null()
def evaluate_ingroup(self):
self._kernel_entry.evaluate_olist_change()
def notify_igmp(self, has_members: bool):
with self.get_state_lock():
#with self._igmp_lock:
if has_members != self._igmp_has_members:
self._igmp_has_members = has_members
self.change_tree()
self.evaluate_ingroup()
def igmp_has_members(self):
#with self._igmp_lock:
return self._igmp_has_members
def rprint(self, msg, *entrys):
return
def __str__(self):
return '{}<{}>'.format(self.__class__, self._interface.get_link())
def get_link(self):
# todo
return self._interface.get_link()
def get_interface(self):
kernel = Main.kernel
interface_name = kernel.vif_index_to_name_dic[self._interface_id]
interface = Main.interfaces[interface_name]
return interface
def get_node(self):
# todo: para ser substituido por get_ip
return self.get_ip()
def get_ip(self):
ip = self.get_interface().get_ip()
return ip
def get_tree_id(self):
return (self._kernel_entry.source_ip, self._kernel_entry.group_ip)
def change_tree(self):
self._kernel_entry.change()
def get_state_lock(self):
return self._kernel_entry.CHANGE_STATE_LOCK
@abstractmethod
def is_downstream(self):
raise NotImplementedError()
def get_rpf_(self):
return self.get_neighbor_RPF()
# obtain ip of RPF'(S)
def get_neighbor_RPF(self):
'''
RPF'(S)
'''
if not self.is_assert_winner():
return self._assert_winner_ip
else:
return self._kernel_entry._rpf_node
def is_assert_winner(self):
return not self.is_downstream() and not self._assert_state == AssertState.Loser
\ No newline at end of file
from abc import ABCMeta, abstractstaticmethod
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from .tree_if_upstream import TreeInterfaceUpstream
class UpstreamStateABC(metaclass=ABCMeta):
@abstractstaticmethod
def dataArrivesRPFinterface_OListNull_PLTstoped(interface: "TreeInterfaceUpstream"):
"""
Data arrives on RPF_Interface(S) AND
olist(S, G) == NULL AND
PLT(S, G) not running
@type interface: Upstream
"""
raise NotImplementedError()
@abstractstaticmethod
def stateRefreshArrivesRPFnbr_pruneIs1(interface: "TreeInterfaceUpstream"):
"""
State Refresh(S,G) received from RPF‘(S) AND
Prune Indicator == 1
@type interface: Upstream
"""
raise NotImplementedError()
@abstractstaticmethod
def stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped(interface: "TreeInterfaceUpstream"):
"""
State Refresh(S,G) received from RPF‘(S) AND
Prune Indicator == 0 AND
PLT(S, G) not running
@type interface: Upstream
"""
raise NotImplementedError()
@abstractstaticmethod
def seeJoinToRPFnbr(interface: "TreeInterfaceUpstream"):
"""
See Join(S,G) to RPF’(S)
@type interface: Upstream
"""
raise NotImplementedError()
@abstractstaticmethod
def seePrune(interface: "TreeInterfaceUpstream"):
"""
See Prune(S,G)
@type interface: Upstream
"""
raise NotImplementedError()
@abstractstaticmethod
def OTexpires(interface: "TreeInterfaceUpstream"):
"""
OT(S,G) Expires
@type interface: Upstream
"""
raise NotImplementedError()
@abstractstaticmethod
def olistIsNowNull(interface: "TreeInterfaceUpstream"):
"""
olist(S,G)->NULL
@type interface: Upstream
"""
raise NotImplementedError()
@abstractstaticmethod
def olistIsNowNotNull(interface: "TreeInterfaceUpstream"):
"""
olist(S,G)->non-NULL
@type interface: Upstream
"""
raise NotImplementedError()
@abstractstaticmethod
def RPFnbrChanges_olistIsNotNull(interface: "TreeInterfaceUpstream"):
"""
RPF’(S) Changes AND
olist(S,G) != NULL AND
S not directly connected
@type interface: Upstream
"""
raise NotImplementedError()
@abstractstaticmethod
def RPFnbrChanges_olistIsNull(interface: "TreeInterfaceUpstream"):
"""
RPF’(S) Changes AND
olist(S,G) == NULL
@type interface: Upstream
"""
raise NotImplementedError()
@abstractstaticmethod
def sourceIsNowDirectConnect(interface: "TreeInterfaceUpstream"):
"""
S becomes directly connected
@type interface: Upstream
"""
raise NotImplementedError()
@abstractstaticmethod
def GRTexpires(interface: "TreeInterfaceUpstream"):
"""
GRT(S,G) Expires
@type interface: Upstream
"""
raise NotImplementedError()
@abstractstaticmethod
def recvGraftAckFromRPFnbr(interface: "TreeInterfaceUpstream"):
"""
Receive GraftAck(S,G) from RPF’(S)
@type interface: Upstream
"""
raise NotImplementedError()
class AckPending(UpstreamStateABC):
"""
AckPending (AP)
The router was in the Pruned(P) state, but a transition has
occurred in the Downstream(S,G) state machine for one of this
(S,G) entry’s outgoing interfaces, indicating that traffic from S
addressed to G should again be forwarded. A Graft message has
been sent to RPF’(S), but a Graft Ack message has not yet been
received.
"""
@staticmethod
def dataArrivesRPFinterface_OListNull_PLTstoped(interface: "TreeInterfaceUpstream"):
"""
Data arrives on RPF_Interface(S) AND
olist(S, G) == NULL AND
PLT(S, G) not running
@type interface: TreeInterfaceUpstream
"""
#assert False
return
@staticmethod
def stateRefreshArrivesRPFnbr_pruneIs1(interface: "TreeInterfaceUpstream"):
"""
State Refresh(S,G) received from RPF‘(S) AND
Prune Indicator == 1
@type interface: TreeInterfaceUpstream
"""
#interface.set_ot()
interface.set_override_timer()
print('stateRefreshArrivesRPFnbr_pruneIs1, AP -> AP')
@staticmethod
def stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped(interface: "TreeInterfaceUpstream"):
"""
State Refresh(S,G) received from RPF‘(S) AND
Prune Indicator == 0 AND
PLT(S, G) not running
@type interface: TreeInterfaceUpstream
"""
print(UpstreamState.Forward)
interface.set_state(UpstreamState.Forward)
#interface.get_grt().cancel()
interface.clear_graft_retry_timer()
print(
'stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped, AP -> P')
@staticmethod
def seeJoinToRPFnbr(interface: "TreeInterfaceUpstream"):
"""
See Join(S,G) to RPF’(S)
@type interface: TreeInterfaceUpstream
"""
#interface.cancel_ot()
interface.clear_override_timer()
print('seeJoinToRPFnbr, P -> P')
@staticmethod
def seePrune(interface: "TreeInterfaceUpstream"):
"""
See Prune(S,G)
@type interface: TreeInterfaceUpstream
"""
#interface.set_ot()
interface.set_override_timer()
interface.rprint('seePrune, AP -> AP')
@staticmethod
def OTexpires(interface: "TreeInterfaceUpstream"):
"""
OT(S,G) Expires
@type interface: TreeInterfaceUpstream
"""
interface.send_join()
interface.rprint('OTexpires, AP -> AP')
@staticmethod
def olistIsNowNull(interface: "TreeInterfaceUpstream"):
"""
olist(S,G)->NULL
@type interface: TreeInterfaceUpstream
"""
interface.set_state(UpstreamState.Pruned)
#timer = interface._prune_limit_timer
#timer.set_timer(interface.t_override)
interface.set_prune_limit_timer()
#timer.start()
#interface.get_grt().stop()
interface.clear_graft_retry_timer()
interface.send_prune()
print("olistIsNowNull, AP -> P")
@staticmethod
def olistIsNowNotNull(interface: "TreeInterfaceUpstream"):
"""
olist(S,G)->non-NULL
@type interface: TreeInterfaceUpstream
"""
#assert False
return
@staticmethod
def RPFnbrChanges_olistIsNotNull(interface: "TreeInterfaceUpstream"):
"""
RPF’(S) Changes AND
olist(S,G) != NULL AND
S not directly connected
@type interface: TreeInterfaceUpstream
"""
interface.send_graft()
#interface.get_grt().reset()
interface.set_graft_retry_timer()
print('olistIsNowNotNull, AP -> AP')
@staticmethod
def RPFnbrChanges_olistIsNull(interface: "TreeInterfaceUpstream"):
"""
RPF’(S) Changes AND
olist(S,G) == NULL
@type interface: TreeInterfaceUpstream
"""
#interface.get_grt().cancel()
interface.clear_graft_retry_timer()
print('RPFnbrChanges_olistIsNull, AP -> P')
@staticmethod
def sourceIsNowDirectConnect(interface: "TreeInterfaceUpstream"):
"""
S becomes directly connected
@type interface: TreeInterfaceUpstream
"""
interface.set_state(UpstreamState.Forward)
#interface.get_grt().stop()
interface.clear_graft_retry_timer()
print("sourceIsNowDirectConnect, AP -> F")
@staticmethod
def GRTexpires(interface: "TreeInterfaceUpstream"):
"""
GRT(S,G) Expires
@type interface: TreeInterfaceUpstream
"""
#interface.get_grt().start()
interface.set_graft_retry_timer()
interface.send_graft()
print('GRTexpires, AP -> AP')
@staticmethod
def recvGraftAckFromRPFnbr(interface: "TreeInterfaceUpstream"):
"""
Receive GraftAck(S,G) from RPF’(S)
@type interface: TreeInterfaceUpstream
"""
interface.set_state(UpstreamState.Forward)
#interface.get_grt().stop()
interface.clear_graft_retry_timer()
print('recvGraftAckFromRPFnbr, AP -> F')
class Forward(UpstreamStateABC):
"""
Forwarding (F)
This is the starting state of the Upsteam(S,G) state machine.
The state machine is in this state if it just started or if
oiflist(S,G) != NULL.
"""
@staticmethod
def dataArrivesRPFinterface_OListNull_PLTstoped(interface: "TreeInterfaceUpstream"):
"""
Data arrives on RPF_Interface(S) AND
olist(S, G) == NULL AND
PLT(S, G) not running
@type interface: TreeInterfaceUpstream
"""
interface.set_state(UpstreamState.Pruned)
#interface.get_ot().stop()
#timer = interface._prune_limit_timer
#timer.set_timer(interface.t_override)
#timer.start()
interface.set_prune_limit_timer()
interface.send_prune()
print("dataArrivesRPFinterface_OListNull_PLTstoped, F -> P")
@staticmethod
def stateRefreshArrivesRPFnbr_pruneIs1(interface: "TreeInterfaceUpstream"):
"""
State Refresh(S,G) received from RPF‘(S) AND
Prune Indicator == 1
@type interface: TreeInterfaceUpstream
"""
#interface.set_ot()
interface.set_override_timer()
print('stateRefreshArrivesRPFnbr_pruneIs1, F -> F')
@staticmethod
def stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped(interface: "TreeInterfaceUpstream"):
"""
State Refresh(S,G) received from RPF‘(S) AND
Prune Indicator == 0 AND
PLT(S, G) not running
@type interface: TreeInterfaceUpstream
"""
print(
'stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped, F -> F')
@staticmethod
def seeJoinToRPFnbr(interface: "TreeInterfaceUpstream"):
"""
See Join(S,G) to RPF’(S)
@type interface: TreeInterfaceUpstream
"""
#interface.cancel_ot()
interface.clear_override_timer()
print('seeJoinToRPFnbr, F -> F')
@staticmethod
def seePrune(interface: "TreeInterfaceUpstream"):
"""
See Prune(S,G)
@type interface: TreeInterfaceUpstream
"""
#interface.set_ot()
interface.set_override_timer()
print('seePrune, F -> F')
@staticmethod
def OTexpires(interface: "TreeInterfaceUpstream"):
"""
OT(S,G) Expires
@type interface: TreeInterfaceUpstream
"""
interface.send_join()
print('OTexpires, F -> F')
@staticmethod
def olistIsNowNull(interface: "TreeInterfaceUpstream"):
"""
olist(S,G)->NULL
@type interface: TreeInterfaceUpstream
"""
interface.set_state(UpstreamState.Pruned)
#timer = interface._prune_limit_timer
#timer.set_timer(interface.t_override)
#timer.start()
interface.set_prune_limit_timer()
interface.send_prune()
print("olistIsNowNull, F -> P")
@staticmethod
def olistIsNowNotNull(interface: "TreeInterfaceUpstream"):
"""
olist(S,G)->non-NULL
@type interface: TreeInterfaceUpstream
"""
#assert False
return
@staticmethod
def RPFnbrChanges_olistIsNotNull(interface: "TreeInterfaceUpstream"):
"""
RPF’(S) Changes AND
olist(S,G) != NULL AND
S not directly connected
@type interface: TreeInterfaceUpstream
"""
interface.send_graft()
#interface.get_grt().start()
interface.set_graft_retry_timer()
interface.set_state(UpstreamState.AckPending)
print('RPFnbrChanges_olistIsNotNull, F -> AP')
@staticmethod
def RPFnbrChanges_olistIsNull(interface: "TreeInterfaceUpstream"):
"""
RPF’(S) Changes AND
olist(S,G) == NULL
@type interface: TreeInterfaceUpstream
"""
interface.set_state(UpstreamState.Pruned)
print('RPFnbrChanges_olistIsNull, F -> P')
@staticmethod
def sourceIsNowDirectConnect(interface: "TreeInterfaceUpstream"):
"""
S becomes directly connected
@type interface: TreeInterfaceUpstream
"""
print("sourceIsNowDirectConnect, F -> F")
@staticmethod
def GRTexpires(interface: "TreeInterfaceUpstream"):
"""
GRT(S,G) Expires
@type interface: TreeInterfaceUpstream
"""
#assert False
return
@staticmethod
def recvGraftAckFromRPFnbr(interface: "TreeInterfaceUpstream"):
"""
Receive GraftAck(S,G) from RPF’(S)
@type interface: TreeInterfaceUpstream
"""
print('recvGraftAckFromRPFnbr, P -> P')
class Pruned(UpstreamStateABC):
'''
Pruned (P)
The set, olist(S,G), is empty.
The router will not forward data from S addressed to group G.
'''
@staticmethod
def dataArrivesRPFinterface_OListNull_PLTstoped(interface: "TreeInterfaceUpstream"):
"""
Data arrives on RPF_Interface(S) AND
olist(S, G) == NULL AND
PLT(S, G) not running
@type interface: TreeInterfaceUpstream
"""
interface.set_state(UpstreamState.Pruned)
# todo send prune?!?!?!?!
#timer = interface._prune_limit_timer
#timer.set_timer(interface.t_override)
#timer.start()
interface.set_prune_limit_timer()
print("dataArrivesRPFinterface_OListNull_PLTstoped, P -> P")
@staticmethod
def stateRefreshArrivesRPFnbr_pruneIs1(interface: "TreeInterfaceUpstream"):
"""
State Refresh(S,G) received from RPF‘(S) AND
Prune Indicator == 1
@type interface: TreeInterfaceUpstream
"""
#interface.get_plt().reset()
interface.set_prune_limit_timer()
interface.rprint('stateRefreshArrivesRPFnbr_pruneIs1, P -> P')
@staticmethod
def stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped(interface: "TreeInterfaceUpstream"):
"""
State Refresh(S,G) received from RPF‘(S) AND
Prune Indicator == 0 AND
PLT(S, G) not running
@type interface: TreeInterfaceUpstream
"""
# todo: desnecessario pq PLT stopped????!!!
#plt = interface.get_plt()
#if not plt.is_ticking():
# plt.start()
# interface.send_prune()
interface.send_prune()
interface.set_prune_limit_timer()
print(
'stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped, P -> P')
@staticmethod
def seeJoinToRPFnbr(interface: "TreeInterfaceUpstream"):
"""
See Join(S,G) to RPF’(S)
@type interface: TreeInterfaceUpstream
"""
# Do nothing
print('seeJoinToRPFnbr, P -> P')
@staticmethod
def seePrune(interface: "TreeInterfaceUpstream"):
"""
See Prune(S,G)
@type interface: TreeInterfaceUpstream
"""
print('seePrune, P -> P')
@staticmethod
def OTexpires(interface: "TreeInterfaceUpstream"):
"""
OT(S,G) Expires
@type interface: TreeInterfaceUpstream
"""
#assert False
return
@staticmethod
def olistIsNowNull(interface: "TreeInterfaceUpstream"):
"""
olist(S,G)->NULL
@type interface: TreeInterfaceUpstream
"""
#assert False
return
@staticmethod
def olistIsNowNotNull(interface: "TreeInterfaceUpstream"):
"""
olist(S,G)->non-NULL
@type interface: TreeInterfaceUpstream
"""
interface.send_graft()
#interface.get_grt().start()
interface.set_graft_retry_timer()
interface.set_state(UpstreamState.AckPending)
print('olistIsNowNotNull, P -> AP')
@staticmethod
def RPFnbrChanges_olistIsNotNull(interface: "TreeInterfaceUpstream"):
"""
RPF’(S) Changes AND
olist(S,G) != NULL AND
S not directly connected
@type interface: TreeInterfaceUpstream
"""
interface.send_graft()
#interface.get_grt().start()
interface.set_graft_retry_timer()
interface.set_state(UpstreamState.AckPending)
print('olistIsNowNotNull, P -> AP')
@staticmethod
def RPFnbrChanges_olistIsNull(interface: "TreeInterfaceUpstream"):
"""
RPF’(S) Changes AND
olist(S,G) == NULL
@type interface: TreeInterfaceUpstream
"""
#interface.get_plt().stop()
interface.clear_prune_limit_timer()
print('RPFnbrChanges_olistIsNull, P -> P')
@staticmethod
def sourceIsNowDirectConnect(interface: "TreeInterfaceUpstream"):
"""
S becomes directly connected
@type interface: TreeInterfaceUpstream
"""
print("sourceIsNowDirectConnect, P -> P")
@staticmethod
def GRTexpires(interface: "TreeInterfaceUpstream"):
"""
GRT(S,G) Expires
@type interface: TreeInterfaceUpstream
"""
#assert False
return
@staticmethod
def recvGraftAckFromRPFnbr(interface: "TreeInterfaceUpstream"):
"""
Receive GraftAck(S,G) from RPF’(S)
@type interface: TreeInterfaceUpstream
"""
print('recvGraftAckFromRPFnbr, P -> P')
class UpstreamState():
Forward = Forward()
Pruned = Pruned()
AckPending = AckPending()
import array
'''
import struct
if struct.pack("H",1) == "\x00\x01": # big endian
def checksum(pkt):
if len(pkt) % 2 == 1:
pkt += "\0"
s = sum(array.array("H", pkt))
s = (s >> 16) + (s & 0xffff)
s += s >> 16
s = ~s
return s & 0xffff
else:
def checksum(pkt):
if len(pkt) % 2 == 1:
pkt += "\0"
s = sum(array.array("H", pkt))
s = (s >> 16) + (s & 0xffff)
s += s >> 16
s = ~s
return (((s>>8)&0xff)|s<<8) & 0xffff
'''
HELLO_HOLD_TIME_NO_TIMEOUT = 0xFFFF
HELLO_HOLD_TIME = 160
HELLO_HOLD_TIME_TIMEOUT = 0
def checksum(pkt: bytes) -> bytes:
if len(pkt) % 2 == 1:
pkt += "\0"
s = sum(array.array("H", pkt))
s = (s >> 16) + (s & 0xffff)
s += s >> 16
s = ~s
return (((s >> 8) & 0xff) | s << 8) & 0xffff
import ctypes
import ctypes.util
libc = ctypes.CDLL(ctypes.util.find_library('c'))
def if_nametoindex(name):
if not isinstance(name, str):
raise TypeError('name must be a string.')
ret = libc.if_nametoindex(name)
if not ret:
raise RuntimeError("Invalid Name")
return ret
def if_indextoname(index):
if not isinstance(index, int):
raise TypeError('index must be an int.')
libc.if_indextoname.argtypes = [ctypes.c_uint32, ctypes.c_char_p]
libc.if_indextoname.restype = ctypes.c_char_p
ifname = ctypes.create_string_buffer(32)
ifname = libc.if_indextoname(index, ifname)
if not ifname:
raise RuntimeError ("Inavlid Index")
return ifname.decode("utf-8")
# obtain TYPE_CHECKING (for type hinting)
try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False
# IGMP timers (in seconds)
RobustnessVariable = 2
QueryInterval = 125
QueryResponseInterval = 10
MaxResponseTime_QueryResponseInterval = QueryResponseInterval*10
GroupMembershipInterval = RobustnessVariable * QueryInterval + QueryResponseInterval
OtherQuerierPresentInterval = RobustnessVariable * QueryInterval + QueryResponseInterval/2
StartupQueryInterval = QueryInterval / 4
StartupQueryCount = RobustnessVariable
LastMemberQueryInterval = 1
MaxResponseTime_LastMemberQueryInterval = LastMemberQueryInterval*10
LastMemberQueryCount = RobustnessVariable
UnsolicitedReportInterval = 10
Version1RouterPresentTimeout = 400
# IGMP msg type
Membership_Query = 0x11
Version_1_Membership_Report = 0x12
Version_2_Membership_Report = 0x16
Leave_Group = 0x17
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment