Commit b26c6216 authored by Pedro Oliveira's avatar Pedro Oliveira Committed by GitHub

IPv6 multicast routing support (#3)

parent f6ac9c82
...@@ -8,11 +8,14 @@ We have implemented PIM-DM specification ([RFC3973](https://tools.ietf.org/html/ ...@@ -8,11 +8,14 @@ We have implemented PIM-DM specification ([RFC3973](https://tools.ietf.org/html/
This repository stores the implementation of this protocol. The implementation is written in Python language and is destined to Linux systems. This repository stores the implementation of this protocol. The implementation is written in Python language and is destined to Linux systems.
Additionally, IGMPv2 and MLDv1 are implemented alongside with PIM-DM to detect interest of hosts.
# Requirements # Requirements
- Linux machine - Linux machine
- Python3 (we have written all code to be compatible with at least Python v3.2) - Unicast routing protocol
- Python3 (we have written all code to be compatible with at least Python v3.3)
- pip (to install all dependencies) - pip (to install all dependencies)
- tcpdump - tcpdump
...@@ -41,6 +44,8 @@ In order to start the protocol you first need to explicitly start it. This will ...@@ -41,6 +44,8 @@ In order to start the protocol you first need to explicitly start it. This will
sudo pim-dm -start sudo pim-dm -start
``` ```
IPv4 and IPv6 multicast is supported. By default all commands will be executed on IPv4 daemon. To execute a command on the IPv6 daemon use `-6`.
#### Add interface #### Add interface
...@@ -49,21 +54,27 @@ After starting the protocol process you can enable the protocol in specific inte ...@@ -49,21 +54,27 @@ After starting the protocol process you can enable the protocol in specific inte
- To enable PIM-DM without State-Refresh, in a given interface, you need to run the following command: - To enable PIM-DM without State-Refresh, in a given interface, you need to run the following command:
``` ```
sudo pim-dm -ai INTERFACE_NAME sudo pim-dm -ai INTERFACE_NAME [-4 | -6]
``` ```
- To enable PIM-DM with State-Refresh, in a given interface, you need to run the following command: - To enable PIM-DM with State-Refresh, in a given interface, you need to run the following command:
``` ```
sudo pim-dm -aisf INTERFACE_NAME sudo pim-dm -aisf INTERFACE_NAME [-4 | -6]
``` ```
- To enable IGMP in a given interface, you need to run the following command: - To enable IGMP/MLD in a given interface, you need to run the following command:
- IGMP:
``` ```
sudo pim-dm -aiigmp INTERFACE_NAME sudo pim-dm -aiigmp INTERFACE_NAME
``` ```
- MLD:
```
sudo pim-dm -aimld INTERFACE_NAME
```
#### Remove interface #### Remove interface
To remove a previously added interface, you need run the following commands: To remove a previously added interface, you need run the following commands:
...@@ -71,15 +82,20 @@ To remove a previously added interface, you need run the following commands: ...@@ -71,15 +82,20 @@ To remove a previously added interface, you need run the following commands:
- To remove a previously added PIM-DM interface: - To remove a previously added PIM-DM interface:
``` ```
sudo pim-dm -ri INTERFACE_NAME sudo pim-dm -ri INTERFACE_NAME [-4 | -6]
``` ```
- To remove a previously added IGMP interface: - To remove a previously added IGMP/MLD interface:
- IGMP:
``` ```
sudo pim-dm -riigmp INTERFACE_NAME sudo pim-dm -riigmp INTERFACE_NAME
``` ```
- MLD:
```
sudo pim-dm -rimld INTERFACE_NAME
```
#### Stop protocol process #### Stop protocol process
...@@ -96,31 +112,31 @@ We have built some list commands that can be used to check the "internals" of th ...@@ -96,31 +112,31 @@ We have built some list commands that can be used to check the "internals" of th
- #### List interfaces: - #### List interfaces:
Show all router interfaces and which ones have PIM-DM and IGMP enabled. For IGMP enabled interfaces check the IGMP Querier state. Show all router interfaces and which ones have PIM-DM and IGMP/MLD enabled. For IGMP/MLD enabled interfaces you can check the Querier state.
``` ```
sudo pim-dm -li sudo pim-dm -li [-4 | -6]
``` ```
- #### List neighbors - #### List neighbors
Verify neighbors that have established a neighborhood relationship. Verify neighbors that have established a neighborhood relationship.
``` ```
sudo pim-dm -ln sudo pim-dm -ln [-4 | -6]
``` ```
- #### List state - #### List state
List all state machines and corresponding state of all trees that are being monitored. Also list IGMP state for each group being monitored. List all state machines and corresponding state of all trees that are being monitored. Also list IGMP state for each group being monitored.
``` ```
sudo pim-dm -ls sudo pim-dm -ls [-4 | -6]
``` ```
- #### Multicast Routing Table - #### Multicast Routing Table
List Linux Multicast Routing Table (equivalent to `ip mroute -show`) List Linux Multicast Routing Table (equivalent to `ip mroute show`)
``` ```
sudo pim-dm -mr sudo pim-dm -mr [-4 | -6]
``` ```
...@@ -131,15 +147,10 @@ In order to determine which commands and corresponding arguments are available y ...@@ -131,15 +147,10 @@ In order to determine which commands and corresponding arguments are available y
pim-dm -h pim-dm -h
``` ```
or
```
pim-dm --help
```
## Change settings ## Change settings
Files tree/globals.py and igmp/igmp_globals.py store all timer values and some configurations regarding IGMP and the PIM-DM. If you want to tune the implementation, you can change the values of these files. These configurations are used by all interfaces, meaning that there is no tuning per interface. Files tree/globals.py, igmp/igmp_globals.py and mld/mld_globals.py store all timer values and some configurations regarding PIM-DM, IGMP and MLD. If you want to tune the implementation, you can change the values of these files. These configurations are used by all interfaces, meaning that there is no tuning per interface.
## Tests ## Tests
...@@ -151,4 +162,4 @@ We have performed tests to our PIM-DM implementation. You can check on the corre ...@@ -151,4 +162,4 @@ We have performed tests to our PIM-DM implementation. You can check on the corre
- [Test_PIM_Assert](https://github.com/pedrofran12/pim_dm/tree/Test_PIM_Assert) - Topology used to test the election of the AssertWinner. - [Test_PIM_Assert](https://github.com/pedrofran12/pim_dm/tree/Test_PIM_Assert) - Topology used to test the election of the AssertWinner.
- [Test_PIM_Join_Prune_Graft](https://github.com/pedrofran12/pim_dm/tree/Test_PIM_Join_Prune_Graft) - Topology used to test the Pruning and Grafting of the multicast distribution tree. - [Test_PIM_Join_Prune_Graft](https://github.com/pedrofran12/pim_dm/tree/Test_PIM_Join_Prune_Graft) - Topology used to test the Pruning and Grafting of the multicast distribution tree.
- [Test_PIM_StateRefresh](https://github.com/pedrofran12/pim_dm/tree/Test_PIM_StateRefresh) - Topology used to test PIM-DM State Refresh. - [Test_PIM_StateRefresh](https://github.com/pedrofran12/pim_dm/tree/Test_PIM_StateRefresh) - Topology used to test PIM-DM State Refresh.
- [Test_IGMP](https://github.com/pedrofran12/hpim_dm/tree/Test_IGMP) - Topology used to test our IGMPv2 implementation. - [Test_IGMP](https://github.com/pedrofran12/pim_dm/tree/Test_IGMP) - Topology used to test our IGMPv2 implementation.
...@@ -18,8 +18,11 @@ class Interface(metaclass=ABCMeta): ...@@ -18,8 +18,11 @@ class Interface(metaclass=ABCMeta):
self._recv_socket = recv_socket self._recv_socket = recv_socket
self.interface_enabled = False self.interface_enabled = False
def _enable(self): def _enable(self):
"""
Enable this interface
This will start a thread to be executed in the background to be used in the reception of control packets
"""
self.interface_enabled = True self.interface_enabled = True
# run receive method in background # run receive method in background
receive_thread = threading.Thread(target=self.receive) receive_thread = threading.Thread(target=self.receive)
...@@ -27,24 +30,39 @@ class Interface(metaclass=ABCMeta): ...@@ -27,24 +30,39 @@ class Interface(metaclass=ABCMeta):
receive_thread.start() receive_thread.start()
def receive(self): def receive(self):
"""
Method that will be executed in the background for the reception of control packets
"""
while self.interface_enabled: while self.interface_enabled:
try: try:
(raw_bytes, _) = self._recv_socket.recvfrom(256 * 1024) (raw_bytes, ancdata, _, src_addr) = self._recv_socket.recvmsg(256 * 1024, 500)
if raw_bytes: if raw_bytes:
self._receive(raw_bytes) self._receive(raw_bytes, ancdata, src_addr)
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
continue continue
@abstractmethod @abstractmethod
def _receive(self, raw_bytes): def _receive(self, raw_bytes, ancdata, src_addr):
"""
Subclass method to be implemented
This method will be invoked whenever a new control packet is received
"""
raise NotImplementedError raise NotImplementedError
def send(self, data: bytes, group_ip: str): def send(self, data: bytes, group_ip: str):
"""
Send a control packet through this interface
Explicitly destined to group_ip (can be unicast or multicast IP)
"""
if self.interface_enabled and data: if self.interface_enabled and data:
self._send_socket.sendto(data, (group_ip, 0)) self._send_socket.sendto(data, (group_ip, 0))
def remove(self): def remove(self):
"""
This interface is no longer active....
Clear all state regarding it
"""
self.interface_enabled = False self.interface_enabled = False
try: try:
self._recv_socket.shutdown(socket.SHUT_RDWR) self._recv_socket.shutdown(socket.SHUT_RDWR)
...@@ -54,8 +72,14 @@ class Interface(metaclass=ABCMeta): ...@@ -54,8 +72,14 @@ class Interface(metaclass=ABCMeta):
self._send_socket.close() self._send_socket.close()
def is_enabled(self): def is_enabled(self):
"""
Verify if this interface is enabled
"""
return self.interface_enabled return self.interface_enabled
@abstractmethod @abstractmethod
def get_ip(self): def get_ip(self):
"""
Get IP of this interface
"""
raise NotImplementedError raise NotImplementedError
...@@ -4,7 +4,7 @@ from ipaddress import IPv4Address ...@@ -4,7 +4,7 @@ from ipaddress import IPv4Address
from ctypes import create_string_buffer, addressof from ctypes import create_string_buffer, addressof
import netifaces import netifaces
from pimdm.Interface import Interface from pimdm.Interface import Interface
from pimdm.Packet.ReceivedPacket import ReceivedPacket from pimdm.packet.ReceivedPacket import ReceivedPacket
from pimdm.igmp.igmp_globals import Version_1_Membership_Report, Version_2_Membership_Report, Leave_Group, Membership_Query from pimdm.igmp.igmp_globals import Version_1_Membership_Report, Version_2_Membership_Report, Leave_Group, Membership_Query
if not hasattr(socket, 'SO_BINDTODEVICE'): if not hasattr(socket, 'SO_BINDTODEVICE'):
socket.SO_BINDTODEVICE = 25 socket.SO_BINDTODEVICE = 25
...@@ -48,18 +48,20 @@ class InterfaceIGMP(Interface): ...@@ -48,18 +48,20 @@ class InterfaceIGMP(Interface):
self.interface_state = RouterState(self) self.interface_state = RouterState(self)
super()._enable() super()._enable()
def get_ip(self): def get_ip(self):
return netifaces.ifaddresses(self.interface_name)[netifaces.AF_INET][0]['addr'] return netifaces.ifaddresses(self.interface_name)[netifaces.AF_INET][0]['addr']
@property @property
def ip_interface(self): def ip_interface(self):
"""
Get IP of this interface
"""
return self.get_ip() return self.get_ip()
def send(self, data: bytes, address: str="224.0.0.1"): def send(self, data: bytes, address: str="224.0.0.1"):
super().send(data, address) super().send(data, address)
def _receive(self, raw_bytes): def _receive(self, raw_bytes, ancdata, src_addr):
if raw_bytes: if raw_bytes:
raw_bytes = raw_bytes[14:] raw_bytes = raw_bytes[14:]
packet = ReceivedPacket(raw_bytes, self) packet = ReceivedPacket(raw_bytes, self)
...@@ -91,7 +93,8 @@ class InterfaceIGMP(Interface): ...@@ -91,7 +93,8 @@ class InterfaceIGMP(Interface):
def receive_membership_query(self, packet): def receive_membership_query(self, packet):
ip_dst = packet.ip_header.ip_dst ip_dst = packet.ip_header.ip_dst
igmp_group = packet.payload.group_address igmp_group = packet.payload.group_address
if ip_dst == igmp_group or (ip_dst == "224.0.0.1" and igmp_group == "0.0.0.0"): if (IPv4Address(igmp_group).is_multicast and ip_dst == igmp_group) or \
(ip_dst == "224.0.0.1" and igmp_group == "0.0.0.0"):
self.interface_state.receive_query(packet) self.interface_state.receive_query(packet)
def receive_unknown_type(self, packet): def receive_unknown_type(self, packet):
......
import socket
import struct
import netifaces
import ipaddress
from socket import if_nametoindex
from ipaddress import IPv6Address
from .Interface import Interface
from .packet.ReceivedPacket import ReceivedPacket_v6
from .mld.mld_globals import MULTICAST_LISTENER_QUERY_TYPE, MULTICAST_LISTENER_DONE_TYPE, MULTICAST_LISTENER_REPORT_TYPE
from ctypes import create_string_buffer, addressof
ETH_P_IPV6 = 0x86DD # IPv6 over bluebook
SO_ATTACH_FILTER = 26
ICMP6_FILTER = 1
IPV6_ROUTER_ALERT = 22
def ICMP6_FILTER_SETBLOCKALL():
return struct.pack("I"*8, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF)
def ICMP6_FILTER_SETPASS(type, filterp):
return filterp[:type >> 5] + (bytes([(filterp[type >> 5] & ~(1 << ((type) & 31)))])) + filterp[(type >> 5) + 1:]
class InterfaceMLD(Interface):
IPv6_LINK_SCOPE_ALL_NODES = IPv6Address("ff02::1")
IPv6_LINK_SCOPE_ALL_ROUTERS = IPv6Address("ff02::2")
IPv6_ALL_ZEROS = IPv6Address("::")
FILTER_MLD = [
struct.pack('HBBI', 0x28, 0, 0, 0x0000000c),
struct.pack('HBBI', 0x15, 0, 9, 0x000086dd),
struct.pack('HBBI', 0x30, 0, 0, 0x00000014),
struct.pack('HBBI', 0x15, 0, 7, 0x00000000),
struct.pack('HBBI', 0x30, 0, 0, 0x00000036),
struct.pack('HBBI', 0x15, 0, 5, 0x0000003a),
struct.pack('HBBI', 0x30, 0, 0, 0x0000003e),
struct.pack('HBBI', 0x15, 2, 0, 0x00000082),
struct.pack('HBBI', 0x15, 1, 0, 0x00000083),
struct.pack('HBBI', 0x15, 0, 1, 0x00000084),
struct.pack('HBBI', 0x6, 0, 0, 0x00040000),
struct.pack('HBBI', 0x6, 0, 0, 0x00000000),
]
def __init__(self, interface_name: str, vif_index: int):
# SEND SOCKET
s = socket.socket(socket.AF_INET6, socket.SOCK_RAW, socket.IPPROTO_ICMPV6)
# set socket output interface
if_index = if_nametoindex(interface_name)
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, struct.pack('@I', if_index))
"""
# set ICMP6 filter to only receive MLD packets
icmp6_filter = ICMP6_FILTER_SETBLOCKALL()
icmp6_filter = ICMP6_FILTER_SETPASS(MULTICAST_LISTENER_QUERY_TYPE, icmp6_filter)
icmp6_filter = ICMP6_FILTER_SETPASS(MULTICAST_LISTENER_REPORT_TYPE, icmp6_filter)
icmp6_filter = ICMP6_FILTER_SETPASS(MULTICAST_LISTENER_DONE_TYPE, icmp6_filter)
s.setsockopt(socket.IPPROTO_ICMPV6, ICMP6_FILTER, icmp6_filter)
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_RECVPKTINFO, True)
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, False)
s.setsockopt(socket.IPPROTO_IPV6, self.IPV6_ROUTER_ALERT, 0)
rcv_s = s
"""
ip_interface = "::"
for if_addr in netifaces.ifaddresses(interface_name)[netifaces.AF_INET6]:
ip_interface = if_addr["addr"]
if ipaddress.IPv6Address(ip_interface.split("%")[0]).is_link_local:
# bind to interface
s.bind(socket.getaddrinfo(ip_interface, None, 0, socket.SOCK_RAW, 0, socket.AI_PASSIVE)[0][4])
ip_interface = ip_interface.split("%")[0]
break
self.ip_interface = ip_interface
# RECEIVE SOCKET
rcv_s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, socket.htons(ETH_P_IPV6))
# receive only MLD packets by setting a BPF filter
bpf_filter = b''.join(InterfaceMLD.FILTER_MLD)
b = create_string_buffer(bpf_filter)
mem_addr_of_filters = addressof(b)
fprog = struct.pack('HL', len(InterfaceMLD.FILTER_MLD), mem_addr_of_filters)
rcv_s.setsockopt(socket.SOL_SOCKET, SO_ATTACH_FILTER, fprog)
# bind to interface
rcv_s.bind((interface_name, ETH_P_IPV6))
super().__init__(interface_name=interface_name, recv_socket=rcv_s, send_socket=s, vif_index=vif_index)
self.interface_enabled = True
from .mld.RouterState import RouterState
self.interface_state = RouterState(self)
super()._enable()
@staticmethod
def _get_address_family():
return socket.AF_INET6
def get_ip(self):
return self.ip_interface
def send(self, data: bytes, address: str = "FF02::1"):
# send router alert option
cmsg_level = socket.IPPROTO_IPV6
cmsg_type = socket.IPV6_HOPOPTS
cmsg_data = b'\x3a\x00\x05\x02\x00\x00\x01\x00'
self._send_socket.sendmsg([data], [(cmsg_level, cmsg_type, cmsg_data)], 0, (address, 0))
"""
def receive(self):
while self.interface_enabled:
try:
(raw_bytes, ancdata, _, src_addr) = self._recv_socket.recvmsg(256 * 1024, 500)
if raw_bytes:
self._receive(raw_bytes, ancdata, src_addr)
except Exception:
import traceback
traceback.print_exc()
continue
"""
def _receive(self, raw_bytes, ancdata, src_addr):
if raw_bytes:
raw_bytes = raw_bytes[14:]
src_addr = (socket.inet_ntop(socket.AF_INET6, raw_bytes[8:24]),)
print("MLD IP_SRC bf= ", src_addr)
dst_addr = raw_bytes[24:40]
(next_header,) = struct.unpack("B", raw_bytes[6:7])
print("NEXT HEADER:", next_header)
payload_starts_at_len = 40
if next_header == 0:
# Hop by Hop options
(next_header,) = struct.unpack("B", raw_bytes[40:41])
if next_header != 58:
return
(hdr_ext_len,) = struct.unpack("B", raw_bytes[payload_starts_at_len +1:payload_starts_at_len + 2])
if hdr_ext_len > 0:
payload_starts_at_len = payload_starts_at_len + 1 + hdr_ext_len*8
else:
payload_starts_at_len = payload_starts_at_len + 8
raw_bytes = raw_bytes[payload_starts_at_len:]
ancdata = [(socket.IPPROTO_IPV6, socket.IPV6_PKTINFO, dst_addr)]
print("RECEIVE MLD")
print("ANCDATA: ", ancdata, "; SRC_ADDR: ", src_addr)
packet = ReceivedPacket_v6(raw_bytes, ancdata, src_addr, 58, self)
ip_src = packet.ip_header.ip_src
print("MLD IP_SRC = ", ip_src)
if not (ip_src == "::" or IPv6Address(ip_src).is_multicast):
self.PKT_FUNCTIONS.get(packet.payload.get_mld_type(), InterfaceMLD.receive_unknown_type)(self, packet)
"""
def _receive(self, raw_bytes, ancdata, src_addr):
if raw_bytes:
packet = ReceivedPacket_v6(raw_bytes, ancdata, src_addr, 58, self)
self.PKT_FUNCTIONS[packet.payload.get_mld_type(), InterfaceMLD.receive_unknown_type](self, packet)
"""
###########################################
# Recv packets
###########################################
def receive_multicast_listener_report(self, packet):
print("RECEIVE MULTICAST LISTENER REPORT")
ip_dst = packet.ip_header.ip_dst
mld_group = packet.payload.group_address
ipv6_group = IPv6Address(mld_group)
ipv6_dst = IPv6Address(ip_dst)
if ipv6_dst == ipv6_group and ipv6_group.is_multicast:
self.interface_state.receive_report(packet)
def receive_multicast_listener_done(self, packet):
print("RECEIVE MULTICAST LISTENER DONE")
ip_dst = packet.ip_header.ip_dst
mld_group = packet.payload.group_address
if IPv6Address(ip_dst) == self.IPv6_LINK_SCOPE_ALL_ROUTERS and IPv6Address(mld_group).is_multicast:
self.interface_state.receive_done(packet)
def receive_multicast_listener_query(self, packet):
print("RECEIVE MULTICAST LISTENER QUERY")
ip_dst = packet.ip_header.ip_dst
mld_group = packet.payload.group_address
ipv6_group = IPv6Address(mld_group)
ipv6_dst = IPv6Address(ip_dst)
if (ipv6_group.is_multicast and ipv6_dst == ipv6_group) or\
(ipv6_dst == self.IPv6_LINK_SCOPE_ALL_NODES and ipv6_group == self.IPv6_ALL_ZEROS):
self.interface_state.receive_query(packet)
def receive_unknown_type(self, packet):
raise Exception("UNKNOWN MLD TYPE: " + str(packet.payload.get_mld_type()))
PKT_FUNCTIONS = {
MULTICAST_LISTENER_REPORT_TYPE: receive_multicast_listener_report,
MULTICAST_LISTENER_DONE_TYPE: receive_multicast_listener_done,
MULTICAST_LISTENER_QUERY_TYPE: receive_multicast_listener_query,
}
##################
def remove(self):
super().remove()
self.interface_state.remove()
import socket
import random import random
from pimdm.Interface import Interface import logging
from pimdm.Packet.ReceivedPacket import ReceivedPacket import netifaces
from pimdm import Main
import traceback import traceback
from pimdm.RWLock.RWLock import RWLockWrite
from pimdm.Packet.PacketPimHelloOptions import *
from pimdm.Packet.PacketPimHello import PacketPimHello
from pimdm.Packet.PacketPimHeader import PacketPimHeader
from pimdm.Packet.Packet import Packet
from pimdm.utils import HELLO_HOLD_TIME_TIMEOUT
from threading import Timer from threading import Timer
from pimdm.tree.globals import REFRESH_INTERVAL
import socket from pimdm.Interface import Interface
import netifaces from pimdm.packet.ReceivedPacket import ReceivedPacket
import logging from pimdm import Main
from pimdm.rwlock.RWLock import RWLockWrite
from pimdm.packet.PacketPimHelloOptions import *
from pimdm.packet.PacketPimHello import PacketPimHello
from pimdm.packet.PacketPimHeader import PacketPimHeader
from pimdm.packet.Packet import Packet
from pimdm.tree.globals import HELLO_HOLD_TIME_TIMEOUT, REFRESH_INTERVAL
class InterfacePim(Interface): class InterfacePim(Interface):
...@@ -83,18 +83,37 @@ class InterfacePim(Interface): ...@@ -83,18 +83,37 @@ class InterfacePim(Interface):
self.force_send_hello() self.force_send_hello()
def get_ip(self): def get_ip(self):
"""
Get IP of this interface
"""
return self.ip_interface return self.ip_interface
def _receive(self, raw_bytes): @staticmethod
def get_kernel():
"""
Get Kernel object
"""
return Main.kernel
def _receive(self, raw_bytes, ancdata, src_addr):
"""
Interface received a new control packet
"""
if raw_bytes: if raw_bytes:
packet = ReceivedPacket(raw_bytes, self) packet = ReceivedPacket(raw_bytes, self)
self.PKT_FUNCTIONS[packet.payload.get_pim_type()](self, packet) self.PKT_FUNCTIONS.get(packet.payload.get_pim_type(), InterfacePim.receive_unknown)(self, packet)
def send(self, data: bytes, group_ip: str=MCAST_GRP): def send(self, data: bytes, group_ip: str=MCAST_GRP):
"""
Send a new packet destined to group_ip IP
"""
super().send(data=data, group_ip=group_ip) super().send(data=data, group_ip=group_ip)
#Random interval for initial Hello message on bootup or triggered Hello message to a rebooting neighbor #Random interval for initial Hello message on bootup or triggered Hello message to a rebooting neighbor
def force_send_hello(self): def force_send_hello(self):
"""
Force the transmission of a new Hello message
"""
if self.hello_timer is not None: if self.hello_timer is not None:
self.hello_timer.cancel() self.hello_timer.cancel()
...@@ -103,6 +122,10 @@ class InterfacePim(Interface): ...@@ -103,6 +122,10 @@ class InterfacePim(Interface):
self.hello_timer.start() self.hello_timer.start()
def send_hello(self): def send_hello(self):
"""
Send a new Hello message
Include in it the HelloHoldTime and GenerationID
"""
self.interface_logger.debug('Send Hello message') self.interface_logger.debug('Send Hello message')
self.hello_timer.cancel() self.hello_timer.cancel()
...@@ -125,6 +148,10 @@ class InterfacePim(Interface): ...@@ -125,6 +148,10 @@ class InterfacePim(Interface):
self.hello_timer.start() self.hello_timer.start()
def remove(self): def remove(self):
"""
Remove this interface
Clear all state
"""
self.hello_timer.cancel() self.hello_timer.cancel()
self.hello_timer = None self.hello_timer = None
...@@ -136,17 +163,20 @@ class InterfacePim(Interface): ...@@ -136,17 +163,20 @@ class InterfacePim(Interface):
packet = Packet(payload=ph) packet = Packet(payload=ph)
self.send(packet.bytes()) self.send(packet.bytes())
Main.kernel.interface_change_number_of_neighbors() self.get_kernel().interface_change_number_of_neighbors()
super().remove() super().remove()
def check_number_of_neighbors(self): def check_number_of_neighbors(self):
has_neighbors = len(self.neighbors) > 0 has_neighbors = len(self.neighbors) > 0
if has_neighbors != self._had_neighbors: if has_neighbors != self._had_neighbors:
self._had_neighbors = has_neighbors self._had_neighbors = has_neighbors
Main.kernel.interface_change_number_of_neighbors() self.get_kernel().interface_change_number_of_neighbors()
def new_or_reset_neighbor(self, neighbor_ip): def new_or_reset_neighbor(self, neighbor_ip):
Main.kernel.new_or_reset_neighbor(self.vif_index, neighbor_ip) """
React to new neighbor or restart of known neighbor
"""
self.get_kernel().new_or_reset_neighbor(self.vif_index, neighbor_ip)
''' '''
def add_neighbor(self, ip, random_number, hello_hold_time): def add_neighbor(self, ip, random_number, hello_hold_time):
...@@ -160,27 +190,44 @@ class InterfacePim(Interface): ...@@ -160,27 +190,44 @@ class InterfacePim(Interface):
''' '''
def get_neighbors(self): def get_neighbors(self):
"""
Get list of known neighbors
"""
with self.neighbors_lock.genRlock(): with self.neighbors_lock.genRlock():
return self.neighbors.values() return self.neighbors.values()
def get_neighbor(self, ip): def get_neighbor(self, ip):
"""
Get specific neighbor by its IP
"""
with self.neighbors_lock.genRlock(): with self.neighbors_lock.genRlock():
return self.neighbors.get(ip) return self.neighbors.get(ip)
def remove_neighbor(self, ip): def remove_neighbor(self, ip):
"""
Remove known neighbor
"""
with self.neighbors_lock.genWlock(): with self.neighbors_lock.genWlock():
del self.neighbors[ip] del self.neighbors[ip]
self.interface_logger.debug("Remove neighbor: " + ip) self.interface_logger.debug("Remove neighbor: " + ip)
self.check_number_of_neighbors() self.check_number_of_neighbors()
def set_state_refresh_capable(self, value): def set_state_refresh_capable(self, value):
"""
Change StateRefresh capability of interface
"""
self._state_refresh_capable = value self._state_refresh_capable = value
def is_state_refresh_enabled(self): def is_state_refresh_enabled(self):
"""
Check if state refresh is enabled
"""
return self._state_refresh_capable return self._state_refresh_capable
# check if Interface is StateRefreshCapable
def is_state_refresh_capable(self): def is_state_refresh_capable(self):
"""
Check StateRefresh capability of interface neighbors
"""
with self.neighbors_lock.genWlock(): with self.neighbors_lock.genWlock():
if len(self.neighbors) == 0: if len(self.neighbors) == 0:
return False return False
...@@ -214,6 +261,9 @@ class InterfacePim(Interface): ...@@ -214,6 +261,9 @@ class InterfacePim(Interface):
# Recv packets # Recv packets
########################################### ###########################################
def receive_hello(self, packet): def receive_hello(self, packet):
"""
Receive an Hello packet
"""
ip = packet.ip_header.ip_src ip = packet.ip_header.ip_src
print("ip = ", ip) print("ip = ", ip)
options = packet.payload.payload.get_options() options = packet.payload.payload.get_options()
...@@ -226,7 +276,6 @@ class InterfacePim(Interface): ...@@ -226,7 +276,6 @@ class InterfacePim(Interface):
state_refresh_capable = (21 in options) state_refresh_capable = (21 in options)
with self.neighbors_lock.genWlock(): with self.neighbors_lock.genWlock():
if ip not in self.neighbors: if ip not in self.neighbors:
if hello_hold_time == 0: if hello_hold_time == 0:
...@@ -244,17 +293,23 @@ class InterfacePim(Interface): ...@@ -244,17 +293,23 @@ class InterfacePim(Interface):
neighbor.receive_hello(generation_id, hello_hold_time, state_refresh_capable) neighbor.receive_hello(generation_id, hello_hold_time, state_refresh_capable)
def receive_assert(self, packet): def receive_assert(self, packet):
"""
Receive an Assert packet
"""
pkt_assert = packet.payload.payload # type: PacketPimAssert pkt_assert = packet.payload.payload # type: PacketPimAssert
source = pkt_assert.source_address source = pkt_assert.source_address
group = pkt_assert.multicast_group_address group = pkt_assert.multicast_group_address
source_group = (source, group) source_group = (source, group)
try: try:
Main.kernel.get_routing_entry(source_group).recv_assert_msg(self.vif_index, packet) self.get_kernel().get_routing_entry(source_group).recv_assert_msg(self.vif_index, packet)
except: except:
traceback.print_exc() traceback.print_exc()
def receive_join_prune(self, packet): def receive_join_prune(self, packet):
"""
Receive Join/Prune packet
"""
pkt_join_prune = packet.payload.payload # type: PacketPimJoinPrune pkt_join_prune = packet.payload.payload # type: PacketPimJoinPrune
join_prune_groups = pkt_join_prune.groups join_prune_groups = pkt_join_prune.groups
...@@ -266,7 +321,7 @@ class InterfacePim(Interface): ...@@ -266,7 +321,7 @@ class InterfacePim(Interface):
for source_address in joined_src_addresses: for source_address in joined_src_addresses:
source_group = (source_address, multicast_group) source_group = (source_address, multicast_group)
try: try:
Main.kernel.get_routing_entry(source_group).recv_join_msg(self.vif_index, packet) self.get_kernel().get_routing_entry(source_group).recv_join_msg(self.vif_index, packet)
except: except:
traceback.print_exc() traceback.print_exc()
continue continue
...@@ -274,12 +329,15 @@ class InterfacePim(Interface): ...@@ -274,12 +329,15 @@ class InterfacePim(Interface):
for source_address in pruned_src_addresses: for source_address in pruned_src_addresses:
source_group = (source_address, multicast_group) source_group = (source_address, multicast_group)
try: try:
Main.kernel.get_routing_entry(source_group).recv_prune_msg(self.vif_index, packet) self.get_kernel().get_routing_entry(source_group).recv_prune_msg(self.vif_index, packet)
except: except:
traceback.print_exc() traceback.print_exc()
continue continue
def receive_graft(self, packet): def receive_graft(self, packet):
"""
Receive Graft packet
"""
pkt_join_prune = packet.payload.payload # type: PacketPimGraft pkt_join_prune = packet.payload.payload # type: PacketPimGraft
join_prune_groups = pkt_join_prune.groups join_prune_groups = pkt_join_prune.groups
...@@ -290,12 +348,15 @@ class InterfacePim(Interface): ...@@ -290,12 +348,15 @@ class InterfacePim(Interface):
for source_address in joined_src_addresses: for source_address in joined_src_addresses:
source_group = (source_address, multicast_group) source_group = (source_address, multicast_group)
try: try:
Main.kernel.get_routing_entry(source_group).recv_graft_msg(self.vif_index, packet) self.get_kernel().get_routing_entry(source_group).recv_graft_msg(self.vif_index, packet)
except: except:
traceback.print_exc() traceback.print_exc()
continue continue
def receive_graft_ack(self, packet): def receive_graft_ack(self, packet):
"""
Receive an GraftAck packet
"""
pkt_join_prune = packet.payload.payload # type: PacketPimGraftAck pkt_join_prune = packet.payload.payload # type: PacketPimGraftAck
join_prune_groups = pkt_join_prune.groups join_prune_groups = pkt_join_prune.groups
...@@ -306,12 +367,15 @@ class InterfacePim(Interface): ...@@ -306,12 +367,15 @@ class InterfacePim(Interface):
for source_address in joined_src_addresses: for source_address in joined_src_addresses:
source_group = (source_address, multicast_group) source_group = (source_address, multicast_group)
try: try:
Main.kernel.get_routing_entry(source_group).recv_graft_ack_msg(self.vif_index, packet) self.get_kernel().get_routing_entry(source_group).recv_graft_ack_msg(self.vif_index, packet)
except: except:
traceback.print_exc() traceback.print_exc()
continue continue
def receive_state_refresh(self, packet): def receive_state_refresh(self, packet):
"""
Receive an StateRefresh packet
"""
if not self.is_state_refresh_enabled(): if not self.is_state_refresh_enabled():
return return
pkt_state_refresh = packet.payload.payload # type: PacketPimStateRefresh pkt_state_refresh = packet.payload.payload # type: PacketPimStateRefresh
...@@ -320,10 +384,15 @@ class InterfacePim(Interface): ...@@ -320,10 +384,15 @@ class InterfacePim(Interface):
group = pkt_state_refresh.multicast_group_adress group = pkt_state_refresh.multicast_group_adress
source_group = (source, group) source_group = (source, group)
try: try:
Main.kernel.get_routing_entry(source_group).recv_state_refresh_msg(self.vif_index, packet) self.get_kernel().get_routing_entry(source_group).recv_state_refresh_msg(self.vif_index, packet)
except: except:
traceback.print_exc() traceback.print_exc()
def receive_unknown(self, packet):
"""
Receive an unknown packet
"""
raise Exception("Unknown PIM type: " + str(packet.payload.get_pim_type()))
PKT_FUNCTIONS = { PKT_FUNCTIONS = {
0: receive_hello, 0: receive_hello,
......
import socket
import random
import struct
import logging
import ipaddress
import netifaces
from pimdm import Main
from socket import if_nametoindex
from pimdm.Interface import Interface
from .InterfacePIM import InterfacePim
from pimdm.rwlock.RWLock import RWLockWrite
from pimdm.packet.ReceivedPacket import ReceivedPacket_v6
class InterfacePim6(InterfacePim):
MCAST_GRP = "ff02::d"
def __init__(self, interface_name: str, vif_index:int, state_refresh_capable:bool=False):
# generation id
self.generation_id = random.getrandbits(32)
# When PIM is enabled on an interface or when a router first starts, the Hello Timer (HT)
# MUST be set to random value between 0 and Triggered_Hello_Delay
self.hello_timer = None
# state refresh capable
self._state_refresh_capable = state_refresh_capable
self._neighbors_state_refresh_capable = False
# todo: lan delay enabled
self._lan_delay_enabled = False
# todo: propagation delay
self._propagation_delay = self.PROPAGATION_DELAY
# todo: override interval
self._override_interval = self.OVERRIDE_INTERNAL
# pim neighbors
self._had_neighbors = False
self.neighbors = {}
self.neighbors_lock = RWLockWrite()
self.interface_logger = logging.LoggerAdapter(InterfacePim.LOGGER, {'vif': vif_index, 'interfacename': interface_name})
# SOCKET
s = socket.socket(socket.AF_INET6, socket.SOCK_RAW, socket.IPPROTO_PIM)
ip_interface = ""
for if_addr in netifaces.ifaddresses(interface_name)[netifaces.AF_INET6]:
ip_interface = if_addr["addr"]
if ipaddress.IPv6Address(if_addr['addr'].split("%")[0]).is_link_local:
ip_interface = if_addr['addr'].split("%")[0]
# bind to interface
s.bind(socket.getaddrinfo(if_addr['addr'], None, 0, socket.SOCK_RAW, 0, socket.AI_PASSIVE)[0][4])
break
self.ip_interface = ip_interface
# allow other sockets to bind this port too
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# explicitly join the multicast group on the interface specified
if_index = if_nametoindex(interface_name)
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP,
socket.inet_pton(socket.AF_INET6, InterfacePim6.MCAST_GRP) + struct.pack('@I', if_index))
s.setsockopt(socket.SOL_SOCKET, 25, str(interface_name + '\0').encode('utf-8'))
# set socket output interface
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, struct.pack('@I', if_index))
# set socket TTL to 1
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 1)
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_UNICAST_HOPS, 1)
# don't receive outgoing packets
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, 0)
Interface.__init__(self, interface_name, s, s, vif_index)
Interface._enable(self)
self.force_send_hello()
@staticmethod
def get_kernel():
return Main.kernel_v6
def send(self, data: bytes, group_ip: str=MCAST_GRP):
super().send(data=data, group_ip=group_ip)
def _receive(self, raw_bytes, ancdata, src_addr):
if raw_bytes:
packet = ReceivedPacket_v6(raw_bytes, ancdata, src_addr, 103, self)
self.PKT_FUNCTIONS[packet.payload.get_pim_type()](self, packet)
import socket import socket
import struct import struct
from threading import RLock, Thread
import traceback
import ipaddress import ipaddress
import traceback
from socket import if_nametoindex
from threading import RLock, Thread
from abc import abstractmethod, ABCMeta
from pimdm.RWLock.RWLock import RWLockWrite from pimdm import UnicastRouting, Main
from pimdm.rwlock.RWLock import RWLockWrite
from pimdm.InterfacePIM import InterfacePim from pimdm.InterfaceMLD import InterfaceMLD
from pimdm.InterfaceIGMP import InterfaceIGMP from pimdm.InterfaceIGMP import InterfaceIGMP
from pimdm.InterfacePIM import InterfacePim
from pimdm.InterfacePIM6 import InterfacePim6
from pimdm.tree.KernelEntry import KernelEntry from pimdm.tree.KernelEntry import KernelEntry
from pimdm import UnicastRouting, Main from pimdm.tree.KernelEntryInterface import KernelEntry4Interface, KernelEntry6Interface
class Kernel:
# MRT
MRT_BASE = 200
MRT_INIT = (MRT_BASE) # /* Activate the kernel mroute code */
MRT_DONE = (MRT_BASE + 1) # /* Shutdown the kernel mroute */
MRT_ADD_VIF = (MRT_BASE + 2) # /* Add a virtual interface */
MRT_DEL_VIF = (MRT_BASE + 3) # /* Delete a virtual interface */
MRT_ADD_MFC = (MRT_BASE + 4) # /* Add a multicast forwarding entry */
MRT_DEL_MFC = (MRT_BASE + 5) # /* Delete a multicast forwarding entry */
MRT_VERSION = (MRT_BASE + 6) # /* Get the kernel multicast version */
MRT_ASSERT = (MRT_BASE + 7) # /* Activate PIM assert mode */
MRT_PIM = (MRT_BASE + 8) # /* enable PIM code */
MRT_TABLE = (MRT_BASE + 9) # /* Specify mroute table ID */
#MRT_ADD_MFC_PROXY = (MRT_BASE + 10) # /* Add a (*,*|G) mfc entry */
#MRT_DEL_MFC_PROXY = (MRT_BASE + 11) # /* Del a (*,*|G) mfc entry */
#MRT_MAX = (MRT_BASE + 11)
class Kernel(metaclass=ABCMeta):
# Max Number of Virtual Interfaces # Max Number of Virtual Interfaces
MAXVIFS = 32 MAXVIFS = 32
# SIGNAL MSG TYPE def __init__(self, kernel_socket):
IGMPMSG_NOCACHE = 1
IGMPMSG_WRONGVIF = 2
IGMPMSG_WHOLEPKT = 3 # NOT USED ON PIM-DM
# Interface flags
VIFF_TUNNEL = 0x1 # IPIP tunnel
VIFF_SRCRT = 0x2 # NI
VIFF_REGISTER = 0x4 # register vif
VIFF_USE_IFINDEX = 0x8 # use vifc_lcl_ifindex instead of vifc_lcl_addr to find an interface
def __init__(self):
# Kernel is running # Kernel is running
self.running = True self.running = True
# KEY : interface_ip, VALUE : vif_index # KEY : interface_ip, VALUE : vif_index
self.vif_dic = {} self.vif_index_to_name_dic = {} # KEY : vif_index, VALUE : interface_name
self.vif_index_to_name_dic = {} self.vif_name_to_index_dic = {} # KEY : interface_name, VALUE : vif_index
self.vif_name_to_index_dic = {}
# KEY : source_ip, VALUE : {group_ip: KernelEntry} # KEY : source_ip, VALUE : {group_ip: KernelEntry}
self.routing = {} self.routing = {}
s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP) self.socket = kernel_socket
# MRT INIT
s.setsockopt(socket.IPPROTO_IP, Kernel.MRT_INIT, 1)
# MRT PIM
s.setsockopt(socket.IPPROTO_IP, Kernel.MRT_PIM, 0)
s.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ASSERT, 1)
self.socket = s
self.rwlock = RWLockWrite() self.rwlock = RWLockWrite()
self.interface_lock = RLock() self.interface_lock = RLock()
...@@ -74,9 +40,9 @@ class Kernel: ...@@ -74,9 +40,9 @@ class Kernel:
# todo useless in PIM-DM... useful in PIM-SM # todo useless in PIM-DM... useful in PIM-SM
#self.create_virtual_interface("0.0.0.0", "pimreg", index=0, flags=Kernel.VIFF_REGISTER) #self.create_virtual_interface("0.0.0.0", "pimreg", index=0, flags=Kernel.VIFF_REGISTER)
# interfaces being monitored by this process
self.pim_interface = {} # name: interface_pim self.pim_interface = {} # name: interface_pim
self.igmp_interface = {} # name: interface_igmp self.membership_interface = {} # name: interface_igmp or interface_mld
# logs # logs
self.interface_logger = Main.logger.getChild('KernelInterface') self.interface_logger = Main.logger.getChild('KernelInterface')
...@@ -101,55 +67,43 @@ class Kernel: ...@@ -101,55 +67,43 @@ class Kernel:
struct in_addr vifc_rmt_addr; /* IPIP tunnel addr */ struct in_addr vifc_rmt_addr; /* IPIP tunnel addr */
}; };
''' '''
@abstractmethod
def create_virtual_interface(self, ip_interface: str or bytes, interface_name: str, index, flags=0x0): def create_virtual_interface(self, ip_interface: str or bytes, interface_name: str, index, flags=0x0):
if type(ip_interface) is str: raise NotImplementedError
ip_interface = socket.inet_aton(ip_interface)
struct_mrt_add_vif = struct.pack("HBBI 4s 4s", index, flags, 1, 0, ip_interface,
socket.inet_aton("0.0.0.0"))
with self.rwlock.genWlock():
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_VIF, struct_mrt_add_vif)
self.vif_dic[socket.inet_ntoa(ip_interface)] = index
self.vif_index_to_name_dic[index] = interface_name
self.vif_name_to_index_dic[interface_name] = index
for source_dict in list(self.routing.values()):
for kernel_entry in list(source_dict.values()):
kernel_entry.new_interface(index)
self.interface_logger.debug('Create virtual interface: %s -> %d', interface_name, index)
return index
def create_pim_interface(self, interface_name: str, state_refresh_capable:bool): def create_pim_interface(self, interface_name: str, state_refresh_capable:bool):
with self.interface_lock: with self.interface_lock:
pim_interface = self.pim_interface.get(interface_name) pim_interface = self.pim_interface.get(interface_name)
igmp_interface = self.igmp_interface.get(interface_name) membership_interface = self.membership_interface.get(interface_name)
vif_already_exists = pim_interface or igmp_interface vif_already_exists = pim_interface or membership_interface
if pim_interface: if pim_interface:
# already exists # already exists
pim_interface.set_state_refresh_capable(state_refresh_capable) pim_interface.set_state_refresh_capable(state_refresh_capable)
return return
elif igmp_interface: elif membership_interface:
index = igmp_interface.vif_index index = membership_interface.vif_index
else: else:
index = list(range(0, self.MAXVIFS) - self.vif_index_to_name_dic.keys())[0] index = list(range(0, self.MAXVIFS) - self.vif_index_to_name_dic.keys())[0]
ip_interface = None ip_interface = None
if interface_name not in self.pim_interface: if interface_name not in self.pim_interface:
pim_interface = InterfacePim(interface_name, index, state_refresh_capable) pim_interface = self._create_pim_interface_object(interface_name, index, state_refresh_capable)
self.pim_interface[interface_name] = pim_interface self.pim_interface[interface_name] = pim_interface
ip_interface = pim_interface.ip_interface ip_interface = pim_interface.ip_interface
if not vif_already_exists: if not vif_already_exists:
self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index) self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index)
def create_igmp_interface(self, interface_name: str): @abstractmethod
def _create_pim_interface_object(self, interface_name, index, state_refresh_capable):
raise NotImplementedError
def create_membership_interface(self, interface_name: str):
with self.interface_lock: with self.interface_lock:
pim_interface = self.pim_interface.get(interface_name) pim_interface = self.pim_interface.get(interface_name)
igmp_interface = self.igmp_interface.get(interface_name) membership_interface = self.membership_interface.get(interface_name)
vif_already_exists = pim_interface or igmp_interface vif_already_exists = pim_interface or membership_interface
if igmp_interface: if membership_interface:
# already exists # already exists
return return
elif pim_interface: elif pim_interface:
...@@ -158,47 +112,222 @@ class Kernel: ...@@ -158,47 +112,222 @@ class Kernel:
index = list(range(0, self.MAXVIFS) - self.vif_index_to_name_dic.keys())[0] index = list(range(0, self.MAXVIFS) - self.vif_index_to_name_dic.keys())[0]
ip_interface = None ip_interface = None
if interface_name not in self.igmp_interface: if interface_name not in self.membership_interface:
igmp_interface = InterfaceIGMP(interface_name, index) igmp_interface = self._create_membership_interface_object(interface_name, index)
self.igmp_interface[interface_name] = igmp_interface self.membership_interface[interface_name] = igmp_interface
ip_interface = igmp_interface.ip_interface ip_interface = igmp_interface.ip_interface
if not vif_already_exists: if not vif_already_exists:
self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index) self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index)
@abstractmethod
def _create_membership_interface_object(self, interface_name, index):
raise NotImplementedError
def remove_interface(self, interface_name, igmp:bool=False, pim:bool=False): def remove_interface(self, interface_name, membership: bool = False, pim: bool = False):
with self.interface_lock: with self.interface_lock:
ip_interface = None
pim_interface = self.pim_interface.get(interface_name) pim_interface = self.pim_interface.get(interface_name)
igmp_interface = self.igmp_interface.get(interface_name) membership_interface = self.membership_interface.get(interface_name)
if (igmp and not igmp_interface) or (pim and not pim_interface) or (not igmp and not pim): if (membership and not membership_interface) or (pim and not pim_interface) or (not membership and not pim):
return return
if pim: if pim:
pim_interface = self.pim_interface.pop(interface_name) pim_interface = self.pim_interface.pop(interface_name)
ip_interface = pim_interface.ip_interface
pim_interface.remove() pim_interface.remove()
elif igmp: elif membership:
igmp_interface = self.igmp_interface.pop(interface_name) membership_interface = self.membership_interface.pop(interface_name)
ip_interface = igmp_interface.ip_interface membership_interface.remove()
igmp_interface.remove()
if not self.membership_interface.get(interface_name) and not self.pim_interface.get(interface_name):
self.remove_virtual_interface(interface_name)
@abstractmethod
def remove_virtual_interface(self, interface_name):
raise NotImplementedError
#############################################
# Manipulate multicast routing table
#############################################
@abstractmethod
def set_multicast_route(self, kernel_entry: KernelEntry):
raise NotImplementedError
@abstractmethod
def set_flood_multicast_route(self, source_ip, group_ip, inbound_interface_index):
raise NotImplementedError
@abstractmethod
def remove_multicast_route(self, kernel_entry: KernelEntry):
raise NotImplementedError
@abstractmethod
def exit(self):
raise NotImplementedError
@abstractmethod
def handler(self):
raise NotImplementedError
def get_routing_entry(self, source_group: tuple, create_if_not_existent=True):
ip_src = source_group[0]
ip_dst = source_group[1]
with self.rwlock.genRlock():
if ip_src in self.routing and ip_dst in self.routing[ip_src]:
return self.routing[ip_src][ip_dst]
with self.rwlock.genWlock():
if ip_src in self.routing and ip_dst in self.routing[ip_src]:
return self.routing[ip_src][ip_dst]
elif create_if_not_existent:
kernel_entry = KernelEntry(ip_src, ip_dst, self._get_kernel_entry_interface())
if ip_src not in self.routing:
self.routing[ip_src] = {}
iif = UnicastRouting.check_rpf(ip_src)
self.set_flood_multicast_route(ip_src, ip_dst, iif)
self.routing[ip_src][ip_dst] = kernel_entry
return kernel_entry
else:
return None
@staticmethod
@abstractmethod
def _get_kernel_entry_interface():
pass
# notify KernelEntries about changes at the unicast routing table
def notify_unicast_changes(self, subnet):
with self.rwlock.genWlock():
for source_ip in list(self.routing.keys()):
source_ip_obj = ipaddress.ip_address(source_ip)
if source_ip_obj not in subnet:
continue
for group_ip in list(self.routing[source_ip].keys()):
self.routing[source_ip][group_ip].network_update()
# notify about changes at the interface (IP)
'''
def notify_interface_change(self, interface_name):
with self.interface_lock:
# check if interface was already added
if interface_name not in self.vif_name_to_index_dic:
return
print("trying to change ip")
pim_interface = self.pim_interface.get(interface_name)
if pim_interface:
old_ip = pim_interface.get_ip()
pim_interface.change_interface()
new_ip = pim_interface.get_ip()
if old_ip != new_ip:
self.vif_dic[new_ip] = self.vif_dic.pop(old_ip)
igmp_interface = self.igmp_interface.get(interface_name)
if igmp_interface:
igmp_interface.change_interface()
'''
# When interface changes number of neighbors verify if olist changes and prune/forward respectively
def interface_change_number_of_neighbors(self):
with self.rwlock.genRlock():
for groups_dict in self.routing.values():
for entry in groups_dict.values():
entry.change_at_number_of_neighbors()
# When new neighbor connects try to resend last state refresh msg (if AssertWinner)
def new_or_reset_neighbor(self, vif_index, neighbor_ip):
with self.rwlock.genRlock():
for groups_dict in self.routing.values():
for entry in groups_dict.values():
entry.new_or_reset_neighbor(vif_index, neighbor_ip)
class Kernel4(Kernel):
# MRT
MRT_BASE = 200
MRT_INIT = (MRT_BASE) # /* Activate the kernel mroute code */
MRT_DONE = (MRT_BASE + 1) # /* Shutdown the kernel mroute */
MRT_ADD_VIF = (MRT_BASE + 2) # /* Add a virtual interface */
MRT_DEL_VIF = (MRT_BASE + 3) # /* Delete a virtual interface */
MRT_ADD_MFC = (MRT_BASE + 4) # /* Add a multicast forwarding entry */
MRT_DEL_MFC = (MRT_BASE + 5) # /* Delete a multicast forwarding entry */
MRT_VERSION = (MRT_BASE + 6) # /* Get the kernel multicast version */
MRT_ASSERT = (MRT_BASE + 7) # /* Activate PIM assert mode */
MRT_PIM = (MRT_BASE + 8) # /* enable PIM code */
MRT_TABLE = (MRT_BASE + 9) # /* Specify mroute table ID */
#MRT_ADD_MFC_PROXY = (MRT_BASE + 10) # /* Add a (*,*|G) mfc entry */
#MRT_DEL_MFC_PROXY = (MRT_BASE + 11) # /* Del a (*,*|G) mfc entry */
#MRT_MAX = (MRT_BASE + 11)
# Max Number of Virtual Interfaces
MAXVIFS = 32
# SIGNAL MSG TYPE
IGMPMSG_NOCACHE = 1
IGMPMSG_WRONGVIF = 2
IGMPMSG_WHOLEPKT = 3 # NOT USED ON PIM-DM
# Interface flags
VIFF_TUNNEL = 0x1 # IPIP tunnel
VIFF_SRCRT = 0x2 # NI
VIFF_REGISTER = 0x4 # register vif
VIFF_USE_IFINDEX = 0x8 # use vifc_lcl_ifindex instead of vifc_lcl_addr to find an interface
if (not self.igmp_interface.get(interface_name) and not self.pim_interface.get(interface_name)): def __init__(self):
self.remove_virtual_interface(ip_interface) s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
# MRT INIT
s.setsockopt(socket.IPPROTO_IP, self.MRT_INIT, 1)
# MRT PIM
s.setsockopt(socket.IPPROTO_IP, self.MRT_PIM, 0)
s.setsockopt(socket.IPPROTO_IP, self.MRT_ASSERT, 1)
super().__init__(s)
'''
Structure to create/remove virtual interfaces
struct vifctl {
vifi_t vifc_vifi; /* Index of VIF */
unsigned char vifc_flags; /* VIFF_ flags */
unsigned char vifc_threshold; /* ttl limit */
unsigned int vifc_rate_limit; /* Rate limiter values (NI) */
union {
struct in_addr vifc_lcl_addr; /* Local interface address */
int vifc_lcl_ifindex; /* Local interface index */
};
struct in_addr vifc_rmt_addr; /* IPIP tunnel addr */
};
'''
def create_virtual_interface(self, ip_interface: str or bytes, interface_name: str, index, flags=0x0):
if type(ip_interface) is str:
ip_interface = socket.inet_aton(ip_interface)
struct_mrt_add_vif = struct.pack("HBBI 4s 4s", index, flags, 1, 0, ip_interface,
socket.inet_aton("0.0.0.0"))
with self.rwlock.genWlock():
self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_ADD_VIF, struct_mrt_add_vif)
self.vif_index_to_name_dic[index] = interface_name
self.vif_name_to_index_dic[interface_name] = index
for source_dict in list(self.routing.values()):
for kernel_entry in list(source_dict.values()):
kernel_entry.new_interface(index)
def remove_virtual_interface(self, ip_interface): self.interface_logger.debug('Create virtual interface: %s -> %d', interface_name, index)
return index
def remove_virtual_interface(self, interface_name):
#with self.interface_lock: #with self.interface_lock:
index = self.vif_dic[ip_interface] index = self.vif_name_to_index_dic.pop(interface_name, None)
struct_vifctl = struct.pack("HBBI 4s 4s", index, 0, 0, 0, socket.inet_aton("0.0.0.0"), socket.inet_aton("0.0.0.0")) struct_vifctl = struct.pack("HBBI 4s 4s", index, 0, 0, 0, socket.inet_aton("0.0.0.0"), socket.inet_aton("0.0.0.0"))
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_DEL_VIF, struct_vifctl) self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_DEL_VIF, struct_vifctl)
del self.vif_dic[ip_interface]
del self.vif_name_to_index_dic[self.vif_index_to_name_dic[index]] del self.vif_name_to_index_dic[self.vif_index_to_name_dic[index]]
interface_name = self.vif_index_to_name_dic.pop(index) interface_name = self.vif_index_to_name_dic.pop(index)
# alterar MFC's para colocar a 0 esta interface # change MFC's to not forward traffic by this interface (set OIL to 0 for this interface)
with self.rwlock.genWlock(): with self.rwlock.genWlock():
for source_dict in list(self.routing.values()): for source_dict in list(self.routing.values()):
for kernel_entry in list(source_dict.values()): for kernel_entry in list(source_dict.values()):
...@@ -235,7 +364,7 @@ class Kernel: ...@@ -235,7 +364,7 @@ class Kernel:
#outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5 #outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5
#struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0) #struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0)
struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, kernel_entry.inbound_interface_index, *outbound_interfaces_and_other_parameters) struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, kernel_entry.inbound_interface_index, *outbound_interfaces_and_other_parameters)
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_MFC, struct_mfcctl) self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_ADD_MFC, struct_mfcctl)
def set_flood_multicast_route(self, source_ip, group_ip, inbound_interface_index): def set_flood_multicast_route(self, source_ip, group_ip, inbound_interface_index):
source_ip = socket.inet_aton(source_ip) source_ip = socket.inet_aton(source_ip)
...@@ -250,7 +379,7 @@ class Kernel: ...@@ -250,7 +379,7 @@ class Kernel:
#outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5 #outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5
#struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0) #struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0)
struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces_and_other_parameters) struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces_and_other_parameters)
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_MFC, struct_mfcctl) self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_ADD_MFC, struct_mfcctl)
def remove_multicast_route(self, kernel_entry: KernelEntry): def remove_multicast_route(self, kernel_entry: KernelEntry):
source_ip = socket.inet_aton(kernel_entry.source_ip) source_ip = socket.inet_aton(kernel_entry.source_ip)
...@@ -258,7 +387,7 @@ class Kernel: ...@@ -258,7 +387,7 @@ class Kernel:
outbound_interfaces_and_other_parameters = [0] + [0]*Kernel.MAXVIFS + [0]*4 outbound_interfaces_and_other_parameters = [0] + [0]*Kernel.MAXVIFS + [0]*4
struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, *outbound_interfaces_and_other_parameters) struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, *outbound_interfaces_and_other_parameters)
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_DEL_MFC, struct_mfcctl) self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_DEL_MFC, struct_mfcctl)
self.routing[kernel_entry.source_ip].pop(kernel_entry.group_ip) self.routing[kernel_entry.source_ip].pop(kernel_entry.group_ip)
if len(self.routing[kernel_entry.source_ip]) == 0: if len(self.routing[kernel_entry.source_ip]) == 0:
self.routing.pop(kernel_entry.source_ip) self.routing.pop(kernel_entry.source_ip)
...@@ -267,7 +396,7 @@ class Kernel: ...@@ -267,7 +396,7 @@ class Kernel:
self.running = False self.running = False
# MRT DONE # MRT DONE
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_DONE, 1) self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_DONE, 1)
self.socket.close() self.socket.close()
...@@ -304,10 +433,10 @@ class Kernel: ...@@ -304,10 +433,10 @@ class Kernel:
ip_src = socket.inet_ntoa(im_src) ip_src = socket.inet_ntoa(im_src)
ip_dst = socket.inet_ntoa(im_dst) ip_dst = socket.inet_ntoa(im_dst)
if im_msgtype == Kernel.IGMPMSG_NOCACHE: if im_msgtype == self.IGMPMSG_NOCACHE:
print("IGMP NO CACHE") print("IGMP NO CACHE")
self.igmpmsg_nocache_handler(ip_src, ip_dst, im_vif) self.igmpmsg_nocache_handler(ip_src, ip_dst, im_vif)
elif im_msgtype == Kernel.IGMPMSG_WRONGVIF: elif im_msgtype == self.IGMPMSG_WRONGVIF:
print("WRONG VIF HANDLER") print("WRONG VIF HANDLER")
self.igmpmsg_wrongvif_handler(ip_src, ip_dst, im_vif) self.igmpmsg_wrongvif_handler(ip_src, ip_dst, im_vif)
#elif im_msgtype == Kernel.IGMPMSG_WHOLEPKT: #elif im_msgtype == Kernel.IGMPMSG_WHOLEPKT:
...@@ -338,73 +467,266 @@ class Kernel: ...@@ -338,73 +467,266 @@ class Kernel:
#kernel_entry.recv_data_msg(iif) #kernel_entry.recv_data_msg(iif)
''' '''
@staticmethod
def _get_kernel_entry_interface():
return KernelEntry4Interface
def _create_pim_interface_object(self, interface_name, index, state_refresh_capable):
return InterfacePim(interface_name, index, state_refresh_capable)
def _create_membership_interface_object(self, interface_name, index):
return InterfaceIGMP(interface_name, index)
class Kernel6(Kernel):
# MRT6
MRT6_BASE = 200
MRT6_INIT = (MRT6_BASE) # /* Activate the kernel mroute code */
MRT6_DONE = (MRT6_BASE + 1) # /* Shutdown the kernel mroute */
MRT6_ADD_MIF = (MRT6_BASE + 2) # /* Add a virtual interface */
MRT6_DEL_MIF = (MRT6_BASE + 3) # /* Delete a virtual interface */
MRT6_ADD_MFC = (MRT6_BASE + 4) # /* Add a multicast forwarding entry */
MRT6_DEL_MFC = (MRT6_BASE + 5) # /* Delete a multicast forwarding entry */
MRT6_VERSION = (MRT6_BASE + 6) # /* Get the kernel multicast version */
MRT6_ASSERT = (MRT6_BASE + 7) # /* Activate PIM assert mode */
MRT6_PIM = (MRT6_BASE + 8) # /* enable PIM code */
MRT6_TABLE = (MRT6_BASE + 9) # /* Specify mroute table ID */
MRT6_ADD_MFC_PROXY = (MRT6_BASE + 10) # /* Add a (*,*|G) mfc entry */
MRT6_DEL_MFC_PROXY = (MRT6_BASE + 11) # /* Del a (*,*|G) mfc entry */
MRT6_MAX = (MRT6_BASE + 11)
# define SIOCGETMIFCNT_IN6 SIOCPROTOPRIVATE /* IP protocol privates */
# define SIOCGETSGCNT_IN6 (SIOCPROTOPRIVATE+1)
# define SIOCGETRPF (SIOCPROTOPRIVATE+2)
# Max Number of Virtual Interfaces
MAXVIFS = 32
def get_routing_entry(self, source_group: tuple, create_if_not_existent=True): # SIGNAL MSG TYPE
ip_src = source_group[0] MRT6MSG_NOCACHE = 1
ip_dst = source_group[1] MRT6MSG_WRONGMIF = 2
with self.rwlock.genRlock(): MRT6MSG_WHOLEPKT = 3 # /* used for use level encap */
if ip_src in self.routing and ip_dst in self.routing[ip_src]:
return self.routing[ip_src][ip_dst] # Interface flags
MIFF_REGISTER = 0x1 # /* register vif */
def __init__(self):
s = socket.socket(socket.AF_INET6, socket.SOCK_RAW, socket.IPPROTO_ICMPV6)
# MRT INIT
s.setsockopt(socket.IPPROTO_IPV6, self.MRT6_INIT, 1)
# MRT PIM
s.setsockopt(socket.IPPROTO_IPV6, self.MRT6_PIM, 0)
s.setsockopt(socket.IPPROTO_IPV6, self.MRT6_ASSERT, 1)
super().__init__(s)
'''
Structure to create/remove multicast interfaces
struct mif6ctl {
mifi_t mif6c_mifi; /* Index of MIF */
unsigned char mif6c_flags; /* MIFF_ flags */
unsigned char vifc_threshold; /* ttl limit */
__u16 mif6c_pifi; /* the index of the physical IF */
unsigned int vifc_rate_limit; /* Rate limiter values (NI) */
};
'''
def create_virtual_interface(self, ip_interface, interface_name: str, index, flags=0x0):
physical_if_index = if_nametoindex(interface_name)
struct_mrt_add_vif = struct.pack("HBBHI", index, flags, 1, physical_if_index, 0)
with self.rwlock.genWlock(): with self.rwlock.genWlock():
if ip_src in self.routing and ip_dst in self.routing[ip_src]: self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_ADD_MIF, struct_mrt_add_vif)
return self.routing[ip_src][ip_dst] self.vif_index_to_name_dic[index] = interface_name
elif create_if_not_existent: self.vif_name_to_index_dic[interface_name] = index
kernel_entry = KernelEntry(ip_src, ip_dst)
if ip_src not in self.routing:
self.routing[ip_src] = {}
iif = UnicastRouting.check_rpf(ip_src) for source_dict in list(self.routing.values()):
self.set_flood_multicast_route(ip_src, ip_dst, iif) for kernel_entry in list(source_dict.values()):
self.routing[ip_src][ip_dst] = kernel_entry kernel_entry.new_interface(index)
return kernel_entry
else:
return None
# notify KernelEntries about changes at the unicast routing table self.interface_logger.debug('Create virtual interface: %s -> %d', interface_name, index)
def notify_unicast_changes(self, subnet): return index
def remove_virtual_interface(self, interface_name):
# with self.interface_lock:
mif_index = self.vif_name_to_index_dic.pop(interface_name, None)
interface_name = self.vif_index_to_name_dic.pop(mif_index)
physical_if_index = if_nametoindex(interface_name)
struct_vifctl = struct.pack("HBBHI", mif_index, 0, 0, physical_if_index, 0)
self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_DEL_MIF, struct_vifctl)
# alterar MFC's para colocar a 0 esta interface
with self.rwlock.genWlock(): with self.rwlock.genWlock():
for source_ip in list(self.routing.keys()): for source_dict in list(self.routing.values()):
source_ip_obj = ipaddress.ip_address(source_ip) for kernel_entry in list(source_dict.values()):
if source_ip_obj not in subnet: kernel_entry.remove_interface(mif_index)
continue
for group_ip in list(self.routing[source_ip].keys()):
self.routing[source_ip][group_ip].network_update()
self.interface_logger.debug('Remove virtual interface: %s -> %d', interface_name, mif_index)
# notify about changes at the interface (IP)
''' '''
def notify_interface_change(self, interface_name): /* Cache manipulation structures for mrouted and PIMd */
with self.interface_lock: typedef __u32 if_mask;
# check if interface was already added typedef struct if_set {
if interface_name not in self.vif_name_to_index_dic: if_mask ifs_bits[__KERNEL_DIV_ROUND_UP(IF_SETSIZE, NIFBITS)];
return } if_set;
struct mf6cctl {
struct sockaddr_in6 mf6cc_origin; /* Origin of mcast */
struct sockaddr_in6 mf6cc_mcastgrp; /* Group in question */
mifi_t mf6cc_parent; /* Where it arrived */
struct if_set mf6cc_ifset; /* Where it is going */
};
'''
def set_multicast_route(self, kernel_entry: KernelEntry):
source_ip = socket.inet_pton(socket.AF_INET6, kernel_entry.source_ip)
sockaddr_in6_source = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, source_ip, 0)
group_ip = socket.inet_pton(socket.AF_INET6, kernel_entry.group_ip)
sockaddr_in6_group = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, group_ip, 0)
print("trying to change ip") outbound_interfaces = kernel_entry.get_outbound_interfaces_indexes()
pim_interface = self.pim_interface.get(interface_name) if len(outbound_interfaces) != 8:
if pim_interface: raise Exception
old_ip = pim_interface.get_ip()
pim_interface.change_interface()
new_ip = pim_interface.get_ip()
if old_ip != new_ip:
self.vif_dic[new_ip] = self.vif_dic.pop(old_ip)
igmp_interface = self.igmp_interface.get(interface_name) # outbound_interfaces_and_other_parameters = list(kernel_entry.outbound_interfaces) + [0]*4
if igmp_interface: # outbound_interfaces_and_other_parameters = outbound_interfaces + [0]*4
igmp_interface.change_interface() outgoing_interface_list = outbound_interfaces
# outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5
# struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0)
struct_mf6cctl = struct.pack("28s 28s H " + "I" * 8, sockaddr_in6_source, sockaddr_in6_group,
kernel_entry.inbound_interface_index,
*outgoing_interface_list)
self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_ADD_MFC, struct_mf6cctl)
def set_flood_multicast_route(self, source_ip, group_ip, inbound_interface_index):
source_ip = socket.inet_pton(socket.AF_INET6, source_ip)
sockaddr_in6_source = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, source_ip, 0)
group_ip = socket.inet_pton(socket.AF_INET6, group_ip)
sockaddr_in6_group = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, group_ip, 0)
outbound_interfaces = [255] * 8
outbound_interfaces[inbound_interface_index // 32] = 0xFFFFFFFF & ~(1 << (inbound_interface_index % 32))
# outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5
# struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0)
# struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces_and_other_parameters)
struct_mf6cctl = struct.pack("28s 28s H " + "I" * 8, sockaddr_in6_source, sockaddr_in6_group,
inbound_interface_index, *outbound_interfaces)
self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_ADD_MFC, struct_mf6cctl)
def remove_multicast_route(self, kernel_entry: KernelEntry):
source_ip = socket.inet_pton(socket.AF_INET6, kernel_entry.source_ip)
sockaddr_in6_source = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, source_ip, 0)
group_ip = socket.inet_pton(socket.AF_INET6, kernel_entry.group_ip)
sockaddr_in6_group = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, group_ip, 0)
outbound_interfaces = [0] * 8
# struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, *outbound_interfaces_and_other_parameters)
struct_mf6cctl = struct.pack("28s 28s H " + "I" * 8, sockaddr_in6_source, sockaddr_in6_group, 0,
*outbound_interfaces)
self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_DEL_MFC, struct_mf6cctl)
self.routing[kernel_entry.source_ip].pop(kernel_entry.group_ip)
if len(self.routing[kernel_entry.source_ip]) == 0:
self.routing.pop(kernel_entry.source_ip)
def exit(self):
self.running = False
# MRT DONE
self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_DONE, 1)
self.socket.close()
'''
/*
* Structure used to communicate from kernel to multicast router.
* We'll overlay the structure onto an MLD header (not an IPv6 heder like igmpmsg{}
* used for IPv4 implementation). This is because this structure will be passed via an
* IPv6 raw socket, on which an application will only receiver the payload i.e the data after
* the IPv6 header and all the extension headers. (See section 3 of RFC 3542)
*/
struct mrt6msg {
__u8 im6_mbz; /* must be zero */
__u8 im6_msgtype; /* what type of message */
__u16 im6_mif; /* mif rec'd on */
__u32 im6_pad; /* padding for 64 bit arch */
struct in6_addr im6_src, im6_dst;
};
/* ip6mr netlink cache report attributes */
enum {
IP6MRA_CREPORT_UNSPEC,
IP6MRA_CREPORT_MSGTYPE,
IP6MRA_CREPORT_MIF_ID,
IP6MRA_CREPORT_SRC_ADDR,
IP6MRA_CREPORT_DST_ADDR,
IP6MRA_CREPORT_PKT,
__IP6MRA_CREPORT_MAX
};
''' '''
def handler(self):
while self.running:
try:
msg = self.socket.recv(500)
if len(msg) < 40:
continue
(im6_mbz, im6_msgtype, im6_mif, _, im6_src, im6_dst) = struct.unpack("B B H I 16s 16s", msg[:40])
# print((im_msgtype, im_mbz, socket.inet_ntoa(im_src), socket.inet_ntoa(im_dst)))
# When interface changes number of neighbors verify if olist changes and prune/forward respectively if im6_mbz != 0:
def interface_change_number_of_neighbors(self): continue
with self.rwlock.genRlock():
for groups_dict in self.routing.values():
for entry in groups_dict.values():
entry.change_at_number_of_neighbors()
# When new neighbor connects try to resend last state refresh msg (if AssertWinner) print(im6_mbz)
def new_or_reset_neighbor(self, vif_index, neighbor_ip): print(im6_msgtype)
with self.rwlock.genRlock(): print(im6_mif)
for groups_dict in self.routing.values(): print(socket.inet_ntop(socket.AF_INET6, im6_src))
for entry in groups_dict.values(): print(socket.inet_ntop(socket.AF_INET6, im6_dst))
entry.new_or_reset_neighbor(vif_index, neighbor_ip) # print((im_msgtype, im_mbz, socket.inet_ntoa(im_src), socket.inet_ntoa(im_dst)))
ip_src = socket.inet_ntop(socket.AF_INET6, im6_src)
ip_dst = socket.inet_ntop(socket.AF_INET6, im6_dst)
if im6_msgtype == self.MRT6MSG_NOCACHE:
print("MRT6 NO CACHE")
self.msg_nocache_handler(ip_src, ip_dst, im6_mif)
elif im6_msgtype == self.MRT6MSG_WRONGMIF:
print("WRONG MIF HANDLER")
self.msg_wrongvif_handler(ip_src, ip_dst, im6_mif)
# elif im_msgtype == Kernel.IGMPMSG_WHOLEPKT:
# print("IGMP_WHOLEPKT")
# self.igmpmsg_wholepacket_handler(ip_src, ip_dst)
else:
raise Exception
except Exception:
traceback.print_exc()
continue
# receive multicast (S,G) packet and multicast routing table has no (S,G) entry
def msg_nocache_handler(self, ip_src, ip_dst, iif):
source_group_pair = (ip_src, ip_dst)
self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg(iif)
# receive multicast (S,G) packet in a outbound_interface
def msg_wrongvif_handler(self, ip_src, ip_dst, iif):
source_group_pair = (ip_src, ip_dst)
self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg(iif)
''' useless in PIM-DM... useful in PIM-SM
def msg_wholepacket_handler(self, ip_src, ip_dst):
#kernel_entry = self.routing[(ip_src, ip_dst)]
source_group_pair = (ip_src, ip_dst)
self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg()
#kernel_entry.recv_data_msg(iif)
'''
@staticmethod
def _get_kernel_entry_interface():
return KernelEntry6Interface
def _create_pim_interface_object(self, interface_name, index, state_refresh_capable):
return InterfacePim6(interface_name, index, state_refresh_capable)
def _create_membership_interface_object(self, interface_name, index):
return InterfaceMLD(interface_name, index)
import sys import sys
import time import time
import netifaces import netifaces
import logging, logging.handlers import logging
import logging.handlers
from prettytable import PrettyTable from prettytable import PrettyTable
from pimdm.TestLogger import RootFilter
from pimdm import UnicastRouting
from pimdm import UnicastRouting
from pimdm.TestLogger import RootFilter
interfaces = {} # interfaces with multicast routing enabled interfaces = {} # interfaces with multicast routing enabled
igmp_interfaces = {} # igmp interfaces igmp_interfaces = {} # igmp interfaces
interfaces_v6 = {} # pim v6 interfaces
mld_interfaces = {} # mld interfaces
kernel = None kernel = None
kernel_v6 = None
unicast_routing = None unicast_routing = None
logger = None logger = None
def add_pim_interface(interface_name, state_refresh_capable:bool=False):
kernel.create_pim_interface(interface_name=interface_name, state_refresh_capable=state_refresh_capable) def add_pim_interface(interface_name, state_refresh_capable: bool = False, ipv4=True, ipv6=False):
if interface_name == "*":
for interface_name in netifaces.interfaces():
def add_igmp_interface(interface_name): add_pim_interface(interface_name, ipv4, ipv6)
kernel.create_igmp_interface(interface_name=interface_name) return
''' if ipv4 and kernel is not None:
def add_interface(interface_name, pim=False, igmp=False): kernel.create_pim_interface(interface_name=interface_name, state_refresh_capable=state_refresh_capable)
#if pim is True and interface_name not in interfaces: if ipv6 and kernel_v6 is not None:
# interface = InterfacePim(interface_name) kernel_v6.create_pim_interface(interface_name=interface_name, state_refresh_capable=state_refresh_capable)
# interfaces[interface_name] = interface
# interface.create_virtual_interface()
#if igmp is True and interface_name not in igmp_interfaces: def add_membership_interface(interface_name, ipv4=True, ipv6=False):
# interface = InterfaceIGMP(interface_name) if interface_name == "*":
# igmp_interfaces[interface_name] = interface for interface_name in netifaces.interfaces():
kernel.create_interface(interface_name=interface_name, pim=pim, igmp=igmp) add_membership_interface(interface_name, ipv4, ipv6)
#if pim: return
# interfaces[interface_name] = kernel.pim_interface[interface_name]
#if igmp: if ipv4 and kernel is not None:
# igmp_interfaces[interface_name] = kernel.igmp_interface[interface_name] kernel.create_membership_interface(interface_name=interface_name)
''' if ipv6 and kernel_v6 is not None:
kernel_v6.create_membership_interface(interface_name=interface_name)
def remove_interface(interface_name, pim=False, igmp=False):
#if pim is True and ((interface_name in interfaces) or interface_name == "*"):
# if interface_name == "*": def remove_interface(interface_name, pim=False, membership=False, ipv4=True, ipv6=False):
# interface_name_list = list(interfaces.keys()) if interface_name == "*":
# else: for interface_name in netifaces.interfaces():
# interface_name_list = [interface_name] remove_interface(interface_name, pim, membership, ipv4, ipv6)
# for if_name in interface_name_list: return
# interface_obj = interfaces.pop(if_name)
# interface_obj.remove() if ipv4 and kernel is not None:
# #interfaces[if_name].remove() kernel.remove_interface(interface_name, pim=pim, membership=membership)
# #del interfaces[if_name] if ipv6 and kernel_v6 is not None:
# print("removido interface") kernel_v6.remove_interface(interface_name, pim=pim, membership=membership)
# print(interfaces)
#if igmp is True and ((interface_name in igmp_interfaces) or interface_name == "*"): def list_neighbors(ipv4=False, ipv6=False):
# if interface_name == "*": if ipv4:
# interface_name_list = list(igmp_interfaces.keys()) interfaces_list = interfaces.values()
# else: elif ipv6:
# interface_name_list = [interface_name] interfaces_list = interfaces_v6.values()
# for if_name in interface_name_list: else:
# igmp_interfaces[if_name].remove() return "Unknown IP family"
# del igmp_interfaces[if_name]
# print("removido interface")
# print(igmp_interfaces)
kernel.remove_interface(interface_name, pim=pim, igmp=igmp)
def list_neighbors():
interfaces_list = interfaces.values()
t = PrettyTable(['Interface', 'Neighbor IP', 'Hello Hold Time', "Generation ID", "Uptime"]) t = PrettyTable(['Interface', 'Neighbor IP', 'Hello Hold Time', "Generation ID", "Uptime"])
check_time = time.time() check_time = time.time()
for interface in interfaces_list: for interface in interfaces_list:
...@@ -76,38 +74,62 @@ def list_neighbors(): ...@@ -76,38 +74,62 @@ def list_neighbors():
print(t) print(t)
return str(t) return str(t)
def list_enabled_interfaces():
global interfaces
t = PrettyTable(['Interface', 'IP', 'PIM/IGMP Enabled', 'State Refresh Enabled', 'IGMP State']) def list_enabled_interfaces(ipv4=False, ipv6=False):
if ipv4:
t = PrettyTable(['Interface', 'IP', 'PIM/IGMP Enabled', 'State Refresh Enabled', 'IGMP State'])
family = netifaces.AF_INET
pim_interfaces = interfaces
membership_interfaces = igmp_interfaces
elif ipv6:
t = PrettyTable(['Interface', 'IP', 'PIM/MLD Enabled', 'State Refresh Enabled', 'MLD State'])
family = netifaces.AF_INET6
pim_interfaces = interfaces_v6
membership_interfaces = mld_interfaces
else:
return "Unknown IP family"
for interface in netifaces.interfaces(): for interface in netifaces.interfaces():
try: try:
# TODO: fix same interface with multiple ips # TODO: fix same interface with multiple ips
ip = netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr'] ip = netifaces.ifaddresses(interface)[family][0]['addr']
pim_enabled = interface in interfaces pim_enabled = interface in pim_interfaces
igmp_enabled = interface in igmp_interfaces membership_enabled = interface in membership_interfaces
enabled = str(pim_enabled) + "/" + str(igmp_enabled) enabled = str(pim_enabled) + "/" + str(membership_enabled)
state_refresh_enabled = "-" state_refresh_enabled = "-"
if pim_enabled: if pim_enabled:
state_refresh_enabled = interfaces[interface].is_state_refresh_enabled() state_refresh_enabled = pim_interfaces[interface].is_state_refresh_enabled()
igmp_state = "-" membership_state = "-"
if igmp_enabled: if membership_enabled:
igmp_state = igmp_interfaces[interface].interface_state.print_state() membership_state = membership_interfaces[interface].interface_state.print_state()
t.add_row([interface, ip, enabled, state_refresh_enabled, igmp_state]) t.add_row([interface, ip, enabled, state_refresh_enabled, membership_state])
except Exception: except Exception:
continue continue
print(t) print(t)
return str(t) return str(t)
def list_state(): def list_state(ipv4=True, ipv6=False):
state_text = "IGMP State:\n" + list_igmp_state() + "\n\n\n\n" + "Multicast Routing State:\n" + list_routing_state() state_text = ""
return state_text if ipv4:
state_text = "IGMP State:\n{}\n\n\n\nMulticast Routing State:\n{}"
elif ipv6:
state_text = "MLD State:\n{}\n\n\n\nMulticast Routing State:\n{}"
else:
return state_text
return state_text.format(list_membership_state(ipv4, ipv6), list_routing_state(ipv4, ipv6))
def list_igmp_state(): def list_membership_state(ipv4=True, ipv6=False):
t = PrettyTable(['Interface', 'RouterState', 'Group Adress', 'GroupState']) t = PrettyTable(['Interface', 'RouterState', 'Group Adress', 'GroupState'])
for (interface_name, interface_obj) in list(igmp_interfaces.items()): if ipv4:
membership_interfaces = igmp_interfaces
elif ipv6:
membership_interfaces = mld_interfaces
else:
membership_interfaces = {}
for (interface_name, interface_obj) in list(membership_interfaces.items()):
interface_state = interface_obj.interface_state interface_state = interface_obj.interface_state
state_txt = interface_state.print_state() state_txt = interface_state.print_state()
print(interface_state.group_state.items()) print(interface_state.group_state.items())
...@@ -119,12 +141,22 @@ def list_igmp_state(): ...@@ -119,12 +141,22 @@ def list_igmp_state():
return str(t) return str(t)
def list_routing_state(): def list_routing_state(ipv4=False, ipv6=False):
if ipv4:
routes = kernel.routing.values()
vif_indexes = kernel.vif_index_to_name_dic.keys()
dict_index_to_name = kernel.vif_index_to_name_dic
elif ipv6:
routes = kernel_v6.routing.values()
vif_indexes = kernel_v6.vif_index_to_name_dic.keys()
dict_index_to_name = kernel_v6.vif_index_to_name_dic
else:
raise Exception("Unknown IP family")
routing_entries = [] routing_entries = []
for a in list(kernel.routing.values()): for a in list(routes):
for b in list(a.values()): for b in list(a.values()):
routing_entries.append(b) routing_entries.append(b)
vif_indexes = kernel.vif_index_to_name_dic.keys()
t = PrettyTable(['SourceIP', 'GroupIP', 'Interface', 'PruneState', 'AssertState', 'LocalMembership', "Is Forwarding?"]) t = PrettyTable(['SourceIP', 'GroupIP', 'Interface', 'PruneState', 'AssertState', 'LocalMembership', "Is Forwarding?"])
for entry in routing_entries: for entry in routing_entries:
...@@ -134,7 +166,7 @@ def list_routing_state(): ...@@ -134,7 +166,7 @@ def list_routing_state():
for index in vif_indexes: for index in vif_indexes:
interface_state = entry.interface_state[index] interface_state = entry.interface_state[index]
interface_name = kernel.vif_index_to_name_dic[index] interface_name = dict_index_to_name[index]
local_membership = type(interface_state._local_membership_state).__name__ local_membership = type(interface_state._local_membership_state).__name__
try: try:
assert_state = type(interface_state._assert_state).__name__ assert_state = type(interface_state._assert_state).__name__
...@@ -154,8 +186,11 @@ def list_routing_state(): ...@@ -154,8 +186,11 @@ def list_routing_state():
def stop(): def stop():
remove_interface("*", pim=True, igmp=True) remove_interface("*", pim=True, membership=True, ipv4=True, ipv6=True)
kernel.exit() if kernel is not None:
kernel.exit()
if kernel_v6 is not None:
kernel_v6.exit()
unicast_routing.stop() unicast_routing.stop()
...@@ -169,6 +204,22 @@ def test(router_name, server_logger_ip): ...@@ -169,6 +204,22 @@ def test(router_name, server_logger_ip):
logger.addHandler(socketHandler) logger.addHandler(socketHandler)
def enable_ipv6_kernel():
"""
Function to explicitly enable IPv6 Multicast Routing stack.
This may not be enabled by default due to some old linux kernels that may not have IPv6 stack or do not have
IPv6 multicast routing support
"""
global kernel_v6
from pimdm.Kernel import Kernel6
kernel_v6 = Kernel6()
global interfaces_v6
global mld_interfaces
interfaces_v6 = kernel_v6.pim_interface
mld_interfaces = kernel_v6.membership_interface
def main(): def main():
# logging # logging
global logger global logger
...@@ -177,8 +228,8 @@ def main(): ...@@ -177,8 +228,8 @@ def main():
logger.addHandler(logging.StreamHandler(sys.stdout)) logger.addHandler(logging.StreamHandler(sys.stdout))
global kernel global kernel
from pimdm.Kernel import Kernel from pimdm.Kernel import Kernel4
kernel = Kernel() kernel = Kernel4()
global unicast_routing global unicast_routing
unicast_routing = UnicastRouting.UnicastRouting() unicast_routing = UnicastRouting.UnicastRouting()
...@@ -186,4 +237,9 @@ def main(): ...@@ -186,4 +237,9 @@ def main():
global interfaces global interfaces
global igmp_interfaces global igmp_interfaces
interfaces = kernel.pim_interface interfaces = kernel.pim_interface
igmp_interfaces = kernel.igmp_interface igmp_interfaces = kernel.membership_interface
try:
enable_ipv6_kernel()
except:
pass
from threading import Timer
import time import time
from pimdm.utils import HELLO_HOLD_TIME_NO_TIMEOUT, HELLO_HOLD_TIME_TIMEOUT, TYPE_CHECKING
from threading import Lock, RLock
import logging import logging
from threading import Timer
from threading import Lock, RLock
from pimdm.tree.globals import HELLO_HOLD_TIME_NO_TIMEOUT, HELLO_HOLD_TIME_TIMEOUT
from pimdm.utils import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from pimdm.InterfacePIM import InterfacePim from pimdm.InterfacePIM import InterfacePim
...@@ -10,7 +12,6 @@ if TYPE_CHECKING: ...@@ -10,7 +12,6 @@ if TYPE_CHECKING:
class Neighbor: class Neighbor:
LOGGER = logging.getLogger('pim.Interface.Neighbor') LOGGER = logging.getLogger('pim.Interface.Neighbor')
def __init__(self, contact_interface: "InterfacePim", ip, generation_id: int, hello_hold_time: int, def __init__(self, contact_interface: "InterfacePim", ip, generation_id: int, hello_hold_time: int,
state_refresh_capable: bool): state_refresh_capable: bool):
if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT: if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT:
...@@ -37,7 +38,6 @@ class Neighbor: ...@@ -37,7 +38,6 @@ class Neighbor:
self.tree_interface_nlt_subscribers = [] self.tree_interface_nlt_subscribers = []
self.tree_interface_nlt_subscribers_lock = RLock() self.tree_interface_nlt_subscribers_lock = RLock()
def set_hello_hold_time(self, hello_hold_time: int): def set_hello_hold_time(self, hello_hold_time: int):
self.hello_hold_time = hello_hold_time self.hello_hold_time = hello_hold_time
if self.neighbor_liveness_timer is not None: if self.neighbor_liveness_timer is not None:
...@@ -85,11 +85,9 @@ class Neighbor: ...@@ -85,11 +85,9 @@ class Neighbor:
for tree_if in self.tree_interface_nlt_subscribers: for tree_if in self.tree_interface_nlt_subscribers:
tree_if.assert_winner_nlt_expires() tree_if.assert_winner_nlt_expires()
def reset(self): def reset(self):
self.contact_interface.new_or_reset_neighbor(self.ip) self.contact_interface.new_or_reset_neighbor(self.ip)
def receive_hello(self, generation_id, hello_hold_time, state_refresh_capable): def receive_hello(self, generation_id, hello_hold_time, state_refresh_capable):
self.neighbor_logger.debug('Receive Hello message with HelloHoldTime: ' + str(hello_hold_time) + self.neighbor_logger.debug('Receive Hello message with HelloHoldTime: ' + str(hello_hold_time) +
'; GenerationID: ' + str(generation_id) + '; StateRefreshCapable: ' + '; GenerationID: ' + str(generation_id) + '; StateRefreshCapable: ' +
......
import struct
import socket
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|Version| IHL |Type of Service| Total Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Identification |Flags| Fragment Offset |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Time to Live | Protocol | Header Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Destination Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Options | Padding |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketIpHeader:
IP_HDR = "! BBH HH BBH 4s 4s"
IP_HDR_LEN = struct.calcsize(IP_HDR)
def __init__(self, ver, hdr_len, ttl, proto, ip_src, ip_dst):
self.version = ver
self.hdr_length = hdr_len
self.ttl = ttl
self.proto = proto
self.ip_src = ip_src
self.ip_dst = ip_dst
def __len__(self):
return self.hdr_length
@staticmethod
def parse_bytes(data: bytes):
(verhlen, tos, iplen, ipid, frag, ttl, proto, cksum, src, dst) = \
struct.unpack(PacketIpHeader.IP_HDR, data)
ver = (verhlen & 0xf0) >> 4
hlen = (verhlen & 0x0f) * 4
'''
"VER": ver,
"HLEN": hlen,
"TOS": tos,
"IPLEN": iplen,
"IPID": ipid,
"FRAG": frag,
"TTL": ttl,
"PROTO": proto,
"CKSUM": cksum,
"SRC": socket.inet_ntoa(src),
"DST": socket.inet_ntoa(dst)
'''
src_ip = socket.inet_ntoa(src)
dst_ip = socket.inet_ntoa(dst)
return PacketIpHeader(ver, hlen, ttl, proto, src_ip, dst_ip)
#!/usr/bin/env python3 #!/usr/bin/env python3
from pimdm.Daemon.Daemon import Daemon
from pimdm import Main
import _pickle as pickle
import socket
import sys
import os import os
import sys
import socket
import argparse import argparse
import traceback import traceback
import _pickle as pickle
from pimdm import Main
from pimdm.daemon.Daemon import Daemon
VERSION = "1.1"
VERSION = "1.0.4.2"
def client_socket(data_to_send): def client_socket(data_to_send):
# Create a UDS socket # Create a UDS socket
...@@ -58,26 +60,36 @@ class MyDaemon(Daemon): ...@@ -58,26 +60,36 @@ class MyDaemon(Daemon):
print(sys.stderr, 'sending data back to the client') print(sys.stderr, 'sending data back to the client')
print(pickle.loads(data)) print(pickle.loads(data))
args = pickle.loads(data) args = pickle.loads(data)
if 'ipv4' not in args and 'ipv6' not in args or not (args.ipv4 or args.ipv6):
args.ipv4 = True
args.ipv6 = False
if 'list_interfaces' in args and args.list_interfaces: if 'list_interfaces' in args and args.list_interfaces:
connection.sendall(pickle.dumps(Main.list_enabled_interfaces())) connection.sendall(pickle.dumps(Main.list_enabled_interfaces(ipv4=args.ipv4, ipv6=args.ipv6)))
elif 'list_neighbors' in args and args.list_neighbors: elif 'list_neighbors' in args and args.list_neighbors:
connection.sendall(pickle.dumps(Main.list_neighbors())) connection.sendall(pickle.dumps(Main.list_neighbors(ipv4=args.ipv4, ipv6=args.ipv6)))
elif 'list_state' in args and args.list_state: elif 'list_state' in args and args.list_state:
connection.sendall(pickle.dumps(Main.list_state())) connection.sendall(pickle.dumps(Main.list_state(ipv4=args.ipv4, ipv6=args.ipv6)))
elif 'add_interface' in args and args.add_interface: elif 'add_interface' in args and args.add_interface:
Main.add_pim_interface(args.add_interface[0], False) Main.add_pim_interface(args.add_interface[0], False, ipv4=args.ipv4, ipv6=args.ipv6)
connection.shutdown(socket.SHUT_RDWR) connection.shutdown(socket.SHUT_RDWR)
elif 'add_interface_sr' in args and args.add_interface_sr: elif 'add_interface_sr' in args and args.add_interface_sr:
Main.add_pim_interface(args.add_interface_sr[0], True) Main.add_pim_interface(args.add_interface_sr[0], True, ipv4=args.ipv4, ipv6=args.ipv6)
connection.shutdown(socket.SHUT_RDWR) connection.shutdown(socket.SHUT_RDWR)
elif 'add_interface_igmp' in args and args.add_interface_igmp: elif 'add_interface_igmp' in args and args.add_interface_igmp:
Main.add_igmp_interface(args.add_interface_igmp[0]) Main.add_membership_interface(interface_name=args.add_interface_igmp[0], ipv4=True, ipv6=False)
connection.shutdown(socket.SHUT_RDWR)
elif 'add_interface_mld' in args and args.add_interface_mld:
Main.add_membership_interface(interface_name=args.add_interface_mld[0], ipv4=False, ipv6=True)
connection.shutdown(socket.SHUT_RDWR) connection.shutdown(socket.SHUT_RDWR)
elif 'remove_interface' in args and args.remove_interface: elif 'remove_interface' in args and args.remove_interface:
Main.remove_interface(args.remove_interface[0], pim=True) Main.remove_interface(args.remove_interface[0], pim=True, ipv4=args.ipv4, ipv6=args.ipv6)
connection.shutdown(socket.SHUT_RDWR) connection.shutdown(socket.SHUT_RDWR)
elif 'remove_interface_igmp' in args and args.remove_interface_igmp: elif 'remove_interface_igmp' in args and args.remove_interface_igmp:
Main.remove_interface(args.remove_interface_igmp[0], igmp=True) Main.remove_interface(args.remove_interface_igmp[0], membership=True, ipv4=True, ipv6=False)
connection.shutdown(socket.SHUT_RDWR)
elif 'remove_interface_mld' in args and args.remove_interface_mld:
Main.remove_interface(args.remove_interface_mld[0], membership=True, ipv4=False, ipv6=True)
connection.shutdown(socket.SHUT_RDWR) connection.shutdown(socket.SHUT_RDWR)
elif 'stop' in args and args.stop: elif 'stop' in args and args.stop:
Main.stop() Main.stop()
...@@ -102,18 +114,30 @@ def main(): ...@@ -102,18 +114,30 @@ def main():
group.add_argument("-start", "--start", action="store_true", default=False, help="Start PIM") group.add_argument("-start", "--start", action="store_true", default=False, help="Start PIM")
group.add_argument("-stop", "--stop", action="store_true", default=False, help="Stop PIM") group.add_argument("-stop", "--stop", action="store_true", default=False, help="Stop PIM")
group.add_argument("-restart", "--restart", action="store_true", default=False, help="Restart PIM") group.add_argument("-restart", "--restart", action="store_true", default=False, help="Restart PIM")
group.add_argument("-li", "--list_interfaces", action="store_true", default=False, help="List All PIM Interfaces") group.add_argument("-li", "--list_interfaces", action="store_true", default=False, help="List All PIM Interfaces. "
group.add_argument("-ln", "--list_neighbors", action="store_true", default=False, help="List All PIM Neighbors") "Use -4 or -6 to specify IPv4 or IPv6 interfaces.")
group.add_argument("-ls", "--list_state", action="store_true", default=False, help="List state of IGMP") group.add_argument("-ln", "--list_neighbors", action="store_true", default=False, help="List All PIM Neighbors. "
group.add_argument("-mr", "--multicast_routes", action="store_true", default=False, help="List Multicast Routing table") "Use -4 or -6 to specify IPv4 or IPv6 PIM neighbors.")
group.add_argument("-ai", "--add_interface", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface") group.add_argument("-ls", "--list_state", action="store_true", default=False, help="List IGMP/MLD and PIM-DM state machines."
group.add_argument("-aisr", "--add_interface_sr", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface with State Refresh enabled") " Use -4 or -6 to specify IPv4 or IPv6 state respectively.")
group.add_argument("-mr", "--multicast_routes", action="store_true", default=False, help="List Multicast Routing table. "
"Use -4 or -6 to specify IPv4 or IPv6 multicast routing table.")
group.add_argument("-ai", "--add_interface", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface. "
"Use -4 or -6 to specify IPv4 or IPv6 interface.")
group.add_argument("-aisr", "--add_interface_sr", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface with State Refresh enabled. "
"Use -4 or -6 to specify IPv4 or IPv6 interface.")
group.add_argument("-aiigmp", "--add_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Add IGMP interface") group.add_argument("-aiigmp", "--add_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Add IGMP interface")
group.add_argument("-ri", "--remove_interface", nargs=1, metavar='INTERFACE_NAME', help="Remove PIM interface") group.add_argument("-aimld", "--add_interface_mld", nargs=1, metavar='INTERFACE_NAME', help="Add MLD interface")
group.add_argument("-ri", "--remove_interface", nargs=1, metavar='INTERFACE_NAME', help="Remove PIM interface. "
"Use -4 or -6 to specify IPv4 or IPv6 interface.")
group.add_argument("-riigmp", "--remove_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Remove IGMP interface") group.add_argument("-riigmp", "--remove_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Remove IGMP interface")
group.add_argument("-rimld", "--remove_interface_mld", nargs=1, metavar='INTERFACE_NAME', help="Remove MLD interface")
group.add_argument("-v", "--verbose", action="store_true", default=False, help="Verbose (print all debug messages)") group.add_argument("-v", "--verbose", action="store_true", default=False, help="Verbose (print all debug messages)")
group.add_argument("-t", "--test", nargs=2, metavar=('ROUTER_NAME', 'SERVER_LOG_IP'), help="Tester... send log information to SERVER_LOG_IP. Set the router name to ROUTER_NAME") group.add_argument("-t", "--test", nargs=2, metavar=('ROUTER_NAME', 'SERVER_LOG_IP'), help="Tester... send log information to SERVER_LOG_IP. Set the router name to ROUTER_NAME")
group.add_argument("--version", action='version', version='%(prog)s ' + VERSION) group.add_argument("--version", action='version', version='%(prog)s ' + VERSION)
group_ipversion = parser.add_mutually_exclusive_group(required=False)
group_ipversion.add_argument("-4", "--ipv4", action="store_true", default=False, help="Setting for IPv4")
group_ipversion.add_argument("-6", "--ipv6", action="store_true", default=False, help="Setting for IPv6")
args = parser.parse_args() args = parser.parse_args()
#print(parser.parse_args()) #print(parser.parse_args())
...@@ -137,7 +161,10 @@ def main(): ...@@ -137,7 +161,10 @@ def main():
os.system("tail -f /var/log/pimdm/stdout") os.system("tail -f /var/log/pimdm/stdout")
sys.exit(0) sys.exit(0)
elif args.multicast_routes: elif args.multicast_routes:
os.system("ip mroute show") if args.ipv4 or not args.ipv6:
os.system("ip mroute show")
elif args.ipv6:
os.system("ip -6 mroute show")
sys.exit(0) sys.exit(0)
elif not daemon.is_running(): elif not daemon.is_running():
print("PIM-DM is not running") print("PIM-DM is not running")
......
import socket import socket
import ipaddress import ipaddress
from pyroute2 import IPDB
from threading import RLock from threading import RLock
from socket import if_indextoname
from pimdm.utils import if_indextoname from pyroute2 import IPDB
def get_route(ip_dst: str): def get_route(ip_dst: str):
...@@ -48,27 +47,34 @@ class UnicastRouting(object): ...@@ -48,27 +47,34 @@ class UnicastRouting(object):
@staticmethod @staticmethod
def get_route(ip_dst: str): def get_route(ip_dst: str):
ip_bytes = socket.inet_aton(ip_dst) ip_version = ipaddress.ip_address(ip_dst).version
ip_int = int.from_bytes(ip_bytes, byteorder='big') if ip_version == 4:
family = socket.AF_INET
full_mask = 32
elif ip_version == 6:
family = socket.AF_INET6
full_mask = 128
else:
raise Exception("Unknown IP version")
info = None info = None
with UnicastRouting.lock: with UnicastRouting.lock:
ipdb = UnicastRouting.ipdb # type:IPDB ipdb = UnicastRouting.ipdb # type:IPDB
for mask_len in range(32, 0, -1): for mask_len in range(full_mask, 0, -1):
ip_bytes = (ip_int & (0xFFFFFFFF << (32 - mask_len))).to_bytes(4, "big") dst_network = str(ipaddress.ip_interface(ip_dst + "/" + str(mask_len)).network)
ip_dst = socket.inet_ntoa(ip_bytes) + "/" + str(mask_len)
print(ip_dst) print(dst_network)
if ip_dst in ipdb.routes: if dst_network in ipdb.routes:
print(info) print(info)
if ipdb.routes[ip_dst]['ipdb_scope'] != 'gc': if ipdb.routes[{'dst': dst_network, 'family': family}]['ipdb_scope'] != 'gc':
info = ipdb.routes[ip_dst] info = ipdb.routes[dst_network]
break break
else: else:
continue continue
if not info: if not info:
print("0.0.0.0/0") print("0.0.0.0/0 or ::/0")
if "default" in ipdb.routes: if "default" in ipdb.routes:
info = ipdb.routes["default"] info = ipdb.routes[{'dst': 'default', 'family': family}]
print(info) print(info)
return info return info
...@@ -85,13 +91,16 @@ class UnicastRouting(object): ...@@ -85,13 +91,16 @@ class UnicastRouting(object):
oif = unicast_route.get("oif") oif = unicast_route.get("oif")
next_hop = unicast_route["gateway"] next_hop = unicast_route["gateway"]
multipaths = unicast_route["multipath"] multipaths = unicast_route["multipath"]
# prefsrc = unicast_route.get("prefsrc") #prefsrc = unicast_route.get("prefsrc")
# rpf_node = ip_dst if (next_hop is None and prefsrc is not None) else next_hop #rpf_node = ip_dst if (next_hop is None and prefsrc is not None) else next_hop
rpf_node = next_hop if next_hop is not None else ip_dst rpf_node = next_hop if next_hop is not None else ip_dst
highest_ip = ipaddress.ip_address("0.0.0.0") if ipaddress.ip_address(ip_dst).version == 4:
highest_ip = ipaddress.ip_address("0.0.0.0")
else:
highest_ip = ipaddress.ip_address("::")
for m in multipaths: for m in multipaths:
if m["gateway"] is None: if m.get("gateway", None) is None:
oif = m.get('oif') oif = m.get('oif')
rpf_node = ip_dst rpf_node = ip_dst
break break
...@@ -107,14 +116,22 @@ class UnicastRouting(object): ...@@ -107,14 +116,22 @@ class UnicastRouting(object):
interface_name = None if oif is None else if_indextoname(int(oif)) interface_name = None if oif is None else if_indextoname(int(oif))
from pimdm import Main from pimdm import Main
rpf_if = Main.kernel.vif_name_to_index_dic.get(interface_name) if ipaddress.ip_address(ip_dst).version == 4:
rpf_if = Main.kernel.vif_name_to_index_dic.get(interface_name)
else:
rpf_if = Main.kernel_v6.vif_name_to_index_dic.get(interface_name)
return (metric_administrative_distance, metric_cost, rpf_node, rpf_if, mask) return (metric_administrative_distance, metric_cost, rpf_node, rpf_if, mask)
@staticmethod @staticmethod
def unicast_changes(ipdb, msg, action): def unicast_changes(ipdb, msg, action):
"""
Kernel notified about a change
Verify the type of change and recheck all trees if necessary
"""
print("unicast change?") print("unicast change?")
print(action) print(action)
UnicastRouting.lock.acquire() UnicastRouting.lock.acquire()
family = msg['family']
if action == "RTM_NEWROUTE" or action == "RTM_DELROUTE": if action == "RTM_NEWROUTE" or action == "RTM_DELROUTE":
print(ipdb.routes) print(ipdb.routes)
mask_len = msg["dst_len"] mask_len = msg["dst_len"]
...@@ -126,8 +143,10 @@ class UnicastRouting(object): ...@@ -126,8 +143,10 @@ class UnicastRouting(object):
if key == "RTA_DST": if key == "RTA_DST":
network_address = value network_address = value
break break
if network_address is None: if network_address is None and family == socket.AF_INET:
network_address = "0.0.0.0" network_address = "0.0.0.0"
elif network_address is None and family == socket.AF_INET6:
network_address = "::"
print(network_address) print(network_address)
print(mask_len) print(mask_len)
print(network_address + "/" + str(mask_len)) print(network_address + "/" + str(mask_len))
...@@ -135,7 +154,10 @@ class UnicastRouting(object): ...@@ -135,7 +154,10 @@ class UnicastRouting(object):
print(str(subnet)) print(str(subnet))
UnicastRouting.lock.release() UnicastRouting.lock.release()
from pimdm import Main from pimdm import Main
Main.kernel.notify_unicast_changes(subnet) if family == socket.AF_INET:
Main.kernel.notify_unicast_changes(subnet)
elif family == socket.AF_INET6:
Main.kernel_v6.notify_unicast_changes(subnet)
''' '''
elif action == "RTM_NEWADDR" or action == "RTM_DELADDR": elif action == "RTM_NEWADDR" or action == "RTM_DELADDR":
print(action) print(action)
...@@ -154,7 +176,7 @@ class UnicastRouting(object): ...@@ -154,7 +176,7 @@ class UnicastRouting(object):
import traceback import traceback
traceback.print_exc() traceback.print_exc()
pass pass
bnet = ipaddress.ip_network("0.0.0.0/0") subnet = ipaddress.ip_network("0.0.0.0/0")
Main.kernel.notify_unicast_changes(subnet) Main.kernel.notify_unicast_changes(subnet)
elif action == "RTM_NEWLINK" or action == "RTM_DELLINK": elif action == "RTM_NEWLINK" or action == "RTM_DELLINK":
attrs = msg["attrs"] attrs = msg["attrs"]
...@@ -172,7 +194,7 @@ class UnicastRouting(object): ...@@ -172,7 +194,7 @@ class UnicastRouting(object):
print(if_name + ": " + operation) print(if_name + ": " + operation)
UnicastRouting.lock.release() UnicastRouting.lock.release()
if operation == 'DOWN': if operation == 'DOWN':
Main.kernel.remove_interface(if_name, igmp=True, pim=True) Main.kernel.remove_interface(if_name, membership=True, pim=True)
subnet = ipaddress.ip_network("0.0.0.0/0") subnet = ipaddress.ip_network("0.0.0.0/0")
Main.kernel.notify_unicast_changes(subnet) Main.kernel.notify_unicast_changes(subnet)
''' '''
...@@ -180,6 +202,10 @@ class UnicastRouting(object): ...@@ -180,6 +202,10 @@ class UnicastRouting(object):
UnicastRouting.lock.release() UnicastRouting.lock.release()
def stop(self): def stop(self):
"""
No longer monitor unicast changes....
Invoked whenever the protocol is stopped
"""
if self._ipdb: if self._ipdb:
self._ipdb.release() self._ipdb.release()
if UnicastRouting.ipdb: if UnicastRouting.ipdb:
......
...@@ -129,13 +129,13 @@ class GroupState(object): ...@@ -129,13 +129,13 @@ class GroupState(object):
with self.multicast_interface_state_lock: with self.multicast_interface_state_lock:
print("notify+", self.multicast_interface_state) print("notify+", self.multicast_interface_state)
for interface_state in self.multicast_interface_state: for interface_state in self.multicast_interface_state:
interface_state.notify_igmp(has_members=True) interface_state.notify_membership(has_members=True)
def notify_routing_remove(self): def notify_routing_remove(self):
with self.multicast_interface_state_lock: with self.multicast_interface_state_lock:
print("notify-", self.multicast_interface_state) print("notify-", self.multicast_interface_state)
for interface_state in self.multicast_interface_state: for interface_state in self.multicast_interface_state:
interface_state.notify_igmp(has_members=False) interface_state.notify_membership(has_members=False)
def add_multicast_routing_entry(self, kernel_entry): def add_multicast_routing_entry(self, kernel_entry):
with self.multicast_interface_state_lock: with self.multicast_interface_state_lock:
...@@ -155,5 +155,5 @@ class GroupState(object): ...@@ -155,5 +155,5 @@ class GroupState(object):
self.clear_timer() self.clear_timer()
self.clear_v1_host_timer() self.clear_v1_host_timer()
for interface_state in self.multicast_interface_state: for interface_state in self.multicast_interface_state:
interface_state.notify_igmp(has_members=False) interface_state.notify_membership(has_members=False)
del self.multicast_interface_state[:] del self.multicast_interface_state[:]
from threading import Timer from threading import Timer
import logging import logging
from pimdm.Packet.PacketIGMPHeader import PacketIGMPHeader from pimdm.packet.PacketIGMPHeader import PacketIGMPHeader
from pimdm.Packet.ReceivedPacket import ReceivedPacket from pimdm.packet.ReceivedPacket import ReceivedPacket
from pimdm.utils import TYPE_CHECKING from pimdm.utils import TYPE_CHECKING
from pimdm.RWLock.RWLock import RWLockWrite from pimdm.rwlock.RWLock import RWLockWrite
from .querier.Querier import Querier from .querier.Querier import Querier
from .nonquerier.NonQuerier import NonQuerier from .nonquerier.NonQuerier import NonQuerier
from .GroupState import GroupState from .GroupState import GroupState
......
from ipaddress import IPv4Address from ipaddress import IPv4Address
from pimdm.utils import TYPE_CHECKING from pimdm.utils import TYPE_CHECKING
from ..igmp_globals import Membership_Query, QueryResponseInterval, LastMemberQueryCount from ..igmp_globals import Membership_Query, QueryResponseInterval, LastMemberQueryCount
from pimdm.Packet.PacketIGMPHeader import PacketIGMPHeader from pimdm.packet.PacketIGMPHeader import PacketIGMPHeader
from pimdm.Packet.ReceivedPacket import ReceivedPacket from pimdm.packet.ReceivedPacket import ReceivedPacket
from . import NoMembersPresent, MembersPresent, CheckingMembership from . import NoMembersPresent, MembersPresent, CheckingMembership
if TYPE_CHECKING: if TYPE_CHECKING:
......
from pimdm.Packet.PacketIGMPHeader import PacketIGMPHeader from pimdm.packet.PacketIGMPHeader import PacketIGMPHeader
from pimdm.utils import TYPE_CHECKING from pimdm.utils import TYPE_CHECKING
from ..igmp_globals import Membership_Query, LastMemberQueryInterval from ..igmp_globals import Membership_Query, LastMemberQueryInterval
from ..wrapper import NoMembersPresent, MembersPresent, Version1MembersPresent from ..wrapper import NoMembersPresent, MembersPresent, Version1MembersPresent
......
from pimdm.Packet.PacketIGMPHeader import PacketIGMPHeader from pimdm.packet.PacketIGMPHeader import PacketIGMPHeader
from pimdm.utils import TYPE_CHECKING from pimdm.utils import TYPE_CHECKING
from ..igmp_globals import Membership_Query, LastMemberQueryInterval from ..igmp_globals import Membership_Query, LastMemberQueryInterval
from ..wrapper import Version1MembersPresent, CheckingMembership, NoMembersPresent from ..wrapper import Version1MembersPresent, CheckingMembership, NoMembersPresent
......
...@@ -3,8 +3,8 @@ from ipaddress import IPv4Address ...@@ -3,8 +3,8 @@ from ipaddress import IPv4Address
from pimdm.utils import TYPE_CHECKING from pimdm.utils import TYPE_CHECKING
from ..igmp_globals import Membership_Query, QueryResponseInterval, LastMemberQueryCount, LastMemberQueryInterval from ..igmp_globals import Membership_Query, QueryResponseInterval, LastMemberQueryCount, LastMemberQueryInterval
from pimdm.Packet.PacketIGMPHeader import PacketIGMPHeader from pimdm.packet.PacketIGMPHeader import PacketIGMPHeader
from pimdm.Packet.ReceivedPacket import ReceivedPacket from pimdm.packet.ReceivedPacket import ReceivedPacket
from . import CheckingMembership, MembersPresent, Version1MembersPresent, NoMembersPresent from . import CheckingMembership, MembersPresent, Version1MembersPresent, NoMembersPresent
if TYPE_CHECKING: if TYPE_CHECKING:
......
import logging
from threading import Lock
from threading import Timer
from pimdm.utils import TYPE_CHECKING
from .wrapper import NoListenersPresent
from .mld_globals import MulticastListenerInterval, LastListenerQueryInterval
if TYPE_CHECKING:
from .RouterState import RouterState
class GroupState(object):
LOGGER = logging.getLogger('pim.mld.RouterState.GroupState')
def __init__(self, router_state: 'RouterState', group_ip: str):
#logger
extra_dict_logger = router_state.router_state_logger.extra.copy()
extra_dict_logger['tree'] = '(*,' + group_ip + ')'
self.group_state_logger = logging.LoggerAdapter(GroupState.LOGGER, extra_dict_logger)
#timers and state
self.router_state = router_state
self.group_ip = group_ip
self.state = NoListenersPresent
self.timer = None
self.retransmit_timer = None
# lock
self.lock = Lock()
# KernelEntry's instances to notify change of igmp state
self.multicast_interface_state = []
self.multicast_interface_state_lock = Lock()
def print_state(self):
return self.state.print_state()
###########################################
# Set state
###########################################
def set_state(self, state):
self.state = state
self.group_state_logger.debug("change membership state to: " + state.print_state())
###########################################
# Set timers
###########################################
def set_timer(self, alternative: bool=False, max_response_time: int=None):
self.clear_timer()
if not alternative:
time = MulticastListenerInterval
else:
time = self.router_state.interface_state.get_group_membership_time(max_response_time)
timer = Timer(time, self.group_membership_timeout)
timer.start()
self.timer = timer
def clear_timer(self):
if self.timer is not None:
self.timer.cancel()
def set_retransmit_timer(self):
self.clear_retransmit_timer()
retransmit_timer = Timer(LastListenerQueryInterval, self.retransmit_timeout)
retransmit_timer.start()
self.retransmit_timer = retransmit_timer
def clear_retransmit_timer(self):
if self.retransmit_timer is not None:
self.retransmit_timer.cancel()
###########################################
# Get group state from specific interface state
###########################################
def get_interface_group_state(self):
return self.state.get_state(self.router_state)
###########################################
# Timer timeout
###########################################
def group_membership_timeout(self):
with self.lock:
self.get_interface_group_state().group_membership_timeout(self)
def retransmit_timeout(self):
with self.lock:
self.get_interface_group_state().retransmit_timeout(self)
###########################################
# Receive Packets
###########################################
def receive_report(self):
with self.lock:
self.get_interface_group_state().receive_report(self)
def receive_done(self):
with self.lock:
self.get_interface_group_state().receive_done(self)
def receive_group_specific_query(self, max_response_time: int):
with self.lock:
self.get_interface_group_state().receive_group_specific_query(self, max_response_time)
###########################################
# Notify Routing
###########################################
def notify_routing_add(self):
with self.multicast_interface_state_lock:
print("notify+", self.multicast_interface_state)
for interface_state in self.multicast_interface_state:
interface_state.notify_membership(has_members=True)
def notify_routing_remove(self):
with self.multicast_interface_state_lock:
print("notify-", self.multicast_interface_state)
for interface_state in self.multicast_interface_state:
interface_state.notify_membership(has_members=False)
def add_multicast_routing_entry(self, kernel_entry):
with self.multicast_interface_state_lock:
self.multicast_interface_state.append(kernel_entry)
return self.has_members()
def remove_multicast_routing_entry(self, kernel_entry):
with self.multicast_interface_state_lock:
self.multicast_interface_state.remove(kernel_entry)
def has_members(self):
return self.state is not NoListenersPresent
def remove(self):
with self.multicast_interface_state_lock:
self.clear_retransmit_timer()
self.clear_timer()
for interface_state in self.multicast_interface_state:
interface_state.notify_membership(has_members=False)
del self.multicast_interface_state[:]
import logging
from threading import Timer
from pimdm.packet.PacketMLDHeader import PacketMLDHeader
from pimdm.packet.ReceivedPacket import ReceivedPacket
from pimdm.utils import TYPE_CHECKING
from pimdm.rwlock.RWLock import RWLockWrite
from .querier.Querier import Querier
from .nonquerier.NonQuerier import NonQuerier
from .GroupState import GroupState
from .mld_globals import QueryResponseInterval, QueryInterval, OtherQuerierPresentInterval, MULTICAST_LISTENER_QUERY_TYPE
if TYPE_CHECKING:
from pimdm.InterfaceMLD import InterfaceMLD
class RouterState(object):
ROUTER_STATE_LOGGER = logging.getLogger('pim.mld.RouterState')
def __init__(self, interface: 'InterfaceMLD'):
#logger
logger_extra = dict()
logger_extra['vif'] = interface.vif_index
logger_extra['interfacename'] = interface.interface_name
self.router_state_logger = logging.LoggerAdapter(RouterState.ROUTER_STATE_LOGGER, logger_extra)
# interface of the router connected to the network
self.interface = interface
# state of the router (Querier/NonQuerier)
self.interface_state = Querier
# state of each group
# Key: GroupIPAddress, Value: GroupState object
self.group_state = {}
self.group_state_lock = RWLockWrite()
# send general query
packet = PacketMLDHeader(type=MULTICAST_LISTENER_QUERY_TYPE, max_resp_delay=QueryResponseInterval*1000)
self.interface.send(packet.bytes())
# set initial general query timer
timer = Timer(QueryInterval, self.general_query_timeout)
timer.start()
self.general_query_timer = timer
# present timer
self.other_querier_present_timer = None
# Send packet via interface
def send(self, data: bytes, address: str):
self.interface.send(data, address)
############################################
# interface_state methods
############################################
def print_state(self):
return self.interface_state.state_name()
def set_general_query_timer(self):
self.clear_general_query_timer()
general_query_timer = Timer(QueryInterval, self.general_query_timeout)
general_query_timer.start()
self.general_query_timer = general_query_timer
def clear_general_query_timer(self):
if self.general_query_timer is not None:
self.general_query_timer.cancel()
def set_other_querier_present_timer(self):
self.clear_other_querier_present_timer()
other_querier_present_timer = Timer(OtherQuerierPresentInterval, self.other_querier_present_timeout)
other_querier_present_timer.start()
self.other_querier_present_timer = other_querier_present_timer
def clear_other_querier_present_timer(self):
if self.other_querier_present_timer is not None:
self.other_querier_present_timer.cancel()
def general_query_timeout(self):
self.interface_state.general_query_timeout(self)
def other_querier_present_timeout(self):
self.interface_state.other_querier_present_timeout(self)
def change_interface_state(self, querier: bool):
if querier:
self.interface_state = Querier
self.router_state_logger.debug('change querier state to -> Querier')
else:
self.interface_state = NonQuerier
self.router_state_logger.debug('change querier state to -> NonQuerier')
############################################
# group state methods
############################################
def get_group_state(self, group_ip):
with self.group_state_lock.genRlock():
if group_ip in self.group_state:
return self.group_state[group_ip]
with self.group_state_lock.genWlock():
if group_ip in self.group_state:
group_state = self.group_state[group_ip]
else:
group_state = GroupState(self, group_ip)
self.group_state[group_ip] = group_state
return group_state
def receive_report(self, packet: ReceivedPacket):
mld_group = packet.payload.group_address
#if igmp_group not in self.group_state:
# self.group_state[igmp_group] = GroupState(self, igmp_group)
#self.group_state[igmp_group].receive_v2_membership_report()
self.get_group_state(mld_group).receive_report()
def receive_done(self, packet: ReceivedPacket):
mld_group = packet.payload.group_address
#if igmp_group in self.group_state:
# self.group_state[igmp_group].receive_leave_group()
self.get_group_state(mld_group).receive_done()
def receive_query(self, packet: ReceivedPacket):
self.interface_state.receive_query(self, packet)
mld_group = packet.payload.group_address
# process group specific query
if mld_group != "::" and mld_group in self.group_state:
#if igmp_group != "0.0.0.0":
max_response_time = packet.payload.max_resp_delay
#self.group_state[igmp_group].receive_group_specific_query(max_response_time)
self.get_group_state(mld_group).receive_group_specific_query(max_response_time)
def remove(self):
for group in self.group_state.values():
group.remove()
#MLD timers (in seconds)
RobustnessVariable = 2
QueryInterval = 125
QueryResponseInterval = 10
MulticastListenerInterval = (RobustnessVariable * QueryInterval) + (QueryResponseInterval)
OtherQuerierPresentInterval = (RobustnessVariable * QueryInterval) + 0.5 * QueryResponseInterval
StartupQueryInterval = (1/4) * QueryInterval
StartupQueryCount = RobustnessVariable
LastListenerQueryInterval = 1
LastListenerQueryCount = RobustnessVariable
UnsolicitedReportInterval = 10
# MLD msg type
MULTICAST_LISTENER_QUERY_TYPE = 130
MULTICAST_LISTENER_REPORT_TYPE = 131
MULTICAST_LISTENER_DONE_TYPE = 132
\ No newline at end of file
from pimdm.utils import TYPE_CHECKING
from ..wrapper import NoListenersPresent
from ..wrapper import ListenersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def receive_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingListeners: receive_report')
group_state.set_timer()
group_state.set_state(ListenersPresent)
def receive_done(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingListeners: receive_done')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('NonQuerier CheckingListeners: receive_group_specific_query')
# do nothing
return
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingListeners: group_membership_timeout')
group_state.set_state(NoListenersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingListeners: retransmit_timeout')
# do nothing
return
from pimdm.utils import TYPE_CHECKING
from ..wrapper import NoListenersPresent
from ..wrapper import CheckingListeners
if TYPE_CHECKING:
from ..GroupState import GroupState
def receive_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier ListenersPresent: receive_report')
group_state.set_timer()
def receive_done(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier ListenersPresent: receive_done')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('NonQuerier ListenersPresent: receive_group_specific_query')
group_state.set_timer(alternative=True, max_response_time=max_response_time)
group_state.set_state(CheckingListeners)
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier ListenersPresent: group_membership_timeout')
group_state.set_state(NoListenersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier ListenersPresent: retransmit_timeout')
# do nothing
return
from pimdm.utils import TYPE_CHECKING
from ..wrapper import ListenersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def receive_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoListenersPresent: receive_report')
group_state.set_timer()
group_state.set_state(ListenersPresent)
# NOTIFY ROUTING + !!!!
group_state.notify_routing_add()
def receive_done(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoListenersPresent: receive_done')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('NonQuerier NoListenersPresent: receive_group_specific_query')
# do nothing
return
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoListenersPresent: group_membership_timeout')
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoListenersPresent: retransmit_timeout')
# do nothing
return
from ipaddress import IPv6Address
from pimdm.utils import TYPE_CHECKING
from ..mld_globals import QueryResponseInterval, LastListenerQueryCount
from pimdm.packet.PacketMLDHeader import PacketMLDHeader
from pimdm.packet.ReceivedPacket import ReceivedPacket
from . import NoListenersPresent, ListenersPresent, CheckingListeners
if TYPE_CHECKING:
from ..RouterState import RouterState
class NonQuerier:
@staticmethod
def general_query_timeout(router_state: 'RouterState'):
router_state.router_state_logger.debug('NonQuerier state: general_query_timeout')
# do nothing
return
@staticmethod
def other_querier_present_timeout(router_state: 'RouterState'):
router_state.router_state_logger.debug('NonQuerier state: other_querier_present_timeout')
#change state to Querier
router_state.change_interface_state(querier=True)
# send general query
packet = PacketMLDHeader(type=PacketMLDHeader.MULTICAST_LISTENER_QUERY_TYPE,
max_resp_delay=QueryResponseInterval*1000)
router_state.interface.send(packet.bytes())
# set general query timer
router_state.set_general_query_timer()
@staticmethod
def receive_query(router_state: 'RouterState', packet: ReceivedPacket):
router_state.router_state_logger.debug('NonQuerier state: receive_query')
source_ip = packet.ip_header.ip_src
# if source ip of membership query not lower than the ip of the received interface => ignore
if IPv6Address(source_ip) >= IPv6Address(router_state.interface.get_ip()):
return
# reset other present querier timer
router_state.set_other_querier_present_timer()
# TODO ver se existe uma melhor maneira de fazer isto
@staticmethod
def state_name():
return "Non Querier"
@staticmethod
def get_group_membership_time(max_response_time: int):
return (max_response_time/1000.0) * LastListenerQueryCount
# State
@staticmethod
def get_checking_listeners_state():
return CheckingListeners
@staticmethod
def get_listeners_present_state():
return ListenersPresent
@staticmethod
def get_no_listeners_present_state():
return NoListenersPresent
from pimdm.packet.PacketMLDHeader import PacketMLDHeader
from pimdm.utils import TYPE_CHECKING
from ..mld_globals import LastListenerQueryInterval
from ..wrapper import ListenersPresent, NoListenersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def receive_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingListeners: receive_report')
group_state.set_timer()
group_state.clear_retransmit_timer()
group_state.set_state(ListenersPresent)
def receive_done(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingListeners: receive_done')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('Querier CheckingListeners: receive_group_specific_query')
# do nothing
return
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingListeners: group_membership_timeout')
group_state.clear_retransmit_timer()
group_state.set_state(NoListenersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingListeners: retransmit_timeout')
group_addr = group_state.group_ip
packet = PacketMLDHeader(type=PacketMLDHeader.MULTICAST_LISTENER_QUERY_TYPE,
max_resp_delay=LastListenerQueryInterval*1000, group_address=group_addr)
group_state.router_state.send(data=packet.bytes(), address=group_addr)
group_state.set_retransmit_timer()
from pimdm.packet.PacketMLDHeader import PacketMLDHeader
from pimdm.utils import TYPE_CHECKING
from ..mld_globals import LastListenerQueryInterval
from ..wrapper import CheckingListeners, NoListenersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def receive_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier ListenersPresent: receive_report')
group_state.set_timer()
def receive_done(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier ListenersPresent: receive_done')
group_ip = group_state.group_ip
group_state.set_timer(alternative=True)
group_state.set_retransmit_timer()
packet = PacketMLDHeader(type=PacketMLDHeader.MULTICAST_LISTENER_QUERY_TYPE,
max_resp_delay=LastListenerQueryInterval*1000, group_address=group_ip)
group_state.router_state.send(data=packet.bytes(), address=group_ip)
group_state.set_state(CheckingListeners)
def receive_group_specific_query(group_state: 'GroupState', max_response_time):
group_state.group_state_logger.debug('Querier ListenersPresent: receive_group_specific_query')
# do nothing
return
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier ListenersPresent: group_membership_timeout')
group_state.set_state(NoListenersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier ListenersPresent: retransmit_timeout')
# do nothing
return
from pimdm.utils import TYPE_CHECKING
from ..wrapper import ListenersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def receive_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoListenersPresent: receive_report')
group_state.set_timer()
group_state.set_state(ListenersPresent)
# NOTIFY ROUTING + !!!!
group_state.notify_routing_add()
def receive_done(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoListenersPresent: receive_done')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('Querier NoListenersPresent: receive_group_specific_query')
# do nothing
return
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoListenersPresent: group_membership_timeout')
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoListenersPresent: retransmit_timeout')
# do nothing
return
from ipaddress import IPv6Address
from pimdm.utils import TYPE_CHECKING
from ..mld_globals import LastListenerQueryInterval, LastListenerQueryCount, QueryResponseInterval
from pimdm.packet.PacketMLDHeader import PacketMLDHeader
from pimdm.packet.ReceivedPacket import ReceivedPacket
from . import CheckingListeners, ListenersPresent, NoListenersPresent
if TYPE_CHECKING:
from ..RouterState import RouterState
class Querier:
@staticmethod
def general_query_timeout(router_state: 'RouterState'):
router_state.router_state_logger.debug('Querier state: general_query_timeout')
# send general query
packet = PacketMLDHeader(type=PacketMLDHeader.MULTICAST_LISTENER_QUERY_TYPE,
max_resp_delay=QueryResponseInterval*1000)
router_state.interface.send(packet.bytes())
# set general query timer
router_state.set_general_query_timer()
@staticmethod
def receive_query(router_state: 'RouterState', packet: ReceivedPacket):
router_state.router_state_logger.debug('Querier state: receive_query')
source_ip = packet.ip_header.ip_src
# if source ip of membership query not lower than the ip of the received interface => ignore
if IPv6Address(source_ip) >= IPv6Address(router_state.interface.get_ip()):
return
# if source ip of membership query lower than the ip of the received interface => change state
# change state of interface
# Querier -> Non Querier
router_state.change_interface_state(querier=False)
# set other present querier timer
router_state.clear_general_query_timer()
router_state.set_other_querier_present_timer()
@staticmethod
def other_querier_present_timeout(router_state: 'RouterState'):
router_state.router_state_logger.debug('Querier state: other_querier_present_timeout')
# do nothing
return
# TODO ver se existe uma melhor maneira de fazer isto
@staticmethod
def state_name():
return "Querier"
@staticmethod
def get_group_membership_time(max_response_time: int):
return LastListenerQueryInterval * LastListenerQueryCount
# State
@staticmethod
def get_checking_listeners_state():
return CheckingListeners
@staticmethod
def get_listeners_present_state():
return ListenersPresent
@staticmethod
def get_no_listeners_present_state():
return NoListenersPresent
from pimdm.utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_checking_listeners_state()
def print_state():
return "CheckingListeners"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
\ No newline at end of file
from pimdm.utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_listeners_present_state()
def print_state():
return "ListenersPresent"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
\ No newline at end of file
from pimdm.utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_no_listeners_present_state()
def print_state():
return "NoListenersPresent"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
import struct
import socket
class PacketIpHeader:
"""
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|Version|
+-+-+-+-+
"""
IP_HDR = "! B"
IP_HDR_LEN = struct.calcsize(IP_HDR)
def __init__(self, ver, hdr_len):
self.version = ver
self.hdr_length = hdr_len
def __len__(self):
return self.hdr_length
@staticmethod
def parse_bytes(data: bytes):
(verhlen, ) = struct.unpack(PacketIpHeader.IP_HDR, data[:PacketIpHeader.IP_HDR_LEN])
ver = (verhlen & 0xF0) >> 4
print("ver:", ver)
return PACKET_HEADER.get(ver).parse_bytes(data)
class PacketIpv4Header(PacketIpHeader):
"""
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|Version| IHL |Type of Service| Total Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Identification |Flags| Fragment Offset |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Time to Live | Protocol | Header Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Destination Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Options | Padding |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
"""
IP_HDR = "! BBH HH BBH 4s 4s"
IP_HDR_LEN = struct.calcsize(IP_HDR)
def __init__(self, ver, hdr_len, ttl, proto, ip_src, ip_dst):
super().__init__(ver, hdr_len)
self.ttl = ttl
self.proto = proto
self.ip_src = ip_src
self.ip_dst = ip_dst
def __len__(self):
return self.hdr_length
@staticmethod
def parse_bytes(data: bytes):
(verhlen, tos, iplen, ipid, frag, ttl, proto, cksum, src, dst) = \
struct.unpack(PacketIpv4Header.IP_HDR, data[:PacketIpv4Header.IP_HDR_LEN])
ver = (verhlen & 0xf0) >> 4
hlen = (verhlen & 0x0f) * 4
'''
"VER": ver,
"HLEN": hlen,
"TOS": tos,
"IPLEN": iplen,
"IPID": ipid,
"FRAG": frag,
"TTL": ttl,
"PROTO": proto,
"CKSUM": cksum,
"SRC": socket.inet_ntoa(src),
"DST": socket.inet_ntoa(dst)
'''
src_ip = socket.inet_ntoa(src)
dst_ip = socket.inet_ntoa(dst)
return PacketIpv4Header(ver, hlen, ttl, proto, src_ip, dst_ip)
class PacketIpv6Header(PacketIpHeader):
"""
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|Version| Traffic Class | Flow Label |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Payload Length | Next Header | Hop Limit |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ +
| |
+ Source Address +
| |
+ +
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ +
| |
+ Destination Address +
| |
+ +
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
"""
IP6_HDR = "! I HBB 16s 16s"
IP6_HDR_LEN = struct.calcsize(IP6_HDR)
def __init__(self, ver, next_header, hop_limit, ip_src, ip_dst):
# TODO: confirm hdr_length in case of multiple options/headers
super().__init__(ver, PacketIpv6Header.IP6_HDR_LEN)
self.next_header = next_header
self.hop_limit = hop_limit
self.ip_src = ip_src
self.ip_dst = ip_dst
def __len__(self):
return PacketIpv6Header.IP6_HDR_LEN
@staticmethod
def parse_bytes(data: bytes):
(ver_tc_fl, _, next_header, hop_limit, src, dst) = \
struct.unpack(PacketIpv6Header.IP6_HDR, data[:PacketIpv6Header.IP6_HDR_LEN])
ver = (ver_tc_fl & 0xf0000000) >> 28
#tc = (ver_tc_fl & 0x0ff00000) >> 20
#fl = (ver_tc_fl & 0x000fffff)
'''
"VER": ver,
"TRAFFIC CLASS": tc,
"FLOW LABEL": fl,
"PAYLOAD LEN": payload_length,
"NEXT HEADER": next_header,
"HOP LIMIT": hop_limit,
"SRC": socket.inet_atop(socket.AF_INET6, src),
"DST": socket.inet_atop(socket.AF_INET6, dst)
'''
src_ip = socket.inet_ntop(socket.AF_INET6, src)
dst_ip = socket.inet_ntop(socket.AF_INET6, dst)
return PacketIpv6Header(ver, next_header, hop_limit, src_ip, dst_ip)
PACKET_HEADER = {
4: PacketIpv4Header,
6: PacketIpv6Header,
}
import struct
import socket
from .PacketPayload import PacketPayload
"""
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type | Code | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Maximum Response Delay | Reserved |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ +
| |
+ Multicast Address +
| |
+ +
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
"""
class PacketMLDHeader(PacketPayload):
MLD_TYPE = 58
MLD_HDR = "! BB H H H 16s"
MLD_HDR_LEN = struct.calcsize(MLD_HDR)
MULTICAST_LISTENER_QUERY_TYPE = 130
MULTICAST_LISTENER_REPORT_TYPE = 131
MULTICAST_LISTENER_DONE_TYPE = 132
def __init__(self, type: int, max_resp_delay: int, group_address: str = "::"):
# todo check type
self.type = type
self.max_resp_delay = max_resp_delay
self.group_address = group_address
def get_mld_type(self):
return self.type
def bytes(self) -> bytes:
# obter mensagem e criar checksum
msg_without_chcksum = struct.pack(PacketMLDHeader.MLD_HDR, self.type, 0, 0, self.max_resp_delay, 0,
socket.inet_pton(socket.AF_INET6, self.group_address))
#mld_checksum = checksum(msg_without_chcksum)
#msg = msg_without_chcksum[0:2] + struct.pack("! H", mld_checksum) + msg_without_chcksum[4:]
# checksum handled by linux kernel
return msg_without_chcksum
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
mld_hdr = data[0:PacketMLDHeader.MLD_HDR_LEN]
if len(mld_hdr) < PacketMLDHeader.MLD_HDR_LEN:
raise Exception("MLD packet length is lower than expected")
(mld_type, _, _, max_resp_delay, _, group_address) = struct.unpack(PacketMLDHeader.MLD_HDR, mld_hdr)
# checksum is handled by linux kernel
mld_hdr = mld_hdr[PacketMLDHeader.MLD_HDR_LEN:]
group_address = socket.inet_ntop(socket.AF_INET6, group_address)
pkt = PacketMLDHeader(mld_type, max_resp_delay, group_address)
return pkt
...@@ -55,7 +55,7 @@ class PacketPimEncodedGroupAddress: ...@@ -55,7 +55,7 @@ class PacketPimEncodedGroupAddress:
elif version == 6: elif version == 6:
return (PacketPimEncodedGroupAddress.IPV6_HDR, PacketPimEncodedGroupAddress.FAMILY_IPV6, socket.AF_INET6) return (PacketPimEncodedGroupAddress.IPV6_HDR, PacketPimEncodedGroupAddress.FAMILY_IPV6, socket.AF_INET6)
else: else:
raise Exception raise Exception("Unknown address family")
def __len__(self): def __len__(self):
version = ipaddress.ip_address(self.group_address).version version = ipaddress.ip_address(self.group_address).version
...@@ -64,7 +64,7 @@ class PacketPimEncodedGroupAddress: ...@@ -64,7 +64,7 @@ class PacketPimEncodedGroupAddress:
elif version == 6: elif version == 6:
return self.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6 return self.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6
else: else:
raise Exception raise Exception("Unknown address family")
@staticmethod @staticmethod
def parse_bytes(data: bytes): def parse_bytes(data: bytes):
...@@ -72,13 +72,14 @@ class PacketPimEncodedGroupAddress: ...@@ -72,13 +72,14 @@ class PacketPimEncodedGroupAddress:
(addr_family, encoding, _, mask_len) = struct.unpack(PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_MULTICAST_ADDRESS, data_without_group_addr) (addr_family, encoding, _, mask_len) = struct.unpack(PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_MULTICAST_ADDRESS, data_without_group_addr)
data_group_addr = data[PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_ADDRESS_LEN:] data_group_addr = data[PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_ADDRESS_LEN:]
ip = None
if addr_family == PacketPimEncodedGroupAddress.FAMILY_IPV4: if addr_family == PacketPimEncodedGroupAddress.FAMILY_IPV4:
(ip,) = struct.unpack("! " + PacketPimEncodedGroupAddress.IPV4_HDR, data_group_addr[:4]) (ip,) = struct.unpack("! " + PacketPimEncodedGroupAddress.IPV4_HDR, data_group_addr[:4])
ip = socket.inet_ntop(socket.AF_INET, ip) ip = socket.inet_ntop(socket.AF_INET, ip)
elif addr_family == PacketPimEncodedGroupAddress.FAMILY_IPV6: elif addr_family == PacketPimEncodedGroupAddress.FAMILY_IPV6:
(ip,) = struct.unpack("! " + PacketPimEncodedGroupAddress.IPV6_HDR, data_group_addr[:16]) (ip,) = struct.unpack("! " + PacketPimEncodedGroupAddress.IPV6_HDR, data_group_addr[:16])
ip = socket.inet_ntop(socket.AF_INET6, ip) ip = socket.inet_ntop(socket.AF_INET6, ip)
else:
raise Exception("Unknown address family")
if encoding != 0: if encoding != 0:
print("unknown encoding") print("unknown encoding")
......
...@@ -57,7 +57,7 @@ class PacketPimEncodedSourceAddress: ...@@ -57,7 +57,7 @@ class PacketPimEncodedSourceAddress:
elif version == 6: elif version == 6:
return (PacketPimEncodedSourceAddress.IPV6_HDR, PacketPimEncodedSourceAddress.FAMILY_IPV6, socket.AF_INET6) return (PacketPimEncodedSourceAddress.IPV6_HDR, PacketPimEncodedSourceAddress.FAMILY_IPV6, socket.AF_INET6)
else: else:
raise Exception raise Exception("Unknown address family")
def __len__(self): def __len__(self):
version = ipaddress.ip_address(self.source_address).version version = ipaddress.ip_address(self.source_address).version
...@@ -66,7 +66,7 @@ class PacketPimEncodedSourceAddress: ...@@ -66,7 +66,7 @@ class PacketPimEncodedSourceAddress:
elif version == 6: elif version == 6:
return self.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN_IPV6 return self.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN_IPV6
else: else:
raise Exception raise Exception("Unknown address family")
@staticmethod @staticmethod
def parse_bytes(data: bytes): def parse_bytes(data: bytes):
...@@ -74,13 +74,14 @@ class PacketPimEncodedSourceAddress: ...@@ -74,13 +74,14 @@ class PacketPimEncodedSourceAddress:
(addr_family, encoding, _, mask_len) = struct.unpack(PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS, data_without_source_addr) (addr_family, encoding, _, mask_len) = struct.unpack(PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS, data_without_source_addr)
data_source_addr = data[PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS_LEN:] data_source_addr = data[PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS_LEN:]
ip = None
if addr_family == PacketPimEncodedSourceAddress.FAMILY_IPV4: if addr_family == PacketPimEncodedSourceAddress.FAMILY_IPV4:
(ip,) = struct.unpack("! " + PacketPimEncodedSourceAddress.IPV4_HDR, data_source_addr[:4]) (ip,) = struct.unpack("! " + PacketPimEncodedSourceAddress.IPV4_HDR, data_source_addr[:4])
ip = socket.inet_ntop(socket.AF_INET, ip) ip = socket.inet_ntop(socket.AF_INET, ip)
elif addr_family == PacketPimEncodedSourceAddress.FAMILY_IPV6: elif addr_family == PacketPimEncodedSourceAddress.FAMILY_IPV6:
(ip,) = struct.unpack("! " + PacketPimEncodedSourceAddress.IPV6_HDR, data_source_addr[:16]) (ip,) = struct.unpack("! " + PacketPimEncodedSourceAddress.IPV6_HDR, data_source_addr[:16])
ip = socket.inet_ntop(socket.AF_INET6, ip) ip = socket.inet_ntop(socket.AF_INET6, ip)
else:
raise Exception("Unknown address family")
if encoding != 0: if encoding != 0:
print("unknown encoding") print("unknown encoding")
......
...@@ -46,7 +46,7 @@ class PacketPimEncodedUnicastAddress: ...@@ -46,7 +46,7 @@ class PacketPimEncodedUnicastAddress:
elif version == 6: elif version == 6:
return (PacketPimEncodedUnicastAddress.IPV6_HDR, PacketPimEncodedUnicastAddress.FAMILY_IPV6, socket.AF_INET6) return (PacketPimEncodedUnicastAddress.IPV6_HDR, PacketPimEncodedUnicastAddress.FAMILY_IPV6, socket.AF_INET6)
else: else:
raise Exception raise Exception("Unknown address family")
def __len__(self): def __len__(self):
version = ipaddress.ip_address(self.unicast_address).version version = ipaddress.ip_address(self.unicast_address).version
...@@ -55,7 +55,7 @@ class PacketPimEncodedUnicastAddress: ...@@ -55,7 +55,7 @@ class PacketPimEncodedUnicastAddress:
elif version == 6: elif version == 6:
return self.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6 return self.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6
else: else:
raise Exception raise Exception("Unknown address family")
@staticmethod @staticmethod
def parse_bytes(data: bytes): def parse_bytes(data: bytes):
...@@ -69,6 +69,8 @@ class PacketPimEncodedUnicastAddress: ...@@ -69,6 +69,8 @@ class PacketPimEncodedUnicastAddress:
elif addr_family == PacketPimEncodedUnicastAddress.FAMILY_IPV6: elif addr_family == PacketPimEncodedUnicastAddress.FAMILY_IPV6:
(ip,) = struct.unpack("! " + PacketPimEncodedUnicastAddress.IPV6_HDR, data_unicast_addr[:16]) (ip,) = struct.unpack("! " + PacketPimEncodedUnicastAddress.IPV6_HDR, data_unicast_addr[:16])
ip = socket.inet_ntop(socket.AF_INET6, ip) ip = socket.inet_ntop(socket.AF_INET6, ip)
else:
raise Exception("Unknown address family")
if encoding != 0: if encoding != 0:
print("unknown encoding") print("unknown encoding")
......
import socket
from .Packet import Packet from .Packet import Packet
from .PacketIpHeader import PacketIpHeader
from .PacketIGMPHeader import PacketIGMPHeader
from .PacketPimHeader import PacketPimHeader from .PacketPimHeader import PacketPimHeader
from .PacketMLDHeader import PacketMLDHeader
from .PacketIGMPHeader import PacketIGMPHeader
from .PacketIpHeader import PacketIpv4Header, PacketIpv6Header
from pimdm.utils import TYPE_CHECKING from pimdm.utils import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from pimdm.Interface import Interface from pimdm.Interface import Interface
...@@ -13,13 +15,32 @@ class ReceivedPacket(Packet): ...@@ -13,13 +15,32 @@ class ReceivedPacket(Packet):
def __init__(self, raw_packet: bytes, interface: 'Interface'): def __init__(self, raw_packet: bytes, interface: 'Interface'):
self.interface = interface self.interface = interface
# Parse ao packet e preencher objeto Packet
packet_ip_hdr = raw_packet[:PacketIpHeader.IP_HDR_LEN] # Parse packet and fill Packet super class
ip_header = PacketIpHeader.parse_bytes(packet_ip_hdr) ip_header = PacketIpv4Header.parse_bytes(raw_packet)
protocol_number = ip_header.proto protocol_number = ip_header.proto
packet_without_ip_hdr = raw_packet[ip_header.hdr_length:] packet_without_ip_hdr = raw_packet[ip_header.hdr_length:]
payload = ReceivedPacket.payload_protocol[protocol_number].parse_bytes(packet_without_ip_hdr) payload = ReceivedPacket.payload_protocol[protocol_number].parse_bytes(packet_without_ip_hdr)
super().__init__(ip_header=ip_header, payload=payload) super().__init__(ip_header=ip_header, payload=payload)
class ReceivedPacket_v6(Packet):
# choose payload protocol class based on ip protocol number
payload_protocol_v6 = {58: PacketMLDHeader, 103: PacketPimHeader}
def __init__(self, raw_packet: bytes, ancdata: list, src_addr: str, next_header: int, interface: 'Interface'):
self.interface = interface
# Parse packet and fill Packet super class
dst_addr = "::"
for cmsg_level, cmsg_type, cmsg_data in ancdata:
if cmsg_level == socket.IPPROTO_IPV6 and cmsg_type == socket.IPV6_PKTINFO:
dst_addr = socket.inet_ntop(socket.AF_INET6, cmsg_data[:16])
break
src_addr = src_addr[0].split("%")[0]
ipv6_packet = PacketIpv6Header(ver=6, hop_limit=1, next_header=next_header, ip_src=src_addr, ip_dst=dst_addr)
payload = ReceivedPacket_v6.payload_protocol_v6[next_header].parse_bytes(raw_packet)
super().__init__(ip_header=ipv6_packet, payload=payload)
import logging
from time import time
from threading import Lock, RLock
from pimdm import UnicastRouting
from .metric import AssertMetric
from .tree_if_upstream import TreeInterfaceUpstream from .tree_if_upstream import TreeInterfaceUpstream
from .tree_if_downstream import TreeInterfaceDownstream from .tree_if_downstream import TreeInterfaceDownstream
from .tree_interface import TreeInterface from .tree_interface import TreeInterface
from threading import Lock, RLock
from .metric import AssertMetric
from pimdm import UnicastRouting, Main
from time import time
import logging
class KernelEntry: class KernelEntry:
TREE_TIMEOUT = 180
KERNEL_LOGGER = logging.getLogger('pim.KernelEntry') KERNEL_LOGGER = logging.getLogger('pim.KernelEntry')
def __init__(self, source_ip: str, group_ip: str): def __init__(self, source_ip: str, group_ip: str, kernel_entry_interface):
self.kernel_entry_logger = logging.LoggerAdapter(KernelEntry.KERNEL_LOGGER, {'tree': '(' + source_ip + ',' + group_ip + ')'}) self.kernel_entry_logger = logging.LoggerAdapter(KernelEntry.KERNEL_LOGGER,
{'tree': '(' + source_ip + ',' + group_ip + ')'})
self.kernel_entry_logger.debug('Create KernelEntry') self.kernel_entry_logger.debug('Create KernelEntry')
self.source_ip = source_ip self.source_ip = source_ip
self.group_ip = group_ip self.group_ip = group_ip
self._kernel_entry_interface = kernel_entry_interface
# OBTAIN UNICAST ROUTING INFORMATION################################################### # OBTAIN UNICAST ROUTING INFORMATION###################################################
(metric_administrative_distance, metric_cost, rpf_node, root_if, mask) = \ (metric_administrative_distance, metric_cost, rpf_node, root_if, mask) = \
UnicastRouting.get_unicast_info(source_ip) UnicastRouting.get_unicast_info(source_ip)
...@@ -38,7 +42,7 @@ class KernelEntry: ...@@ -38,7 +42,7 @@ class KernelEntry:
self.interface_state = {} # type: Dict[int, TreeInterface] self.interface_state = {} # type: Dict[int, TreeInterface]
with self.CHANGE_STATE_LOCK: with self.CHANGE_STATE_LOCK:
for i in Main.kernel.vif_index_to_name_dic.keys(): for i in self.get_kernel().vif_index_to_name_dic.keys():
try: try:
if i == self.inbound_interface_index: if i == self.inbound_interface_index:
self.interface_state[i] = TreeInterfaceUpstream(self, i) self.interface_state[i] = TreeInterfaceUpstream(self, i)
...@@ -55,22 +59,31 @@ class KernelEntry: ...@@ -55,22 +59,31 @@ class KernelEntry:
print('Tree created') print('Tree created')
def get_inbound_interface_index(self): def get_inbound_interface_index(self):
"""
Get VIF of root interface of this tree
"""
return self.inbound_interface_index return self.inbound_interface_index
def get_outbound_interfaces_indexes(self): def get_outbound_interfaces_indexes(self):
outbound_indexes = [0] * Main.kernel.MAXVIFS """
for (index, state) in self.interface_state.items(): Get OIL of this tree
outbound_indexes[index] = state.is_forwarding() """
return outbound_indexes return self._kernel_entry_interface.get_outbound_interfaces_indexes(self)
################################################ ################################################
# Receive (S,G) data packets or control packets # Receive (S,G) data packets or control packets
################################################ ################################################
def recv_data_msg(self, index): def recv_data_msg(self, index):
"""
Receive data packet regarding this tree in interface with VIF index
"""
print("recv data") print("recv data")
self.interface_state[index].recv_data_msg() self.interface_state[index].recv_data_msg()
def recv_assert_msg(self, index, packet): def recv_assert_msg(self, index, packet):
"""
Receive assert packet regarding this tree in interface with VIF index
"""
print("recv assert") print("recv assert")
pkt_assert = packet.payload.payload pkt_assert = packet.payload.payload
metric = pkt_assert.metric metric = pkt_assert.metric
...@@ -81,28 +94,43 @@ class KernelEntry: ...@@ -81,28 +94,43 @@ class KernelEntry:
self.interface_state[index].recv_assert_msg(received_metric) self.interface_state[index].recv_assert_msg(received_metric)
def recv_prune_msg(self, index, packet): def recv_prune_msg(self, index, packet):
"""
Receive Prune packet regarding this tree in interface with VIF index
"""
print("recv prune msg") print("recv prune msg")
holdtime = packet.payload.payload.hold_time holdtime = packet.payload.payload.hold_time
upstream_neighbor_address = packet.payload.payload.upstream_neighbor_address upstream_neighbor_address = packet.payload.payload.upstream_neighbor_address
self.interface_state[index].recv_prune_msg(upstream_neighbor_address=upstream_neighbor_address, holdtime=holdtime) self.interface_state[index].recv_prune_msg(upstream_neighbor_address=upstream_neighbor_address, holdtime=holdtime)
def recv_join_msg(self, index, packet): def recv_join_msg(self, index, packet):
"""
Receive Join packet regarding this tree in interface with VIF index
"""
print("recv join msg") print("recv join msg")
upstream_neighbor_address = packet.payload.payload.upstream_neighbor_address upstream_neighbor_address = packet.payload.payload.upstream_neighbor_address
self.interface_state[index].recv_join_msg(upstream_neighbor_address) self.interface_state[index].recv_join_msg(upstream_neighbor_address)
def recv_graft_msg(self, index, packet): def recv_graft_msg(self, index, packet):
"""
Receive Graft packet regarding this tree in interface with VIF index
"""
print("recv graft msg") print("recv graft msg")
upstream_neighbor_address = packet.payload.payload.upstream_neighbor_address upstream_neighbor_address = packet.payload.payload.upstream_neighbor_address
source_ip = packet.ip_header.ip_src source_ip = packet.ip_header.ip_src
self.interface_state[index].recv_graft_msg(upstream_neighbor_address, source_ip) self.interface_state[index].recv_graft_msg(upstream_neighbor_address, source_ip)
def recv_graft_ack_msg(self, index, packet): def recv_graft_ack_msg(self, index, packet):
"""
Receive GraftAck packet regarding this tree in interface with VIF index
"""
print("recv graft ack msg") print("recv graft ack msg")
source_ip = packet.ip_header.ip_src source_ip = packet.ip_header.ip_src
self.interface_state[index].recv_graft_ack_msg(source_ip) self.interface_state[index].recv_graft_ack_msg(source_ip)
def recv_state_refresh_msg(self, index, packet): def recv_state_refresh_msg(self, index, packet):
"""
Receive StateRefresh packet regarding this tree in interface with VIF index
"""
print("recv state refresh msg") print("recv state refresh msg")
source_of_state_refresh = packet.ip_header.ip_src source_of_state_refresh = packet.ip_header.ip_src
...@@ -129,11 +157,13 @@ class KernelEntry: ...@@ -129,11 +157,13 @@ class KernelEntry:
self.forward_state_refresh_msg(packet.payload.payload) self.forward_state_refresh_msg(packet.payload.payload)
################################################ ################################################
# Send state refresh msg # Send state refresh msg
################################################ ################################################
def forward_state_refresh_msg(self, state_refresh_packet): def forward_state_refresh_msg(self, state_refresh_packet):
"""
Forward StateRefresh packet through all interfaces
"""
for interface in self.interface_state.values(): for interface in self.interface_state.values():
interface.send_state_refresh(state_refresh_packet) interface.send_state_refresh(state_refresh_packet)
...@@ -142,6 +172,9 @@ class KernelEntry: ...@@ -142,6 +172,9 @@ class KernelEntry:
# Unicast Changes to RPF # Unicast Changes to RPF
############################################################### ###############################################################
def network_update(self): def network_update(self):
"""
Unicast routing table suffered an update and this tree might be affected by it
"""
# TODO TALVEZ OUTRO LOCK PARA BLOQUEAR ENTRADA DE PACOTES # TODO TALVEZ OUTRO LOCK PARA BLOQUEAR ENTRADA DE PACOTES
with self.CHANGE_STATE_LOCK: with self.CHANGE_STATE_LOCK:
...@@ -184,24 +217,34 @@ class KernelEntry: ...@@ -184,24 +217,34 @@ class KernelEntry:
self.rpf_node = rpf_node self.rpf_node = rpf_node
self.interface_state[self.inbound_interface_index].change_on_unicast_routing() self.interface_state[self.inbound_interface_index].change_on_unicast_routing()
# check if add/removal of neighbors from interface afects olist and forward/prune state of interface
def change_at_number_of_neighbors(self): def change_at_number_of_neighbors(self):
"""
Check if modification of number of neighbors causes changes to OIL and interest of interface
"""
with self.CHANGE_STATE_LOCK: with self.CHANGE_STATE_LOCK:
self.change() self.change()
self.evaluate_olist_change() self.evaluate_olist_change()
def new_or_reset_neighbor(self, if_index, neighbor_ip): def new_or_reset_neighbor(self, if_index, neighbor_ip):
"""
An interface identified by if_index has a new neighbor
"""
# todo maybe lock de interfaces # todo maybe lock de interfaces
self.interface_state[if_index].new_or_reset_neighbor(neighbor_ip) self.interface_state[if_index].new_or_reset_neighbor(neighbor_ip)
def is_olist_null(self): def is_olist_null(self):
"""
Check if olist is null
"""
for interface in self.interface_state.values(): for interface in self.interface_state.values():
if interface.is_forwarding(): if interface.is_forwarding():
return False return False
return True return True
def evaluate_olist_change(self): def evaluate_olist_change(self):
"""
React to changes on the olist
"""
with self._lock_test2: with self._lock_test2:
is_olist_null = self.is_olist_null() is_olist_null = self.is_olist_null()
...@@ -214,33 +257,74 @@ class KernelEntry: ...@@ -214,33 +257,74 @@ class KernelEntry:
self._was_olist_null = is_olist_null self._was_olist_null = is_olist_null
def get_source(self): def get_source(self):
"""
Get source IP of multicast source
"""
return self.source_ip return self.source_ip
def get_group(self): def get_group(self):
"""
Get group IP of multicast tree
"""
return self.group_ip return self.group_ip
def change(self): def change(self):
"""
Trigger an update on the multicast routing table
"""
with self._multicast_change: with self._multicast_change:
Main.kernel.set_multicast_route(self) self.get_kernel().set_multicast_route(self)
def delete(self): def delete(self):
"""
Remove kernel entry
"""
with self._multicast_change: with self._multicast_change:
for state in self.interface_state.values(): for state in self.interface_state.values():
state.delete() state.delete()
Main.kernel.remove_multicast_route(self) self.get_kernel().remove_multicast_route(self)
def get_interface_name(self, interface_id):
"""
Get interface name of interface identified by interface_id
"""
return self._kernel_entry_interface.get_interface_name(interface_id)
def get_interface(self, interface_id):
"""
Get PIM interface
"""
return self._kernel_entry_interface.get_interface(self, interface_id)
def get_membership_interface(self, interface_id):
"""
Get IGMP/MLD interface
"""
return self._kernel_entry_interface.get_membership_interface(self, interface_id)
def get_kernel(self):
"""
Get kernel
"""
return self._kernel_entry_interface.get_kernel()
###################################### ######################################
# Interface change # Interface change
####################################### #######################################
def new_interface(self, index): def new_interface(self, index):
"""
React to a new interface that was added and in which a tree was already built
"""
with self.CHANGE_STATE_LOCK: with self.CHANGE_STATE_LOCK:
self.interface_state[index] = TreeInterfaceDownstream(self, index) self.interface_state[index] = TreeInterfaceDownstream(self, index)
self.change() self.change()
self.evaluate_olist_change() self.evaluate_olist_change()
def remove_interface(self, index): def remove_interface(self, index):
"""
React to removal of an interface of a tree that was already built
"""
with self.CHANGE_STATE_LOCK: with self.CHANGE_STATE_LOCK:
#check if removed interface is root interface #check if removed interface is root interface
if self.inbound_interface_index == index: if self.inbound_interface_index == index:
......
from pimdm import Main
from abc import abstractmethod, ABCMeta
class KernelEntryInterface(metaclass=ABCMeta):
@staticmethod
@abstractmethod
def get_outbound_interfaces_indexes(kernel_tree):
"""
Get OIL of this tree
"""
pass
@staticmethod
@abstractmethod
def get_interface_name(interface_id):
"""
Get name of interface from vif id
"""
pass
@staticmethod
@abstractmethod
def get_interface(kernel_tree, interface_id):
"""
Get PIM interface from interface id
"""
pass
@staticmethod
@abstractmethod
def get_membership_interface(kernel_tree, interface_id):
"""
Get IGMP/MLD interface from interface id
"""
pass
@staticmethod
@abstractmethod
def get_kernel():
"""
Get kernel
"""
pass
class KernelEntry4Interface(KernelEntryInterface):
@staticmethod
def get_outbound_interfaces_indexes(kernel_tree):
"""
Get OIL of this tree
"""
outbound_indexes = [0] * Main.kernel.MAXVIFS
for (index, state) in kernel_tree.interface_state.items():
outbound_indexes[index] = state.is_forwarding()
return outbound_indexes
@staticmethod
def get_interface_name(interface_id):
"""
Get name of interface from vif id
"""
return Main.kernel.vif_index_to_name_dic[interface_id]
@staticmethod
def get_interface(kernel_tree, interface_id):
"""
Get PIM interface from interface id
"""
interface_name = kernel_tree.get_interface_name(interface_id)
return Main.interfaces.get(interface_name, None)
@staticmethod
def get_membership_interface(kernel_tree, interface_id):
"""
Get IGMP interface from interface id
"""
interface_name = kernel_tree.get_interface_name(interface_id)
return Main.igmp_interfaces.get(interface_name, None) # type: InterfaceIGMP
@staticmethod
def get_kernel():
"""
Get kernel
"""
return Main.kernel
class KernelEntry6Interface(KernelEntryInterface):
@staticmethod
def get_outbound_interfaces_indexes(kernel_tree):
"""
Get OIL of this tree
"""
outbound_indexes = [0] * 8
for (index, state) in kernel_tree.interface_state.items():
outbound_indexes[index // 32] |= state.is_forwarding() << (index % 32)
return outbound_indexes
@staticmethod
def get_interface_name(interface_id):
"""
Get name of interface from vif id
"""
return Main.kernel_v6.vif_index_to_name_dic[interface_id]
@staticmethod
def get_interface(kernel_tree, interface_id):
"""
Get PIM interface from interface id
"""
interface_name = kernel_tree.get_interface_name(interface_id)
return Main.interfaces_v6.get(interface_name, None)
@staticmethod
def get_membership_interface(kernel_tree, interface_id):
"""
Get MLD interface from interface id
"""
interface_name = kernel_tree.get_interface_name(interface_id)
return Main.mld_interfaces.get(interface_name, None) # type: InterfaceMLD
@staticmethod
def get_kernel():
"""
Get kernel
"""
return Main.kernel_v6
...@@ -112,12 +112,10 @@ class AssertStateABC(metaclass=ABCMeta): ...@@ -112,12 +112,10 @@ class AssertStateABC(metaclass=ABCMeta):
""" """
raise NotImplementedError() raise NotImplementedError()
def _sendAssert_setAT(interface: "TreeInterfaceDownstream"): def _sendAssert_setAT(interface: "TreeInterfaceDownstream"):
interface.set_assert_timer(pim_globals.ASSERT_TIME) interface.set_assert_timer(pim_globals.ASSERT_TIME)
interface.send_assert() interface.send_assert()
# Override # Override
def __str__(self) -> str: def __str__(self) -> str:
return "AssertSM:" + self.__class__.__name__ return "AssertSM:" + self.__class__.__name__
...@@ -289,7 +287,6 @@ class WinnerState(AssertStateABC): ...@@ -289,7 +287,6 @@ class WinnerState(AssertStateABC):
return "Winner" return "Winner"
class LoserState(AssertStateABC): class LoserState(AssertStateABC):
''' '''
I am Assert Loser (L) I am Assert Loser (L)
...@@ -370,6 +367,7 @@ class LoserState(AssertStateABC): ...@@ -370,6 +367,7 @@ class LoserState(AssertStateABC):
def __str__(self) -> str: def __str__(self) -> str:
return "Loser" return "Loser"
class AssertState(): class AssertState():
NoInfo = NoInfoState() NoInfo = NoInfoState()
Winner = WinnerState() Winner = WinnerState()
......
import subprocess
import struct import struct
import socket import socket
import ipaddress
import subprocess
from ctypes import create_string_buffer, addressof from ctypes import create_string_buffer, addressof
SO_ATTACH_FILTER = 26 SO_ATTACH_FILTER = 26
ETH_P_IP = 0x0800 # Internet Protocol packet ETH_P_IP = 0x0800 # Internet Protocol packet
ETH_P_IPV6 = 0x86DD # IPv6 over bluebook
SO_RCVBUFFORCE = 33 SO_RCVBUFFORCE = 33
def get_s_g_bpf_filter_code(source, group, interface_name): def get_s_g_bpf_filter_code(source, group, interface_name):
#cmd = "tcpdump -ddd \"(udp or icmp) and host %s and dst %s\"" % (source, group) ip_source_version = ipaddress.ip_address(source).version
cmd = "tcpdump -ddd \"(ip proto not 2) and host %s and dst %s\"" % (source, group) ip_group_version = ipaddress.ip_address(source).version
if ip_source_version == ip_group_version == 4:
# cmd = "tcpdump -ddd \"(udp or icmp) and host %s and dst %s\"" % (source, group)
cmd = "tcpdump -ddd \"(ip proto not 2) and host %s and dst %s\"" % (source, group)
protocol = ETH_P_IP
elif ip_source_version == ip_group_version == 6:
# TODO: allow ICMPv6 echo request/echo response to be considered multicast packets
cmd = "tcpdump -ddd \"(ip6 proto not 58) and host %s and dst %s\"" % (source, group)
protocol = ETH_P_IPV6
else:
raise Exception("Unknown IP family")
result = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) result = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
bpf_filter = b'' bpf_filter = b''
...@@ -28,10 +42,10 @@ def get_s_g_bpf_filter_code(source, group, interface_name): ...@@ -28,10 +42,10 @@ def get_s_g_bpf_filter_code(source, group, interface_name):
# Create listening socket with filters # Create listening socket with filters
s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, ETH_P_IP) s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, protocol)
s.setsockopt(socket.SOL_SOCKET, SO_ATTACH_FILTER, fprog) s.setsockopt(socket.SOL_SOCKET, SO_ATTACH_FILTER, fprog)
# todo pequeno ajuste (tamanho de buffer pequeno para o caso de trafego em rajadas): # todo pequeno ajuste (tamanho de buffer pequeno para o caso de trafego em rajadas):
#s.setsockopt(socket.SOL_SOCKET, SO_RCVBUFFORCE, 1) #s.setsockopt(socket.SOL_SOCKET, SO_RCVBUFFORCE, 1)
s.bind((interface_name, ETH_P_IP)) s.bind((interface_name, protocol))
return s return s
...@@ -7,4 +7,8 @@ REFRESH_INTERVAL = 60 # State Refresh Interval ...@@ -7,4 +7,8 @@ REFRESH_INTERVAL = 60 # State Refresh Interval
SOURCE_LIFETIME = 210 SOURCE_LIFETIME = 210
T_LIMIT = 210 T_LIMIT = 210
HELLO_HOLD_TIME_NO_TIMEOUT = 0xFFFF
HELLO_HOLD_TIME = 160
HELLO_HOLD_TIME_TIMEOUT = 0
ASSERT_CANCEL_METRIC = 0xFFFFFFFF ASSERT_CANCEL_METRIC = 0xFFFFFFFF
\ No newline at end of file
from abc import ABCMeta, abstractstaticmethod from abc import ABCMeta, abstractmethod
class OriginatorStateABC(metaclass=ABCMeta): class OriginatorStateABC(metaclass=ABCMeta):
@abstractstaticmethod @staticmethod
@abstractmethod
def recvDataMsgFromSource(tree): def recvDataMsgFromSource(tree):
pass pass
@abstractstaticmethod @staticmethod
@abstractmethod
def SRTexpires(tree): def SRTexpires(tree):
pass pass
@abstractstaticmethod @staticmethod
@abstractmethod
def SATexpires(tree): def SATexpires(tree):
pass pass
@abstractstaticmethod @staticmethod
@abstractmethod
def SourceNotConnected(tree): def SourceNotConnected(tree):
pass pass
......
'''
Created on Jul 16, 2015
@author: alex
'''
from threading import Timer from threading import Timer
from pimdm.CustomTimer.RemainingTimer import RemainingTimer from pimdm.custom_timer.RemainingTimer import RemainingTimer
from .assert_ import AssertState from .assert_state import AssertState
from .downstream_prune import DownstreamState, DownstreamStateABS from .downstream_prune import DownstreamState, DownstreamStateABS
from .tree_interface import TreeInterface from .tree_interface import TreeInterface
from pimdm.Packet.PacketPimStateRefresh import PacketPimStateRefresh from pimdm.packet.PacketPimStateRefresh import PacketPimStateRefresh
from pimdm.Packet.Packet import Packet from pimdm.packet.Packet import Packet
from pimdm.Packet.PacketPimHeader import PacketPimHeader from pimdm.packet.PacketPimHeader import PacketPimHeader
import traceback import traceback
import logging import logging
from .. import Main
class TreeInterfaceDownstream(TreeInterface): class TreeInterfaceDownstream(TreeInterface):
...@@ -22,7 +16,7 @@ class TreeInterfaceDownstream(TreeInterface): ...@@ -22,7 +16,7 @@ class TreeInterfaceDownstream(TreeInterface):
def __init__(self, kernel_entry, interface_id): def __init__(self, kernel_entry, interface_id):
extra_dict_logger = kernel_entry.kernel_entry_logger.extra.copy() extra_dict_logger = kernel_entry.kernel_entry_logger.extra.copy()
extra_dict_logger['vif'] = interface_id extra_dict_logger['vif'] = interface_id
extra_dict_logger['interfacename'] = Main.kernel.vif_index_to_name_dic[interface_id] extra_dict_logger['interfacename'] = kernel_entry.get_interface_name(interface_id)
logger = logging.LoggerAdapter(TreeInterfaceDownstream.LOGGER, extra_dict_logger) logger = logging.LoggerAdapter(TreeInterfaceDownstream.LOGGER, extra_dict_logger)
TreeInterface.__init__(self, kernel_entry, interface_id, logger) TreeInterface.__init__(self, kernel_entry, interface_id, logger)
self.logger.debug('Created DownstreamInterface') self.logger.debug('Created DownstreamInterface')
......
'''
Created on Jul 16, 2015
@author: alex
'''
from .tree_interface import TreeInterface from .tree_interface import TreeInterface
from .upstream_prune import UpstreamState from .upstream_prune import UpstreamState
from threading import Timer from threading import Timer
from pimdm.CustomTimer.RemainingTimer import RemainingTimer from pimdm.custom_timer.RemainingTimer import RemainingTimer
from .globals import * from .globals import *
import random import random
from .metric import AssertMetric from .metric import AssertMetric
from .originator import OriginatorState, OriginatorStateABC from .originator import OriginatorState, OriginatorStateABC
from pimdm.Packet.PacketPimStateRefresh import PacketPimStateRefresh from pimdm.packet.PacketPimStateRefresh import PacketPimStateRefresh
import traceback import traceback
from . import DataPacketsSocket from . import data_packets_socket
import threading import threading
import logging import logging
from .. import Main
class TreeInterfaceUpstream(TreeInterface): class TreeInterfaceUpstream(TreeInterface):
...@@ -25,7 +19,7 @@ class TreeInterfaceUpstream(TreeInterface): ...@@ -25,7 +19,7 @@ class TreeInterfaceUpstream(TreeInterface):
def __init__(self, kernel_entry, interface_id): def __init__(self, kernel_entry, interface_id):
extra_dict_logger = kernel_entry.kernel_entry_logger.extra.copy() extra_dict_logger = kernel_entry.kernel_entry_logger.extra.copy()
extra_dict_logger['vif'] = interface_id extra_dict_logger['vif'] = interface_id
extra_dict_logger['interfacename'] = Main.kernel.vif_index_to_name_dic[interface_id] extra_dict_logger['interfacename'] = kernel_entry.get_interface_name(interface_id)
logger = logging.LoggerAdapter(TreeInterfaceUpstream.LOGGER, extra_dict_logger) logger = logging.LoggerAdapter(TreeInterfaceUpstream.LOGGER, extra_dict_logger)
TreeInterface.__init__(self, kernel_entry, interface_id, logger) TreeInterface.__init__(self, kernel_entry, interface_id, logger)
...@@ -47,15 +41,16 @@ class TreeInterfaceUpstream(TreeInterface): ...@@ -47,15 +41,16 @@ class TreeInterfaceUpstream(TreeInterface):
if self.is_S_directly_conn(): if self.is_S_directly_conn():
self._graft_prune_state.sourceIsNowDirectConnect(self) self._graft_prune_state.sourceIsNowDirectConnect(self)
if self.get_interface().is_state_refresh_enabled(): interface = self.get_interface()
if interface is not None and interface.is_state_refresh_enabled():
self._originator_state.recvDataMsgFromSource(self) self._originator_state.recvDataMsgFromSource(self)
# TODO TESTE SOCKET RECV DATA PCKTS # TODO TESTE SOCKET RECV DATA PCKTS
self.socket_is_enabled = True self.socket_is_enabled = True
(s,g) = self.get_tree_id() (s, g) = self.get_tree_id()
interface_name = self.get_interface().interface_name interface_name = self.get_interface_name()
self.socket_pkt = DataPacketsSocket.get_s_g_bpf_filter_code(s, g, interface_name) self.socket_pkt = data_packets_socket.get_s_g_bpf_filter_code(s, g, interface_name)
# run receive method in background # run receive method in background
receive_thread = threading.Thread(target=self.socket_recv) receive_thread = threading.Thread(target=self.socket_recv)
...@@ -182,8 +177,10 @@ class TreeInterfaceUpstream(TreeInterface): ...@@ -182,8 +177,10 @@ class TreeInterfaceUpstream(TreeInterface):
def recv_data_msg(self): def recv_data_msg(self):
if not self.is_prune_limit_timer_running() and not self.is_S_directly_conn() and self.is_olist_null(): if not self.is_prune_limit_timer_running() and not self.is_S_directly_conn() and self.is_olist_null():
self._graft_prune_state.dataArrivesRPFinterface_OListNull_PLTstoped(self) self._graft_prune_state.dataArrivesRPFinterface_OListNull_PLTstoped(self)
elif self.is_S_directly_conn() and self.get_interface().is_state_refresh_enabled(): elif self.is_S_directly_conn():
self._originator_state.recvDataMsgFromSource(self) interface = self.get_interface()
if interface is not None and interface.is_state_refresh_enabled():
self._originator_state.recvDataMsgFromSource(self)
def recv_join_msg(self, upstream_neighbor_address): def recv_join_msg(self, upstream_neighbor_address):
......
'''
Created on Jul 16, 2015
@author: alex
'''
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from .. import Main
from threading import RLock from threading import RLock
import traceback import traceback
from .downstream_prune import DownstreamState from .downstream_prune import DownstreamState
from .assert_ import AssertState, AssertStateABC from .assert_state import AssertState, AssertStateABC
from pimdm.Packet.PacketPimGraft import PacketPimGraft from pimdm.packet.PacketPimGraft import PacketPimGraft
from pimdm.Packet.PacketPimGraftAck import PacketPimGraftAck from pimdm.packet.PacketPimGraftAck import PacketPimGraftAck
from pimdm.Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup from pimdm.packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
from pimdm.Packet.PacketPimHeader import PacketPimHeader from pimdm.packet.PacketPimHeader import PacketPimHeader
from pimdm.Packet.Packet import Packet from pimdm.packet.Packet import Packet
from pimdm.Packet.PacketPimJoinPrune import PacketPimJoinPrune from pimdm.packet.PacketPimJoinPrune import PacketPimJoinPrune
from pimdm.Packet.PacketPimAssert import PacketPimAssert from pimdm.packet.PacketPimAssert import PacketPimAssert
from pimdm.Packet.PacketPimStateRefresh import PacketPimStateRefresh from pimdm.packet.PacketPimStateRefresh import PacketPimStateRefresh
from .metric import AssertMetric from .metric import AssertMetric
from threading import Timer from threading import Timer
from .local_membership import LocalMembership from .local_membership import LocalMembership
from .globals import * from .globals import T_LIMIT
import logging import logging
class TreeInterface(metaclass=ABCMeta): class TreeInterface(metaclass=ABCMeta):
def __init__(self, kernel_entry, interface_id, logger: logging.LoggerAdapter): def __init__(self, kernel_entry, interface_id, logger: logging.LoggerAdapter):
self._kernel_entry = kernel_entry self._kernel_entry = kernel_entry
...@@ -36,9 +31,8 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -36,9 +31,8 @@ class TreeInterface(metaclass=ABCMeta):
# Local Membership State # Local Membership State
try: try:
interface_name = Main.kernel.vif_index_to_name_dic[interface_id] membership_interface = self.get_membership_interface()
igmp_interface = Main.igmp_interfaces[interface_name] # type: InterfaceIGMP group_state = membership_interface.interface_state.get_group_state(kernel_entry.group_ip)
group_state = igmp_interface.interface_state.get_group_state(kernel_entry.group_ip)
#self._igmp_has_members = group_state.add_multicast_routing_entry(self) #self._igmp_has_members = group_state.add_multicast_routing_entry(self)
igmp_has_members = group_state.add_multicast_routing_entry(self) igmp_has_members = group_state.add_multicast_routing_entry(self)
self._local_membership_state = LocalMembership.Include if igmp_has_members else LocalMembership.NoInfo self._local_membership_state = LocalMembership.Include if igmp_has_members else LocalMembership.NoInfo
...@@ -60,8 +54,7 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -60,8 +54,7 @@ class TreeInterface(metaclass=ABCMeta):
# Received prune hold time # Received prune hold time
self._received_prune_holdtime = None self._received_prune_holdtime = None
self._igmp_lock = RLock() self._membership_lock = RLock()
############################################ ############################################
# Set ASSERT State # Set ASSERT State
...@@ -90,7 +83,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -90,7 +83,6 @@ class TreeInterface(metaclass=ABCMeta):
finally: finally:
self._assert_winner_metric = new_assert_metric self._assert_winner_metric = new_assert_metric
############################################ ############################################
# ASSERT Timer # ASSERT Timer
############################################ ############################################
...@@ -106,7 +98,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -106,7 +98,6 @@ class TreeInterface(metaclass=ABCMeta):
def assert_timeout(self): def assert_timeout(self):
self._assert_state.assertTimerExpires(self) self._assert_state.assertTimerExpires(self)
########################################### ###########################################
# Recv packets # Recv packets
########################################### ###########################################
...@@ -145,7 +136,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -145,7 +136,6 @@ class TreeInterface(metaclass=ABCMeta):
def recv_state_refresh_msg(self, received_metric: AssertMetric, prune_indicator): def recv_state_refresh_msg(self, received_metric: AssertMetric, prune_indicator):
self.recv_assert_msg(received_metric) self.recv_assert_msg(received_metric)
###################################### ######################################
# Send messages # Send messages
###################################### ######################################
...@@ -163,7 +153,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -163,7 +153,6 @@ class TreeInterface(metaclass=ABCMeta):
traceback.print_exc() traceback.print_exc()
return return
def send_graft_ack(self, ip_sender): def send_graft_ack(self, ip_sender):
print("send graft ack") print("send graft ack")
try: try:
...@@ -177,7 +166,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -177,7 +166,6 @@ class TreeInterface(metaclass=ABCMeta):
traceback.print_exc() traceback.print_exc()
return return
def send_prune(self, holdtime=None): def send_prune(self, holdtime=None):
if holdtime is None: if holdtime is None:
holdtime = T_LIMIT holdtime = T_LIMIT
...@@ -195,7 +183,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -195,7 +183,6 @@ class TreeInterface(metaclass=ABCMeta):
traceback.print_exc() traceback.print_exc()
return return
def send_pruneecho(self): def send_pruneecho(self):
holdtime = T_LIMIT holdtime = T_LIMIT
try: try:
...@@ -210,7 +197,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -210,7 +197,6 @@ class TreeInterface(metaclass=ABCMeta):
traceback.print_exc() traceback.print_exc()
return return
def send_join(self): def send_join(self):
print("send join") print("send join")
...@@ -225,7 +211,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -225,7 +211,6 @@ class TreeInterface(metaclass=ABCMeta):
traceback.print_exc() traceback.print_exc()
return return
def send_assert(self): def send_assert(self):
print("send assert") print("send assert")
...@@ -240,7 +225,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -240,7 +225,6 @@ class TreeInterface(metaclass=ABCMeta):
traceback.print_exc() traceback.print_exc()
return return
def send_assert_cancel(self): def send_assert_cancel(self):
print("send assert cancel") print("send assert cancel")
...@@ -254,7 +238,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -254,7 +238,6 @@ class TreeInterface(metaclass=ABCMeta):
traceback.print_exc() traceback.print_exc()
return return
def send_state_refresh(self, state_refresh_msg_received: PacketPimStateRefresh): def send_state_refresh(self, state_refresh_msg_received: PacketPimStateRefresh):
pass pass
...@@ -282,9 +265,8 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -282,9 +265,8 @@ class TreeInterface(metaclass=ABCMeta):
(s, g) = self.get_tree_id() (s, g) = self.get_tree_id()
# unsubscribe igmp information # unsubscribe igmp information
try: try:
interface_name = Main.kernel.vif_index_to_name_dic[self._interface_id] membership_interface = self.get_membership_interface()
igmp_interface = Main.igmp_interfaces[interface_name] # type: InterfaceIGMP group_state = membership_interface.interface_state.get_group_state(g)
group_state = igmp_interface.interface_state.get_group_state(g)
group_state.remove_multicast_routing_entry(self) group_state.remove_multicast_routing_entry(self)
except: except:
pass pass
...@@ -306,29 +288,29 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -306,29 +288,29 @@ class TreeInterface(metaclass=ABCMeta):
def evaluate_ingroup(self): def evaluate_ingroup(self):
self._kernel_entry.evaluate_olist_change() self._kernel_entry.evaluate_olist_change()
############################################################# #############################################################
# Local Membership (IGMP) # Local Membership (IGMP)
############################################################ ############################################################
def notify_igmp(self, has_members: bool): def notify_membership(self, has_members: bool):
with self.get_state_lock(): with self.get_state_lock():
with self._igmp_lock: with self._membership_lock:
if has_members != self._local_membership_state.has_members(): if has_members != self._local_membership_state.has_members():
self._local_membership_state = LocalMembership.Include if has_members else LocalMembership.NoInfo self._local_membership_state = LocalMembership.Include if has_members else LocalMembership.NoInfo
self.change_tree() self.change_tree()
self.evaluate_ingroup() self.evaluate_ingroup()
def igmp_has_members(self): def igmp_has_members(self):
with self._igmp_lock: with self._membership_lock:
return self._local_membership_state.has_members() return self._local_membership_state.has_members()
def get_interface_name(self):
return self._kernel_entry.get_interface_name(self._interface_id)
def get_interface(self): def get_interface(self):
kernel = Main.kernel return self._kernel_entry.get_interface(self._interface_id)
interface_name = kernel.vif_index_to_name_dic[self._interface_id]
interface = Main.interfaces[interface_name]
return interface
def get_membership_interface(self):
return self._kernel_entry.get_membership_interface(self._interface_id)
def get_ip(self): def get_ip(self):
ip = self.get_interface().get_ip() ip = self.get_interface().get_ip()
...@@ -353,9 +335,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -353,9 +335,6 @@ class TreeInterface(metaclass=ABCMeta):
def is_downstream(self): def is_downstream(self):
raise NotImplementedError() raise NotImplementedError()
# obtain ip of RPF'(S) # obtain ip of RPF'(S)
def get_neighbor_RPF(self): def get_neighbor_RPF(self):
''' '''
...@@ -375,8 +354,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -375,8 +354,6 @@ class TreeInterface(metaclass=ABCMeta):
def get_received_prune_holdtime(self): def get_received_prune_holdtime(self):
return self._received_prune_holdtime return self._received_prune_holdtime
################################################### ###################################################
# ASSERT # ASSERT
################################################### ###################################################
......
import array import array
'''
import struct
if struct.pack("H",1) == "\x00\x01": # big endian
def checksum(pkt):
if len(pkt) % 2 == 1:
pkt += "\0"
s = sum(array.array("H", pkt))
s = (s >> 16) + (s & 0xffff)
s += s >> 16
s = ~s
return s & 0xffff
else:
def checksum(pkt):
if len(pkt) % 2 == 1:
pkt += "\0"
s = sum(array.array("H", pkt))
s = (s >> 16) + (s & 0xffff)
s += s >> 16
s = ~s
return (((s>>8)&0xff)|s<<8) & 0xffff
'''
HELLO_HOLD_TIME_NO_TIMEOUT = 0xFFFF
HELLO_HOLD_TIME = 160
HELLO_HOLD_TIME_TIMEOUT = 0
def checksum(pkt: bytes) -> bytes: def checksum(pkt: bytes) -> bytes:
...@@ -36,35 +11,6 @@ def checksum(pkt: bytes) -> bytes: ...@@ -36,35 +11,6 @@ def checksum(pkt: bytes) -> bytes:
return (((s >> 8) & 0xff) | s << 8) & 0xffff return (((s >> 8) & 0xff) | s << 8) & 0xffff
import ctypes
import ctypes.util
libc = ctypes.CDLL(ctypes.util.find_library('c'))
def if_nametoindex(name):
if not isinstance(name, str):
raise TypeError('name must be a string.')
ret = libc.if_nametoindex(name)
if not ret:
raise RuntimeError("Invalid Name")
return ret
def if_indextoname(index):
if not isinstance(index, int):
raise TypeError('index must be an int.')
libc.if_indextoname.argtypes = [ctypes.c_uint32, ctypes.c_char_p]
libc.if_indextoname.restype = ctypes.c_char_p
ifname = ctypes.create_string_buffer(32)
ifname = libc.if_indextoname(index, ifname)
if not ifname:
raise RuntimeError ("Inavlid Index")
return ifname.decode("utf-8")
# obtain TYPE_CHECKING (for type hinting) # obtain TYPE_CHECKING (for type hinting)
try: try:
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
......
...@@ -12,8 +12,8 @@ setup( ...@@ -12,8 +12,8 @@ setup(
description="PIM-DM protocol", description="PIM-DM protocol",
long_description=open("README.md", "r").read(), long_description=open("README.md", "r").read(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
keywords="PIM-DM Multicast Routing Protocol Dense-Mode Router RFC3973", keywords="PIM-DM Multicast Routing Protocol Dense-Mode Router RFC3973 IPv4 IPv6",
version="1.0.4.2", version="1.1",
url="http://github.com/pedrofran12/pim_dm", url="http://github.com/pedrofran12/pim_dm",
author="Pedro Oliveira", author="Pedro Oliveira",
author_email="pedro.francisco.oliveira@tecnico.ulisboa.pt", author_email="pedro.francisco.oliveira@tecnico.ulisboa.pt",
...@@ -38,7 +38,6 @@ setup( ...@@ -38,7 +38,6 @@ setup(
"Operating System :: OS Independent", "Operating System :: OS Independent",
"Programming Language :: Python", "Programming Language :: Python",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.2",
"Programming Language :: Python :: 3.3", "Programming Language :: Python :: 3.3",
"Programming Language :: Python :: 3.4", "Programming Language :: Python :: 3.4",
"Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.5",
...@@ -46,5 +45,5 @@ setup( ...@@ -46,5 +45,5 @@ setup(
"Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
], ],
python_requires=">=3.2", python_requires=">=3.3",
) )
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment