Commit fd60f995 authored by Pedro Oliveira's avatar Pedro Oliveira

IGMP efficient packet reception (via BPF) & StateRefresh/Originator state...

IGMP efficient packet reception (via BPF) & StateRefresh/Originator state machine & socket to recv (S,G) data packets (also with BPF) & fix some state machine errors
parent 43fc51da
...@@ -18,8 +18,10 @@ class Hello: ...@@ -18,8 +18,10 @@ class Hello:
options = packet.payload.payload.get_options() options = packet.payload.payload.get_options()
if (1 in options) and (20 in options): if (1 in options) and (20 in options):
hello_hold_time = options[1] #hello_hold_time = options[1]
generation_id = options[20] hello_hold_time = options[1].holdtime
#generation_id = options[20]
generation_id = options[20].generation_id
else: else:
raise Exception raise Exception
......
...@@ -34,6 +34,7 @@ class Interface(object): ...@@ -34,6 +34,7 @@ class Interface(object):
# set socket TTL to 1 # set socket TTL to 1
s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1) s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1)
s.setsockopt(socket.IPPROTO_IP, socket.IP_TTL, 1)
# don't receive outgoing packets # don't receive outgoing packets
s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 0) s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 0)
...@@ -62,6 +63,7 @@ class Interface(object): ...@@ -62,6 +63,7 @@ class Interface(object):
packet = None packet = None
return packet return packet
except Exception: except Exception:
traceback.print_exc()
return None return None
""" """
......
...@@ -5,6 +5,7 @@ import netifaces ...@@ -5,6 +5,7 @@ import netifaces
from Packet.ReceivedPacket import ReceivedPacket from Packet.ReceivedPacket import ReceivedPacket
import Main import Main
import traceback import traceback
from ctypes import create_string_buffer, addressof
if not hasattr(socket, 'SO_BINDTODEVICE'): if not hasattr(socket, 'SO_BINDTODEVICE'):
socket.SO_BINDTODEVICE = 25 socket.SO_BINDTODEVICE = 25
...@@ -12,17 +13,32 @@ if not hasattr(socket, 'SO_BINDTODEVICE'): ...@@ -12,17 +13,32 @@ if not hasattr(socket, 'SO_BINDTODEVICE'):
class InterfaceIGMP(object): class InterfaceIGMP(object):
ETH_P_IP = 0x0800 # Internet Protocol packet ETH_P_IP = 0x0800 # Internet Protocol packet
FILTER_IGMP = [
struct.pack('HBBI', 0x28, 0, 0, 0x0000000c),
struct.pack('HBBI', 0x15, 0, 3, 0x00000800),
struct.pack('HBBI', 0x30, 0, 0, 0x00000017),
struct.pack('HBBI', 0x15, 0, 1, 0x00000002),
struct.pack('HBBI', 0x6, 0, 0, 0x00040000),
struct.pack('HBBI', 0x6, 0, 0, 0x00000000),
]
SO_ATTACH_FILTER = 26
PACKET_MR_ALLMULTI = 2 PACKET_MR_ALLMULTI = 2
def __init__(self, interface_name: str, vif_index:int): def __init__(self, interface_name: str, vif_index:int):
# RECEIVE SOCKET # RECEIVE SOCKET
rcv_s = socket.socket(socket.PF_PACKET, socket.SOCK_RAW, socket.htons(InterfaceIGMP.ETH_P_IP)) rcv_s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, socket.htons(InterfaceIGMP.ETH_P_IP))
# allow all multicast packets # receive only IGMP packets by setting a BPF filter
rcv_s.setsockopt(socket.SOL_SOCKET, InterfaceIGMP.PACKET_MR_ALLMULTI, struct.pack("i HH BBBBBBBB", 0, InterfaceIGMP.PACKET_MR_ALLMULTI, 0, 0,0,0,0,0,0,0,0)) filters = b''.join(InterfaceIGMP.FILTER_IGMP)
b = create_string_buffer(filters)
mem_addr_of_filters = addressof(b)
fprog = struct.pack('HL', len(InterfaceIGMP.FILTER_IGMP), mem_addr_of_filters)
rcv_s.setsockopt(socket.SOL_SOCKET, InterfaceIGMP.SO_ATTACH_FILTER, fprog)
# bind to interface # bind to interface
rcv_s.bind((interface_name, 0)) rcv_s.bind((interface_name, 0x0800))
self.recv_socket = rcv_s self.recv_socket = rcv_s
...@@ -62,17 +78,9 @@ class InterfaceIGMP(object): ...@@ -62,17 +78,9 @@ class InterfaceIGMP(object):
def receive(self): def receive(self):
while self.interface_enabled: while self.interface_enabled:
try: try:
(raw_packet, x) = self.recv_socket.recvfrom(256 * 1024) (raw_packet, _) = self.recv_socket.recvfrom(256 * 1024)
if raw_packet: if raw_packet:
raw_packet = raw_packet[14:] raw_packet = raw_packet[14:]
from Packet.PacketIpHeader import PacketIpHeader
(verhlen, tos, iplen, ipid, frag, ttl, proto, cksum, src, dst) = \
struct.unpack(PacketIpHeader.IP_HDR, raw_packet[:PacketIpHeader.IP_HDR_LEN])
#print(proto)
if proto != socket.IPPROTO_IGMP:
continue
#print((raw_packet, x))
packet = ReceivedPacket(raw_packet, self) packet = ReceivedPacket(raw_packet, self)
Main.igmp.receive_handle(packet) Main.igmp.receive_handle(packet)
except Exception: except Exception:
......
...@@ -5,12 +5,14 @@ from Packet.ReceivedPacket import ReceivedPacket ...@@ -5,12 +5,14 @@ from Packet.ReceivedPacket import ReceivedPacket
import Main import Main
import traceback import traceback
from RWLock.RWLock import RWLockWrite from RWLock.RWLock import RWLockWrite
from Packet.PacketPimHelloOptions import *
from Packet.PacketPimHello import PacketPimHello from Packet.PacketPimHello import PacketPimHello
from Packet.PacketPimHeader import PacketPimHeader from Packet.PacketPimHeader import PacketPimHeader
from Packet.Packet import Packet from Packet.Packet import Packet
from Hello import Hello from Hello import Hello
from utils import HELLO_HOLD_TIME_TIMEOUT from utils import HELLO_HOLD_TIME_TIMEOUT
from threading import Timer from threading import Timer
from tree.globals import REFRESH_INTERVAL
class InterfacePim(Interface): class InterfacePim(Interface):
MCAST_GRP = '224.0.0.13' MCAST_GRP = '224.0.0.13'
...@@ -20,7 +22,7 @@ class InterfacePim(Interface): ...@@ -20,7 +22,7 @@ class InterfacePim(Interface):
MAX_TRIGGERED_HELLO_PERIOD = 5 MAX_TRIGGERED_HELLO_PERIOD = 5
def __init__(self, interface_name: str, vif_index:int): def __init__(self, interface_name: str, vif_index:int, state_refresh_capable:bool=False):
super().__init__(interface_name) super().__init__(interface_name)
# generation id # generation id
...@@ -33,9 +35,8 @@ class InterfacePim(Interface): ...@@ -33,9 +35,8 @@ class InterfacePim(Interface):
self.hello_timer.start() self.hello_timer.start()
# state refresh capable
# todo: state refresh capable self._state_refresh_capable = state_refresh_capable
self._state_refresh_capable = False
# todo: lan delay enabled # todo: lan delay enabled
self._lan_delay_enabled = False self._lan_delay_enabled = False
...@@ -58,10 +59,6 @@ class InterfacePim(Interface): ...@@ -58,10 +59,6 @@ class InterfacePim(Interface):
receive_thread.daemon = True receive_thread.daemon = True
receive_thread.start() receive_thread.start()
def create_virtual_interface(self):
self.vif_index = Main.kernel.create_virtual_interface(ip_interface=self.ip_interface, interface_name=self.interface_name)
def receive(self): def receive(self):
while self.is_enabled(): while self.is_enabled():
try: try:
...@@ -80,8 +77,15 @@ class InterfacePim(Interface): ...@@ -80,8 +77,15 @@ class InterfacePim(Interface):
self.hello_timer.cancel() self.hello_timer.cancel()
pim_payload = PacketPimHello() pim_payload = PacketPimHello()
pim_payload.add_option(1, 3.5 * Hello.TRIGGERED_HELLO_DELAY) pim_payload.add_option(PacketPimHelloHoldtime(holdtime=3.5 * Hello.TRIGGERED_HELLO_DELAY))
pim_payload.add_option(20, self.generation_id) pim_payload.add_option(PacketPimHelloGenerationID(self.generation_id))
# TODO implementar LANPRUNEDELAY e OVERRIDE_INTERVAL por interface e nas maquinas de estados ler valor de interface e nao do globals.py
#pim_payload.add_option(PacketPimHelloLANPruneDelay(lan_prune_delay=self._propagation_delay, override_interval=self._override_interval))
if self._state_refresh_capable:
pim_payload.add_option(PacketPimHelloStateRefreshCapable(REFRESH_INTERVAL))
ph = PacketPimHeader(pim_payload) ph = PacketPimHeader(pim_payload)
packet = Packet(payload=ph) packet = Packet(payload=ph)
self.send(packet.bytes()) self.send(packet.bytes())
...@@ -96,8 +100,8 @@ class InterfacePim(Interface): ...@@ -96,8 +100,8 @@ class InterfacePim(Interface):
# send pim_hello timeout message # send pim_hello timeout message
pim_payload = PacketPimHello() pim_payload = PacketPimHello()
pim_payload.add_option(1, HELLO_HOLD_TIME_TIMEOUT) pim_payload.add_option(PacketPimHelloHoldtime(holdtime=HELLO_HOLD_TIME_TIMEOUT))
pim_payload.add_option(20, self.generation_id) pim_payload.add_option(PacketPimHelloGenerationID(self.generation_id))
ph = PacketPimHeader(pim_payload) ph = PacketPimHeader(pim_payload)
packet = Packet(payload=ph) packet = Packet(payload=ph)
self.send(packet.bytes()) self.send(packet.bytes())
...@@ -130,3 +134,7 @@ class InterfacePim(Interface): ...@@ -130,3 +134,7 @@ class InterfacePim(Interface):
def remove_neighbor(self, ip): def remove_neighbor(self, ip):
with self.neighbors_lock.genWlock(): with self.neighbors_lock.genWlock():
del self.neighbors[ip] del self.neighbors[ip]
def is_state_refresh_enabled(self):
return self._state_refresh_capable
...@@ -41,6 +41,12 @@ class Kernel: ...@@ -41,6 +41,12 @@ class Kernel:
IGMPMSG_WHOLEPKT = 3 # NOT USED ON PIM-DM IGMPMSG_WHOLEPKT = 3 # NOT USED ON PIM-DM
# Interface flags
VIFF_TUNNEL = 0x1 # IPIP tunnel
VIFF_SRCRT = 0x2 # NI
VIFF_REGISTER = 0x4 # register vif
VIFF_USE_IFINDEX = 0x8 # use vifc_lcl_ifindex instead of vifc_lcl_addr to find an interface
def __init__(self): def __init__(self):
# Kernel is running # Kernel is running
self.running = True self.running = True
...@@ -66,6 +72,9 @@ class Kernel: ...@@ -66,6 +72,9 @@ class Kernel:
self.rwlock = RWLockWrite() self.rwlock = RWLockWrite()
self.interface_lock = Lock() self.interface_lock = Lock()
# Create register interface
# todo useless in PIM-DM... useful in PIM-SM
#self.create_virtual_interface("0.0.0.0", "pimreg", index=0, flags=Kernel.VIFF_REGISTER)
# Create virtual interfaces # Create virtual interfaces
''' '''
...@@ -149,8 +158,55 @@ class Kernel: ...@@ -149,8 +158,55 @@ class Kernel:
return index return index
def create_pim_interface(self, interface_name: str, state_refresh_capable:bool):
from InterfacePIM import InterfacePim
with self.interface_lock:
pim_interface = self.pim_interface.get(interface_name)
igmp_interface = self.igmp_interface.get(interface_name)
vif_already_exists = pim_interface or igmp_interface
if pim_interface:
# already exists
return
elif igmp_interface:
index = igmp_interface.vif_index
else:
index = list(range(0, self.MAXVIFS) - self.vif_index_to_name_dic.keys())[0]
ip_interface = None
if interface_name not in self.pim_interface:
pim_interface = InterfacePim(interface_name, index, state_refresh_capable)
self.pim_interface[interface_name] = pim_interface
ip_interface = pim_interface.ip_interface
if not vif_already_exists:
self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index)
def create_igmp_interface(self, interface_name: str):
from InterfaceIGMP import InterfaceIGMP
with self.interface_lock:
pim_interface = self.pim_interface.get(interface_name)
igmp_interface = self.igmp_interface.get(interface_name)
vif_already_exists = pim_interface or igmp_interface
if igmp_interface:
# already exists
return
elif pim_interface:
index = pim_interface.vif_index
else:
index = list(range(0, self.MAXVIFS) - self.vif_index_to_name_dic.keys())[0]
ip_interface = None
if interface_name not in self.igmp_interface:
igmp_interface = InterfaceIGMP(interface_name, index)
self.igmp_interface[interface_name] = igmp_interface
ip_interface = igmp_interface.ip_interface
if not vif_already_exists:
self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index)
'''
def create_interface(self, interface_name: str, igmp:bool = False, pim:bool = False): def create_interface(self, interface_name: str, igmp:bool = False, pim:bool = False):
from InterfaceIGMP import InterfaceIGMP from InterfaceIGMP import InterfaceIGMP
from InterfacePIM import InterfacePim from InterfacePIM import InterfacePim
...@@ -180,7 +236,7 @@ class Kernel: ...@@ -180,7 +236,7 @@ class Kernel:
if not vif_already_exists: if not vif_already_exists:
self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index) self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index)
'''
...@@ -263,20 +319,6 @@ class Kernel: ...@@ -263,20 +319,6 @@ class Kernel:
# TODO: ver melhor tabela routing # TODO: ver melhor tabela routing
#self.routing[(socket.inet_ntoa(source_ip), socket.inet_ntoa(group_ip))] = {"inbound_interface_index": inbound_interface_index, "outbound_interfaces": outbound_interfaces} #self.routing[(socket.inet_ntoa(source_ip), socket.inet_ntoa(group_ip))] = {"inbound_interface_index": inbound_interface_index, "outbound_interfaces": outbound_interfaces}
'''
def flood(self, ip_src, ip_dst, iif):
source_ip = socket.inet_aton(ip_src)
group_ip = socket.inet_aton(ip_dst)
outbound_interfaces = [1]*Kernel.MAXVIFS
outbound_interfaces[iif] = 0
outbound_interfaces_and_other_parameters = outbound_interfaces + [0]*4
#outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5
#struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0)
struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, iif, *outbound_interfaces_and_other_parameters)
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_MFC, struct_mfcctl)
'''
def remove_multicast_route(self, kernel_entry: KernelEntry): def remove_multicast_route(self, kernel_entry: KernelEntry):
source_ip = socket.inet_aton(kernel_entry.source_ip) source_ip = socket.inet_aton(kernel_entry.source_ip)
...@@ -312,8 +354,7 @@ class Kernel: ...@@ -312,8 +354,7 @@ class Kernel:
def handler(self): def handler(self):
while self.running: while self.running:
try: try:
msg = self.socket.recv(5000) msg = self.socket.recv(20)
#print(len(msg))
(_, _, im_msgtype, im_mbz, im_vif, _, im_src, im_dst) = struct.unpack("II B B B B 4s 4s", msg[:20]) (_, _, im_msgtype, im_mbz, im_vif, _, im_src, im_dst) = struct.unpack("II B B B B 4s 4s", msg[:20])
print((im_msgtype, im_mbz, socket.inet_ntoa(im_src), socket.inet_ntoa(im_dst))) print((im_msgtype, im_mbz, socket.inet_ntoa(im_src), socket.inet_ntoa(im_dst)))
...@@ -336,6 +377,9 @@ class Kernel: ...@@ -336,6 +377,9 @@ class Kernel:
elif im_msgtype == Kernel.IGMPMSG_WRONGVIF: elif im_msgtype == Kernel.IGMPMSG_WRONGVIF:
print("WRONG VIF HANDLER") print("WRONG VIF HANDLER")
self.igmpmsg_wrongvif_handler(ip_src, ip_dst, im_vif) self.igmpmsg_wrongvif_handler(ip_src, ip_dst, im_vif)
#elif im_msgtype == Kernel.IGMPMSG_WHOLEPKT:
# print("IGMP_WHOLEPKT")
# self.igmpmsg_wholepacket_handler(ip_src, ip_dst)
else: else:
raise Exception raise Exception
except Exception: except Exception:
...@@ -379,6 +423,17 @@ class Kernel: ...@@ -379,6 +423,17 @@ class Kernel:
self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg(iif) self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg(iif)
#kernel_entry.recv_data_msg(iif) #kernel_entry.recv_data_msg(iif)
''' useless in PIM-DM... useful in PIM-SM
def igmpmsg_wholepacket_handler(self, ip_src, ip_dst):
#kernel_entry = self.routing[(ip_src, ip_dst)]
source_group_pair = (ip_src, ip_dst)
self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg()
#kernel_entry.recv_data_msg(iif)
'''
""" """
def get_routing_entry(self, source_group: tuple): def get_routing_entry(self, source_group: tuple):
with self.rwlock.genRlock(): with self.rwlock.genRlock():
......
...@@ -15,6 +15,14 @@ kernel = None ...@@ -15,6 +15,14 @@ kernel = None
igmp = None igmp = None
def add_pim_interface(interface_name, state_refresh_capable:bool=False):
kernel.create_pim_interface(interface_name=interface_name, state_refresh_capable=state_refresh_capable)
def add_igmp_interface(interface_name):
kernel.create_igmp_interface(interface_name=interface_name)
'''
def add_interface(interface_name, pim=False, igmp=False): def add_interface(interface_name, pim=False, igmp=False):
#if pim is True and interface_name not in interfaces: #if pim is True and interface_name not in interfaces:
# interface = InterfacePim(interface_name) # interface = InterfacePim(interface_name)
...@@ -28,7 +36,7 @@ def add_interface(interface_name, pim=False, igmp=False): ...@@ -28,7 +36,7 @@ def add_interface(interface_name, pim=False, igmp=False):
# interfaces[interface_name] = kernel.pim_interface[interface_name] # interfaces[interface_name] = kernel.pim_interface[interface_name]
#if igmp: #if igmp:
# igmp_interfaces[interface_name] = kernel.igmp_interface[interface_name] # igmp_interfaces[interface_name] = kernel.igmp_interface[interface_name]
'''
def remove_interface(interface_name, pim=False, igmp=False): def remove_interface(interface_name, pim=False, igmp=False):
#if pim is True and ((interface_name in interfaces) or interface_name == "*"): #if pim is True and ((interface_name in interfaces) or interface_name == "*"):
...@@ -76,38 +84,6 @@ def list_neighbors(): ...@@ -76,38 +84,6 @@ def list_neighbors():
def list_enabled_interfaces(): def list_enabled_interfaces():
global interfaces global interfaces
# TESTE DE PIM JOIN/PRUNE
for interface in interfaces:
from Packet.Packet import Packet
from Packet.PacketPimHeader import PacketPimHeader
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
from Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
ph = PacketPimJoinPrune("10.0.0.13", 210)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.123", ["1.1.1.1", "10.1.1.1"], []))
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.124", ["1.1.1.2", "10.1.1.2"], []))
pckt = Packet(payload=PacketPimHeader(ph))
interfaces[interface].send(pckt.bytes())
ph = PacketPimJoinPrune("ff08::1", 210)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("2001:1:a:b:c::1", ["1.1.1.1", "2001:1:a:b:c::2"], []))
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.123", ["1.1.1.1"], ["2001:1:a:b:c::3"]))
pckt = Packet(payload=PacketPimHeader(ph))
interfaces[interface].send(pckt.bytes())
from Packet.PacketPimAssert import PacketPimAssert
ph = PacketPimAssert("224.12.12.12", "10.0.0.2", 210, 2)
pckt = Packet(payload=PacketPimHeader(ph))
interfaces[interface].send(pckt.bytes())
from Packet.PacketPimGraft import PacketPimGraft
ph = PacketPimGraft("10.0.0.13")
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.124", ["1.1.1.2", "10.1.1.2"], []))
pckt = Packet(payload=PacketPimHeader(ph))
interfaces[interface].send(pckt.bytes())
t = PrettyTable(['Interface', 'IP', 'PIM/IGMP Enabled', 'IGMP State']) t = PrettyTable(['Interface', 'IP', 'PIM/IGMP Enabled', 'IGMP State'])
for interface in netifaces.interfaces(): for interface in netifaces.interfaces():
...@@ -191,12 +167,14 @@ def main(): ...@@ -191,12 +167,14 @@ def main():
from JoinPrune import JoinPrune from JoinPrune import JoinPrune
from GraftAck import GraftAck from GraftAck import GraftAck
from Graft import Graft from Graft import Graft
from StateRefresh import StateRefresh
Hello() Hello()
Assert() Assert()
JoinPrune() JoinPrune()
Graft() Graft()
GraftAck() GraftAck()
StateRefresh()
global kernel global kernel
kernel = Kernel() kernel = Kernel()
......
import struct import struct
from abc import ABCMeta, abstractstaticmethod
from .PacketPimHelloOptions import PacketPimHelloOptions, PacketPimHelloStateRefreshCapable, PacketPimHelloGenerationID, PacketPimHelloLANPruneDelay, PacketPimHelloHoldtime
''' '''
0 1 2 3 0 1 2 3
...@@ -25,6 +27,7 @@ class PacketPimHello: ...@@ -25,6 +27,7 @@ class PacketPimHello:
PIM_HDR_OPTS_LEN = struct.calcsize(PIM_HDR_OPTS) PIM_HDR_OPTS_LEN = struct.calcsize(PIM_HDR_OPTS)
PIM_MSG_TYPES_LENGTH = {1: 2, PIM_MSG_TYPES_LENGTH = {1: 2,
2: 4,
20: 4, 20: 4,
21: 4, 21: 4,
} }
...@@ -33,16 +36,26 @@ class PacketPimHello: ...@@ -33,16 +36,26 @@ class PacketPimHello:
def __init__(self): def __init__(self):
self.options = {} self.options = {}
'''
def add_option(self, option_type: int, option_value: int or float): def add_option(self, option_type: int, option_value: int or float):
option_value = int(option_value) option_value = int(option_value)
# if option_value requires more bits than the bits available for that field: option value will have all field bits = 1 # if option_value requires more bits than the bits available for that field: option value will have all field bits = 1
if option_type in self.PIM_MSG_TYPES_LENGTH and self.PIM_MSG_TYPES_LENGTH[option_type] * 8 < option_value.bit_length(): if option_type in self.PIM_MSG_TYPES_LENGTH and self.PIM_MSG_TYPES_LENGTH[option_type] * 8 < option_value.bit_length():
option_value = (1 << (self.PIM_MSG_TYPES_LENGTH[option_type] * 8)) - 1 option_value = (1 << (self.PIM_MSG_TYPES_LENGTH[option_type] * 8)) - 1
self.options[option_type] = option_value self.options[option_type] = option_value
'''
def add_option(self, option: 'PacketPimHelloOptions'):
#if option_type in self.PIM_MSG_TYPES_LENGTH and self.PIM_MSG_TYPES_LENGTH[option_type] * 8 < option_value.bit_length():
# option_value = (1 << (self.PIM_MSG_TYPES_LENGTH[option_type] * 8)) - 1
self.options[option.type] = option
def get_options(self): def get_options(self):
return self.options return self.options
'''
def bytes(self) -> bytes: def bytes(self) -> bytes:
res = b'' res = b''
for (option_type, option_value) in self.options.items(): for (option_type, option_value) in self.options.items():
...@@ -50,10 +63,21 @@ class PacketPimHello: ...@@ -50,10 +63,21 @@ class PacketPimHello:
type_length_hdr = struct.pack(PacketPimHello.PIM_HDR_OPTS, option_type, option_length) type_length_hdr = struct.pack(PacketPimHello.PIM_HDR_OPTS, option_type, option_length)
res += type_length_hdr + struct.pack("! " + str(option_length) + "s", option_value.to_bytes(option_length, byteorder='big')) res += type_length_hdr + struct.pack("! " + str(option_length) + "s", option_value.to_bytes(option_length, byteorder='big'))
return res return res
'''
def bytes(self) -> bytes:
res = b''
for option in self.options.values():
res += option.bytes()
return res
def __len__(self): def __len__(self):
return len(self.bytes()) return len(self.bytes())
'''
@staticmethod @staticmethod
def parse_bytes(data: bytes): def parse_bytes(data: bytes):
pim_payload = PacketPimHello() pim_payload = PacketPimHello()
...@@ -66,13 +90,24 @@ class PacketPimHello: ...@@ -66,13 +90,24 @@ class PacketPimHello:
(option_value,) = struct.unpack("! " + str(option_length) + "s", data[:option_length]) (option_value,) = struct.unpack("! " + str(option_length) + "s", data[:option_length])
option_value_number = int.from_bytes(option_value, byteorder='big') option_value_number = int.from_bytes(option_value, byteorder='big')
print("option value: ", option_value_number) print("option value: ", option_value_number)
'''
options_list.append({"OPTION TYPE": option_type, #options_list.append({"OPTION TYPE": option_type,
"OPTION LENGTH": option_length, # "OPTION LENGTH": option_length,
"OPTION VALUE": option_value_number # "OPTION VALUE": option_value_number
}) # })
'''
pim_payload.add_option(option_type, option_value_number) pim_payload.add_option(option_type, option_value_number)
data = data[option_length:] data = data[option_length:]
return pim_payload return pim_payload
'''
@staticmethod
def parse_bytes(data: bytes):
pim_payload = PacketPimHello()
while data != b'':
option = PacketPimHelloOptions.parse_bytes(data)
option_length = len(option)
data = data[option_length:]
pim_payload.add_option(option)
return pim_payload
import struct
from abc import ABCMeta
import math
class PacketPimHelloOptions(metaclass=ABCMeta):
PIM_HDR_OPTS = "! HH"
PIM_HDR_OPTS_LEN = struct.calcsize(PIM_HDR_OPTS)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type | Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, type: int, length: int):
self.type = type
self.length = length
def bytes(self) -> bytes:
return struct.pack(PacketPimHelloOptions.PIM_HDR_OPTS, self.type, self.length)
def __len__(self):
return self.PIM_HDR_OPTS_LEN + self.length
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
(type, length) = struct.unpack(PacketPimHelloOptions.PIM_HDR_OPTS,
data[:PacketPimHelloOptions.PIM_HDR_OPTS_LEN])
print("TYPE:", type)
print("LENGTH:", length)
data = data[PacketPimHelloOptions.PIM_HDR_OPTS_LEN:]
#return PIM_MSG_TYPES[type](data)
return PIM_MSG_TYPES.get(type, PacketPimHelloUnknown).parse_bytes(data, type, length)
class PacketPimHelloStateRefreshCapable(PacketPimHelloOptions):
PIM_HDR_OPT = "! BBH"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Version = 1 | Interval | Reserved |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
VERSION = 1
def __init__(self, interval: int):
super().__init__(type=21, length=4)
self.interval = interval
def bytes(self) -> bytes:
return super().bytes() + struct.pack(self.PIM_HDR_OPT, self.VERSION, self.interval, 0)
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
(version, interval, _) = struct.unpack(PacketPimHelloStateRefreshCapable.PIM_HDR_OPT,
data[:PacketPimHelloStateRefreshCapable.PIM_HDR_OPT_LEN])
return PacketPimHelloStateRefreshCapable(interval)
class PacketPimHelloLANPruneDelay(PacketPimHelloOptions):
PIM_HDR_OPT = "! HH"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|T| LAN Prune Delay | Override Interval |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, lan_prune_delay: float, override_interval: float):
super().__init__(type=2, length=4)
self.lan_prune_delay = 0x7FFF & math.ceil(lan_prune_delay)
self.override_interval = math.ceil(override_interval)
def bytes(self) -> bytes:
return super().bytes() + struct.pack(self.PIM_HDR_OPT, self.lan_prune_delay, self.override_interval)
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
(lan_prune_delay, override_interval) = struct.unpack(PacketPimHelloLANPruneDelay.PIM_HDR_OPT,
data[:PacketPimHelloLANPruneDelay.PIM_HDR_OPT_LEN])
lan_prune_delay = lan_prune_delay & 0x7FFF
return PacketPimHelloLANPruneDelay(lan_prune_delay=lan_prune_delay, override_interval=override_interval)
class PacketPimHelloHoldtime(PacketPimHelloOptions):
PIM_HDR_OPT = "! H"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Hold Time |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, holdtime: int or float):
super().__init__(type=1, length=2)
self.holdtime = int(holdtime)
def bytes(self) -> bytes:
return super().bytes() + struct.pack(self.PIM_HDR_OPT, self.holdtime)
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
(holdtime, ) = struct.unpack(PacketPimHelloHoldtime.PIM_HDR_OPT,
data[:PacketPimHelloHoldtime.PIM_HDR_OPT_LEN])
print("HOLDTIME:", holdtime)
return PacketPimHelloHoldtime(holdtime=holdtime)
class PacketPimHelloGenerationID(PacketPimHelloOptions):
PIM_HDR_OPT = "! L"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Generation ID |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, generation_id: int):
super().__init__(type=20, length=4)
self.generation_id = generation_id
def bytes(self) -> bytes:
return super().bytes() + struct.pack(self.PIM_HDR_OPT, self.generation_id)
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
(generation_id, ) = struct.unpack(PacketPimHelloGenerationID.PIM_HDR_OPT,
data[:PacketPimHelloGenerationID.PIM_HDR_OPT_LEN])
print("GenerationID:", generation_id)
return PacketPimHelloGenerationID(generation_id=generation_id)
class PacketPimHelloUnknown(PacketPimHelloOptions):
PIM_HDR_OPT = "! L"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Unknown |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, type, length):
super().__init__(type=type, length=length)
print("PIM Hello Option Unknown... TYPE=", type, "LENGTH=", length)
def bytes(self) -> bytes:
raise Exception
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
return PacketPimHelloUnknown(type, length)
PIM_MSG_TYPES = {1: PacketPimHelloHoldtime,
2: PacketPimHelloLANPruneDelay,
20: PacketPimHelloGenerationID,
21: PacketPimHelloStateRefreshCapable,
}
...@@ -64,10 +64,13 @@ class MyDaemon(Daemon): ...@@ -64,10 +64,13 @@ class MyDaemon(Daemon):
elif 'list_state' in args and args.list_state: elif 'list_state' in args and args.list_state:
connection.sendall(pickle.dumps(Main.list_state())) connection.sendall(pickle.dumps(Main.list_state()))
elif 'add_interface' in args and args.add_interface: elif 'add_interface' in args and args.add_interface:
Main.add_interface(args.add_interface[0], pim=True) Main.add_pim_interface(args.add_interface[0], False)
connection.shutdown(socket.SHUT_RDWR)
elif 'add_interface_sr' in args and args.add_interface_sr:
Main.add_pim_interface(args.add_interface_sr[0], True)
connection.shutdown(socket.SHUT_RDWR) connection.shutdown(socket.SHUT_RDWR)
elif 'add_interface_igmp' in args and args.add_interface_igmp: elif 'add_interface_igmp' in args and args.add_interface_igmp:
Main.add_interface(args.add_interface_igmp[0], igmp=True) Main.add_igmp_interface(args.add_interface_igmp[0])
connection.shutdown(socket.SHUT_RDWR) connection.shutdown(socket.SHUT_RDWR)
elif 'remove_interface' in args and args.remove_interface: elif 'remove_interface' in args and args.remove_interface:
Main.remove_interface(args.remove_interface[0], pim=True) Main.remove_interface(args.remove_interface[0], pim=True)
...@@ -99,6 +102,7 @@ if __name__ == "__main__": ...@@ -99,6 +102,7 @@ if __name__ == "__main__":
group.add_argument("-ls", "--list_state", action="store_true", default=False, help="List state of IGMP") group.add_argument("-ls", "--list_state", action="store_true", default=False, help="List state of IGMP")
group.add_argument("-mr", "--multicast_routes", action="store_true", default=False, help="List Multicast Routing table") group.add_argument("-mr", "--multicast_routes", action="store_true", default=False, help="List Multicast Routing table")
group.add_argument("-ai", "--add_interface", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface") group.add_argument("-ai", "--add_interface", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface")
group.add_argument("-aisr", "--add_interface_sr", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface with State Refresh enabled")
group.add_argument("-aiigmp", "--add_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Add IGMP interface") group.add_argument("-aiigmp", "--add_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Add IGMP interface")
group.add_argument("-ri", "--remove_interface", nargs=1, metavar='INTERFACE_NAME', help="Remove PIM interface") group.add_argument("-ri", "--remove_interface", nargs=1, metavar='INTERFACE_NAME', help="Remove PIM interface")
group.add_argument("-riigmp", "--remove_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Remove IGMP interface") group.add_argument("-riigmp", "--remove_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Remove IGMP interface")
......
...@@ -4,9 +4,9 @@ from Packet.Packet import Packet ...@@ -4,9 +4,9 @@ from Packet.Packet import Packet
from Packet.ReceivedPacket import ReceivedPacket from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimHello import PacketPimHello from Packet.PacketPimHello import PacketPimHello
from Packet.PacketPimHeader import PacketPimHeader from Packet.PacketPimHeader import PacketPimHeader
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
from Interface import Interface from Interface import Interface
import Main import Main
from utils import HELLO_HOLD_TIME_TIMEOUT
class StateRefresh: class StateRefresh:
...@@ -17,8 +17,27 @@ class StateRefresh: ...@@ -17,8 +17,27 @@ class StateRefresh:
# receive handler # receive handler
def receive_handle(self, packet: ReceivedPacket): def receive_handle(self, packet: ReceivedPacket):
#check if interface supports state refresh
if not packet.interface._state_refresh_capable:
return
ip = packet.ip_header.ip_src ip = packet.ip_header.ip_src
print("ip = ", ip) print("ip = ", ip)
pkt_join_prune = packet.payload.payload pkt_state_refresh = packet.payload.payload # type: PacketPimStateRefresh
# TODO # TODO
raise Exception
\ No newline at end of file interface_index = packet.interface.vif_index
source = pkt_state_refresh.source_address
group = pkt_state_refresh.multicast_group_adress
source_group = (source, group)
try:
Main.kernel.get_routing_entry(source_group).recv_state_refresh_msg(interface_index, packet)
except:
try:
# import time
# time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_state_refresh_msg(interface_index, packet)
except:
pass
...@@ -65,7 +65,10 @@ class UnicastRouting(object): ...@@ -65,7 +65,10 @@ class UnicastRouting(object):
unicast_routing_entry = UnicastRouting.get_route(ip_dst) unicast_routing_entry = UnicastRouting.get_route(ip_dst)
entry_protocol = unicast_routing_entry["proto"] entry_protocol = unicast_routing_entry["proto"]
entry_cost = unicast_routing_entry["priority"] entry_cost = unicast_routing_entry["priority"]
return (entry_protocol, entry_cost) mask = unicast_routing_entry["dst_len"]
if entry_cost is None:
entry_cost = 0
return (entry_protocol, entry_cost, mask)
""" """
def get_rpf(ip_dst: str): def get_rpf(ip_dst: str):
......
import subprocess
import struct
import socket
from ctypes import create_string_buffer, addressof
SO_ATTACH_FILTER = 26
ETH_P_IP = 0x0800 # Internet Protocol packet
def get_s_g_bpf_filter_code(source, group, interface_name):
cmd = "tcpdump -ddd \"(udp or icmp) and host %s and dst %s\"" % (source, group)
result = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
bpf_filter = b''
FILTER = []
tmp = result.stdout.read().splitlines()
num = int(tmp[0])
for line in tmp[1:]: #read and store result in log file
print(line)
#FILTER += (struct.pack("HBBI", *tuple(map(int, line.split(b' ')))), )
bpf_filter += struct.pack("HBBI", *tuple(map(int, line.split(b' '))))
print(num)
print(FILTER)
# defined in linux/filter.h.
b = create_string_buffer(bpf_filter)
mem_addr_of_filters = addressof(b)
fprog = struct.pack('HL', num, mem_addr_of_filters)
# Create listening socket with filters
s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, ETH_P_IP)
s.setsockopt(socket.SOL_SOCKET, SO_ATTACH_FILTER, fprog)
s.bind((interface_name, ETH_P_IP))
return s
...@@ -8,6 +8,7 @@ from .tree_interface import TreeInterface ...@@ -8,6 +8,7 @@ from .tree_interface import TreeInterface
from threading import Timer, Lock, RLock from threading import Timer, Lock, RLock
from tree.metric import AssertMetric from tree.metric import AssertMetric
import UnicastRouting import UnicastRouting
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
class KernelEntry: class KernelEntry:
TREE_TIMEOUT = 180 TREE_TIMEOUT = 180
...@@ -37,21 +38,12 @@ class KernelEntry: ...@@ -37,21 +38,12 @@ class KernelEntry:
print("RPF_NODE:", UnicastRouting.get_route(source_ip)) print("RPF_NODE:", UnicastRouting.get_route(source_ip))
print(self.rpf_node == source_ip) print(self.rpf_node == source_ip)
# (S,G) starts IG state # (S,G) starts IG state
self._was_olist_null = False self._was_olist_null = False
# todo
#self._rpf_is_origin = False
self._originator_state = OriginatorState.NotOriginator
# decide inbound interface based on rpf check # decide inbound interface based on rpf check
self.inbound_interface_index = Main.kernel.vif_dic[self.check_rpf()] self.inbound_interface_index = Main.kernel.vif_dic[self.check_rpf()]
#Main.kernel.flood(source_ip, group_ip, self.inbound_interface_index)
self.interface_state = {} # type: Dict[int, TreeInterface] self.interface_state = {} # type: Dict[int, TreeInterface]
for i in Main.kernel.vif_index_to_name_dic.keys(): for i in Main.kernel.vif_index_to_name_dic.keys():
try: try:
...@@ -71,10 +63,6 @@ class KernelEntry: ...@@ -71,10 +63,6 @@ class KernelEntry:
self.change() self.change()
self.evaluate_olist_change() self.evaluate_olist_change()
print('Tree created') print('Tree created')
#self._liveliness_timer = None
#if self.is_originater():
# self.set_liveliness_timer()
# print('set SAT')
#self._lock = threading.RLock() #self._lock = threading.RLock()
...@@ -92,9 +80,9 @@ class KernelEntry: ...@@ -92,9 +80,9 @@ class KernelEntry:
return UnicastRouting.check_rpf(self.source_ip) return UnicastRouting.check_rpf(self.source_ip)
################################# ################################################
# Receive (S,G) packet # Receive (S,G) data packets or control packets
################################# ################################################
def recv_data_msg(self, index): def recv_data_msg(self, index):
print("recv data") print("recv data")
self.interface_state[index].recv_data_msg() self.interface_state[index].recv_data_msg()
...@@ -132,8 +120,38 @@ class KernelEntry: ...@@ -132,8 +120,38 @@ class KernelEntry:
def recv_state_refresh_msg(self, index, packet): def recv_state_refresh_msg(self, index, packet):
print("recv state refresh msg") print("recv state refresh msg")
prune_indicator = 1 source_of_state_refresh = packet.ip_header.ip_src
self.interface_state[index].recv_state_refresh_msg(prune_indicator)
metric_preference = packet.payload.payload.metric_preference
metric = packet.payload.payload.metric
mask_len = packet.payload.payload.mask_len
ttl = packet.payload.payload.ttl
prune_indicator_flag = packet.payload.payload.prune_indicator_flag #P
assert_override_flag = packet.payload.payload.assert_override_flag #O
interval = packet.payload.payload.interval
received_metric = AssertMetric(metric_preference=metric_preference, route_metric=metric, ip_address=source_of_state_refresh, state_refresh_interval=interval)
self.interface_state[index].recv_state_refresh_msg(received_metric, prune_indicator_flag)
iif = packet.interface.vif_index
if iif != self.inbound_interface_index:
return
if self.interface_state[iif].get_neighbor_RPF() != source_of_state_refresh:
return
# todo refresh limit
if ttl == 0:
return
self.forward_state_refresh_msg(packet.payload.payload)
################################################
# Send state refresh msg
################################################
def forward_state_refresh_msg(self, state_refresh_packet):
for interface in self.interface_state.values():
interface.send_state_refresh(state_refresh_packet)
############################################################### ###############################################################
...@@ -177,13 +195,9 @@ class KernelEntry: ...@@ -177,13 +195,9 @@ class KernelEntry:
self.rpf_node = rpf_node self.rpf_node = rpf_node
self.interface_state[self.inbound_interface_index].change_rpf(self._was_olist_null) self.interface_state[self.inbound_interface_index].change_rpf(self._was_olist_null)
def update(self, caller, arg):
#todo
return
def nbr_event(self, link, node, event): def nbr_event(self, link, node, event):
# todo # todo pode ser interessante verificar se a adicao/remocao de vizinhos se altera o olist
return return
def is_olist_null(self): def is_olist_null(self):
......
...@@ -37,12 +37,11 @@ class AssertStateABC(metaclass=ABCMeta): ...@@ -37,12 +37,11 @@ class AssertStateABC(metaclass=ABCMeta):
raise NotImplementedError() raise NotImplementedError()
@abstractstaticmethod @abstractstaticmethod
def receivedPreferedMetric(interface: "TreeInterfaceDownstream", assert_time, better_metric): def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric):
""" """
Receive Preferred Assert OR State Refresh Receive Preferred Assert OR State Refresh
@type interface: TreeInterface @type interface: TreeInterface
@type assert_time: int
@type better_metric: AssertMetric @type better_metric: AssertMetric
""" """
raise NotImplementedError() raise NotImplementedError()
...@@ -160,11 +159,11 @@ class NoInfoState(AssertStateABC): ...@@ -160,11 +159,11 @@ class NoInfoState(AssertStateABC):
'receivedInferiorMetricFromNonWinner_couldAssertIsTrue, NI -> W') 'receivedInferiorMetricFromNonWinner_couldAssertIsTrue, NI -> W')
@staticmethod @staticmethod
def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric, state_refresh_interval = None): def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric):
''' '''
@type interface: TreeInterface @type interface: TreeInterface
''' '''
#interface.assert_timer.set_timer(assert_time) state_refresh_interval = better_metric.state_refresh_interval
if state_refresh_interval is None: if state_refresh_interval is None:
# event caused by Assert Msg # event caused by Assert Msg
assert_timer_value = pim_globals.ASSERT_TIME assert_timer_value = pim_globals.ASSERT_TIME
...@@ -175,12 +174,8 @@ class NoInfoState(AssertStateABC): ...@@ -175,12 +174,8 @@ class NoInfoState(AssertStateABC):
interface.set_assert_timer(assert_timer_value) interface.set_assert_timer(assert_timer_value)
interface.set_assert_winner_metric(better_metric) interface.set_assert_winner_metric(better_metric)
interface.set_assert_state(AssertState.Loser) interface.set_assert_state(AssertState.Loser)
#interface.assert_timer.reset()
#interface.assert_state = AssertState.Loser
#interface.assert_winner_metric = better_metric
# todo MUST also multicast a Prune(S,G) to the Assert winner <- TO THE colocar endereco do winner # MUST also multicast a Prune(S,G) to the Assert winner
if interface.could_assert(): if interface.could_assert():
interface.send_prune(holdtime=assert_timer_value) interface.send_prune(holdtime=assert_timer_value)
...@@ -240,14 +235,11 @@ class WinnerState(AssertStateABC): ...@@ -240,14 +235,11 @@ class WinnerState(AssertStateABC):
'receivedInferiorMetricFromNonWinner_couldAssertIsTrue, W -> W') 'receivedInferiorMetricFromNonWinner_couldAssertIsTrue, W -> W')
@staticmethod @staticmethod
def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric, state_refresh_interval = None): def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric):
''' '''
@type better_metric: AssertMetric @type better_metric: AssertMetric
''' '''
state_refresh_interval = better_metric.state_refresh_interval
#interface.assert_timer.set_timer(assert_time)
#interface.assert_timer.reset()
if state_refresh_interval is None: if state_refresh_interval is None:
# event caused by AssertMsg # event caused by AssertMsg
assert_timer_value = pim_globals.ASSERT_TIME assert_timer_value = pim_globals.ASSERT_TIME
...@@ -256,22 +248,15 @@ class WinnerState(AssertStateABC): ...@@ -256,22 +248,15 @@ class WinnerState(AssertStateABC):
assert_timer_value = state_refresh_interval*3 assert_timer_value = state_refresh_interval*3
interface.set_assert_timer(assert_timer_value) interface.set_assert_timer(assert_timer_value)
interface.set_assert_winner_metric(better_metric) interface.set_assert_winner_metric(better_metric)
#interface.assert_state = AssertState.Loser
interface.set_assert_state(AssertState.Loser) interface.set_assert_state(AssertState.Loser)
if interface.could_assert: interface.send_prune(holdtime=assert_timer_value)
interface.send_prune(holdtime=assert_timer_value)
print('receivedPreferedMetric, W -> L') print('receivedPreferedMetric, W -> L')
@staticmethod @staticmethod
def sendStateRefresh(interface: "TreeInterfaceDownstream", state_refresh_interval): def sendStateRefresh(interface: "TreeInterfaceDownstream", state_refresh_interval):
#interface.assert_timer.set_timer(time)
interface.set_assert_timer(state_refresh_interval*3) interface.set_assert_timer(state_refresh_interval*3)
#interface.assert_timer.reset()
@staticmethod @staticmethod
def assertTimerExpires(interface: "TreeInterfaceDownstream"): def assertTimerExpires(interface: "TreeInterfaceDownstream"):
...@@ -334,12 +319,11 @@ class LoserState(AssertStateABC): ...@@ -334,12 +319,11 @@ class LoserState(AssertStateABC):
'receivedInferiorMetricFromNonWinner_couldAssertIsTrue, L -> L') 'receivedInferiorMetricFromNonWinner_couldAssertIsTrue, L -> L')
@staticmethod @staticmethod
def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric, state_refresh_interval = None): def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric):
''' '''
@type better_metric: AssertMetric @type better_metric: AssertMetric
''' '''
#interface.assert_timer.set_timer(assert_time) state_refresh_interval = better_metric.state_refresh_interval
#interface.assert_timer.reset()
if state_refresh_interval is None: if state_refresh_interval is None:
assert_timer_value = pim_globals.ASSERT_TIME assert_timer_value = pim_globals.ASSERT_TIME
else: else:
...@@ -353,7 +337,7 @@ class LoserState(AssertStateABC): ...@@ -353,7 +337,7 @@ class LoserState(AssertStateABC):
if interface.could_assert(): if interface.could_assert():
# todo enviar holdtime = assert_timer_value???! # todo enviar holdtime = assert_timer_value???!
interface.send_prune() interface.send_prune(holdtime=assert_timer_value)
print('receivedPreferedMetric, L -> L') print('receivedPreferedMetric, L -> L')
......
...@@ -34,7 +34,7 @@ class DownstreamStateABS(metaclass=ABCMeta): ...@@ -34,7 +34,7 @@ class DownstreamStateABS(metaclass=ABCMeta):
raise NotImplementedError() raise NotImplementedError()
@abstractstaticmethod @abstractstaticmethod
def PPTexpires(interface: "TreeInterfaceDownstream", prune_holdtime): def PPTexpires(interface: "TreeInterfaceDownstream"):
""" """
PPT(S,G) Expires PPT(S,G) Expires
...@@ -127,7 +127,7 @@ class NoInfo(DownstreamStateABS): ...@@ -127,7 +127,7 @@ class NoInfo(DownstreamStateABS):
print('receivedGraft, NI -> NI') print('receivedGraft, NI -> NI')
@staticmethod @staticmethod
def PPTexpires(interface: "TreeInterfaceDownstream", prune_holdtime): def PPTexpires(interface: "TreeInterfaceDownstream"):
""" """
PPT(S,G) Expires PPT(S,G) Expires
...@@ -221,7 +221,7 @@ class PrunePending(DownstreamStateABS): ...@@ -221,7 +221,7 @@ class PrunePending(DownstreamStateABS):
print('receivedGraft, PP -> NI') print('receivedGraft, PP -> NI')
@staticmethod @staticmethod
def PPTexpires(interface: "TreeInterfaceDownstream", prune_holdtime): def PPTexpires(interface: "TreeInterfaceDownstream"):
""" """
PPT(S,G) Expires PPT(S,G) Expires
...@@ -335,7 +335,7 @@ class Pruned(DownstreamStateABS): ...@@ -335,7 +335,7 @@ class Pruned(DownstreamStateABS):
print('receivedGraft, P -> NI') print('receivedGraft, P -> NI')
@staticmethod @staticmethod
def PPTexpires(interface: "TreeInterfaceDownstream", prune_holdtime): def PPTexpires(interface: "TreeInterfaceDownstream"):
""" """
PPT(S,G) Expires PPT(S,G) Expires
......
...@@ -10,6 +10,7 @@ ASSERT_TIME = 180 ...@@ -10,6 +10,7 @@ ASSERT_TIME = 180
GRAFT_RETRY_PERIOD = 3 GRAFT_RETRY_PERIOD = 3
JT_OVERRIDE_INTERVAL = 3.0 JT_OVERRIDE_INTERVAL = 3.0
OVERRIDE_INTERVAL = 2.5 OVERRIDE_INTERVAL = 2.5
PROPAGATION_DELAY = 0.5
REFRESH_INTERVAL = 60 # State Refresh Interval REFRESH_INTERVAL = 60 # State Refresh Interval
SOURCE_LIFETIME = 210 SOURCE_LIFETIME = 210
T_LIMIT = 210 T_LIMIT = 210
......
...@@ -5,13 +5,14 @@ class AssertMetric(object): ...@@ -5,13 +5,14 @@ class AssertMetric(object):
Note: we consider the node name the ip of the metric. Note: we consider the node name the ip of the metric.
''' '''
def __init__(self, metric_preference: int or float = float("Inf"), route_metric: int or float = float("Inf"), ip_address: str = "0.0.0.0"): def __init__(self, metric_preference: int or float = float("Inf"), route_metric: int or float = float("Inf"), ip_address: str = "0.0.0.0", state_refresh_interval:int = None):
if type(ip_address) is str: if type(ip_address) is str:
ip_address = ipaddress.ip_address(ip_address) ip_address = ipaddress.ip_address(ip_address)
self._metric_preference = metric_preference self._metric_preference = metric_preference
self._route_metric = route_metric self._route_metric = route_metric
self._ip_address = ip_address self._ip_address = ip_address
self._state_refresh_interval = state_refresh_interval
def is_better_than(self, other): def is_better_than(self, other):
if self.metric_preference != other.metric_preference: if self.metric_preference != other.metric_preference:
...@@ -39,7 +40,7 @@ class AssertMetric(object): ...@@ -39,7 +40,7 @@ class AssertMetric(object):
''' '''
(source_ip, _) = tree_if.get_tree_id() (source_ip, _) = tree_if.get_tree_id()
import UnicastRouting import UnicastRouting
metric_preference, metric_cost = UnicastRouting.get_metric(source_ip) (metric_preference, metric_cost, _) = UnicastRouting.get_metric(source_ip)
return AssertMetric(metric_preference, metric_cost, tree_if.get_ip()) return AssertMetric(metric_preference, metric_cost, tree_if.get_ip())
...@@ -75,6 +76,14 @@ class AssertMetric(object): ...@@ -75,6 +76,14 @@ class AssertMetric(object):
self._ip_address = value self._ip_address = value
@property
def state_refresh_interval(self):
return self._state_refresh_interval
@state_refresh_interval.setter
def state_refresh_interval(self, value):
self._state_refresh_interval = value
def get_ip(self): def get_ip(self):
return str(self._ip_address) return str(self._ip_address)
...@@ -3,6 +3,7 @@ from abc import ABCMeta, abstractstaticmethod ...@@ -3,6 +3,7 @@ from abc import ABCMeta, abstractstaticmethod
from tree import globals as pim_globals from tree import globals as pim_globals
class OriginatorStateABC(metaclass=ABCMeta): class OriginatorStateABC(metaclass=ABCMeta):
@abstractstaticmethod
def recvDataMsgFromSource(tree): def recvDataMsgFromSource(tree):
pass pass
...@@ -22,33 +23,32 @@ class OriginatorStateABC(metaclass=ABCMeta): ...@@ -22,33 +23,32 @@ class OriginatorStateABC(metaclass=ABCMeta):
class Originator(OriginatorStateABC): class Originator(OriginatorStateABC):
@staticmethod @staticmethod
def recvDataMsgFromSource(tree): def recvDataMsgFromSource(tree):
tree.source_active_timer.reset() tree.set_source_active_timer()
@staticmethod @staticmethod
def SRTexpires(tree): def SRTexpires(tree):
''' '''
@type tree: Tree @type tree: Tree
''' '''
tree.rprint('SRT expired, O to O') print('SRT expired, O to O')
tree.state_refresh_timer.reset() tree.set_state_refresh_timer()
tree.send_state_refresh_msg() tree.create_state_refresh_msg()
@staticmethod @staticmethod
def SATexpires(tree): def SATexpires(tree):
tree.rprint('SAT expired, O to NO') print('SAT expired, O to NO')
tree.source_active_timer.stop() tree.clear_state_refresh_timer()
tree.state_refresh_timer.stop() tree.set_originator_state(OriginatorState.NotOriginator)
tree.originator_state = OriginatorState.NotOriginator
@staticmethod @staticmethod
def SourceNotConnected(tree): def SourceNotConnected(tree):
tree.rprint('Source no longer directly connected, O to NO') print('Source no longer directly connected, O to NO')
tree.source_active_timer.stop() tree.clear_state_refresh_timer()
tree.state_refresh_timer.stop() tree.clear_source_active_timer()
tree.originator_state = OriginatorState.NotOriginator tree.set_originator_state(OriginatorState.NotOriginator)
class NotOriginator(OriginatorStateABC): class NotOriginator(OriginatorStateABC):
...@@ -57,14 +57,12 @@ class NotOriginator(OriginatorStateABC): ...@@ -57,14 +57,12 @@ class NotOriginator(OriginatorStateABC):
''' '''
@type interface: Tree @type interface: Tree
''' '''
tree.originator_state = OriginatorState.Originator tree.set_originator_state(OriginatorState.Originator)
tree.state_refresh_timer.start() tree.set_state_refresh_timer()
tree.source_active_timer.start() tree.set_source_active_timer()
tree.rprint('new DataMsg from Source, NO to O') print('new DataMsg from Source, NO to O')
# Since the recording of the TTL is common to both states,its registering is made on the
# Tree.new_state_refresh_msg(...) method
@staticmethod @staticmethod
def SRTexpires(tree): def SRTexpires(tree):
...@@ -76,7 +74,7 @@ class NotOriginator(OriginatorStateABC): ...@@ -76,7 +74,7 @@ class NotOriginator(OriginatorStateABC):
@staticmethod @staticmethod
def SourceNotConnected(tree): def SourceNotConnected(tree):
pass return
class OriginatorState(): class OriginatorState():
......
...@@ -15,8 +15,13 @@ from .metric import AssertMetric ...@@ -15,8 +15,13 @@ from .metric import AssertMetric
from .downstream_prune import DownstreamState, DownstreamStateABS from .downstream_prune import DownstreamState, DownstreamStateABS
from .tree_interface import TreeInterface from .tree_interface import TreeInterface
from Packet.ReceivedPacket import ReceivedPacket from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimAssert import PacketPimAssert
from threading import Lock from threading import Lock
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
from Packet.Packet import Packet
from Packet.PacketPimHeader import PacketPimHeader
import traceback
class TreeInterfaceDownstream(TreeInterface): class TreeInterfaceDownstream(TreeInterface):
def __init__(self, kernel_entry, interface_id): def __init__(self, kernel_entry, interface_id):
...@@ -88,7 +93,7 @@ class TreeInterfaceDownstream(TreeInterface): ...@@ -88,7 +93,7 @@ class TreeInterfaceDownstream(TreeInterface):
# Timer timeout # Timer timeout
########################################### ###########################################
def prune_pending_timeout(self): def prune_pending_timeout(self):
self._prune_state.PPTexpires(self, 10) self._prune_state.PPTexpires(self)
def prune_timeout(self): def prune_timeout(self):
self._prune_state.PTexpires(self) self._prune_state.PTexpires(self)
...@@ -103,7 +108,6 @@ class TreeInterfaceDownstream(TreeInterface): ...@@ -103,7 +108,6 @@ class TreeInterfaceDownstream(TreeInterface):
def recv_prune_msg(self, upstream_neighbor_address, holdtime): def recv_prune_msg(self, upstream_neighbor_address, holdtime):
super().recv_prune_msg(upstream_neighbor_address, holdtime) super().recv_prune_msg(upstream_neighbor_address, holdtime)
#TODO if upstream_neighbor_address == self.get_ip():
if upstream_neighbor_address == self.get_ip(): if upstream_neighbor_address == self.get_ip():
self.set_receceived_prune_holdtime(holdtime) self.set_receceived_prune_holdtime(holdtime)
self._prune_state.receivedPrune(self, holdtime) self._prune_state.receivedPrune(self, holdtime)
...@@ -124,6 +128,55 @@ class TreeInterfaceDownstream(TreeInterface): ...@@ -124,6 +128,55 @@ class TreeInterfaceDownstream(TreeInterface):
self._prune_state.receivedGraft(self, source_ip) self._prune_state.receivedGraft(self, source_ip)
######################################
# Send messages
######################################
def send_state_refresh(self, state_refresh_msg_received):
if not self.get_interface()._state_refresh_capable:
return
interval = state_refresh_msg_received.interval
self._assert_state.sendStateRefresh(self, interval)
self._prune_state.send_state_refresh(self)
if self.lost_assert():
return
prune_indicator_bit = 0
if self.is_pruned():
prune_indicator_bit = 1
# TODO set timer
# todo maybe ja feito na maquina de estados Prune downstream
# if state_refresh_capable
# set PT....
import UnicastRouting
(metric_preference, metric, mask) = UnicastRouting.get_metric(state_refresh_msg_received.source_address)
assert_override_flag = 0
if self._assert_state == AssertState.NoInfo:
assert_override_flag = 1
try:
ph = PacketPimStateRefresh(multicast_group_adress=state_refresh_msg_received.multicast_group_adress,
source_address=state_refresh_msg_received.source_address,
originator_adress=state_refresh_msg_received.originator_adress,
metric_preference=metric_preference, metric=metric, mask_len=mask,
ttl=state_refresh_msg_received.ttl - 1,
prune_indicator_flag=prune_indicator_bit,
prune_now_flag=state_refresh_msg_received.prune_now_flag,
assert_override_flag=assert_override_flag,
interval=interval)
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes())
except:
traceback.print_exc()
return
##########################################################
# Override # Override
def is_forwarding(self): def is_forwarding(self):
......
...@@ -12,6 +12,12 @@ from .upstream_prune import UpstreamState ...@@ -12,6 +12,12 @@ from .upstream_prune import UpstreamState
from threading import Timer from threading import Timer
from .globals import * from .globals import *
import random import random
from .metric import AssertMetric
from .originator import OriginatorState, OriginatorStateABC
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
import traceback
from . import DataPacketsSocket
import threading
class TreeInterfaceUpstream(TreeInterface): class TreeInterfaceUpstream(TreeInterface):
...@@ -22,10 +28,37 @@ class TreeInterfaceUpstream(TreeInterface): ...@@ -22,10 +28,37 @@ class TreeInterfaceUpstream(TreeInterface):
self._override_timer = None self._override_timer = None
self._prune_limit_timer = None self._prune_limit_timer = None
self._originator_state = None self._originator_state = OriginatorState.NotOriginator
self._state_refresh_timer = None
self._source_active_timer = None
self._prune_now_counter = 0
if self.is_S_directly_conn(): if self.is_S_directly_conn():
self._graft_prune_state.sourceIsNowDirectConnect(self) self._graft_prune_state.sourceIsNowDirectConnect(self)
self._originator_state.recvDataMsgFromSource(self)
# TODO TESTE SOCKET RECV DATA PCKTS
self.socket_is_enabled = True
(s,g) = self.get_tree_id()
interface_name = self.get_interface().interface_name
self.socket_pkt = DataPacketsSocket.get_s_g_bpf_filter_code(s, g, interface_name)
# run receive method in background
receive_thread = threading.Thread(target=self.socket_recv)
receive_thread.daemon = True
receive_thread.start()
def socket_recv(self):
while self.socket_is_enabled:
try:
self.socket_pkt.recvfrom(0)
print("PACOTE DADOS RECEBIDO")
self.recv_data_msg()
except:
traceback.print_exc()
continue
########################################## ##########################################
# Set state # Set state
...@@ -38,6 +71,9 @@ class TreeInterfaceUpstream(TreeInterface): ...@@ -38,6 +71,9 @@ class TreeInterfaceUpstream(TreeInterface):
self.change_tree() self.change_tree()
self.evaluate_ingroup() self.evaluate_ingroup()
def set_originator_state(self, new_state: OriginatorStateABC):
if new_state != self._originator_state:
self._originator_state = new_state
########################################## ##########################################
# Check timers # Check timers
...@@ -81,6 +117,26 @@ class TreeInterfaceUpstream(TreeInterface): ...@@ -81,6 +117,26 @@ class TreeInterfaceUpstream(TreeInterface):
if self._prune_limit_timer is not None: if self._prune_limit_timer is not None:
self._prune_limit_timer.cancel() self._prune_limit_timer.cancel()
# State Refresh timers
def set_state_refresh_timer(self):
self.clear_state_refresh_timer()
self._state_refresh_timer = Timer(REFRESH_INTERVAL, self.state_refresh_timeout)
self._state_refresh_timer.start()
def clear_state_refresh_timer(self):
if self._state_refresh_timer is not None:
self._state_refresh_timer.cancel()
def set_source_active_timer(self):
self.clear_source_active_timer()
self._source_active_timer = Timer(SOURCE_LIFETIME, self.source_active_timeout)
self._source_active_timer.start()
def clear_source_active_timer(self):
if self._source_active_timer is not None:
self._source_active_timer.cancel()
########################################### ###########################################
# Timer timeout # Timer timeout
########################################### ###########################################
...@@ -93,6 +149,13 @@ class TreeInterfaceUpstream(TreeInterface): ...@@ -93,6 +149,13 @@ class TreeInterfaceUpstream(TreeInterface):
def prune_limit_timeout(self): def prune_limit_timeout(self):
return return
# State Refresh timers
def state_refresh_timeout(self):
self._originator_state.SRTexpires(self)
def source_active_timeout(self):
self._originator_state.SATexpires(self)
########################################### ###########################################
# Recv packets # Recv packets
########################################### ###########################################
...@@ -101,12 +164,9 @@ class TreeInterfaceUpstream(TreeInterface): ...@@ -101,12 +164,9 @@ class TreeInterfaceUpstream(TreeInterface):
if self.is_olist_null() and not self.is_prune_limit_timer_running() and not self.is_S_directly_conn(): if self.is_olist_null() and not self.is_prune_limit_timer_running() and not self.is_S_directly_conn():
self._graft_prune_state.dataArrivesRPFinterface_OListNull_PLTstoped(self) self._graft_prune_state.dataArrivesRPFinterface_OListNull_PLTstoped(self)
def recv_state_refresh_msg(self, prune_indicator: int): if self.is_S_directly_conn():
# todo check rpf nbr self._originator_state.recvDataMsgFromSource(self)
if prune_indicator == 1:
self._graft_prune_state.stateRefreshArrivesRPFnbr_pruneIs1(self)
elif prune_indicator == 0 and not self.is_prune_limit_timer_running():
self._graft_prune_state.stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped(self)
def recv_join_msg(self, upstream_neighbor_address): def recv_join_msg(self, upstream_neighbor_address):
super().recv_join_msg(upstream_neighbor_address) super().recv_join_msg(upstream_neighbor_address)
...@@ -122,6 +182,33 @@ class TreeInterfaceUpstream(TreeInterface): ...@@ -122,6 +182,33 @@ class TreeInterfaceUpstream(TreeInterface):
# todo check rpf nbr # todo check rpf nbr
self._graft_prune_state.recvGraftAckFromRPFnbr(self) self._graft_prune_state.recvGraftAckFromRPFnbr(self)
def recv_state_refresh_msg(self, received_metric: AssertMetric, prune_indicator: int):
super().recv_state_refresh_msg(received_metric, prune_indicator)
if self.get_neighbor_RPF() != received_metric.get_ip():
return
if prune_indicator == 1:
self._graft_prune_state.stateRefreshArrivesRPFnbr_pruneIs1(self)
elif prune_indicator == 0 and not self.is_prune_limit_timer_running():
self._graft_prune_state.stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped(self)
####################################
def create_state_refresh_msg(self):
self._prune_now_counter+=1
self._prune_now_counter%=3
(source_ip, group_ip) = self.get_tree_id()
ph = PacketPimStateRefresh(multicast_group_adress=group_ip,
source_address=source_ip,
originator_adress=self.get_ip(),
metric_preference=0, metric=0, mask_len=0,
ttl=256,
prune_indicator_flag=0,
prune_now_flag=(self._prune_now_counter+1)//3,
assert_override_flag=0,
interval=60)
self._kernel_entry.forward_state_refresh_msg(ph)
########################################### ###########################################
# Change olist # Change olist
########################################### ###########################################
...@@ -147,14 +234,20 @@ class TreeInterfaceUpstream(TreeInterface): ...@@ -147,14 +234,20 @@ class TreeInterfaceUpstream(TreeInterface):
#Override #Override
def delete(self): def delete(self):
super().delete() super().delete()
self.socket_is_enabled = False
self.socket_pkt.close()
self.clear_graft_retry_timer() self.clear_graft_retry_timer()
self.clear_assert_timer() self.clear_assert_timer()
self.clear_prune_limit_timer() self.clear_prune_limit_timer()
self.clear_override_timer() self.clear_override_timer()
self.clear_state_refresh_timer()
self.clear_source_active_timer()
def is_downstream(self): def is_downstream(self):
return False return False
def is_originator(self):
return self._originator_state == OriginatorState.Originator
#------------------------------------------------------------------------- #-------------------------------------------------------------------------
# Properties # Properties
......
...@@ -44,6 +44,7 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -44,6 +44,7 @@ class TreeInterface(metaclass=ABCMeta):
#self._cost = cost #self._cost = cost
#self._evaluate_ig = evaluate_ig_cb #self._evaluate_ig = evaluate_ig_cb
# Local Membership State
try: try:
interface_name = Main.kernel.vif_index_to_name_dic[interface_id] interface_name = Main.kernel.vif_index_to_name_dic[interface_id]
igmp_interface = Main.igmp_interfaces[interface_name] # type: InterfaceIGMP igmp_interface = Main.igmp_interfaces[interface_name] # type: InterfaceIGMP
...@@ -56,9 +57,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -56,9 +57,6 @@ class TreeInterface(metaclass=ABCMeta):
self._local_membership_state = LocalMembership.NoInfo self._local_membership_state = LocalMembership.NoInfo
# Local Membership State
#self._local_membership_state = None # todo NoInfo or Include
# Prune State # Prune State
self._prune_state = DownstreamState.NoInfo self._prune_state = DownstreamState.NoInfo
self._prune_pending_timer = None self._prune_pending_timer = None
...@@ -134,7 +132,7 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -134,7 +132,7 @@ class TreeInterface(metaclass=ABCMeta):
pass pass
def recv_assert_msg(self, received_metric: AssertMetric): def recv_assert_msg(self, received_metric: AssertMetric):
if self._assert_winner_metric.is_better_than(received_metric): if self.my_assert_metric().is_better_than(received_metric):
# received inferior assert # received inferior assert
if self._assert_winner_metric.ip_address == received_metric.ip_address: if self._assert_winner_metric.ip_address == received_metric.ip_address:
# received from assert winner # received from assert winner
...@@ -142,16 +140,10 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -142,16 +140,10 @@ class TreeInterface(metaclass=ABCMeta):
elif self.could_assert(): elif self.could_assert():
# received from non assert winner and could_assert # received from non assert winner and could_assert
self._assert_state.receivedInferiorMetricFromNonWinner_couldAssertIsTrue(self) self._assert_state.receivedInferiorMetricFromNonWinner_couldAssertIsTrue(self)
else: elif received_metric.is_better_than(self._assert_winner_metric):
#received preferred assert #received preferred assert
self._assert_state.receivedPreferedMetric(self, received_metric) self._assert_state.receivedPreferedMetric(self, received_metric)
def recv_reset_msg(self):
pass
def recv_prune_msg(self, upstream_neighbor_address, holdtime): def recv_prune_msg(self, upstream_neighbor_address, holdtime):
if upstream_neighbor_address == self.get_ip(): if upstream_neighbor_address == self.get_ip():
self._assert_state.receivedPruneOrJoinOrGraft(self) self._assert_state.receivedPruneOrJoinOrGraft(self)
...@@ -167,14 +159,8 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -167,14 +159,8 @@ class TreeInterface(metaclass=ABCMeta):
def recv_graft_ack_msg(self): def recv_graft_ack_msg(self):
pass pass
def recv_state_refresh_msg(self, prune_indicator): def recv_state_refresh_msg(self, received_metric: AssertMetric, prune_indicator):
pass self.recv_assert_msg(received_metric)
def forward_state_reset_msg(self):
raise NotImplemented
###################################### ######################################
...@@ -185,48 +171,37 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -185,48 +171,37 @@ class TreeInterface(metaclass=ABCMeta):
try: try:
(source, group) = self.get_tree_id() (source, group) = self.get_tree_id()
# todo self.get_rpf_()
ip_dst = self.get_neighbor_RPF() ip_dst = self.get_neighbor_RPF()
ph = PacketPimGraft(ip_dst) ph = PacketPimGraft(ip_dst)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source])) ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph)) pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes(), ip_dst) self.get_interface().send(pckt.bytes(), ip_dst)
#msg = GraftMsg(self.get_tree().tree_id, self.get_rpf_())
#self.pim_if.send_mcast(msg)
except: except:
traceback.print_exc() traceback.print_exc()
return return
def send_graft_ack(self, ip_sender): def send_graft_ack(self, ip_sender):
print("send graft ack") print("send graft ack")
try: try:
(source, group) = self.get_tree_id() (source, group) = self.get_tree_id()
# todo endereco?!!
ph = PacketPimGraftAck(ip_sender) ph = PacketPimGraftAck(ip_sender)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source])) ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph)) pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes(), ip_sender) self.get_interface().send(pckt.bytes(), ip_sender)
#msg = GraftAckMsg(self.get_tree().tree_id, self.get_node())
#self.pim_if.send_mcast(msg)
except: except:
traceback.print_exc() traceback.print_exc()
return return
def send_prune(self, holdtime=None): def send_prune(self, holdtime=None):
if holdtime is None: if holdtime is None:
holdtime = T_LIMIT holdtime = T_LIMIT
print("send prune") print("send prune")
try: try:
(source, group) = self.get_tree_id() (source, group) = self.get_tree_id()
# todo help ip of ph
#ph = PacketPimJoinPrune("123.123.123.123", 210)
ph = PacketPimJoinPrune(self.get_neighbor_RPF(), holdtime) ph = PacketPimJoinPrune(self.get_neighbor_RPF(), holdtime)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, pruned_src_addresses=[source])) ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, pruned_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph)) pckt = Packet(payload=PacketPimHeader(ph))
...@@ -242,7 +217,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -242,7 +217,6 @@ class TreeInterface(metaclass=ABCMeta):
holdtime = T_LIMIT holdtime = T_LIMIT
try: try:
(source, group) = self.get_tree_id() (source, group) = self.get_tree_id()
# todo help ip of ph
ph = PacketPimJoinPrune(self.get_ip(), holdtime) ph = PacketPimJoinPrune(self.get_ip(), holdtime)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, pruned_src_addresses=[source])) ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, pruned_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph)) pckt = Packet(payload=PacketPimHeader(ph))
...@@ -252,24 +226,18 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -252,24 +226,18 @@ class TreeInterface(metaclass=ABCMeta):
except: except:
traceback.print_exc() traceback.print_exc()
return return
# todo
#msg = PruneMsg(self.get_tree().tree_id,
# self.get_node(), self._assert_timer.time_left())
#self.pim_if.send_mcast(msg)
def send_join(self): def send_join(self):
print("send join") print("send join")
try: try:
(source, group) = self.get_tree_id() (source, group) = self.get_tree_id()
# todo help ip of ph
ph = PacketPimJoinPrune(self.get_neighbor_RPF(), 210) ph = PacketPimJoinPrune(self.get_neighbor_RPF(), 210)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source])) ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph)) pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes()) self.get_interface().send(pckt.bytes())
#msg = JoinMsg(self.get_tree().tree_id, self.get_rpf_())
#self.pim_if.send_mcast(msg)
except: except:
traceback.print_exc() traceback.print_exc()
return return
...@@ -290,8 +258,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -290,8 +258,6 @@ class TreeInterface(metaclass=ABCMeta):
return return
def send_assert_cancel(self): def send_assert_cancel(self):
print("send assert cancel") print("send assert cancel")
...@@ -304,12 +270,12 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -304,12 +270,12 @@ class TreeInterface(metaclass=ABCMeta):
except: except:
traceback.print_exc() traceback.print_exc()
return return
#msg = AssertMsg.new_assert_cancel(self.tree_id)
#self.pim_if.send_mcast(msg)
def send_state_refresh(self):
# todo time def send_state_refresh(self, state_refresh_msg_received: PacketPimStateRefresh):
self._assert_state.sendStateRefresh(self) pass
#############################################################
@abstractmethod @abstractmethod
def is_forwarding(self): def is_forwarding(self):
...@@ -364,10 +330,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -364,10 +330,6 @@ class TreeInterface(metaclass=ABCMeta):
def __str__(self): def __str__(self):
return '{}<{}>'.format(self.__class__, self._interface.get_link()) return '{}<{}>'.format(self.__class__, self._interface.get_link())
def get_link(self):
# todo
return self._interface.get_link()
def get_interface(self): def get_interface(self):
kernel = Main.kernel kernel = Main.kernel
interface_name = kernel.vif_index_to_name_dic[self._interface_id] interface_name = kernel.vif_index_to_name_dic[self._interface_id]
...@@ -396,8 +358,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -396,8 +358,6 @@ class TreeInterface(metaclass=ABCMeta):
raise NotImplementedError() raise NotImplementedError()
#def get_rpf_(self):
# return self.get_neighbor_RPF()
# obtain ip of RPF'(S) # obtain ip of RPF'(S)
......
...@@ -169,8 +169,9 @@ class Forward(UpstreamStateABC): ...@@ -169,8 +169,9 @@ class Forward(UpstreamStateABC):
@type interface: TreeInterfaceUpstream @type interface: TreeInterfaceUpstream
""" """
#interface.set_ot() # if OT is not running the router must set OT to t_override seconds
interface.set_override_timer() if not interface.is_override_timer_running():
interface.set_override_timer()
print('stateRefreshArrivesRPFnbr_pruneIs1, F -> F') print('stateRefreshArrivesRPFnbr_pruneIs1, F -> F')
...@@ -332,14 +333,8 @@ class Pruned(UpstreamStateABC): ...@@ -332,14 +333,8 @@ class Pruned(UpstreamStateABC):
@type interface: TreeInterfaceUpstream @type interface: TreeInterfaceUpstream
""" """
if not interface.is_S_directly_conn(): if not interface.is_S_directly_conn():
#interface.set_state(UpstreamState.Pruned)
# todo send prune?!?!?!?!
#timer = interface._prune_limit_timer
#timer.set_timer(interface.t_override)
#timer.start()
interface.set_prune_limit_timer() interface.set_prune_limit_timer()
interface.send_prune()
print("dataArrivesRPFinterface_OListNull_PLTstoped, P -> P") print("dataArrivesRPFinterface_OListNull_PLTstoped, P -> P")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment