Commit 40ef4a8f authored by Pedro Oliveira's avatar Pedro Oliveira Committed by GitHub

Merge pull request #1 from pedrofran12/PIM-DM

Pim dm implementation
parents bcefc317 dec96dcb
from time import time
try:
from threading import _Timer as Timer
except ImportError:
from threading import Timer
class RemainingTimer(Timer):
def __init__(self, interval, function):
super().__init__(interval, function)
self.start_time = time()
def time_remaining(self):
delta_time = time() - self.start_time
return self.interval - delta_time
'''
def test():
print("ola")
x = RemainingTimer(10, test)
x.start()
from time import sleep
for i in range(0, 10):
print(x.time_remaining())
sleep(1)
'''
"""Generic linux daemon base class for python 3.x."""
import sys, os, time, atexit, signal
class Daemon:
"""A generic Daemon class.
Usage: subclass the Daemon class and override the run() method."""
def __init__(self, pidfile): self.pidfile = pidfile
def daemonize(self):
"""Deamonize class. UNIX double fork mechanism."""
try:
pid = os.fork()
if pid > 0:
# exit first parent
sys.exit(0)
except OSError as err:
sys.stderr.write('fork #1 failed: {0}\n'.format(err))
sys.exit(1)
# decouple from parent environment
#os.chdir('/')
#os.setsid()
#os.umask(0)
# do second fork
try:
pid = os.fork()
if pid > 0:
# exit from second parent
sys.exit(0)
except OSError as err:
sys.stderr.write('fork #2 failed: {0}\n'.format(err))
sys.exit(1)
# redirect standard file descriptors
sys.stdout.flush()
sys.stderr.flush()
si = open(os.devnull, 'r')
so = open('stdout', 'a+')
se = open('stderror', 'a+')
os.dup2(si.fileno(), sys.stdin.fileno())
os.dup2(so.fileno(), sys.stdout.fileno())
os.dup2(se.fileno(), sys.stderr.fileno())
# write pidfile
atexit.register(self.delpid)
pid = str(os.getpid())
with open(self.pidfile, 'w+') as f:
f.write(pid + '\n')
def delpid(self):
os.remove(self.pidfile)
def start(self):
"""Start the Daemon."""
# Check for a pidfile to see if the Daemon already runs
if self.is_running():
message = "pidfile {0} already exist. " + \
"Daemon already running?\n"
sys.stderr.write(message.format(self.pidfile))
sys.exit(1)
# Start the Daemon
self.daemonize()
self.run()
def stop(self):
"""Stop the Daemon."""
# Get the pid from the pidfile
try:
with open(self.pidfile, 'r') as pf:
pid = int(pf.read().strip())
except IOError:
pid = None
if not pid:
message = "pidfile {0} does not exist. " + \
"Daemon not running?\n"
sys.stderr.write(message.format(self.pidfile))
return # not an error in a restart
# Try killing the Daemon process
try:
while 1:
#os.killpg(os.getpgid(pid), signal.SIGTERM)
os.kill(pid, signal.SIGTERM)
time.sleep(0.1)
except OSError as err:
e = str(err.args)
if e.find("No such process") > 0:
if os.path.exists(self.pidfile):
os.remove(self.pidfile)
else:
print(str(err.args))
sys.exit(1)
def restart(self):
"""Restart the Daemon."""
self.stop()
self.start()
def run(self):
"""You should override this method when you subclass Daemon.
It will be called after the process has been daemonized by
start() or restart()."""
def is_running(self):
try:
with open(self.pidfile, 'r') as pf:
pid = int(pf.read().strip())
except IOError:
return False
""" Check For the existence of a unix pid. """
try:
os.kill(pid, 0)
return True
except:
return False
import socket
from abc import ABCMeta, abstractmethod
import threading
import random
import netifaces
from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
from RWLock.RWLock import RWLockWrite
class Interface(metaclass=ABCMeta):
MCAST_GRP = '224.0.0.13'
def __init__(self, interface_name, recv_socket, send_socket, vif_index):
self.interface_name = interface_name
# virtual interface index for the multicast routing table
self.vif_index = vif_index
# set receive socket and send socket
self._send_socket = send_socket
self._recv_socket = recv_socket
self.interface_enabled = False
def _enable(self):
self.interface_enabled = True
# run receive method in background
receive_thread = threading.Thread(target=self.receive)
receive_thread.daemon = True
receive_thread.start()
def receive(self):
while self.interface_enabled:
try:
(raw_bytes, _) = self._recv_socket.recvfrom(256 * 1024)
if raw_bytes:
self._receive(raw_bytes)
except Exception:
traceback.print_exc()
continue
@abstractmethod
def _receive(self, raw_bytes):
raise NotImplementedError
def send(self, data: bytes, group_ip: str):
if self.interface_enabled and data:
self._send_socket.sendto(data, (group_ip, 0))
def remove(self):
self.interface_enabled = False
try:
self._recv_socket.shutdown(socket.SHUT_RDWR)
except Exception:
pass
self._recv_socket.close()
self._send_socket.close()
def is_enabled(self):
return self.interface_enabled
@abstractmethod
def get_ip(self):
raise NotImplementedError
\ No newline at end of file
import socket
import struct
from ipaddress import IPv4Address
from ctypes import create_string_buffer, addressof
import netifaces
from Packet.ReceivedPacket import ReceivedPacket
from Interface import Interface
from utils import Version_1_Membership_Report, Version_2_Membership_Report, Leave_Group, Membership_Query
if not hasattr(socket, 'SO_BINDTODEVICE'):
socket.SO_BINDTODEVICE = 25
class InterfaceIGMP(Interface):
ETH_P_IP = 0x0800 # Internet Protocol packet
SO_ATTACH_FILTER = 26
FILTER_IGMP = [
struct.pack('HBBI', 0x28, 0, 0, 0x0000000c),
struct.pack('HBBI', 0x15, 0, 3, 0x00000800),
struct.pack('HBBI', 0x30, 0, 0, 0x00000017),
struct.pack('HBBI', 0x15, 0, 1, 0x00000002),
struct.pack('HBBI', 0x6, 0, 0, 0x00040000),
struct.pack('HBBI', 0x6, 0, 0, 0x00000000),
]
def __init__(self, interface_name: str, vif_index: int):
# SEND SOCKET
snd_s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
# bind to interface
snd_s.setsockopt(socket.SOL_SOCKET, socket.SO_BINDTODEVICE, str(interface_name + "\0").encode('utf-8'))
# RECEIVE SOCKET
rcv_s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, socket.htons(InterfaceIGMP.ETH_P_IP))
# receive only IGMP packets by setting a BPF filter
bpf_filter = b''.join(InterfaceIGMP.FILTER_IGMP)
b = create_string_buffer(bpf_filter)
mem_addr_of_filters = addressof(b)
fprog = struct.pack('HL', len(InterfaceIGMP.FILTER_IGMP), mem_addr_of_filters)
rcv_s.setsockopt(socket.SOL_SOCKET, InterfaceIGMP.SO_ATTACH_FILTER, fprog)
# bind to interface
rcv_s.bind((interface_name, 0x0800))
super().__init__(interface_name=interface_name, recv_socket=rcv_s, send_socket=snd_s, vif_index=vif_index)
self.interface_enabled = True
from igmp.RouterState import RouterState
self.interface_state = RouterState(self)
super()._enable()
def get_ip(self):
return netifaces.ifaddresses(self.interface_name)[netifaces.AF_INET][0]['addr']
@property
def ip_interface(self):
return self.get_ip()
def send(self, data: bytes, address: str="224.0.0.1"):
super().send(data, address)
def _receive(self, raw_bytes):
if raw_bytes:
raw_bytes = raw_bytes[14:]
packet = ReceivedPacket(raw_bytes, self)
ip_src = packet.ip_header.ip_src
if not (ip_src == "0.0.0.0" or IPv4Address(ip_src).is_multicast):
self.PKT_FUNCTIONS.get(packet.payload.get_igmp_type(), InterfaceIGMP.receive_unknown_type)(self, packet)
###########################################
# Recv packets
###########################################
def receive_version_1_membership_report(self, packet):
ip_dst = packet.ip_header.ip_dst
igmp_group = packet.payload.group_address
if ip_dst == igmp_group and IPv4Address(igmp_group).is_multicast:
self.interface_state.receive_v1_membership_report(packet)
def receive_version_2_membership_report(self, packet):
ip_dst = packet.ip_header.ip_dst
igmp_group = packet.payload.group_address
if ip_dst == igmp_group and IPv4Address(igmp_group).is_multicast:
self.interface_state.receive_v2_membership_report(packet)
def receive_leave_group(self, packet):
ip_dst = packet.ip_header.ip_dst
igmp_group = packet.payload.group_address
if ip_dst == "224.0.0.2" and IPv4Address(igmp_group).is_multicast:
self.interface_state.receive_leave_group(packet)
def receive_membership_query(self, packet):
ip_dst = packet.ip_header.ip_dst
igmp_group = packet.payload.group_address
if ip_dst == igmp_group or (ip_dst == "224.0.0.1" and igmp_group == "0.0.0.0"):
self.interface_state.receive_query(packet)
def receive_unknown_type(self, packet):
return
PKT_FUNCTIONS = {
Version_1_Membership_Report: receive_version_1_membership_report,
Version_2_Membership_Report: receive_version_2_membership_report,
Leave_Group: receive_leave_group,
Membership_Query: receive_membership_query,
}
##################
def remove(self):
super().remove()
self.interface_state.remove()
This diff is collapsed.
This diff is collapsed.
import netifaces
import time
from prettytable import PrettyTable
import sys
import logging, logging.handlers
from TestLogger import RootFilter
from Kernel import Kernel
import UnicastRouting
interfaces = {} # interfaces with multicast routing enabled
igmp_interfaces = {} # igmp interfaces
kernel = None
unicast_routing = None
logger = None
def add_pim_interface(interface_name, state_refresh_capable:bool=False):
kernel.create_pim_interface(interface_name=interface_name, state_refresh_capable=state_refresh_capable)
def add_igmp_interface(interface_name):
kernel.create_igmp_interface(interface_name=interface_name)
'''
def add_interface(interface_name, pim=False, igmp=False):
#if pim is True and interface_name not in interfaces:
# interface = InterfacePim(interface_name)
# interfaces[interface_name] = interface
# interface.create_virtual_interface()
#if igmp is True and interface_name not in igmp_interfaces:
# interface = InterfaceIGMP(interface_name)
# igmp_interfaces[interface_name] = interface
kernel.create_interface(interface_name=interface_name, pim=pim, igmp=igmp)
#if pim:
# interfaces[interface_name] = kernel.pim_interface[interface_name]
#if igmp:
# igmp_interfaces[interface_name] = kernel.igmp_interface[interface_name]
'''
def remove_interface(interface_name, pim=False, igmp=False):
#if pim is True and ((interface_name in interfaces) or interface_name == "*"):
# if interface_name == "*":
# interface_name_list = list(interfaces.keys())
# else:
# interface_name_list = [interface_name]
# for if_name in interface_name_list:
# interface_obj = interfaces.pop(if_name)
# interface_obj.remove()
# #interfaces[if_name].remove()
# #del interfaces[if_name]
# print("removido interface")
# print(interfaces)
#if igmp is True and ((interface_name in igmp_interfaces) or interface_name == "*"):
# if interface_name == "*":
# interface_name_list = list(igmp_interfaces.keys())
# else:
# interface_name_list = [interface_name]
# for if_name in interface_name_list:
# igmp_interfaces[if_name].remove()
# del igmp_interfaces[if_name]
# print("removido interface")
# print(igmp_interfaces)
kernel.remove_interface(interface_name, pim=pim, igmp=igmp)
def list_neighbors():
interfaces_list = interfaces.values()
t = PrettyTable(['Interface', 'Neighbor IP', 'Hello Hold Time', "Generation ID", "Uptime"])
check_time = time.time()
for interface in interfaces_list:
for neighbor in interface.get_neighbors():
uptime = check_time - neighbor.time_of_last_update
uptime = 0 if (uptime < 0) else uptime
t.add_row(
[interface.interface_name, neighbor.ip, neighbor.hello_hold_time, neighbor.generation_id, time.strftime("%H:%M:%S", time.gmtime(uptime))])
print(t)
return str(t)
def list_enabled_interfaces():
global interfaces
t = PrettyTable(['Interface', 'IP', 'PIM/IGMP Enabled', 'IGMP State'])
for interface in netifaces.interfaces():
try:
# TODO: fix same interface with multiple ips
ip = netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr']
pim_enabled = interface in interfaces
igmp_enabled = interface in igmp_interfaces
enabled = str(pim_enabled) + "/" + str(igmp_enabled)
if igmp_enabled:
state = igmp_interfaces[interface].interface_state.print_state()
else:
state = "-"
t.add_row([interface, ip, enabled, state])
except Exception:
continue
print(t)
return str(t)
def list_state():
state_text = "IGMP State:\n" + list_igmp_state() + "\n\n\n\n" + "Multicast Routing State:\n" + list_routing_state()
return state_text
def list_igmp_state():
t = PrettyTable(['Interface', 'RouterState', 'Group Adress', 'GroupState'])
for (interface_name, interface_obj) in list(igmp_interfaces.items()):
interface_state = interface_obj.interface_state
state_txt = interface_state.print_state()
print(interface_state.group_state.items())
for (group_addr, group_state) in list(interface_state.group_state.items()):
print(group_addr)
group_state_txt = group_state.print_state()
t.add_row([interface_name, state_txt, group_addr, group_state_txt])
return str(t)
def list_routing_state():
routing_entries = []
for a in list(kernel.routing.values()):
for b in list(a.values()):
routing_entries.append(b)
vif_indexes = kernel.vif_index_to_name_dic.keys()
t = PrettyTable(['SourceIP', 'GroupIP', 'Interface', 'PruneState', 'AssertState', 'LocalMembership', "Is Forwarding?"])
for entry in routing_entries:
ip = entry.source_ip
group = entry.group_ip
upstream_if_index = entry.inbound_interface_index
for index in vif_indexes:
interface_state = entry.interface_state[index]
interface_name = kernel.vif_index_to_name_dic[index]
local_membership = type(interface_state._local_membership_state).__name__
try:
assert_state = type(interface_state._assert_state).__name__
if index != upstream_if_index:
prune_state = type(interface_state._prune_state).__name__
is_forwarding = interface_state.is_forwarding()
else:
prune_state = type(interface_state._graft_prune_state).__name__
is_forwarding = "upstream"
except:
prune_state = "-"
assert_state = "-"
is_forwarding = "-"
t.add_row([ip, group, interface_name, prune_state, assert_state, local_membership, is_forwarding])
return str(t)
def stop():
remove_interface("*", pim=True, igmp=True)
kernel.exit()
unicast_routing.stop()
def test(router_name, server_logger_ip):
global logger
socketHandler = logging.handlers.SocketHandler(server_logger_ip,
logging.handlers.DEFAULT_TCP_LOGGING_PORT)
# don't bother with a formatter, since a socket handler sends the event as
# an unformatted pickle
socketHandler.addFilter(RootFilter(router_name))
logger.addHandler(socketHandler)
def main():
# logging
global logger
logger = logging.getLogger('pim')
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))
global kernel
kernel = Kernel()
global unicast_routing
unicast_routing = UnicastRouting.UnicastRouting()
global interfaces
global igmp_interfaces
interfaces = kernel.pim_interface
igmp_interfaces = kernel.igmp_interface
from threading import Timer
import time
from utils import HELLO_HOLD_TIME_NO_TIMEOUT, HELLO_HOLD_TIME_TIMEOUT, TYPE_CHECKING
from threading import Lock, RLock
import Main
import logging
if TYPE_CHECKING:
from InterfacePIM import InterfacePim
class Neighbor:
LOGGER = logging.getLogger('pim.Interface.Neighbor')
def __init__(self, contact_interface: "InterfacePim", ip, generation_id: int, hello_hold_time: int,
state_refresh_capable: bool):
if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT:
raise Exception
logger_info = dict(contact_interface.interface_logger.extra)
logger_info['neighbor_ip'] = ip
self.neighbor_logger = logging.LoggerAdapter(self.LOGGER, logger_info)
self.neighbor_logger.debug('Monitoring new neighbor ' + ip + ' with GenerationID: ' + str(generation_id) +
'; HelloHoldTime: ' + str(hello_hold_time) + '; StateRefreshCapable: ' +
str(state_refresh_capable))
self.contact_interface = contact_interface
self.ip = ip
self.generation_id = generation_id
# todo lan prune delay
# todo override interval
self.state_refresh_capable = state_refresh_capable
self.neighbor_liveness_timer = None
self.hello_hold_time = None
self.set_hello_hold_time(hello_hold_time)
self.time_of_last_update = time.time()
self.neighbor_lock = Lock()
self.tree_interface_nlt_subscribers = []
self.tree_interface_nlt_subscribers_lock = RLock()
def set_hello_hold_time(self, hello_hold_time: int):
self.hello_hold_time = hello_hold_time
if self.neighbor_liveness_timer is not None:
self.neighbor_liveness_timer.cancel()
if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT:
self.remove()
self.neighbor_logger.debug('Detected neighbor removal of ' + self.ip)
elif hello_hold_time != HELLO_HOLD_TIME_NO_TIMEOUT:
self.neighbor_logger.debug('Neighbor Liveness Timer reseted of ' + self.ip)
self.neighbor_liveness_timer = Timer(hello_hold_time, self.remove)
self.neighbor_liveness_timer.start()
else:
self.neighbor_liveness_timer = None
def set_generation_id(self, generation_id):
# neighbor restarted
if self.generation_id != generation_id:
self.neighbor_logger.debug('Detected reset of ' + self.ip + '... new GenerationID: ' + str(generation_id))
self.generation_id = generation_id
self.contact_interface.force_send_hello()
self.reset()
"""
def heartbeat(self):
if (self.hello_hold_time != HELLO_HOLD_TIME_TIMEOUT) and \
(self.hello_hold_time != HELLO_HOLD_TIME_NO_TIMEOUT):
print("HEARTBEAT")
if self.neighbor_liveness_timer is not None:
self.neighbor_liveness_timer.cancel()
self.neighbor_liveness_timer = Timer(self.hello_hold_time, self.remove)
self.neighbor_liveness_timer.start()
self.time_of_last_update = time.time()
"""
def remove(self):
print('HELLO TIMER EXPIRED... remove neighbor')
if self.neighbor_liveness_timer is not None:
self.neighbor_liveness_timer.cancel()
self.neighbor_logger.debug('Neighbor Liveness Timer expired of ' + self.ip)
self.contact_interface.remove_neighbor(self.ip)
# notify interfaces which have this neighbor as AssertWinner
with self.tree_interface_nlt_subscribers_lock:
for tree_if in self.tree_interface_nlt_subscribers:
tree_if.assert_winner_nlt_expires()
def reset(self):
self.contact_interface.new_or_reset_neighbor(self.ip)
def receive_hello(self, generation_id, hello_hold_time, state_refresh_capable):
self.neighbor_logger.debug('Receive Hello message with HelloHoldTime: ' + str(hello_hold_time) +
'; GenerationID: ' + str(generation_id) + '; StateRefreshCapable: ' +
str(state_refresh_capable) + ' from neighbor ' + self.ip)
if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT:
self.set_hello_hold_time(hello_hold_time)
else:
self.time_of_last_update = time.time()
self.set_generation_id(generation_id)
self.set_hello_hold_time(hello_hold_time)
if state_refresh_capable != self.state_refresh_capable:
self.state_refresh_capable = state_refresh_capable
def subscribe_nlt_expiration(self, tree_if):
with self.tree_interface_nlt_subscribers_lock:
if tree_if not in self.tree_interface_nlt_subscribers:
self.tree_interface_nlt_subscribers.append(tree_if)
def unsubscribe_nlt_expiration(self, tree_if):
with self.tree_interface_nlt_subscribers_lock:
if tree_if in self.tree_interface_nlt_subscribers:
self.tree_interface_nlt_subscribers.remove(tree_if)
from .PacketIpHeader import PacketIpHeader
from .PacketPayload import PacketPayload
class Packet(object):
def __init__(self, ip_header: PacketIpHeader = None, payload: PacketPayload = None):
self.ip_header = ip_header
self.payload = payload
def bytes(self) -> bytes:
return self.payload.bytes()
import struct
from utils import checksum
import socket
from .PacketPayload import PacketPayload
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type | Max Resp Time | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Group Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Resv |S| QRV | QQIC | Number of Sources (N) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address [1] |
+- -+
| Source Address [2] |
+- . -+
. . .
. . .
+- -+
| Source Address [N] |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketIGMPHeader(PacketPayload):
IGMP_TYPE = 2
IGMP_HDR = "! BB H 4s"
IGMP_HDR_LEN = struct.calcsize(IGMP_HDR)
IGMP3_SRC_ADDR_HDR = "! BB H "
IGMP3_SRC_ADDR_HDR_LEN = struct.calcsize(IGMP3_SRC_ADDR_HDR)
IPv4_HDR = "! 4s"
IPv4_HDR_LEN = struct.calcsize(IPv4_HDR)
Membership_Query = 0x11
Version_2_Membership_Report = 0x16
Leave_Group = 0x17
Version_1_Membership_Report = 0x12
def __init__(self, type: int, max_resp_time: int, group_address: str="0.0.0.0"):
# todo check type
self.type = type
self.max_resp_time = max_resp_time
self.group_address = group_address
def get_igmp_type(self):
return self.type
def bytes(self) -> bytes:
# obter mensagem e criar checksum
msg_without_chcksum = struct.pack(PacketIGMPHeader.IGMP_HDR, self.type, self.max_resp_time, 0,
socket.inet_aton(self.group_address))
igmp_checksum = checksum(msg_without_chcksum)
msg = msg_without_chcksum[0:2] + struct.pack("! H", igmp_checksum) + msg_without_chcksum[4:]
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
#print("parseIGMPHdr: ", data)
igmp_hdr = data[0:PacketIGMPHeader.IGMP_HDR_LEN]
(type, max_resp_time, rcv_checksum, group_address) = struct.unpack(PacketIGMPHeader.IGMP_HDR, igmp_hdr)
#print(type, max_resp_time, rcv_checksum, group_address)
msg_to_checksum = data[0:2] + b'\x00\x00' + data[4:]
#print("checksum calculated: " + str(checksum(msg_to_checksum)))
if checksum(msg_to_checksum) != rcv_checksum:
#print("wrong checksum")
raise Exception("wrong checksum")
igmp_hdr = igmp_hdr[PacketIGMPHeader.IGMP_HDR_LEN:]
group_address = socket.inet_ntoa(group_address)
pkt = PacketIGMPHeader(type, max_resp_time, group_address)
return pkt
\ No newline at end of file
import struct
import socket
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|Version| IHL |Type of Service| Total Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Identification |Flags| Fragment Offset |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Time to Live | Protocol | Header Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Destination Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Options | Padding |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketIpHeader:
IP_HDR = "! BBH HH BBH 4s 4s"
IP_HDR_LEN = struct.calcsize(IP_HDR)
def __init__(self, ver, hdr_len, ttl, proto, ip_src, ip_dst):
self.version = ver
self.hdr_length = hdr_len
self.ttl = ttl
self.proto = proto
self.ip_src = ip_src
self.ip_dst = ip_dst
def __len__(self):
return self.hdr_length
@staticmethod
def parse_bytes(data: bytes):
(verhlen, tos, iplen, ipid, frag, ttl, proto, cksum, src, dst) = \
struct.unpack(PacketIpHeader.IP_HDR, data)
ver = (verhlen & 0xf0) >> 4
hlen = (verhlen & 0x0f) * 4
'''
"VER": ver,
"HLEN": hlen,
"TOS": tos,
"IPLEN": iplen,
"IPID": ipid,
"FRAG": frag,
"TTL": ttl,
"PROTO": proto,
"CKSUM": cksum,
"SRC": socket.inet_ntoa(src),
"DST": socket.inet_ntoa(dst)
'''
src_ip = socket.inet_ntoa(src)
dst_ip = socket.inet_ntoa(dst)
return PacketIpHeader(ver, hlen, ttl, proto, src_ip, dst_ip)
import abc
class PacketPayload(object):
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def bytes(self) -> bytes:
"""Get packet payload in bytes format"""
@abc.abstractmethod
def __len__(self):
"""Get packet payload length"""
@staticmethod
@abc.abstractmethod
def parse_bytes(data: bytes):
"""From bytes create a object payload"""
import struct
import socket
from Packet.PacketPimEncodedGroupAddress import PacketPimEncodedGroupAddress
from Packet.PacketPimEncodedUnicastAddress import PacketPimEncodedUnicastAddress
from tree.globals import ASSERT_CANCEL_METRIC
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|PIM Ver| Type | Reserved | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|R| Metric Preference |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Metric |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimAssert:
PIM_TYPE = 5
PIM_HDR_ASSERT = "! %ss %ss LL"
PIM_HDR_ASSERT_WITHOUT_ADDRESS = "! LL"
PIM_HDR_ASSERT_v4 = PIM_HDR_ASSERT % (PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN)
PIM_HDR_ASSERT_v6 = PIM_HDR_ASSERT % (PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6)
PIM_HDR_ASSERT_WITHOUT_ADDRESS_LEN = struct.calcsize(PIM_HDR_ASSERT_WITHOUT_ADDRESS)
PIM_HDR_ASSERT_v4_LEN = struct.calcsize(PIM_HDR_ASSERT_v4)
PIM_HDR_ASSERT_v6_LEN = struct.calcsize(PIM_HDR_ASSERT_v6)
def __init__(self, multicast_group_address: str or bytes, source_address: str or bytes, metric_preference: int or float, metric: int or float):
if type(multicast_group_address) is bytes:
multicast_group_address = socket.inet_ntoa(multicast_group_address)
if type(source_address) is bytes:
source_address = socket.inet_ntoa(source_address)
if metric_preference > 0x7FFFFFFF:
metric_preference = 0x7FFFFFFF
if metric > ASSERT_CANCEL_METRIC:
metric = ASSERT_CANCEL_METRIC
self.multicast_group_address = multicast_group_address
self.source_address = source_address
self.metric_preference = metric_preference
self.metric = metric
def bytes(self) -> bytes:
multicast_group_address = PacketPimEncodedGroupAddress(self.multicast_group_address).bytes()
source_address = PacketPimEncodedUnicastAddress(self.source_address).bytes()
msg = multicast_group_address + source_address + struct.pack(PacketPimAssert.PIM_HDR_ASSERT_WITHOUT_ADDRESS,
0x7FFFFFFF & self.metric_preference,
self.metric)
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
multicast_group_addr_obj = PacketPimEncodedGroupAddress.parse_bytes(data)
multicast_group_addr_len = len(multicast_group_addr_obj)
data = data[multicast_group_addr_len:]
source_addr_obj = PacketPimEncodedUnicastAddress.parse_bytes(data)
source_addr_len = len(source_addr_obj)
data = data[source_addr_len:]
(metric_preference, metric) = struct.unpack(PacketPimAssert.PIM_HDR_ASSERT_WITHOUT_ADDRESS, data[:PacketPimAssert.PIM_HDR_ASSERT_WITHOUT_ADDRESS_LEN])
pim_payload = PacketPimAssert(multicast_group_addr_obj.group_address, source_addr_obj.unicast_address, 0x7FFFFFFF & metric_preference, metric)
return pim_payload
import ipaddress
import struct
import socket
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Addr Family | Encoding Type |B| Reserved |Z| Mask Len |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Group Multicast Address
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+...
'''
class PacketPimEncodedGroupAddress:
PIM_ENCODED_GROUP_ADDRESS_HDR = "! BBBB %s"
PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_MULTICAST_ADDRESS = "! BBBB"
IPV4_HDR = "4s"
IPV6_HDR = "16s"
# TODO ver melhor versao ip
PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_ADDRESS_LEN = struct.calcsize(PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_MULTICAST_ADDRESS)
PIM_ENCODED_GROUP_ADDRESS_HDR_LEN = struct.calcsize(PIM_ENCODED_GROUP_ADDRESS_HDR % IPV4_HDR)
PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6 = struct.calcsize(PIM_ENCODED_GROUP_ADDRESS_HDR % IPV6_HDR)
FAMILY_RESERVED = 0
FAMILY_IPV4 = 1
FAMILY_IPV6 = 2
RESERVED = 0
def __init__(self, group_address, mask_len=None):
if type(group_address) not in (str, bytes):
raise Exception
if type(group_address) is bytes:
group_address = socket.inet_ntoa(group_address)
self.group_address = group_address
self.mask_len = mask_len
def bytes(self) -> bytes:
(string_ip_hdr, hdr_addr_family, socket_family) = PacketPimEncodedGroupAddress.get_ip_info(self.group_address)
mask_len = self.mask_len
if mask_len is None:
mask_len = 8 * struct.calcsize(string_ip_hdr)
ip = socket.inet_pton(socket_family, self.group_address)
msg = struct.pack(PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR % string_ip_hdr, hdr_addr_family, 0,
PacketPimEncodedGroupAddress.RESERVED, mask_len, ip)
return msg
@staticmethod
def get_ip_info(ip):
version = ipaddress.ip_address(ip).version
if version == 4:
return (PacketPimEncodedGroupAddress.IPV4_HDR, PacketPimEncodedGroupAddress.FAMILY_IPV4, socket.AF_INET)
elif version == 6:
return (PacketPimEncodedGroupAddress.IPV6_HDR, PacketPimEncodedGroupAddress.FAMILY_IPV6, socket.AF_INET6)
else:
raise Exception
def __len__(self):
version = ipaddress.ip_address(self.group_address).version
if version == 4:
return self.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN
elif version == 6:
return self.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6
else:
raise Exception
@staticmethod
def parse_bytes(data: bytes):
data_without_group_addr = data[0:PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_ADDRESS_LEN]
(addr_family, encoding, _, mask_len) = struct.unpack(PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_MULTICAST_ADDRESS, data_without_group_addr)
data_group_addr = data[PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_ADDRESS_LEN:]
ip = None
if addr_family == PacketPimEncodedGroupAddress.FAMILY_IPV4:
(ip,) = struct.unpack("! " + PacketPimEncodedGroupAddress.IPV4_HDR, data_group_addr[:4])
ip = socket.inet_ntop(socket.AF_INET, ip)
elif addr_family == PacketPimEncodedGroupAddress.FAMILY_IPV6:
(ip,) = struct.unpack("! " + PacketPimEncodedGroupAddress.IPV6_HDR, data_group_addr[:16])
ip = socket.inet_ntop(socket.AF_INET6, ip)
if encoding != 0:
print("unknown encoding")
raise Exception
return PacketPimEncodedGroupAddress(ip, mask_len)
import ipaddress
import struct
import socket
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Addr Family | Encoding Type | Rsrvd |S|W|R| Mask Len |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+...
'''
class PacketPimEncodedSourceAddress:
PIM_ENCODED_SOURCE_ADDRESS_HDR = "! BBBB %s"
PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS = "! BBBB"
IPV4_HDR = "4s"
IPV6_HDR = "16s"
# TODO ver melhor versao ip
PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS_LEN = struct.calcsize(PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS)
PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN = struct.calcsize(PIM_ENCODED_SOURCE_ADDRESS_HDR % IPV4_HDR)
PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN_IPV6 = struct.calcsize(PIM_ENCODED_SOURCE_ADDRESS_HDR % IPV6_HDR)
FAMILY_RESERVED = 0
FAMILY_IPV4 = 1
FAMILY_IPV6 = 2
RESERVED_AND_SWR_BITS = 0
def __init__(self, source_address, mask_len=None):
if type(source_address) not in (str, bytes):
raise Exception
if type(source_address) is bytes:
source_address = socket.inet_ntoa(source_address)
self.source_address = source_address
self.mask_len = mask_len
def bytes(self) -> bytes:
(string_ip_hdr, hdr_addr_family, socket_family) = PacketPimEncodedSourceAddress.get_ip_info(self.source_address)
mask_len = self.mask_len
if mask_len is None:
mask_len = 8 * struct.calcsize(string_ip_hdr)
ip = socket.inet_pton(socket_family, self.source_address)
msg = struct.pack(PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR % string_ip_hdr, hdr_addr_family, 0,
PacketPimEncodedSourceAddress.RESERVED_AND_SWR_BITS, mask_len, ip)
return msg
@staticmethod
def get_ip_info(ip):
version = ipaddress.ip_address(ip).version
if version == 4:
return (PacketPimEncodedSourceAddress.IPV4_HDR, PacketPimEncodedSourceAddress.FAMILY_IPV4, socket.AF_INET)
elif version == 6:
return (PacketPimEncodedSourceAddress.IPV6_HDR, PacketPimEncodedSourceAddress.FAMILY_IPV6, socket.AF_INET6)
else:
raise Exception
def __len__(self):
version = ipaddress.ip_address(self.source_address).version
if version == 4:
return self.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN
elif version == 6:
return self.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN_IPV6
else:
raise Exception
@staticmethod
def parse_bytes(data: bytes):
data_without_source_addr = data[0:PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS_LEN]
(addr_family, encoding, _, mask_len) = struct.unpack(PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS, data_without_source_addr)
data_source_addr = data[PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS_LEN:]
ip = None
if addr_family == PacketPimEncodedSourceAddress.FAMILY_IPV4:
(ip,) = struct.unpack("! " + PacketPimEncodedSourceAddress.IPV4_HDR, data_source_addr[:4])
ip = socket.inet_ntop(socket.AF_INET, ip)
elif addr_family == PacketPimEncodedSourceAddress.FAMILY_IPV6:
(ip,) = struct.unpack("! " + PacketPimEncodedSourceAddress.IPV6_HDR, data_source_addr[:16])
ip = socket.inet_ntop(socket.AF_INET6, ip)
if encoding != 0:
print("unknown encoding")
raise Exception
return PacketPimEncodedSourceAddress(ip, mask_len)
import ipaddress
import struct
import socket
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Addr Family | Encoding Type | Unicast Address
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+...
'''
class PacketPimEncodedUnicastAddress:
PIM_ENCODED_UNICAST_ADDRESS_HDR = "! BB %s"
PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS = "! BB"
IPV4_HDR = "4s"
IPV6_HDR = "16s"
# TODO ver melhor versao ip
PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS_LEN = struct.calcsize(PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS)
PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN = struct.calcsize(PIM_ENCODED_UNICAST_ADDRESS_HDR % IPV4_HDR)
PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6 = struct.calcsize(PIM_ENCODED_UNICAST_ADDRESS_HDR % IPV6_HDR)
FAMILY_RESERVED = 0
FAMILY_IPV4 = 1
FAMILY_IPV6 = 2
def __init__(self, unicast_address):
if type(unicast_address) not in (str, bytes):
raise Exception
if type(unicast_address) is bytes:
unicast_address = socket.inet_ntoa(unicast_address)
self.unicast_address = unicast_address
def bytes(self) -> bytes:
(string_ip_hdr, hdr_addr_family, socket_family) = PacketPimEncodedUnicastAddress.get_ip_info(self.unicast_address)
ip = socket.inet_pton(socket_family, self.unicast_address)
msg = struct.pack(PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR % string_ip_hdr, hdr_addr_family, 0, ip)
return msg
@staticmethod
def get_ip_info(ip):
version = ipaddress.ip_address(ip).version
if version == 4:
return (PacketPimEncodedUnicastAddress.IPV4_HDR, PacketPimEncodedUnicastAddress.FAMILY_IPV4, socket.AF_INET)
elif version == 6:
return (PacketPimEncodedUnicastAddress.IPV6_HDR, PacketPimEncodedUnicastAddress.FAMILY_IPV6, socket.AF_INET6)
else:
raise Exception
def __len__(self):
version = ipaddress.ip_address(self.unicast_address).version
if version == 4:
return self.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN
elif version == 6:
return self.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6
else:
raise Exception
@staticmethod
def parse_bytes(data: bytes):
data_without_unicast_addr = data[0:PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS_LEN]
(addr_family, encoding) = struct.unpack(PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS, data_without_unicast_addr)
data_unicast_addr = data[PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS_LEN:]
if addr_family == PacketPimEncodedUnicastAddress.FAMILY_IPV4:
(ip,) = struct.unpack("! " + PacketPimEncodedUnicastAddress.IPV4_HDR, data_unicast_addr[:4])
ip = socket.inet_ntop(socket.AF_INET, ip)
elif addr_family == PacketPimEncodedUnicastAddress.FAMILY_IPV6:
(ip,) = struct.unpack("! " + PacketPimEncodedUnicastAddress.IPV6_HDR, data_unicast_addr[:16])
ip = socket.inet_ntop(socket.AF_INET6, ip)
if encoding != 0:
print("unknown encoding")
raise Exception
return PacketPimEncodedUnicastAddress(ip)
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Upstream Neighbor Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Reserved | Num Groups | Hold Time |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address 1 (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address m (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimGraft(PacketPimJoinPrune):
PIM_TYPE = 6
def __init__(self, upstream_neighbor_address, holdtime=0):
super().__init__(upstream_neighbor_address=upstream_neighbor_address, hold_time=holdtime)
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Upstream Neighbor Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Reserved | Num Groups | Hold Time |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address 1 (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address m (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimGraftAck(PacketPimJoinPrune):
PIM_TYPE = 7
def __init__(self, upstream_neighbor_address, holdtime=0):
super().__init__(upstream_neighbor_address, hold_time=holdtime)
import struct
from Packet.PacketPimHello import PacketPimHello
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
from Packet.PacketPimAssert import PacketPimAssert
from Packet.PacketPimGraft import PacketPimGraft
from Packet.PacketPimGraftAck import PacketPimGraftAck
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
from utils import checksum
from .PacketPayload import PacketPayload
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|PIM Ver| Type | Reserved | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimHeader(PacketPayload):
PIM_VERSION = 2
PIM_HDR = "! BB H"
PIM_HDR_LEN = struct.calcsize(PIM_HDR)
PIM_MSG_TYPES = {0: PacketPimHello,
3: PacketPimJoinPrune,
5: PacketPimAssert,
6: PacketPimGraft,
7: PacketPimGraftAck,
9: PacketPimStateRefresh
}
def __init__(self, payload):
self.payload = payload
def get_pim_type(self):
return self.payload.PIM_TYPE
def bytes(self) -> bytes:
# obter mensagem e criar checksum
pim_vrs_type = (PacketPimHeader.PIM_VERSION << 4) + self.get_pim_type()
msg_without_chcksum = struct.pack(PacketPimHeader.PIM_HDR, pim_vrs_type, 0, 0)
msg_without_chcksum += self.payload.bytes()
pim_checksum = checksum(msg_without_chcksum)
msg = msg_without_chcksum[0:2] + struct.pack("! H", pim_checksum) + msg_without_chcksum[4:]
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
print("parsePimHdr: ", data)
pim_hdr = data[0:PacketPimHeader.PIM_HDR_LEN]
(pim_ver_type, reserved, rcv_checksum) = struct.unpack(PacketPimHeader.PIM_HDR, pim_hdr)
print(pim_ver_type, reserved, rcv_checksum)
pim_version = (pim_ver_type & 0xF0) >> 4
pim_type = pim_ver_type & 0x0F
if pim_version != PacketPimHeader.PIM_VERSION:
print("Version of PIM packet received not known (!=2)")
raise Exception
msg_to_checksum = data[0:2] + b'\x00\x00' + data[4:]
if checksum(msg_to_checksum) != rcv_checksum:
print("wrong checksum")
print("checksum calculated: " + str(checksum(msg_to_checksum)))
print("checksum recv: " + str(rcv_checksum))
raise Exception
pim_payload = data[PacketPimHeader.PIM_HDR_LEN:]
pim_payload = PacketPimHeader.PIM_MSG_TYPES[pim_type].parse_bytes(pim_payload)
return PacketPimHeader(pim_payload)
import struct
from abc import ABCMeta, abstractstaticmethod
from .PacketPimHelloOptions import PacketPimHelloOptions, PacketPimHelloStateRefreshCapable, PacketPimHelloGenerationID, PacketPimHelloLANPruneDelay, PacketPimHelloHoldtime
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Option Type | Option Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Option Value |
| ... |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Option Type | Option Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Option Value |
| ... |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimHello:
PIM_TYPE = 0
PIM_HDR_OPTS = "! HH"
PIM_HDR_OPTS_LEN = struct.calcsize(PIM_HDR_OPTS)
PIM_MSG_TYPES_LENGTH = {1: 2,
2: 4,
20: 4,
21: 4,
}
# todo: pensar melhor na implementacao state refresh capable option...
def __init__(self):
self.options = {}
'''
def add_option(self, option_type: int, option_value: int or float):
option_value = int(option_value)
# if option_value requires more bits than the bits available for that field: option value will have all field bits = 1
if option_type in self.PIM_MSG_TYPES_LENGTH and self.PIM_MSG_TYPES_LENGTH[option_type] * 8 < option_value.bit_length():
option_value = (1 << (self.PIM_MSG_TYPES_LENGTH[option_type] * 8)) - 1
self.options[option_type] = option_value
'''
def add_option(self, option: 'PacketPimHelloOptions'):
#if option_type in self.PIM_MSG_TYPES_LENGTH and self.PIM_MSG_TYPES_LENGTH[option_type] * 8 < option_value.bit_length():
# option_value = (1 << (self.PIM_MSG_TYPES_LENGTH[option_type] * 8)) - 1
self.options[option.type] = option
def get_options(self):
return self.options
'''
def bytes(self) -> bytes:
res = b''
for (option_type, option_value) in self.options.items():
option_length = PacketPimHello.PIM_MSG_TYPES_LENGTH[option_type]
type_length_hdr = struct.pack(PacketPimHello.PIM_HDR_OPTS, option_type, option_length)
res += type_length_hdr + struct.pack("! " + str(option_length) + "s", option_value.to_bytes(option_length, byteorder='big'))
return res
'''
def bytes(self) -> bytes:
res = b''
for option in self.options.values():
res += option.bytes()
return res
def __len__(self):
return len(self.bytes())
'''
@staticmethod
def parse_bytes(data: bytes):
pim_payload = PacketPimHello()
while data != b'':
(option_type, option_length) = struct.unpack(PacketPimHello.PIM_HDR_OPTS,
data[:PacketPimHello.PIM_HDR_OPTS_LEN])
print(option_type, option_length)
data = data[PacketPimHello.PIM_HDR_OPTS_LEN:]
print(data)
(option_value,) = struct.unpack("! " + str(option_length) + "s", data[:option_length])
option_value_number = int.from_bytes(option_value, byteorder='big')
print("option value: ", option_value_number)
#options_list.append({"OPTION TYPE": option_type,
# "OPTION LENGTH": option_length,
# "OPTION VALUE": option_value_number
# })
pim_payload.add_option(option_type, option_value_number)
data = data[option_length:]
return pim_payload
'''
@staticmethod
def parse_bytes(data: bytes):
pim_payload = PacketPimHello()
while data != b'':
option = PacketPimHelloOptions.parse_bytes(data)
option_length = len(option)
data = data[option_length:]
pim_payload.add_option(option)
return pim_payload
import struct
from abc import ABCMeta
import math
class PacketPimHelloOptions(metaclass=ABCMeta):
PIM_HDR_OPTS = "! HH"
PIM_HDR_OPTS_LEN = struct.calcsize(PIM_HDR_OPTS)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type | Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, type: int, length: int):
self.type = type
self.length = length
def bytes(self) -> bytes:
return struct.pack(PacketPimHelloOptions.PIM_HDR_OPTS, self.type, self.length)
def __len__(self):
return self.PIM_HDR_OPTS_LEN + self.length
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
(type, length) = struct.unpack(PacketPimHelloOptions.PIM_HDR_OPTS,
data[:PacketPimHelloOptions.PIM_HDR_OPTS_LEN])
#print("TYPE:", type)
#print("LENGTH:", length)
data = data[PacketPimHelloOptions.PIM_HDR_OPTS_LEN:]
#return PIM_MSG_TYPES[type](data)
return PIM_MSG_TYPES.get(type, PacketPimHelloUnknown).parse_bytes(data, type, length)
class PacketPimHelloStateRefreshCapable(PacketPimHelloOptions):
PIM_HDR_OPT = "! BBH"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Version = 1 | Interval | Reserved |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
VERSION = 1
def __init__(self, interval: int):
super().__init__(type=21, length=4)
self.interval = interval
def bytes(self) -> bytes:
return super().bytes() + struct.pack(self.PIM_HDR_OPT, self.VERSION, self.interval, 0)
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
(version, interval, _) = struct.unpack(PacketPimHelloStateRefreshCapable.PIM_HDR_OPT,
data[:PacketPimHelloStateRefreshCapable.PIM_HDR_OPT_LEN])
return PacketPimHelloStateRefreshCapable(interval)
class PacketPimHelloLANPruneDelay(PacketPimHelloOptions):
PIM_HDR_OPT = "! HH"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|T| LAN Prune Delay | Override Interval |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, lan_prune_delay: float, override_interval: float):
super().__init__(type=2, length=4)
self.lan_prune_delay = 0x7FFF & math.ceil(lan_prune_delay)
self.override_interval = math.ceil(override_interval)
def bytes(self) -> bytes:
return super().bytes() + struct.pack(self.PIM_HDR_OPT, self.lan_prune_delay, self.override_interval)
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
(lan_prune_delay, override_interval) = struct.unpack(PacketPimHelloLANPruneDelay.PIM_HDR_OPT,
data[:PacketPimHelloLANPruneDelay.PIM_HDR_OPT_LEN])
lan_prune_delay = lan_prune_delay & 0x7FFF
return PacketPimHelloLANPruneDelay(lan_prune_delay=lan_prune_delay, override_interval=override_interval)
class PacketPimHelloHoldtime(PacketPimHelloOptions):
PIM_HDR_OPT = "! H"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Hold Time |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, holdtime: int or float):
super().__init__(type=1, length=2)
self.holdtime = int(holdtime)
def bytes(self) -> bytes:
return super().bytes() + struct.pack(self.PIM_HDR_OPT, self.holdtime)
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
(holdtime, ) = struct.unpack(PacketPimHelloHoldtime.PIM_HDR_OPT,
data[:PacketPimHelloHoldtime.PIM_HDR_OPT_LEN])
#print("HOLDTIME:", holdtime)
return PacketPimHelloHoldtime(holdtime=holdtime)
class PacketPimHelloGenerationID(PacketPimHelloOptions):
PIM_HDR_OPT = "! L"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Generation ID |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, generation_id: int):
super().__init__(type=20, length=4)
self.generation_id = generation_id
def bytes(self) -> bytes:
return super().bytes() + struct.pack(self.PIM_HDR_OPT, self.generation_id)
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
(generation_id, ) = struct.unpack(PacketPimHelloGenerationID.PIM_HDR_OPT,
data[:PacketPimHelloGenerationID.PIM_HDR_OPT_LEN])
#print("GenerationID:", generation_id)
return PacketPimHelloGenerationID(generation_id=generation_id)
class PacketPimHelloUnknown(PacketPimHelloOptions):
PIM_HDR_OPT = "! L"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Unknown |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, type, length):
super().__init__(type=type, length=length)
#print("PIM Hello Option Unknown... TYPE=", type, "LENGTH=", length)
def bytes(self) -> bytes:
raise Exception
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
return PacketPimHelloUnknown(type, length)
PIM_MSG_TYPES = {1: PacketPimHelloHoldtime,
2: PacketPimHelloLANPruneDelay,
20: PacketPimHelloGenerationID,
21: PacketPimHelloStateRefreshCapable,
}
import struct
import socket
from Packet.PacketPimEncodedUnicastAddress import PacketPimEncodedUnicastAddress
from Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Upstream Neighbor Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Reserved | Num Groups | Hold Time |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimJoinPrune:
PIM_TYPE = 3
PIM_HDR_JOIN_PRUNE = "! %ss BBH"
PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS = "! BBH"
PIM_HDR_JOIN_PRUNE_v4 = PIM_HDR_JOIN_PRUNE % PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN
PIM_HDR_JOIN_PRUNE_v6 = PIM_HDR_JOIN_PRUNE % PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6
PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS_LEN = struct.calcsize(PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS)
PIM_HDR_JOIN_PRUNE_v4_LEN = struct.calcsize(PIM_HDR_JOIN_PRUNE_v4)
PIM_HDR_JOIN_PRUNE_v6_LEN = struct.calcsize(PIM_HDR_JOIN_PRUNE_v6)
def __init__(self, upstream_neighbor_address, hold_time):
if type(upstream_neighbor_address) not in (str, bytes):
raise Exception
if type(upstream_neighbor_address) is bytes:
upstream_neighbor_address = socket.inet_ntoa(upstream_neighbor_address)
self.groups = []
self.upstream_neighbor_address = upstream_neighbor_address
self.hold_time = hold_time
def add_multicast_group(self, group: PacketPimJoinPruneMulticastGroup):
# TODO verificar se grupo ja esta na msg
self.groups.append(group)
def bytes(self) -> bytes:
upstream_neighbor_address = PacketPimEncodedUnicastAddress(self.upstream_neighbor_address).bytes()
msg = upstream_neighbor_address + struct.pack(PacketPimJoinPrune.PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS, 0,
len(self.groups), self.hold_time)
for multicast_group in self.groups:
msg += multicast_group.bytes()
return msg
def __len__(self):
return len(self.bytes())
@classmethod
def parse_bytes(cls, data: bytes):
upstream_neighbor_addr_obj = PacketPimEncodedUnicastAddress.parse_bytes(data)
upstream_neighbor_addr_len = len(upstream_neighbor_addr_obj)
data = data[upstream_neighbor_addr_len:]
(_, num_groups, hold_time) = struct.unpack(PacketPimJoinPrune.PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS,
data[:PacketPimJoinPrune.PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS_LEN])
data = data[PacketPimJoinPrune.PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS_LEN:]
pim_payload = cls(upstream_neighbor_addr_obj.unicast_address, hold_time)
for i in range(0, num_groups):
group = PacketPimJoinPruneMulticastGroup.parse_bytes(data)
group_len = len(group)
pim_payload.add_multicast_group(group)
data = data[group_len:]
return pim_payload
import struct
import socket
from Packet.PacketPimEncodedGroupAddress import PacketPimEncodedGroupAddress
from Packet.PacketPimEncodedSourceAddress import PacketPimEncodedSourceAddress
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address 1 (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimJoinPruneMulticastGroup:
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP = "! %ss HH"
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS = "! HH"
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_v4_LEN_ = struct.calcsize(
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP % PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN)
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_v6_LEN_ = struct.calcsize(
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP % PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6)
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS_LEN = struct.calcsize(
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS)
PIM_HDR_JOINED_PRUNED_SOURCE = "! %ss"
PIM_HDR_JOINED_PRUNED_SOURCE_v4_LEN = PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN
PIM_HDR_JOINED_PRUNED_SOURCE_v6_LEN = PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN_IPV6
def __init__(self, multicast_group: str or bytes, joined_src_addresses: list=[], pruned_src_addresses: list=[]):
if type(multicast_group) not in (str, bytes):
raise Exception
elif type(multicast_group) is bytes:
multicast_group = socket.inet_ntoa(multicast_group)
if type(joined_src_addresses) is not list:
raise Exception
if type(pruned_src_addresses) is not list:
raise Exception
self.multicast_group = multicast_group
self.joined_src_addresses = joined_src_addresses
self.pruned_src_addresses = pruned_src_addresses
def bytes(self) -> bytes:
multicast_group_address = PacketPimEncodedGroupAddress(self.multicast_group).bytes()
msg = multicast_group_address + struct.pack(self.PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS,
len(self.joined_src_addresses), len(self.pruned_src_addresses))
for joined_src_address in self.joined_src_addresses:
joined_src_address_bytes = PacketPimEncodedSourceAddress(joined_src_address).bytes()
msg += joined_src_address_bytes
for pruned_src_address in self.pruned_src_addresses:
pruned_src_address_bytes = PacketPimEncodedSourceAddress(pruned_src_address).bytes()
msg += pruned_src_address_bytes
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
multicast_group_addr_obj = PacketPimEncodedGroupAddress.parse_bytes(data)
multicast_group_addr_len = len(multicast_group_addr_obj)
data = data[multicast_group_addr_len:]
number_join_prune_data = data[:PacketPimJoinPruneMulticastGroup.PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS_LEN]
(number_joined_sources, number_pruned_sources) = struct.unpack(PacketPimJoinPruneMulticastGroup.PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS, number_join_prune_data)
joined = []
pruned = []
data = data[PacketPimJoinPruneMulticastGroup.PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS_LEN:]
for i in range(0, number_joined_sources):
joined_obj = PacketPimEncodedSourceAddress.parse_bytes(data)
joined_obj_len = len(joined_obj)
data = data[joined_obj_len:]
joined.append(joined_obj.source_address)
for i in range(0, number_pruned_sources):
pruned_obj = PacketPimEncodedSourceAddress.parse_bytes(data)
pruned_obj_len = len(pruned_obj)
data = data[pruned_obj_len:]
pruned.append(pruned_obj.source_address)
return PacketPimJoinPruneMulticastGroup(multicast_group_addr_obj.group_address, joined, pruned)
import struct
import socket
from Packet.PacketPimEncodedUnicastAddress import PacketPimEncodedUnicastAddress
from Packet.PacketPimEncodedGroupAddress import PacketPimEncodedGroupAddress
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|PIM Ver| Type | Reserved | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Originator Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|R| Metric Preference |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Metric |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Masklen | TTL |P|N|O|Reserved | Interval |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimStateRefresh:
PIM_TYPE = 9
PIM_HDR_STATE_REFRESH = "! %ss %ss %ss I I BBBB"
PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES = "! I I BBBB"
PIM_HDR_STATE_REFRESH_v4 = PIM_HDR_STATE_REFRESH % (PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN)
PIM_HDR_STATE_REFRESH_v6 = PIM_HDR_STATE_REFRESH % (PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6)
PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES_LEN = struct.calcsize(PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES)
PIM_HDR_STATE_REFRESH_v4_LEN = struct.calcsize(PIM_HDR_STATE_REFRESH_v4)
PIM_HDR_STATE_REFRESH_v6_LEN = struct.calcsize(PIM_HDR_STATE_REFRESH_v6)
def __init__(self, multicast_group_adress: str or bytes, source_address: str or bytes, originator_adress: str or bytes,
metric_preference: int, metric: int, mask_len: int, ttl: int, prune_indicator_flag: bool,
prune_now_flag: bool, assert_override_flag: bool, interval: int):
if type(multicast_group_adress) is bytes:
multicast_group_adress = socket.inet_ntoa(multicast_group_adress)
if type(source_address) is bytes:
source_address = socket.inet_ntoa(source_address)
if type(originator_adress) is bytes:
originator_adress = socket.inet_ntoa(originator_adress)
self.multicast_group_adress = multicast_group_adress
self.source_address = source_address
self.originator_adress = originator_adress
self.metric_preference = metric_preference
self.metric = metric
self.mask_len = mask_len
self.ttl = ttl
self.prune_indicator_flag = prune_indicator_flag
self.prune_now_flag = prune_now_flag
self.assert_override_flag = assert_override_flag
self.interval = interval
def bytes(self) -> bytes:
multicast_group_adress = PacketPimEncodedGroupAddress(self.multicast_group_adress).bytes()
source_address = PacketPimEncodedUnicastAddress(self.source_address).bytes()
originator_adress = PacketPimEncodedUnicastAddress(self.originator_adress).bytes()
prune_and_assert_flags = (self.prune_indicator_flag << 7) | (self.prune_now_flag << 6) | (self.assert_override_flag << 5)
msg = multicast_group_adress + source_address + originator_adress + \
struct.pack(self.PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES, 0x7FFFFFFF & self.metric_preference,
self.metric, self.mask_len, self.ttl, prune_and_assert_flags, self.interval)
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
multicast_group_adress_obj = PacketPimEncodedGroupAddress.parse_bytes(data)
multicast_group_adress_len = len(multicast_group_adress_obj)
data = data[multicast_group_adress_len:]
source_address_obj = PacketPimEncodedUnicastAddress.parse_bytes(data)
source_address_len = len(source_address_obj)
data = data[source_address_len:]
originator_address_obj = PacketPimEncodedUnicastAddress.parse_bytes(data)
originator_address_len = len(originator_address_obj)
data = data[originator_address_len:]
(metric_preference, metric, mask_len, ttl, reserved_and_prune_and_assert_flags, interval) = struct.unpack(PacketPimStateRefresh.PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES, data[:PacketPimStateRefresh.PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES_LEN])
metric_preference = 0x7FFFFFFF & metric_preference
prune_indicator_flag = (0x80 & reserved_and_prune_and_assert_flags) >> 7
prune_now_flag = (0x40 & reserved_and_prune_and_assert_flags) >> 6
assert_override_flag = (0x20 & reserved_and_prune_and_assert_flags) >> 5
pim_payload = PacketPimStateRefresh(multicast_group_adress_obj.group_address, source_address_obj.unicast_address,
originator_address_obj.unicast_address, metric_preference, metric, mask_len,
ttl, prune_indicator_flag, prune_now_flag, assert_override_flag, interval)
return pim_payload
from Packet.Packet import Packet
from Packet.PacketIpHeader import PacketIpHeader
from Packet.PacketIGMPHeader import PacketIGMPHeader
from .PacketPimHeader import PacketPimHeader
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from Interface import Interface
class ReceivedPacket(Packet):
# choose payload protocol class based on ip protocol number
payload_protocol = {2: PacketIGMPHeader, 103: PacketPimHeader}
def __init__(self, raw_packet: bytes, interface: 'Interface'):
self.interface = interface
# Parse ao packet e preencher objeto Packet
packet_ip_hdr = raw_packet[:PacketIpHeader.IP_HDR_LEN]
ip_header = PacketIpHeader.parse_bytes(packet_ip_hdr)
protocol_number = ip_header.proto
packet_without_ip_hdr = raw_packet[ip_header.hdr_length:]
payload = ReceivedPacket.payload_protocol[protocol_number].parse_bytes(packet_without_ip_hdr)
super().__init__(ip_header=ip_header, payload=payload)
# pim # PIM-DM
\ No newline at end of file
We have implemented the specification of PIM-DM ([RFC3973](https://tools.ietf.org/html/rfc3973)).
This repository stores the implementation of this protocol. The implementation is written in Python language and is destined to Linux systems.
# Requirements
- Linux machine
- Python3 (we have written all code to be compatible with at least Python v3.2)
- pip (to install all dependencies)
# Installation
You may need sudo permitions, in order to run this protocol. This is required because we use raw sockets to exchange control messages. For this reason, some sockets to work properly need to have super user permissions.
First clone this repository:
`git clone https://github.com/pedrofran12/pim.git`
Then enter in the cloned repository and install all dependencies:
`pip3 install -r requirements.txt`
And thats it :D
# Run protocol
In order to interact with the protocol you need to allways execute Run.py file. You can interact with the protocol by executing this file and specifying a command and corresponding arguments:
`sudo python3 Run.py -COMMAND ARGUMENTS`
In order to determine which commands are available you can call the help command:
`sudo python3 Run.py -h`
or
`sudo python3 Run.py --help`
In order to start the protocol you first need to explicitly start it. This will start a daemon process, which will be running in the background. The command is the following:
`sudo python3 Run.py -start`
Then you can enable the protocol in specific interfaces. You need to specify which interfaces will have IGMP enabled and which interfaces will have the PIM-DM enabled.
To enable PIM-DM, without State-Refresh, in a given interface, you need to run the following command:
`sudo python3 Run.py -ai INTERFACE_NAME`
To enable PIM-DM, with State-Refresh, in a given interface, you need to run the following command:
`sudo python3 Run.py -aisf INTERFACE_NAME`
To enable IGMP in a given interface, you need to run the following command:
`sudo python3 Run.py -aiigmp INTERFACE_NAME`
If you have previously enabled an interface without State-Refresh and want to enable it, in the same interface, you first need to disable this interface, and the run the command -aisr. The same happens when you want to disable State Refresh in a previously enabled StateRefresh interface.
To remove a previously added interface, you need run the following commands:
To remove a previously added PIM-DM interface:
`sudo python3 Run.py -ri INTERFACE_NAME`
To remove a previously added IGMP interface:
`sudo python3 Run.py -riigmp INTERFACE_NAME`
If you want to stop the protocol process, and stop the daemon process, you need to explicitly run this command:
`sudo python3 Run.py -stop`
## Commands for monitoring the protocol process
We have built some list commands that can be used to check the "internals" of the implementation.
- List neighbors:
Verify neighbors that have established a neighborhood relationship
`sudo python3 Run.py -ln`
- List state:
List all state machines and corresponding state of all trees that are being monitored. Also list IGMP state for each group being monitored.
`sudo python3 Run.py -ls`
- Multicast Routing Table:
List Linux Multicast Routing Table (equivalent to ip mroute -show)
`sudo python3 Run.py -mr`
## Change settings
Files tree/globals.py and igmp/igmp_globals.py store all timer values and some configurations regarding IGMP and the PIM-DM. If you want to tune the implementation, you can change the values of these files. These configurations are used by all interfaces, meaning that there is no tuning per interface.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Read Write Lock
"""
import threading
import time
class RWLockRead(object):
"""
A Read/Write lock giving preference to Reader
"""
def __init__(self):
self.V_ReadCount = 0
self.A_Resource = threading.Lock()
self.A_LockReadCount = threading.Lock()
class _aReader(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockReadCount.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
self.A_RWLock.V_ReadCount += 1
if self.A_RWLock.V_ReadCount == 1:
if not self.A_RWLock.A_Resource.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.V_ReadCount -= 1
self.A_RWLock.A_LockReadCount.release()
return False
self.A_RWLock.A_LockReadCount.release()
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_LockReadCount.acquire()
self.A_RWLock.V_ReadCount -= 1
if self.A_RWLock.V_ReadCount == 0:
self.A_RWLock.A_Resource.release()
self.A_RWLock.A_LockReadCount.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
class _aWriter(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
self.V_Locked = self.A_RWLock.A_Resource.acquire(blocking, timeout)
return self.V_Locked
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_Resource.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
def genRlock(self):
"""
Generate a reader lock
"""
return RWLockRead._aReader(self)
def genWlock(self):
"""
Generate a writer lock
"""
return RWLockRead._aWriter(self)
class RWLockWrite(object):
"""
A Read/Write lock giving preference to Writer
"""
def __init__(self):
self.V_ReadCount = 0
self.V_WriteCount = 0
self.A_LockReadCount = threading.Lock()
self.A_LockWriteCount = threading.Lock()
self.A_LockReadEntry = threading.Lock()
self.A_LockReadTry = threading.Lock()
self.A_Resource = threading.Lock()
class _aReader(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockReadEntry.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
if not self.A_RWLock.A_LockReadTry.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockReadEntry.release()
return False
if not self.A_RWLock.A_LockReadCount.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockReadEntry.release()
return False
self.A_RWLock.V_ReadCount += 1
if (self.A_RWLock.V_ReadCount == 1):
if not self.A_RWLock.A_Resource.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockReadEntry.release()
self.A_RWLock.V_ReadCount -= 1
self.A_RWLock.A_LockReadCount.release()
return False
self.A_RWLock.A_LockReadCount.release()
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockReadEntry.release()
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_LockReadCount.acquire()
self.A_RWLock.V_ReadCount -= 1
if (self.A_RWLock.V_ReadCount == 0):
self.A_RWLock.A_Resource.release()
self.A_RWLock.A_LockReadCount.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
class _aWriter(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockWriteCount.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
self.A_RWLock.V_WriteCount += 1
if (self.A_RWLock.V_WriteCount == 1):
if not self.A_RWLock.A_LockReadTry.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.V_WriteCount -= 1
self.A_RWLock.A_LockWriteCount.release()
return False
self.A_RWLock.A_LockWriteCount.release()
if not self.A_RWLock.A_Resource.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockWriteCount.acquire()
self.A_RWLock.V_WriteCount -= 1
if self.A_RWLock.V_WriteCount == 0:
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockWriteCount.release()
return False
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_Resource.release()
self.A_RWLock.A_LockWriteCount.acquire()
self.A_RWLock.V_WriteCount -= 1
if (self.A_RWLock.V_WriteCount == 0):
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockWriteCount.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
def genRlock(self):
"""
Generate a reader lock
"""
return RWLockWrite._aReader(self)
def genWlock(self):
"""
Generate a writer lock
"""
return RWLockWrite._aWriter(self)
class RWLockFair(object):
"""
A Read/Write lock giving fairness to both Reader and Writer
"""
def __init__(self):
self.V_ReadCount = 0
self.A_LockReadCount = threading.Lock()
self.A_LockRead = threading.Lock()
self.A_LockWrite = threading.Lock()
class _aReader(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockRead.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
if not self.A_RWLock.A_LockReadCount.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockRead.release()
return False
self.A_RWLock.V_ReadCount += 1
if self.A_RWLock.V_ReadCount == 1:
if not self.A_RWLock.A_LockWrite.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.V_ReadCount -= 1
self.A_RWLock.A_LockReadCount.release()
self.A_RWLock.A_LockRead.release()
return False
self.A_RWLock.A_LockReadCount.release()
self.A_RWLock.A_LockRead.release()
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_LockReadCount.acquire()
self.A_RWLock.V_ReadCount -= 1
if self.A_RWLock.V_ReadCount == 0:
self.A_RWLock.A_LockWrite.release()
self.A_RWLock.A_LockReadCount.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
class _aWriter(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockRead.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
if not self.A_RWLock.A_LockWrite.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockRead.release()
return False
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_LockWrite.release()
self.A_RWLock.A_LockRead.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
def genRlock(self):
"""
Generate a reader lock
"""
return RWLockFair._aReader(self)
def genWlock(self):
"""
Generate a writer lock
"""
return RWLockFair._aWriter(self)
#!/usr/bin/env python
from Daemon.Daemon import Daemon
import Main
import _pickle as pickle
import socket
import sys
import os
import argparse
import traceback
def client_socket(data_to_send):
# Create a UDS socket
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
# Connect the socket to the port where the server is listening
server_address = './uds_socket'
#print('connecting to %s' % server_address)
try:
sock.connect(server_address)
sock.sendall(pickle.dumps(data_to_send))
data_rcv = sock.recv(1024 * 256)
if data_rcv:
print(pickle.loads(data_rcv))
except socket.error:
pass
finally:
#print('closing socket')
sock.close()
class MyDaemon(Daemon):
def run(self):
Main.main()
server_address = './uds_socket'
# Make sure the socket does not already exist
try:
os.unlink(server_address)
except OSError:
if os.path.exists(server_address):
raise
# Create a UDS socket
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
# Bind the socket to the port
sock.bind(server_address)
# Listen for incoming connections
sock.listen(1)
while True:
try:
connection, client_address = sock.accept()
data = connection.recv(256 * 1024)
print(sys.stderr, 'sending data back to the client')
print(pickle.loads(data))
args = pickle.loads(data)
if 'list_interfaces' in args and args.list_interfaces:
connection.sendall(pickle.dumps(Main.list_enabled_interfaces()))
elif 'list_neighbors' in args and args.list_neighbors:
connection.sendall(pickle.dumps(Main.list_neighbors()))
elif 'list_state' in args and args.list_state:
connection.sendall(pickle.dumps(Main.list_state()))
elif 'add_interface' in args and args.add_interface:
Main.add_pim_interface(args.add_interface[0], False)
connection.shutdown(socket.SHUT_RDWR)
elif 'add_interface_sr' in args and args.add_interface_sr:
Main.add_pim_interface(args.add_interface_sr[0], True)
connection.shutdown(socket.SHUT_RDWR)
elif 'add_interface_igmp' in args and args.add_interface_igmp:
Main.add_igmp_interface(args.add_interface_igmp[0])
connection.shutdown(socket.SHUT_RDWR)
elif 'remove_interface' in args and args.remove_interface:
Main.remove_interface(args.remove_interface[0], pim=True)
connection.shutdown(socket.SHUT_RDWR)
elif 'remove_interface_igmp' in args and args.remove_interface_igmp:
Main.remove_interface(args.remove_interface_igmp[0], igmp=True)
connection.shutdown(socket.SHUT_RDWR)
elif 'stop' in args and args.stop:
Main.stop()
connection.shutdown(socket.SHUT_RDWR)
elif 'test' in args and args.test:
Main.test(args.test[0], args.test[1])
connection.shutdown(socket.SHUT_RDWR)
except Exception:
connection.shutdown(socket.SHUT_RDWR)
traceback.print_exc()
finally:
# Clean up the connection
connection.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='PIM')
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("-start", "--start", action="store_true", default=False, help="Start PIM")
group.add_argument("-stop", "--stop", action="store_true", default=False, help="Stop PIM")
group.add_argument("-restart", "--restart", action="store_true", default=False, help="Restart PIM")
group.add_argument("-li", "--list_interfaces", action="store_true", default=False, help="List All PIM Interfaces")
group.add_argument("-ln", "--list_neighbors", action="store_true", default=False, help="List All PIM Neighbors")
group.add_argument("-ls", "--list_state", action="store_true", default=False, help="List state of IGMP")
group.add_argument("-mr", "--multicast_routes", action="store_true", default=False, help="List Multicast Routing table")
group.add_argument("-ai", "--add_interface", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface")
group.add_argument("-aisr", "--add_interface_sr", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface with State Refresh enabled")
group.add_argument("-aiigmp", "--add_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Add IGMP interface")
group.add_argument("-ri", "--remove_interface", nargs=1, metavar='INTERFACE_NAME', help="Remove PIM interface")
group.add_argument("-riigmp", "--remove_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Remove IGMP interface")
group.add_argument("-v", "--verbose", action="store_true", default=False, help="Verbose (print all debug messages)")
group.add_argument("-t", "--test", nargs=2, metavar=('ROUTER_NAME', 'SERVER_LOG_IP'), help="Tester... send log information to SERVER_LOG_IP. Set the router name to ROUTER_NAME")
args = parser.parse_args()
print(parser.parse_args())
daemon = MyDaemon('/tmp/Daemon-pim.pid')
if args.start:
print("start")
daemon.start()
sys.exit(0)
elif args.stop:
client_socket(args)
daemon.stop()
sys.exit(0)
elif args.restart:
daemon.restart()
sys.exit(0)
elif args.verbose:
os.system("tailf stdout")
sys.exit(0)
elif args.multicast_routes:
os.system("ip mroute show")
sys.exit(0)
elif not daemon.is_running():
print("PIM is not running")
parser.print_usage()
sys.exit(0)
client_socket(args)
import socket
import time
import struct
# ficheiros importantes: /usr/include/linux/mroute.h
MRT_BASE = 200
MRT_INIT = (MRT_BASE) # Activate the kernel mroute code */
MRT_DONE = (MRT_BASE+1) #/* Shutdown the kernel mroute */
MRT_ADD_VIF = (MRT_BASE+2) #/* Add a virtual interface */
MRT_DEL_VIF = (MRT_BASE+3) #/* Delete a virtual interface */
MRT_ADD_MFC = (MRT_BASE+4) #/* Add a multicast forwarding entry */
MRT_DEL_MFC = (MRT_BASE+5) #/* Delete a multicast forwarding entry */
MRT_VERSION = (MRT_BASE+6) #/* Get the kernel multicast version */
MRT_ASSERT = (MRT_BASE+7) #/* Activate PIM assert mode */
MRT_PIM = (MRT_BASE+8) #/* enable PIM code */
MRT_TABLE = (MRT_BASE+9) #/* Specify mroute table ID */
MRT_ADD_MFC_PROXY = (MRT_BASE+10) #/* Add a (*,*|G) mfc entry */
MRT_DEL_MFC_PROXY = (MRT_BASE+11) #/* Del a (*,*|G) mfc entry */
MRT_MAX = (MRT_BASE+11)
IGMPMSG_NOCACHE = 1
IGMPMSG_WRONGVIF = 2
IGMPMSG_WHOLEPKT = 3
s2 = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
#MRT INIT
s2.setsockopt(socket.IPPROTO_IP, MRT_INIT, 1)
#MRT PIM
s2.setsockopt(socket.IPPROTO_IP, MRT_PIM, 1)
#ADD VIRTUAL INTERFACE
#estrutura = struct.pack("HBBI 4s 4s", 1, 0x4, 0, 0, socket.inet_aton("192.168.1.112"), socket.inet_aton("224.1.1.112"))
estrutura = struct.pack("HBBI 4s 4s", 0, 0x0, 1, 0, socket.inet_aton("10.0.0.1"), socket.inet_aton("0.0.0.0"))
print(estrutura)
s2.setsockopt(socket.IPPROTO_IP, MRT_ADD_VIF, estrutura)
estrutura = struct.pack("HBBI 4s 4s", 1, 0x0, 1, 0, socket.inet_aton("192.168.2.2"), socket.inet_aton("0.0.0.0"))
print(estrutura)
s2.setsockopt(socket.IPPROTO_IP, MRT_ADD_VIF, estrutura)
#time.sleep(5)
while True:
print("recv:")
msg = s2.recv(5000)
print(len(msg))
(_, _, im_msgtype, im_mbz, im_vif, _, im_src, im_dst, _) = struct.unpack("II B B B B 4s 4s 8s", msg)
print(im_msgtype)
print(im_mbz)
print(im_vif)
print(socket.inet_ntoa(im_src))
print(socket.inet_ntoa(im_dst))
if im_msgtype == IGMPMSG_NOCACHE:
print("^^ IGMP NO CACHE")
print(struct.unpack("II B B B B 4s 4s 8s", msg))
#s2.setsockopt(socket.IPPROTO_IP, MRT_PIM, 1)
#print(s2.getsockopt(socket.IPPROTO_IP, 208))
#s2.setsockopt(socket.IPPROTO_IP, 208, 0)
#ADD MULTICAST FORWARDING ENTRY
estrutura = struct.pack("4s 4s H BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB IIIi", socket.inet_aton("10.0.0.2"), socket.inet_aton("224.1.1.113"), 0, 0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0, 0, 0, 0)
s2.setsockopt(socket.IPPROTO_IP, MRT_ADD_MFC, estrutura)
time.sleep(30)
#MRT DONE
s2.setsockopt(socket.IPPROTO_IP, MRT_DONE, 1)
s2.close()
exit(1)
import logging
class RootFilter(logging.Filter):
"""
This is a filter which injects contextual information into the log.
Rather than use actual contextual information, we just use random
data in this demo.
"""
def __init__(self, router_name):
super().__init__()
self.router_name = router_name
def filter(self, record):
record.routername = self.router_name
if not hasattr(record, 'tree'):
record.tree = ''
if not hasattr(record, 'vif'):
record.vif = ''
if not hasattr(record, 'interfacename'):
record.interfacename = ''
if not hasattr(record, 'neighbor_ip'):
record.neighbor_ip = ''
return True
import socket
import ipaddress
from pyroute2 import IPDB, IPRoute
import Main
from utils import if_indextoname
def get_route(ip_dst: str):
return UnicastRouting.get_route(ip_dst)
def get_metric(ip_dst: str):
return UnicastRouting.get_metric(ip_dst)
def check_rpf(ip_dst):
return UnicastRouting.check_rpf(ip_dst)
def get_unicast_info(ip_dst):
return UnicastRouting.get_unicast_info(ip_dst)
class UnicastRouting(object):
ipdb = None
def __init__(self):
UnicastRouting.ipr = IPRoute()
UnicastRouting.ipdb = IPDB()
self._ipdb = UnicastRouting.ipdb
self._ipdb.register_callback(UnicastRouting.unicast_changes, mode="post")
# get metrics (routing preference and cost) to IP ip_dst
@staticmethod
def get_metric(ip_dst: str):
(metric_administrative_distance, metric_cost, _, _, mask) = UnicastRouting.get_unicast_info(ip_dst)
return (metric_administrative_distance, metric_cost, mask)
# get output interface IP, used to send data to IP ip_dst
# (root interface IP to ip_dst)
@staticmethod
def check_rpf(ip_dst):
# vif index of rpf interface
return UnicastRouting.get_unicast_info(ip_dst)[3]
@staticmethod
def get_route(ip_dst: str):
ip_bytes = socket.inet_aton(ip_dst)
ip_int = int.from_bytes(ip_bytes, byteorder='big')
info = None
ipdb = UnicastRouting.ipdb # type:IPDB
for mask_len in range(32, 0, -1):
ip_bytes = (ip_int & (0xFFFFFFFF << (32 - mask_len))).to_bytes(4, "big")
ip_dst = socket.inet_ntoa(ip_bytes) + "/" + str(mask_len)
print(ip_dst)
if ip_dst in ipdb.routes:
print(info)
if ipdb.routes[ip_dst]['ipdb_scope'] != 'gc':
info = ipdb.routes[ip_dst]
break
else:
continue
if not info:
print("0.0.0.0/0")
if "default" in ipdb.routes:
info = ipdb.routes["default"]
print(info)
return info
@staticmethod
def get_unicast_info(ip_dst):
metric_administrative_distance = 0xFFFFFFFF
metric_cost = 0xFFFFFFFF
rpf_node = ip_dst
oif = None
mask = 0
unicast_route = UnicastRouting.get_route(ip_dst)
if unicast_route is not None:
oif = unicast_route.get("oif")
next_hop = unicast_route["gateway"]
multipaths = unicast_route["multipath"]
# prefsrc = unicast_route.get("prefsrc")
# rpf_node = ip_dst if (next_hop is None and prefsrc is not None) else next_hop
rpf_node = next_hop if next_hop is not None else ip_dst
highest_ip = ipaddress.ip_address("0.0.0.0")
for m in multipaths:
if m["gateway"] is None:
oif = m.get('oif')
rpf_node = ip_dst
break
elif ipaddress.ip_address(m["gateway"]) > highest_ip:
highest_ip = ipaddress.ip_address(m["gateway"])
oif = m.get('oif')
rpf_node = m["gateway"]
metric_administrative_distance = unicast_route["proto"]
metric_cost = unicast_route["priority"]
metric_cost = metric_cost if metric_cost is not None else 0
mask = unicast_route["dst_len"]
interface_name = None if oif is None else if_indextoname(int(oif))
rpf_if = Main.kernel.vif_name_to_index_dic.get(interface_name)
return (metric_administrative_distance, metric_cost, rpf_node, rpf_if, mask)
@staticmethod
def unicast_changes(ipdb, msg, action):
print("unicast change?")
print(action)
if action == "RTM_NEWROUTE" or action == "RTM_DELROUTE":
print(ipdb.routes)
mask_len = msg["dst_len"]
network_address = None
attrs = msg["attrs"]
print(attrs)
for (key, value) in attrs:
print((key, value))
if key == "RTA_DST":
network_address = value
break
if network_address is None:
network_address = "0.0.0.0"
print(network_address)
print(mask_len)
print(network_address + "/" + str(mask_len))
subnet = ipaddress.ip_network(network_address + "/" + str(mask_len))
print(str(subnet))
Main.kernel.notify_unicast_changes(subnet)
'''
elif action == "RTM_NEWADDR" or action == "RTM_DELADDR":
print(action)
print(msg)
interface_name = None
attrs = msg["attrs"]
for (key, value) in attrs:
print((key, value))
if key == "IFA_LABEL":
interface_name = value
break
UnicastRouting.lock.release()
try:
Main.kernel.notify_interface_changes(interface_name)
except:
import traceback
traceback.print_exc()
pass
bnet = ipaddress.ip_network("0.0.0.0/0")
Main.kernel.notify_unicast_changes(subnet)
elif action == "RTM_NEWLINK" or action == "RTM_DELLINK":
attrs = msg["attrs"]
if_name = None
operation = None
for (key, value) in attrs:
print((key, value))
if key == "IFLA_IFNAME":
if_name = value
elif key == "IFLA_OPERSTATE":
operation = value
if if_name is not None and operation is not None:
break
if if_name is not None:
print(if_name + ": " + operation)
UnicastRouting.lock.release()
if operation == 'DOWN':
Main.kernel.remove_interface(if_name, igmp=True, pim=True)
subnet = ipaddress.ip_network("0.0.0.0/0")
Main.kernel.notify_unicast_changes(subnet)
'''
def stop(self):
if UnicastRouting.ipdb:
UnicastRouting.ipdb = None
if self._ipdb:
self._ipdb.release()
import logging
from threading import Lock
from threading import Timer
from utils import GroupMembershipInterval, LastMemberQueryInterval, TYPE_CHECKING
from .wrapper import NoMembersPresent
if TYPE_CHECKING:
from .RouterState import RouterState
class GroupState(object):
LOGGER = logging.getLogger('pim.igmp.RouterState.GroupState')
def __init__(self, router_state: 'RouterState', group_ip: str):
#logger
extra_dict_logger = router_state.router_state_logger.extra.copy()
extra_dict_logger['tree'] = '(*,' + group_ip + ')'
self.group_state_logger = logging.LoggerAdapter(GroupState.LOGGER, extra_dict_logger)
#timers and state
self.router_state = router_state
self.group_ip = group_ip
self.state = NoMembersPresent
self.timer = None
self.v1_host_timer = None
self.retransmit_timer = None
# lock
self.lock = Lock()
# KernelEntry's instances to notify change of igmp state
self.multicast_interface_state = []
self.multicast_interface_state_lock = Lock()
def print_state(self):
return self.state.print_state()
###########################################
# Set state
###########################################
def set_state(self, state):
self.state = state
self.group_state_logger.debug("change membership state to: " + state.print_state())
###########################################
# Set timers
###########################################
def set_timer(self, alternative: bool=False, max_response_time: int=None):
self.clear_timer()
if not alternative:
time = GroupMembershipInterval
else:
time = self.router_state.interface_state.get_group_membership_time(max_response_time)
timer = Timer(time, self.group_membership_timeout)
timer.start()
self.timer = timer
def clear_timer(self):
if self.timer is not None:
self.timer.cancel()
def set_v1_host_timer(self):
self.clear_v1_host_timer()
v1_host_timer = Timer(GroupMembershipInterval, self.group_membership_v1_timeout)
v1_host_timer.start()
self.v1_host_timer = v1_host_timer
def clear_v1_host_timer(self):
if self.v1_host_timer is not None:
self.v1_host_timer.cancel()
def set_retransmit_timer(self):
self.clear_retransmit_timer()
retransmit_timer = Timer(LastMemberQueryInterval, self.retransmit_timeout)
retransmit_timer.start()
self.retransmit_timer = retransmit_timer
def clear_retransmit_timer(self):
if self.retransmit_timer is not None:
self.retransmit_timer.cancel()
###########################################
# Get group state from specific interface state
###########################################
def get_interface_group_state(self):
return self.state.get_state(self.router_state)
###########################################
# Timer timeout
###########################################
def group_membership_timeout(self):
with self.lock:
self.get_interface_group_state().group_membership_timeout(self)
def group_membership_v1_timeout(self):
with self.lock:
self.get_interface_group_state().group_membership_v1_timeout(self)
def retransmit_timeout(self):
with self.lock:
self.get_interface_group_state().retransmit_timeout(self)
###########################################
# Receive Packets
###########################################
def receive_v1_membership_report(self):
with self.lock:
self.get_interface_group_state().receive_v1_membership_report(self)
def receive_v2_membership_report(self):
with self.lock:
self.get_interface_group_state().receive_v2_membership_report(self)
def receive_leave_group(self):
with self.lock:
self.get_interface_group_state().receive_leave_group(self)
def receive_group_specific_query(self, max_response_time: int):
with self.lock:
self.get_interface_group_state().receive_group_specific_query(self, max_response_time)
###########################################
# Notify Routing
###########################################
def notify_routing_add(self):
with self.multicast_interface_state_lock:
print("notify+", self.multicast_interface_state)
for interface_state in self.multicast_interface_state:
interface_state.notify_igmp(has_members=True)
def notify_routing_remove(self):
with self.multicast_interface_state_lock:
print("notify-", self.multicast_interface_state)
for interface_state in self.multicast_interface_state:
interface_state.notify_igmp(has_members=False)
def add_multicast_routing_entry(self, kernel_entry):
with self.multicast_interface_state_lock:
self.multicast_interface_state.append(kernel_entry)
return self.has_members()
def remove_multicast_routing_entry(self, kernel_entry):
with self.multicast_interface_state_lock:
self.multicast_interface_state.remove(kernel_entry)
def has_members(self):
return self.state is not NoMembersPresent
def remove(self):
with self.multicast_interface_state_lock:
self.clear_retransmit_timer()
self.clear_timer()
self.clear_v1_host_timer()
for interface_state in self.multicast_interface_state:
interface_state.notify_igmp(has_members=False)
del self.multicast_interface_state[:]
from threading import Timer
import logging
from Packet.PacketIGMPHeader import PacketIGMPHeader
from Packet.ReceivedPacket import ReceivedPacket
from utils import Membership_Query, QueryResponseInterval, QueryInterval, OtherQuerierPresentInterval, TYPE_CHECKING
from RWLock.RWLock import RWLockWrite
from .querier.Querier import Querier
from .nonquerier.NonQuerier import NonQuerier
from .GroupState import GroupState
if TYPE_CHECKING:
from InterfaceIGMP import InterfaceIGMP
class RouterState(object):
ROUTER_STATE_LOGGER = logging.getLogger('pim.igmp.RouterState')
def __init__(self, interface: 'InterfaceIGMP'):
#logger
logger_extra = dict()
logger_extra['vif'] = interface.vif_index
logger_extra['interfacename'] = interface.interface_name
self.router_state_logger = logging.LoggerAdapter(RouterState.ROUTER_STATE_LOGGER, logger_extra)
# interface of the router connected to the network
self.interface = interface
# state of the router (Querier/NonQuerier)
self.interface_state = Querier
# state of each group
# Key: GroupIPAddress, Value: GroupState object
self.group_state = {}
self.group_state_lock = RWLockWrite()
# send general query
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=QueryResponseInterval*10)
self.interface.send(packet.bytes())
# set initial general query timer
timer = Timer(QueryInterval, self.general_query_timeout)
timer.start()
self.general_query_timer = timer
# present timer
self.other_querier_present_timer = None
# Send packet via interface
def send(self, data: bytes, address: str):
self.interface.send(data, address)
############################################
# interface_state methods
############################################
def print_state(self):
return self.interface_state.state_name()
def set_general_query_timer(self):
self.clear_general_query_timer()
general_query_timer = Timer(QueryInterval, self.general_query_timeout)
general_query_timer.start()
self.general_query_timer = general_query_timer
def clear_general_query_timer(self):
if self.general_query_timer is not None:
self.general_query_timer.cancel()
def set_other_querier_present_timer(self):
self.clear_other_querier_present_timer()
other_querier_present_timer = Timer(OtherQuerierPresentInterval, self.other_querier_present_timeout)
other_querier_present_timer.start()
self.other_querier_present_timer = other_querier_present_timer
def clear_other_querier_present_timer(self):
if self.other_querier_present_timer is not None:
self.other_querier_present_timer.cancel()
def general_query_timeout(self):
self.interface_state.general_query_timeout(self)
def other_querier_present_timeout(self):
self.interface_state.other_querier_present_timeout(self)
def change_interface_state(self, querier: bool):
if querier:
self.interface_state = Querier
self.router_state_logger.debug('change querier state to -> Querier')
else:
self.interface_state = NonQuerier
self.router_state_logger.debug('change querier state to -> NonQuerier')
############################################
# group state methods
############################################
def get_group_state(self, group_ip):
with self.group_state_lock.genRlock():
if group_ip in self.group_state:
return self.group_state[group_ip]
with self.group_state_lock.genWlock():
if group_ip in self.group_state:
group_state = self.group_state[group_ip]
else:
group_state = GroupState(self, group_ip)
self.group_state[group_ip] = group_state
return group_state
def receive_v1_membership_report(self, packet: ReceivedPacket):
igmp_group = packet.payload.group_address
#if igmp_group not in self.group_state:
# self.group_state[igmp_group] = GroupState(self, igmp_group)
#self.group_state[igmp_group].receive_v1_membership_report()
self.get_group_state(igmp_group).receive_v1_membership_report()
def receive_v2_membership_report(self, packet: ReceivedPacket):
igmp_group = packet.payload.group_address
#if igmp_group not in self.group_state:
# self.group_state[igmp_group] = GroupState(self, igmp_group)
#self.group_state[igmp_group].receive_v2_membership_report()
self.get_group_state(igmp_group).receive_v2_membership_report()
def receive_leave_group(self, packet: ReceivedPacket):
igmp_group = packet.payload.group_address
#if igmp_group in self.group_state:
# self.group_state[igmp_group].receive_leave_group()
self.get_group_state(igmp_group).receive_leave_group()
def receive_query(self, packet: ReceivedPacket):
self.interface_state.receive_query(self, packet)
igmp_group = packet.payload.group_address
# process group specific query
if igmp_group != "0.0.0.0" and igmp_group in self.group_state:
#if igmp_group != "0.0.0.0":
max_response_time = packet.payload.max_resp_time
#self.group_state[igmp_group].receive_group_specific_query(max_response_time)
self.get_group_state(igmp_group).receive_group_specific_query(max_response_time)
def remove(self):
for group in self.group_state.values():
group.remove()
\ No newline at end of file
from utils import TYPE_CHECKING
from ..wrapper import NoMembersPresent
from ..wrapper import MembersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingMembership: group_membership_timeout')
group_state.set_state(NoMembersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingMembership: group_membership_v1_timeout')
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingMembership: retransmit_timeout')
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingMembership: receive_v1_membership_report')
receive_v2_membership_report(group_state)
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingMembership: receive_v2_membership_report')
group_state.set_timer()
group_state.set_state(MembersPresent)
def receive_leave_group(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingMembership: receive_leave_group')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('NonQuerier CheckingMembership: receive_group_specific_query')
# do nothing
return
from utils import TYPE_CHECKING
from ..wrapper import NoMembersPresent
from ..wrapper import CheckingMembership
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier MembersPresent: group_membership_timeout')
group_state.set_state(NoMembersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier MembersPresent: group_membership_v1_timeout')
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier MembersPresent: retransmit_timeout')
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier MembersPresent: receive_v1_membership_report')
receive_v2_membership_report(group_state)
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier MembersPresent: receive_v2_membership_report')
group_state.set_timer()
def receive_leave_group(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier MembersPresent: receive_leave_group')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('NonQuerier MembersPresent: receive_group_specific_query')
group_state.set_timer(alternative=True, max_response_time=max_response_time)
group_state.set_state(CheckingMembership)
from utils import TYPE_CHECKING
from ..wrapper import MembersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoMembersPresent: group_membership_timeout')
# do nothing
return
def group_membership_v1_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoMembersPresent: group_membership_v1_timeout')
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoMembersPresent: retransmit_timeout')
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoMembersPresent: receive_v1_membership_report')
receive_v2_membership_report(group_state)
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoMembersPresent: receive_v2_membership_report')
group_state.set_timer()
group_state.set_state(MembersPresent)
# NOTIFY ROUTING + !!!!
group_state.notify_routing_add()
def receive_leave_group(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoMembersPresent: receive_leave_group')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('NonQuerier NoMembersPresent: receive_group_specific_query')
# do nothing
return
from ipaddress import IPv4Address
from utils import Membership_Query, QueryResponseInterval, LastMemberQueryCount, TYPE_CHECKING
from Packet.PacketIGMPHeader import PacketIGMPHeader
from Packet.ReceivedPacket import ReceivedPacket
from . import NoMembersPresent, MembersPresent, CheckingMembership
if TYPE_CHECKING:
from ..RouterState import RouterState
class NonQuerier:
@staticmethod
def general_query_timeout(router_state: 'RouterState'):
router_state.router_state_logger.debug('NonQuerier state: general_query_timeout')
# do nothing
return
@staticmethod
def other_querier_present_timeout(router_state: 'RouterState'):
router_state.router_state_logger.debug('NonQuerier state: other_querier_present_timeout')
#change state to Querier
router_state.change_interface_state(querier=True)
# send general query
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=QueryResponseInterval*10)
router_state.interface.send(packet.bytes())
# set general query timer
router_state.set_general_query_timer()
@staticmethod
def receive_query(router_state: 'RouterState', packet: ReceivedPacket):
router_state.router_state_logger.debug('NonQuerier state: receive_query')
source_ip = packet.ip_header.ip_src
# if source ip of membership query not lower than the ip of the received interface => ignore
if IPv4Address(source_ip) >= IPv4Address(router_state.interface.get_ip()):
return
# reset other present querier timer
router_state.set_other_querier_present_timer()
# TODO ver se existe uma melhor maneira de fazer isto
@staticmethod
def state_name():
return "Non Querier"
@staticmethod
def get_group_membership_time(max_response_time: int):
return (max_response_time/10.0) * LastMemberQueryCount
# State
@staticmethod
def get_checking_membership_state():
return CheckingMembership
@staticmethod
def get_members_present_state():
return MembersPresent
@staticmethod
def get_no_members_present_state():
return NoMembersPresent
@staticmethod
def get_version_1_members_present_state():
return NonQuerier.get_members_present_state()
from Packet.PacketIGMPHeader import PacketIGMPHeader
from utils import Membership_Query, LastMemberQueryInterval, TYPE_CHECKING
from ..wrapper import NoMembersPresent, MembersPresent, Version1MembersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingMembership: group_membership_timeout')
group_state.clear_retransmit_timer()
group_state.set_state(NoMembersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingMembership: group_membership_v1_timeout')
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingMembership: retransmit_timeout')
group_addr = group_state.group_ip
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=LastMemberQueryInterval*10, group_address=group_addr)
group_state.router_state.send(data=packet.bytes(), address=group_addr)
group_state.set_retransmit_timer()
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingMembership: receive_v1_membership_report')
group_state.set_timer()
group_state.set_v1_host_timer()
group_state.set_state(Version1MembersPresent)
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingMembership: receive_v2_membership_report')
group_state.set_timer()
group_state.set_state(MembersPresent)
def receive_leave_group(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingMembership: receive_leave_group')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('Querier CheckingMembership: receive_group_specific_query')
# do nothing
return
from Packet.PacketIGMPHeader import PacketIGMPHeader
from utils import Membership_Query, LastMemberQueryInterval, TYPE_CHECKING
from ..wrapper import Version1MembersPresent, CheckingMembership, NoMembersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier MembersPresent: group_membership_timeout')
group_state.set_state(NoMembersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier MembersPresent: group_membership_v1_timeout')
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier MembersPresent: retransmit_timeout')
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier MembersPresent: receive_v1_membership_report')
group_state.set_timer()
group_state.set_v1_host_timer()
group_state.set_state(Version1MembersPresent)
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier MembersPresent: receive_v2_membership_report')
group_state.set_timer()
def receive_leave_group(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier MembersPresent: receive_leave_group')
group_ip = group_state.group_ip
group_state.set_timer(alternative=True)
group_state.set_retransmit_timer()
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=LastMemberQueryInterval*10, group_address=group_ip)
group_state.router_state.send(data=packet.bytes(), address=group_ip)
group_state.set_state(CheckingMembership)
def receive_group_specific_query(group_state: 'GroupState', max_response_time):
group_state.group_state_logger.debug('Querier MembersPresent: receive_group_specific_query')
# do nothing
return
from utils import TYPE_CHECKING
from ..wrapper import MembersPresent
from ..wrapper import Version1MembersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoMembersPresent: group_membership_timeout')
# do nothing
return
def group_membership_v1_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoMembersPresent: group_membership_v1_timeout')
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoMembersPresent: retransmit_timeout')
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoMembersPresent: receive_v1_membership_report')
group_state.set_timer()
group_state.set_v1_host_timer()
group_state.set_state(Version1MembersPresent)
# NOTIFY ROUTING + !!!!
group_state.notify_routing_add()
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoMembersPresent: receive_v2_membership_report')
group_state.set_timer()
group_state.set_state(MembersPresent)
# NOTIFY ROUTING + !!!!
group_state.notify_routing_add()
def receive_leave_group(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoMembersPresent: receive_leave_group')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('Querier NoMembersPresent: receive_group_specific_query')
# do nothing
return
from ipaddress import IPv4Address
from utils import TYPE_CHECKING
from utils import Membership_Query, QueryResponseInterval, LastMemberQueryCount, LastMemberQueryInterval
from Packet.PacketIGMPHeader import PacketIGMPHeader
from Packet.ReceivedPacket import ReceivedPacket
from . import CheckingMembership, MembersPresent, Version1MembersPresent, NoMembersPresent
if TYPE_CHECKING:
from ..RouterState import RouterState
class Querier:
@staticmethod
def general_query_timeout(router_state: 'RouterState'):
router_state.router_state_logger.debug('Querier state: general_query_timeout')
# send general query
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=QueryResponseInterval*10)
router_state.interface.send(packet.bytes())
# set general query timer
router_state.set_general_query_timer()
@staticmethod
def other_querier_present_timeout(router_state: 'RouterState'):
router_state.router_state_logger.debug('Querier state: other_querier_present_timeout')
# do nothing
return
@staticmethod
def receive_query(router_state: 'RouterState', packet: ReceivedPacket):
router_state.router_state_logger.debug('Querier state: receive_query')
source_ip = packet.ip_header.ip_src
# if source ip of membership query not lower than the ip of the received interface => ignore
if IPv4Address(source_ip) >= IPv4Address(router_state.interface.get_ip()):
return
# if source ip of membership query lower than the ip of the received interface => change state
# change state of interface
# Querier -> Non Querier
router_state.change_interface_state(querier=False)
# set other present querier timer
router_state.clear_general_query_timer()
router_state.set_other_querier_present_timer()
# TODO ver se existe uma melhor maneira de fazer isto
@staticmethod
def state_name():
return "Querier"
@staticmethod
def get_group_membership_time(max_response_time: int):
return LastMemberQueryInterval * LastMemberQueryCount
# State
@staticmethod
def get_checking_membership_state():
return CheckingMembership
@staticmethod
def get_members_present_state():
return MembersPresent
@staticmethod
def get_no_members_present_state():
return NoMembersPresent
@staticmethod
def get_version_1_members_present_state():
return Version1MembersPresent
from utils import TYPE_CHECKING
from ..wrapper import NoMembersPresent
from ..wrapper import MembersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier Version1MembersPresent: group_membership_timeout')
group_state.set_state(NoMembersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier Version1MembersPresent: group_membership_v1_timeout')
group_state.set_state(MembersPresent)
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier Version1MembersPresent: retransmit_timeout')
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier Version1MembersPresent: receive_v1_membership_report')
group_state.set_timer()
group_state.set_v1_host_timer()
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier Version1MembersPresent: receive_v2_membership_report')
group_state.set_timer()
def receive_leave_group(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier Version1MembersPresent: receive_leave_group')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('Querier Version1MembersPresent: receive_group_specific_query')
# do nothing
return
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_checking_membership_state()
def print_state():
return "CheckingMembership"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
\ No newline at end of file
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_members_present_state()
def print_state():
return "MembersPresent"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
\ No newline at end of file
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_no_members_present_state()
def print_state():
return "NoMembersPresent"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_version_1_members_present_state()
def print_state():
return "Version1MembersPresent"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
import subprocess
import struct
import socket
from ctypes import create_string_buffer, addressof
SO_ATTACH_FILTER = 26
ETH_P_IP = 0x0800 # Internet Protocol packet
SO_RCVBUFFORCE = 33
def get_s_g_bpf_filter_code(source, group, interface_name):
#cmd = "tcpdump -ddd \"(udp or icmp) and host %s and dst %s\"" % (source, group)
cmd = "tcpdump -ddd \"(ip proto not 2) and host %s and dst %s\"" % (source, group)
result = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
bpf_filter = b''
tmp = result.stdout.read().splitlines()
num = int(tmp[0])
for line in tmp[1:]:
print(line)
bpf_filter += struct.pack("HBBI", *tuple(map(int, line.split(b' '))))
print(num)
# defined in linux/filter.h.
b = create_string_buffer(bpf_filter)
mem_addr_of_filters = addressof(b)
fprog = struct.pack('HL', num, mem_addr_of_filters)
# Create listening socket with filters
s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, ETH_P_IP)
s.setsockopt(socket.SOL_SOCKET, SO_ATTACH_FILTER, fprog)
# todo pequeno ajuste (tamanho de buffer pequeno para o caso de trafego em rajadas):
#s.setsockopt(socket.SOL_SOCKET, SO_RCVBUFFORCE, 1)
s.bind((interface_name, ETH_P_IP))
return s
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
'''
Created on Feb 23, 2015
This module is intended to have all constants and global values for pim_dm
@author: alex
'''
ASSERT_TIME = 180
GRAFT_RETRY_PERIOD = 3
JP_OVERRIDE_INTERVAL = 3.0
OVERRIDE_INTERVAL = 2.5
PROPAGATION_DELAY = 0.5
REFRESH_INTERVAL = 60 # State Refresh Interval
SOURCE_LIFETIME = 210
T_LIMIT = 210
ASSERT_CANCEL_METRIC = 0xFFFFFFFF
\ No newline at end of file
from abc import ABCMeta, abstractmethod
class LocalMembershipStateABC(metaclass=ABCMeta):
@staticmethod
@abstractmethod
def has_members():
raise NotImplementedError
class NoInfo(LocalMembershipStateABC):
@staticmethod
def has_members():
return False
class Include(LocalMembershipStateABC):
@staticmethod
def has_members():
return True
class LocalMembership():
NoInfo = NoInfo()
Include = Include()
\ No newline at end of file
import ipaddress
class AssertMetric(object):
def __init__(self, metric_preference: int = 0x7FFFFFFF, route_metric: int = 0xFFFFFFFF, ip_address: str = "0.0.0.0", state_refresh_interval:int = None):
if type(ip_address) is str:
ip_address = ipaddress.ip_address(ip_address)
self._metric_preference = metric_preference
self._route_metric = route_metric
self._ip_address = ip_address
self._state_refresh_interval = state_refresh_interval
def is_better_than(self, other):
if self.metric_preference != other.metric_preference:
return self.metric_preference < other.metric_preference
elif self.route_metric != other.route_metric:
return self.route_metric < other.route_metric
else:
return self.ip_address > other.ip_address
def is_worse(self, other):
return not self.is_better_than(other)
def equal_metric(self, other):
return self.metric_preference == other.metric_preference and self.metric_preference == other.metric_preference \
and self.ip_address == other.ip_address
@staticmethod
def infinite_assert_metric():
'''
@type metric: AssertMetric
'''
return AssertMetric()
@staticmethod
def spt_assert_metric(tree_if):
'''
@type metric: AssertMetric
@type tree_if: TreeInterface
'''
(source_ip, _) = tree_if.get_tree_id()
import UnicastRouting
(metric_preference, metric_cost, _) = UnicastRouting.get_metric(source_ip)
return AssertMetric(metric_preference, metric_cost, tree_if.get_ip())
def i_am_assert_winner(self, tree_if):
return self.get_ip() == tree_if.get_ip()
@property
def metric_preference(self):
return self._metric_preference
@metric_preference.setter
def metric_preference(self, value):
self._metric_preference = value
@property
def route_metric(self):
return self._route_metric
@route_metric.setter
def route_metric(self, value):
self._route_metric = value
@property
def ip_address(self):
return self._ip_address
@ip_address.setter
def ip_address(self, value):
if type(value) is str:
value = ipaddress.ip_address(value)
self._ip_address = value
@property
def state_refresh_interval(self):
return self._state_refresh_interval
@state_refresh_interval.setter
def state_refresh_interval(self, value):
self._state_refresh_interval = value
def get_ip(self):
return str(self._ip_address)
from abc import ABCMeta, abstractstaticmethod
from tree import globals as pim_globals
class OriginatorStateABC(metaclass=ABCMeta):
@abstractstaticmethod
def recvDataMsgFromSource(tree):
pass
@abstractstaticmethod
def SRTexpires(tree):
pass
@abstractstaticmethod
def SATexpires(tree):
pass
@abstractstaticmethod
def SourceNotConnected(tree):
pass
class Originator(OriginatorStateABC):
@staticmethod
def recvDataMsgFromSource(tree):
tree.set_source_active_timer()
@staticmethod
def SRTexpires(tree):
'''
@type tree: Tree
'''
tree.originator_logger.debug('SRT expired, O -> O')
tree.set_state_refresh_timer()
tree.create_state_refresh_msg()
@staticmethod
def SATexpires(tree):
tree.originator_logger.debug('SAT expired, O -> NO')
tree.clear_state_refresh_timer()
tree.set_originator_state(OriginatorState.NotOriginator)
@staticmethod
def SourceNotConnected(tree):
tree.originator_logger.debug('Source no longer directly connected, O -> NO')
tree.clear_state_refresh_timer()
tree.clear_source_active_timer()
tree.set_originator_state(OriginatorState.NotOriginator)
def __str__(self):
return 'Originator'
class NotOriginator(OriginatorStateABC):
@staticmethod
def recvDataMsgFromSource(tree):
'''
@type interface: Tree
'''
tree.originator_logger.debug('new DataMsg from Source, NO -> O')
tree.set_originator_state(OriginatorState.Originator)
tree.set_state_refresh_timer()
tree.set_source_active_timer()
@staticmethod
def SRTexpires(tree):
assert False, "SRTexpires in NO"
@staticmethod
def SATexpires(tree):
assert False, "SATexpires in NO"
@staticmethod
def SourceNotConnected(tree):
return
def __str__(self):
return 'NotOriginator'
class OriginatorState():
NotOriginator = NotOriginator()
Originator = Originator()
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import array
'''
import struct
if struct.pack("H",1) == "\x00\x01": # big endian
def checksum(pkt):
if len(pkt) % 2 == 1:
pkt += "\0"
s = sum(array.array("H", pkt))
s = (s >> 16) + (s & 0xffff)
s += s >> 16
s = ~s
return s & 0xffff
else:
def checksum(pkt):
if len(pkt) % 2 == 1:
pkt += "\0"
s = sum(array.array("H", pkt))
s = (s >> 16) + (s & 0xffff)
s += s >> 16
s = ~s
return (((s>>8)&0xff)|s<<8) & 0xffff
'''
HELLO_HOLD_TIME_NO_TIMEOUT = 0xFFFF
HELLO_HOLD_TIME = 160
HELLO_HOLD_TIME_TIMEOUT = 0
def checksum(pkt: bytes) -> bytes:
if len(pkt) % 2 == 1:
pkt += "\0"
s = sum(array.array("H", pkt))
s = (s >> 16) + (s & 0xffff)
s += s >> 16
s = ~s
return (((s >> 8) & 0xff) | s << 8) & 0xffff
import ctypes
import ctypes.util
libc = ctypes.CDLL(ctypes.util.find_library('c'))
def if_nametoindex(name):
if not isinstance(name, str):
raise TypeError('name must be a string.')
ret = libc.if_nametoindex(name)
if not ret:
raise RuntimeError("Invalid Name")
return ret
def if_indextoname(index):
if not isinstance(index, int):
raise TypeError('index must be an int.')
libc.if_indextoname.argtypes = [ctypes.c_uint32, ctypes.c_char_p]
libc.if_indextoname.restype = ctypes.c_char_p
ifname = ctypes.create_string_buffer(32)
ifname = libc.if_indextoname(index, ifname)
if not ifname:
raise RuntimeError ("Inavlid Index")
return ifname.decode("utf-8")
# obtain TYPE_CHECKING (for type hinting)
try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False
# IGMP timers (in seconds)
RobustnessVariable = 2
QueryInterval = 125
QueryResponseInterval = 10
MaxResponseTime_QueryResponseInterval = QueryResponseInterval*10
GroupMembershipInterval = RobustnessVariable * QueryInterval + QueryResponseInterval
OtherQuerierPresentInterval = RobustnessVariable * QueryInterval + QueryResponseInterval/2
StartupQueryInterval = QueryInterval / 4
StartupQueryCount = RobustnessVariable
LastMemberQueryInterval = 1
MaxResponseTime_LastMemberQueryInterval = LastMemberQueryInterval*10
LastMemberQueryCount = RobustnessVariable
UnsolicitedReportInterval = 10
Version1RouterPresentTimeout = 400
# IGMP msg type
Membership_Query = 0x11
Version_1_Membership_Report = 0x12
Version_2_Membership_Report = 0x16
Leave_Group = 0x17
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment