Commit 4c12b098 authored by Pedro Oliveira's avatar Pedro Oliveira

fix Hello&Neighbor (timers and deadlocks) & check olist of all trees when...

fix Hello&Neighbor (timers and deadlocks) & check olist of all trees when interface changes number of neighbors & packet reception handled inside of interfaces' classes (instead of methods in separate files) & fix send of state refresh through non-root interfaces (if assert loser dont check prune and assert state)
parent 2d3d8f7e
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimAssert import PacketPimAssert
import Main
import traceback
class Assert:
TYPE = 5
def __init__(self):
Main.add_protocol(Assert.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
interface = packet.interface
interface_name = interface.interface_name
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_assert = packet.payload.payload # type: PacketPimAssert
metric = pkt_assert.metric
metric_preference = pkt_assert.metric_preference
source = pkt_assert.source_address
group = pkt_assert.multicast_group_address
source_group = (source, group)
interface_name = packet.interface.interface_name
interface_index = Main.kernel.vif_name_to_index_dic[interface_name]
try:
#Main.kernel.routing[source_group].recv_assert_msg(interface_index, packet)
Main.kernel.get_routing_entry(source_group).recv_assert_msg(interface_index, packet)
except:
traceback.print_exc()
from Packet.Packet import Packet
from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
class Graft:
TYPE = 6
def __init__(self):
Main.add_protocol(Graft.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
print("GRAFT!!")
interface = packet.interface
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_join_prune = packet.payload.payload # type: PacketPimJoinPrune
# if im not upstream neighbor ignore message
if pkt_join_prune.upstream_neighbor_address != interface.ip_interface:
#return
pass
interface_name = interface.interface_name
interface_index = Main.kernel.vif_name_to_index_dic[interface_name]
join_prune_groups = pkt_join_prune.groups
for group in join_prune_groups:
multicast_group = group.multicast_group
joined_src_addresses = group.joined_src_addresses
for source_address in joined_src_addresses:
source_group = (source_address, multicast_group)
try:
Main.kernel.get_routing_entry(source_group).recv_graft_msg(interface_index, packet)
except:
try:
#import time
#time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_graft_msg(interface_index, packet)
except:
pass
# todo o que fazer quando n existe arvore para (s,g) ???
traceback.print_exc()
print("ATENCAO!!!!")
print(Main.kernel.routing)
continue
from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
class GraftAck:
TYPE = 7
def __init__(self):
Main.add_protocol(GraftAck.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
print("GRAFT ACK!!")
interface = packet.interface
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_join_prune = packet.payload.payload # type: PacketPimJoinPrune
# if im not upstream neighbor ignore message
if pkt_join_prune.upstream_neighbor_address != interface.ip_interface:
#return
pass
interface_name = interface.interface_name
interface_index = Main.kernel.vif_name_to_index_dic[interface_name]
join_prune_groups = pkt_join_prune.groups
for group in join_prune_groups:
multicast_group = group.multicast_group
joined_src_addresses = group.joined_src_addresses
for source_address in joined_src_addresses:
source_group = (source_address, multicast_group)
try:
Main.kernel.get_routing_entry(source_group).recv_graft_ack_msg(interface_index, packet)
except:
try:
#import time
#time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_graft_ack_msg(interface_index, packet)
except:
pass
# todo o que fazer quando n existe arvore para (s,g) ???
traceback.print_exc()
print("ATENCAO!!!!")
print(Main.kernel.routing)
continue
from Packet.ReceivedPacket import ReceivedPacket
import Main
from Neighbor import Neighbor
class Hello:
TYPE = 0
TRIGGERED_HELLO_DELAY = 16 # TODO: configure via external file??
def __init__(self):
Main.add_protocol(Hello.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
interface = packet.interface
ip = packet.ip_header.ip_src
print("ip = ", ip)
options = packet.payload.payload.get_options()
if (1 in options) and (20 in options):
#hello_hold_time = options[1]
hello_hold_time = options[1].holdtime
#generation_id = options[20]
generation_id = options[20].generation_id
else:
raise Exception
with interface.neighbors_lock.genWlock():
if ip in interface.neighbors:
neighbor = interface.neighbors[ip]
else:
interface.neighbors[ip] = Neighbor(interface, ip, generation_id, hello_hold_time)
return
neighbor.receive_hello(generation_id, hello_hold_time)
"""
with neighbor.neighbor_lock:
# Already know Neighbor
print("neighbor conhecido")
neighbor.heartbeat()
if neighbor.hello_hold_time != hello_hold_time:
print("keep alive period diferente")
neighbor.set_hello_hold_time(hello_hold_time)
if neighbor.generation_id != generation_id:
print("neighbor reiniciado")
neighbor.set_generation_id(generation_id)
with interface.neighbors_lock.genWlock():
#if interface.get_neighbor(ip) is None:
if ip in interface.neighbors:
# Unknown Neighbor
if (1 in options) and (20 in options):
try:
#Main.add_neighbor(packet.interface, ip, options[20], options[1])
print("non neighbor and options inside")
except Exception:
# Received Neighbor with Timeout
print("non neighbor and options inside but neighbor timedout")
pass
return
print("non neighbor and required options not inside")
else:
# Already know Neighbor
print("neighbor conhecido")
neighbor = Main.get_neighbor(ip)
neighbor.heartbeat()
if 1 in options and neighbor.hello_hold_time != options[1]:
print("keep alive period diferente")
neighbor.set_hello_hold_time(options[1])
if 20 in options and neighbor.generation_id != options[20]:
print("neighbor reiniciado")
neighbor.remove()
Main.add_neighbor(packet.interface, ip, options[20], options[1])
"""
\ No newline at end of file
from Packet.ReceivedPacket import ReceivedPacket
from utils import *
from ipaddress import IPv4Address
class IGMP:
# receive handler
@staticmethod
def receive_handle(packet: ReceivedPacket):
interface = packet.interface
ip_src = packet.ip_header.ip_src
ip_dst = packet.ip_header.ip_dst
#print("ip = ", ip_src)
igmp_hdr = packet.payload
igmp_type = igmp_hdr.type
igmp_group = igmp_hdr.group_address
# source ip can't be 0.0.0.0 or multicast
if ip_src == "0.0.0.0" or IPv4Address(ip_src).is_multicast:
return
if igmp_type == Version_1_Membership_Report and ip_dst == igmp_group and IPv4Address(igmp_group).is_multicast:
interface.interface_state.receive_v1_membership_report(packet)
elif igmp_type == Version_2_Membership_Report and ip_dst == igmp_group and IPv4Address(igmp_group).is_multicast:
interface.interface_state.receive_v2_membership_report(packet)
elif igmp_type == Leave_Group and ip_dst == "224.0.0.2" and IPv4Address(igmp_group).is_multicast:
interface.interface_state.receive_leave_group(packet)
elif igmp_type == Membership_Query and (ip_dst == igmp_group or (ip_dst == "224.0.0.1" and igmp_group == "0.0.0.0")):
interface.interface_state.receive_query(packet)
else:
raise Exception("Exception igmp packet: type={}; ip_dst={}; packet_group_report={}".format(igmp_type, ip_dst, igmp_group))
import socket
from abc import ABCMeta, abstractmethod
import threading
import random
import netifaces
......@@ -8,110 +9,55 @@ import traceback
from RWLock.RWLock import RWLockWrite
class Interface(object):
class Interface(metaclass=ABCMeta):
MCAST_GRP = '224.0.0.13'
# substituir ip por interface ou algo parecido
def __init__(self, interface_name: str):
def __init__(self, interface_name, recv_socket, send_socket, vif_index):
self.interface_name = interface_name
ip_interface = netifaces.ifaddresses(interface_name)[netifaces.AF_INET][0]['addr']
self.ip_mask_interface = netifaces.ifaddresses(interface_name)[netifaces.AF_INET][0]['netmask']
self.ip_interface = ip_interface
s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_PIM)
# virtual interface index for the multicast routing table
self.vif_index = vif_index
# allow other sockets to bind this port too
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# explicitly join the multicast group on the interface specified
#s.setsockopt(socket.SOL_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(Interface.MCAST_GRP) + socket.inet_aton(ip_interface))
s.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
socket.inet_aton(Interface.MCAST_GRP) + socket.inet_aton(ip_interface))
s.setsockopt(socket.SOL_SOCKET, 25, str(interface_name + '\0').encode('utf-8'))
# set socket output interface
s.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_IF, socket.inet_aton(ip_interface))
# set socket TTL to 1
s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1)
s.setsockopt(socket.IPPROTO_IP, socket.IP_TTL, 1)
# don't receive outgoing packets
s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 0)
self.socket = s
# set receive socket and send socket
self._send_socket = send_socket
self._recv_socket = recv_socket
self.interface_enabled = True
# generation id
#self.generation_id = random.getrandbits(32)
# todo neighbors
#self.neighbors = {}
#self.neighbors_lock = RWLockWrite()
# run receive method in background
#receive_thread = threading.Thread(target=self.receive)
#receive_thread.daemon = True
#receive_thread.start()
receive_thread = threading.Thread(target=self.receive)
receive_thread.daemon = True
receive_thread.start()
def receive(self):
try:
(raw_packet, (ip, _)) = self.socket.recvfrom(256 * 1024)
if raw_packet:
packet = ReceivedPacket(raw_packet, self)
else:
packet = None
return packet
except Exception:
traceback.print_exc()
return None
"""
while self.interface_enabled:
try:
(raw_packet, (ip, _)) = self.socket.recvfrom(256 * 1024)
if raw_packet:
packet = ReceivedPacket(raw_packet, self)
Main.protocols[packet.payload.get_pim_type()].receive_handle(packet) # TODO: perceber se existe melhor maneira de fazer isto
(raw_bytes, _) = self._recv_socket.recvfrom(256 * 1024)
if raw_bytes:
self._receive(raw_bytes)
except Exception:
traceback.print_exc()
continue
"""
@abstractmethod
def _receive(self, raw_bytes):
raise NotImplementedError
def send(self, data: bytes, group_ip: str):
if self.interface_enabled and data:
self.socket.sendto(data, (group_ip, 0))
self._send_socket.sendto(data, (group_ip, 0))
def remove(self):
self.interface_enabled = False
try:
self.socket.shutdown(socket.SHUT_RDWR)
self._recv_socket.shutdown(socket.SHUT_RDWR)
except Exception:
pass
self.socket.close()
self._recv_socket.close()
self._send_socket.close()
def is_enabled(self):
return self.interface_enabled
@abstractmethod
def get_ip(self):
return self.ip_interface
"""
def add_neighbor(self, ip, random_number, hello_hold_time):
with self.neighbors_lock.genWlock():
if ip not in self.neighbors:
print("ADD NEIGHBOR")
from Neighbor import Neighbor
n = Neighbor(self, ip, random_number, hello_hold_time)
self.neighbors[ip] = n
Main.protocols[0].force_send(self)
def get_neighbors(self):
with self.neighbors_lock.genRlock():
return self.neighbors.values()
def get_neighbor(self, ip):
with self.neighbors_lock.genRlock():
return self.neighbors[ip]
"""
\ No newline at end of file
raise NotImplementedError
\ No newline at end of file
import socket
import struct
import threading
import netifaces
from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
from Interface import Interface
from ctypes import create_string_buffer, addressof
from ipaddress import IPv4Address
from utils import Version_1_Membership_Report, Version_2_Membership_Report, Leave_Group, Membership_Query
if not hasattr(socket, 'SO_BINDTODEVICE'):
socket.SO_BINDTODEVICE = 25
class InterfaceIGMP(object):
class InterfaceIGMP(Interface):
ETH_P_IP = 0x0800 # Internet Protocol packet
SO_ATTACH_FILTER = 26
FILTER_IGMP = [
struct.pack('HBBI', 0x28, 0, 0, 0x0000000c),
......@@ -22,10 +24,6 @@ class InterfaceIGMP(object):
struct.pack('HBBI', 0x6, 0, 0, 0x00000000),
]
SO_ATTACH_FILTER = 26
PACKET_MR_ALLMULTI = 2
def __init__(self, interface_name: str, vif_index:int):
# RECEIVE SOCKET
rcv_s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, socket.htons(InterfaceIGMP.ETH_P_IP))
......@@ -40,7 +38,6 @@ class InterfaceIGMP(object):
# bind to interface
rcv_s.bind((interface_name, 0x0800))
self.recv_socket = rcv_s
# SEND SOCKET
snd_s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
......@@ -48,20 +45,12 @@ class InterfaceIGMP(object):
# bind to interface
snd_s.setsockopt(socket.SOL_SOCKET, socket.SO_BINDTODEVICE, str(interface_name + "\0").encode('utf-8'))
self.send_socket = snd_s
super().__init__(interface_name=interface_name, recv_socket=rcv_s, send_socket=snd_s, vif_index=vif_index)
self.interface_enabled = True
self.interface_name = interface_name
from igmp.RouterState import RouterState
self.interface_state = RouterState(self)
# virtual interface index for the multicast routing table
self.vif_index = vif_index
# run receive method in background
receive_thread = threading.Thread(target=self.receive)
receive_thread.daemon = True
receive_thread.start()
def get_ip(self):
return netifaces.ifaddresses(self.interface_name)[netifaces.AF_INET][0]['addr']
......@@ -70,24 +59,48 @@ class InterfaceIGMP(object):
def ip_interface(self):
return self.get_ip()
def send(self, data: bytes, address: str="224.0.0.1"):
if self.interface_enabled:
self.send_socket.sendto(data, (address, 0))
def receive(self):
while self.interface_enabled:
try:
(raw_packet, _) = self.recv_socket.recvfrom(256 * 1024)
if raw_packet:
raw_packet = raw_packet[14:]
packet = ReceivedPacket(raw_packet, self)
Main.igmp.receive_handle(packet)
except Exception:
traceback.print_exc()
continue
def remove(self):
self.interface_enabled = False
self.recv_socket.close()
self.send_socket.close()
super().send(data, address)
def _receive(self, raw_bytes):
if raw_bytes:
raw_bytes = raw_bytes[14:]
packet = ReceivedPacket(raw_bytes, self)
ip_src = packet.ip_header.ip_src
if not (ip_src == "0.0.0.0" or IPv4Address(ip_src).is_multicast):
self.PKT_FUNCTIONS[packet.payload.get_igmp_type()](self, packet)
###########################################
# Recv packets
###########################################
def receive_version_1_membership_report(self, packet):
ip_dst = packet.ip_header.ip_dst
igmp_group = packet.payload.group_address
if ip_dst == igmp_group and IPv4Address(igmp_group).is_multicast:
self.interface_state.receive_v1_membership_report(packet)
def receive_version_2_membership_report(self, packet):
ip_dst = packet.ip_header.ip_dst
igmp_group = packet.payload.group_address
if ip_dst == igmp_group and IPv4Address(igmp_group).is_multicast:
self.interface_state.receive_v2_membership_report(packet)
def receive_leave_group(self, packet):
ip_dst = packet.ip_header.ip_dst
igmp_group = packet.payload.group_address
if ip_dst == "224.0.0.2" and IPv4Address(igmp_group).is_multicast:
self.interface_state.receive_leave_group(packet)
def receive_membership_query(self, packet):
ip_dst = packet.ip_header.ip_dst
igmp_group = packet.payload.group_address
if ip_dst == igmp_group or (ip_dst == "224.0.0.1" and igmp_group == "0.0.0.0"):
self.interface_state.receive_query(packet)
PKT_FUNCTIONS = {
Version_1_Membership_Report: receive_version_1_membership_report,
Version_2_Membership_Report: receive_version_2_membership_report,
Leave_Group: receive_leave_group,
Membership_Query: receive_membership_query,
}
This diff is collapsed.
from Packet.Packet import Packet
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
from Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
from Interface import Interface
import Main
import traceback
class JoinPrune:
TYPE = 3
def __init__(self):
Main.add_protocol(JoinPrune.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
interface = packet.interface
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_join_prune = packet.payload.payload # type: PacketPimJoinPrune
# if im not upstream neighbor ignore message
if pkt_join_prune.upstream_neighbor_address != interface.ip_interface:
#return
pass
interface_name = interface.interface_name
interface_index = Main.kernel.vif_name_to_index_dic[interface_name]
# todo holdtime
holdtime = pkt_join_prune.hold_time
join_prune_groups = pkt_join_prune.groups
for group in join_prune_groups:
multicast_group = group.multicast_group
joined_src_addresses = group.joined_src_addresses
pruned_src_addresses = group.pruned_src_addresses
for source_address in joined_src_addresses:
source_group = (source_address, multicast_group)
try:
#Main.kernel.routing[source_group].recv_join_msg(interface_index, packet)
Main.kernel.get_routing_entry(source_group).recv_join_msg(interface_index, packet)
except:
try:
#import time
#time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_join_msg(interface_index, packet)
except:
pass
# todo o que fazer quando n existe arvore para (s,g) ???
traceback.print_exc()
print("ATENCAO!!!!")
print(Main.kernel.routing)
continue
for source_address in pruned_src_addresses:
source_group = (source_address, multicast_group)
try:
#Main.kernel.routing[source_group].recv_prune_msg(interface_index, packet)
Main.kernel.get_routing_entry(source_group).recv_prune_msg(interface_index, packet)
except:
try:
#import time
#time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_prune_msg(interface_index, packet)
except:
pass
# todo o que fazer quando n existe arvore para (s,g) ???
traceback.print_exc()
print("ATENCAO!!!!")
print(Main.kernel.routing)
continue
......@@ -470,3 +470,8 @@ class Kernel:
pass
# When interface changes number of neighbors verify if olist changes and prune/forward respectively
def interface_change_number_of_neighbors(self):
with self.rwlock.genWlock():
for entry in self.routing.values():
entry.change_at_number_of_neighbors()
......@@ -10,10 +10,9 @@ import UnicastRouting
interfaces = {} # interfaces with multicast routing enabled
igmp_interfaces = {} # igmp interfaces
protocols = {}
kernel = None
igmp = None
unicast_routing = None
def add_pim_interface(interface_name, state_refresh_capable:bool=False):
kernel.create_pim_interface(interface_name=interface_name, state_refresh_capable=state_refresh_capable)
......@@ -64,10 +63,6 @@ def remove_interface(interface_name, pim=False, igmp=False):
# print(igmp_interfaces)
kernel.remove_interface(interface_name, pim=pim, igmp=igmp)
def add_protocol(protocol_number, protocol_obj):
global protocols
protocols[protocol_number] = protocol_obj
def list_neighbors():
interfaces_list = interfaces.values()
t = PrettyTable(['Interface', 'Neighbor IP', 'Hello Hold Time', "Generation ID", "Uptime"])
......@@ -157,32 +152,16 @@ def list_routing_state():
def stop():
remove_interface("*", pim=True, igmp=True)
kernel.exit()
UnicastRouting.stop()
unicast_routing.stop()
def main():
from Hello import Hello
from IGMP import IGMP
from Assert import Assert
from JoinPrune import JoinPrune
from GraftAck import GraftAck
from Graft import Graft
from StateRefresh import StateRefresh
Hello()
Assert()
JoinPrune()
Graft()
GraftAck()
StateRefresh()
global kernel
kernel = Kernel()
global igmp
igmp = IGMP()
global u
u = UnicastRouting.UnicastRouting()
global unicast_routing
unicast_routing = UnicastRouting.UnicastRouting()
global interfaces
global igmp_interfaces
......
from threading import Timer
import time
from utils import HELLO_HOLD_TIME_NO_TIMEOUT, HELLO_HOLD_TIME_TIMEOUT, TYPE_CHECKING
from threading import Lock
from RWLock.RWLock import RWLockWrite
from threading import Lock, RLock
import Main
if TYPE_CHECKING:
from InterfacePIM import InterfacePim
class Neighbor:
def __init__(self, contact_interface: "InterfacePim", ip, generation_id: int, hello_hold_time: int):
def __init__(self, contact_interface: "InterfacePim", ip, generation_id: int, hello_hold_time: int, state_refresh_capable:bool):
if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT:
raise Exception
self.contact_interface = contact_interface
......@@ -17,7 +16,7 @@ class Neighbor:
self.generation_id = generation_id
# todo lan prune delay
# todo override interval
# todo state refresh capable
self.state_refresh_capable = state_refresh_capable
self.neighbor_liveness_timer = None
self.hello_hold_time = None
......@@ -26,13 +25,9 @@ class Neighbor:
self.neighbor_lock = Lock()
self.tree_interface_nlt_subscribers = []
self.tree_interface_nlt_subscribers_lock = RWLockWrite()
self.tree_interface_nlt_subscribers_lock = RLock()
# send hello to new neighbor
#self.contact_interface.send_hello()
# todo RANDOM DELAY??? => DO NOTHING... EVENTUALLY THE HELLO MESSAGE WILL BE SENT
def set_hello_hold_time(self, hello_hold_time: int):
self.hello_hold_time = hello_hold_time
if self.neighbor_liveness_timer is not None:
......@@ -69,14 +64,11 @@ class Neighbor:
print('HELLO TIMER EXPIRED... remove neighbor')
if self.neighbor_liveness_timer is not None:
self.neighbor_liveness_timer.cancel()
#Main.remove_neighbor(self.ip)
interface_name = self.contact_interface.interface_name
neighbor_ip = self.ip
del self.contact_interface.neighbors[self.ip]
self.contact_interface.remove_neighbor(self.ip)
# notify interfaces which have this neighbor as AssertWinner
with self.tree_interface_nlt_subscribers_lock.genRlock():
with self.tree_interface_nlt_subscribers_lock:
for tree_if in self.tree_interface_nlt_subscribers:
tree_if.assert_winner_nlt_expires()
......@@ -85,22 +77,23 @@ class Neighbor:
return
def receive_hello(self, generation_id, hello_hold_time):
def receive_hello(self, generation_id, hello_hold_time, state_refresh_capable):
if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT:
self.set_hello_hold_time(hello_hold_time)
else:
self.time_of_last_update = time.time()
self.set_generation_id(generation_id)
self.set_hello_hold_time(hello_hold_time)
if state_refresh_capable != self.state_refresh_capable:
self.state_refresh_capable = state_refresh_capable
def subscribe_nlt_expiration(self, tree_if):
with self.tree_interface_nlt_subscribers_lock.genWlock():
with self.tree_interface_nlt_subscribers_lock:
if tree_if not in self.tree_interface_nlt_subscribers:
self.tree_interface_nlt_subscribers.append(tree_if)
def unsubscribe_nlt_expiration(self, tree_if):
with self.tree_interface_nlt_subscribers_lock.genWlock():
with self.tree_interface_nlt_subscribers_lock:
if tree_if in self.tree_interface_nlt_subscribers:
self.tree_interface_nlt_subscribers.remove(tree_if)
......@@ -47,6 +47,9 @@ class PacketIGMPHeader(PacketPayload):
self.max_resp_time = max_resp_time
self.group_address = group_address
def get_igmp_type(self):
return self.type
def bytes(self) -> bytes:
# obter mensagem e criar checksum
msg_without_chcksum = struct.pack(PacketIGMPHeader.IGMP_HDR, self.type, self.max_resp_time, 0,
......
......@@ -73,20 +73,4 @@ class PacketPimHeader(PacketPayload):
pim_payload = data[PacketPimHeader.PIM_HDR_LEN:]
pim_payload = PacketPimHeader.PIM_MSG_TYPES[pim_type].parse_bytes(pim_payload)
'''
if pim_type == 0: # hello
pim_payload = PacketPimHello.parse_bytes(pim_payload)
elif pim_type == 3: # join/prune
pim_payload = PacketPimJoinPrune.parse_bytes(pim_payload)
print("hold_time = ", pim_payload.hold_time)
print("upstream_neighbor = ", pim_payload.upstream_neighbor_address)
for i in pim_payload.groups:
print(i.multicast_group)
print(i.joined_src_addresses)
print(i.pruned_src_addresses)
elif pim_type == 5: # assert
pim_payload = PacketPimAssert.parse_bytes(pim_payload)
else:
raise Exception
'''
return PacketPimHeader(pim_payload)
import random
from threading import Timer
from Packet.Packet import Packet
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimHello import PacketPimHello
from Packet.PacketPimHeader import PacketPimHeader
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
from Interface import Interface
import Main
class StateRefresh:
TYPE = 9
def __init__(self):
Main.add_protocol(StateRefresh.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
#check if interface supports state refresh
if not packet.interface._state_refresh_capable:
return
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_state_refresh = packet.payload.payload # type: PacketPimStateRefresh
# TODO
interface_index = packet.interface.vif_index
source = pkt_state_refresh.source_address
group = pkt_state_refresh.multicast_group_adress
source_group = (source, group)
try:
Main.kernel.get_routing_entry(source_group).recv_state_refresh_msg(interface_index, packet)
except:
try:
# import time
# time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_state_refresh_msg(interface_index, packet)
except:
pass
......@@ -195,9 +195,11 @@ class KernelEntry:
self.interface_state[self.inbound_interface_index].change_rpf(self._was_olist_null)
def nbr_event(self, link, node, event):
# todo pode ser interessante verificar se a adicao/remocao de vizinhos se altera o olist
return
# check if add/removal of neighbors from interface afects olist and forward/prune state of interface
def change_at_number_of_neighbors(self):
with self.CHANGE_STATE_LOCK:
self.change()
self.evaluate_olist_change()
def is_olist_null(self):
for interface in self.interface_state.values():
......
......@@ -120,12 +120,13 @@ class TreeInterfaceDownstream(TreeInterface):
return
interval = state_refresh_msg_received.interval
self._assert_state.sendStateRefresh(self, interval)
self._prune_state.send_state_refresh(self)
if self.lost_assert():
return
self._assert_state.sendStateRefresh(self, interval)
self._prune_state.send_state_refresh(self)
prune_indicator_bit = 0
if self.is_pruned():
prune_indicator_bit = 1
......@@ -164,7 +165,7 @@ class TreeInterfaceDownstream(TreeInterface):
# Override
def is_forwarding(self):
return ((len(self.get_interface().neighbors) >= 1 and not self.is_pruned()) or self.igmp_has_members()) and not self.lost_assert()
return ((self.has_neighbors() and not self.is_pruned()) or self.igmp_has_members()) and not self.lost_assert()
#return self._assert_state == AssertState.Winner and self.is_in_group()
def is_pruned(self):
......
......@@ -193,7 +193,6 @@ class TreeInterfaceUpstream(TreeInterface):
####################################
def create_state_refresh_msg(self):
self._prune_now_counter+=1
self._prune_now_counter%=3
(source_ip, group_ip) = self.get_tree_id()
ph = PacketPimStateRefresh(multicast_group_adress=group_ip,
source_address=source_ip,
......@@ -201,9 +200,11 @@ class TreeInterfaceUpstream(TreeInterface):
metric_preference=0, metric=0, mask_len=0,
ttl=256,
prune_indicator_flag=0,
prune_now_flag=(self._prune_now_counter+1)//3,
prune_now_flag=self._prune_now_counter//3,
assert_override_flag=0,
interval=60)
self._prune_now_counter %= 3
self._kernel_entry.forward_state_refresh_msg(ph)
###########################################
......
......@@ -30,19 +30,8 @@ from .globals import *
class TreeInterface(metaclass=ABCMeta):
def __init__(self, kernel_entry, interface_id):
'''
@type interface: SFMRInterface
@type node: Node
'''
#assert isinstance(interface, SFMRInterface)
self._kernel_entry = kernel_entry
self._interface_id = interface_id
#self._interface = interface
#self._node = node
#self._tree_id = tree_id
#self._cost = cost
#self._evaluate_ig = evaluate_ig_cb
# Local Membership State
try:
......@@ -53,7 +42,6 @@ class TreeInterface(metaclass=ABCMeta):
igmp_has_members = group_state.add_multicast_routing_entry(self)
self._local_membership_state = LocalMembership.Include if igmp_has_members else LocalMembership.NoInfo
except:
#traceback.print_exc()
self._local_membership_state = LocalMembership.NoInfo
......@@ -86,24 +74,17 @@ class TreeInterface(metaclass=ABCMeta):
self.evaluate_ingroup()
def set_assert_winner_metric(self, new_assert_metric: AssertMetric):
import ipaddress
with self.get_state_lock():
try:
old_neighbor = self.get_interface().get_neighbor(str(self._assert_winner_metric.ip_address))
new_neighbor = self.get_interface().get_neighbor(str(new_assert_metric.ip_address))
old_neighbor = self.get_interface().get_neighbor(self._assert_winner_metric.get_ip())
new_neighbor = self.get_interface().get_neighbor(new_assert_metric.get_ip())
if old_neighbor is not None:
old_neighbor.unsubscribe_nlt_expiration(self)
if new_neighbor is not None:
new_neighbor.subscribe_nlt_expiration(self)
'''
if new_assert_metric.ip_address == ipaddress.ip_address("0.0.0.0") or new_assert_metric.ip_address is None:
if old_neighbor is not None:
old_neighbor.unsubscribe_nlt_expiration(self)
else:
old_neighbor.unsubscribe_nlt_expiration(self)
new_neighbor.subscribe_nlt_expiration(self)
'''
except:
traceback.print_exc()
finally:
self._assert_winner_metric = new_assert_metric
......@@ -340,6 +321,12 @@ class TreeInterface(metaclass=ABCMeta):
ip = self.get_interface().get_ip()
return ip
def has_neighbors(self):
try:
return len(self.get_interface().neighbors) > 0
except:
return False
def get_tree_id(self):
return (self._kernel_entry.source_ip, self._kernel_entry.group_ip)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment