Commit b26c6216 authored by Pedro Oliveira's avatar Pedro Oliveira Committed by GitHub

IPv6 multicast routing support (#3)

parent f6ac9c82
......@@ -8,11 +8,14 @@ We have implemented PIM-DM specification ([RFC3973](https://tools.ietf.org/html/
This repository stores the implementation of this protocol. The implementation is written in Python language and is destined to Linux systems.
Additionally, IGMPv2 and MLDv1 are implemented alongside with PIM-DM to detect interest of hosts.
# Requirements
- Linux machine
- Python3 (we have written all code to be compatible with at least Python v3.2)
- Unicast routing protocol
- Python3 (we have written all code to be compatible with at least Python v3.3)
- pip (to install all dependencies)
- tcpdump
......@@ -41,6 +44,8 @@ In order to start the protocol you first need to explicitly start it. This will
sudo pim-dm -start
```
IPv4 and IPv6 multicast is supported. By default all commands will be executed on IPv4 daemon. To execute a command on the IPv6 daemon use `-6`.
#### Add interface
......@@ -49,21 +54,27 @@ After starting the protocol process you can enable the protocol in specific inte
- To enable PIM-DM without State-Refresh, in a given interface, you need to run the following command:
```
sudo pim-dm -ai INTERFACE_NAME
sudo pim-dm -ai INTERFACE_NAME [-4 | -6]
```
- To enable PIM-DM with State-Refresh, in a given interface, you need to run the following command:
```
sudo pim-dm -aisf INTERFACE_NAME
sudo pim-dm -aisf INTERFACE_NAME [-4 | -6]
```
- To enable IGMP in a given interface, you need to run the following command:
- To enable IGMP/MLD in a given interface, you need to run the following command:
- IGMP:
```
sudo pim-dm -aiigmp INTERFACE_NAME
```
- MLD:
```
sudo pim-dm -aimld INTERFACE_NAME
```
#### Remove interface
To remove a previously added interface, you need run the following commands:
......@@ -71,15 +82,20 @@ To remove a previously added interface, you need run the following commands:
- To remove a previously added PIM-DM interface:
```
sudo pim-dm -ri INTERFACE_NAME
sudo pim-dm -ri INTERFACE_NAME [-4 | -6]
```
- To remove a previously added IGMP interface:
- To remove a previously added IGMP/MLD interface:
- IGMP:
```
sudo pim-dm -riigmp INTERFACE_NAME
```
- MLD:
```
sudo pim-dm -rimld INTERFACE_NAME
```
#### Stop protocol process
......@@ -96,31 +112,31 @@ We have built some list commands that can be used to check the "internals" of th
- #### List interfaces:
Show all router interfaces and which ones have PIM-DM and IGMP enabled. For IGMP enabled interfaces check the IGMP Querier state.
Show all router interfaces and which ones have PIM-DM and IGMP/MLD enabled. For IGMP/MLD enabled interfaces you can check the Querier state.
```
sudo pim-dm -li
sudo pim-dm -li [-4 | -6]
```
- #### List neighbors
Verify neighbors that have established a neighborhood relationship.
```
sudo pim-dm -ln
sudo pim-dm -ln [-4 | -6]
```
- #### List state
List all state machines and corresponding state of all trees that are being monitored. Also list IGMP state for each group being monitored.
```
sudo pim-dm -ls
sudo pim-dm -ls [-4 | -6]
```
- #### Multicast Routing Table
List Linux Multicast Routing Table (equivalent to `ip mroute -show`)
List Linux Multicast Routing Table (equivalent to `ip mroute show`)
```
sudo pim-dm -mr
sudo pim-dm -mr [-4 | -6]
```
......@@ -131,15 +147,10 @@ In order to determine which commands and corresponding arguments are available y
pim-dm -h
```
or
```
pim-dm --help
```
## Change settings
Files tree/globals.py and igmp/igmp_globals.py store all timer values and some configurations regarding IGMP and the PIM-DM. If you want to tune the implementation, you can change the values of these files. These configurations are used by all interfaces, meaning that there is no tuning per interface.
Files tree/globals.py, igmp/igmp_globals.py and mld/mld_globals.py store all timer values and some configurations regarding PIM-DM, IGMP and MLD. If you want to tune the implementation, you can change the values of these files. These configurations are used by all interfaces, meaning that there is no tuning per interface.
## Tests
......@@ -151,4 +162,4 @@ We have performed tests to our PIM-DM implementation. You can check on the corre
- [Test_PIM_Assert](https://github.com/pedrofran12/pim_dm/tree/Test_PIM_Assert) - Topology used to test the election of the AssertWinner.
- [Test_PIM_Join_Prune_Graft](https://github.com/pedrofran12/pim_dm/tree/Test_PIM_Join_Prune_Graft) - Topology used to test the Pruning and Grafting of the multicast distribution tree.
- [Test_PIM_StateRefresh](https://github.com/pedrofran12/pim_dm/tree/Test_PIM_StateRefresh) - Topology used to test PIM-DM State Refresh.
- [Test_IGMP](https://github.com/pedrofran12/hpim_dm/tree/Test_IGMP) - Topology used to test our IGMPv2 implementation.
- [Test_IGMP](https://github.com/pedrofran12/pim_dm/tree/Test_IGMP) - Topology used to test our IGMPv2 implementation.
......@@ -18,8 +18,11 @@ class Interface(metaclass=ABCMeta):
self._recv_socket = recv_socket
self.interface_enabled = False
def _enable(self):
"""
Enable this interface
This will start a thread to be executed in the background to be used in the reception of control packets
"""
self.interface_enabled = True
# run receive method in background
receive_thread = threading.Thread(target=self.receive)
......@@ -27,24 +30,39 @@ class Interface(metaclass=ABCMeta):
receive_thread.start()
def receive(self):
"""
Method that will be executed in the background for the reception of control packets
"""
while self.interface_enabled:
try:
(raw_bytes, _) = self._recv_socket.recvfrom(256 * 1024)
(raw_bytes, ancdata, _, src_addr) = self._recv_socket.recvmsg(256 * 1024, 500)
if raw_bytes:
self._receive(raw_bytes)
self._receive(raw_bytes, ancdata, src_addr)
except Exception:
traceback.print_exc()
continue
@abstractmethod
def _receive(self, raw_bytes):
def _receive(self, raw_bytes, ancdata, src_addr):
"""
Subclass method to be implemented
This method will be invoked whenever a new control packet is received
"""
raise NotImplementedError
def send(self, data: bytes, group_ip: str):
"""
Send a control packet through this interface
Explicitly destined to group_ip (can be unicast or multicast IP)
"""
if self.interface_enabled and data:
self._send_socket.sendto(data, (group_ip, 0))
def remove(self):
"""
This interface is no longer active....
Clear all state regarding it
"""
self.interface_enabled = False
try:
self._recv_socket.shutdown(socket.SHUT_RDWR)
......@@ -54,8 +72,14 @@ class Interface(metaclass=ABCMeta):
self._send_socket.close()
def is_enabled(self):
"""
Verify if this interface is enabled
"""
return self.interface_enabled
@abstractmethod
def get_ip(self):
"""
Get IP of this interface
"""
raise NotImplementedError
......@@ -4,7 +4,7 @@ from ipaddress import IPv4Address
from ctypes import create_string_buffer, addressof
import netifaces
from pimdm.Interface import Interface
from pimdm.Packet.ReceivedPacket import ReceivedPacket
from pimdm.packet.ReceivedPacket import ReceivedPacket
from pimdm.igmp.igmp_globals import Version_1_Membership_Report, Version_2_Membership_Report, Leave_Group, Membership_Query
if not hasattr(socket, 'SO_BINDTODEVICE'):
socket.SO_BINDTODEVICE = 25
......@@ -48,18 +48,20 @@ class InterfaceIGMP(Interface):
self.interface_state = RouterState(self)
super()._enable()
def get_ip(self):
return netifaces.ifaddresses(self.interface_name)[netifaces.AF_INET][0]['addr']
@property
def ip_interface(self):
"""
Get IP of this interface
"""
return self.get_ip()
def send(self, data: bytes, address: str="224.0.0.1"):
super().send(data, address)
def _receive(self, raw_bytes):
def _receive(self, raw_bytes, ancdata, src_addr):
if raw_bytes:
raw_bytes = raw_bytes[14:]
packet = ReceivedPacket(raw_bytes, self)
......@@ -91,7 +93,8 @@ class InterfaceIGMP(Interface):
def receive_membership_query(self, packet):
ip_dst = packet.ip_header.ip_dst
igmp_group = packet.payload.group_address
if ip_dst == igmp_group or (ip_dst == "224.0.0.1" and igmp_group == "0.0.0.0"):
if (IPv4Address(igmp_group).is_multicast and ip_dst == igmp_group) or \
(ip_dst == "224.0.0.1" and igmp_group == "0.0.0.0"):
self.interface_state.receive_query(packet)
def receive_unknown_type(self, packet):
......
import socket
import struct
import netifaces
import ipaddress
from socket import if_nametoindex
from ipaddress import IPv6Address
from .Interface import Interface
from .packet.ReceivedPacket import ReceivedPacket_v6
from .mld.mld_globals import MULTICAST_LISTENER_QUERY_TYPE, MULTICAST_LISTENER_DONE_TYPE, MULTICAST_LISTENER_REPORT_TYPE
from ctypes import create_string_buffer, addressof
ETH_P_IPV6 = 0x86DD # IPv6 over bluebook
SO_ATTACH_FILTER = 26
ICMP6_FILTER = 1
IPV6_ROUTER_ALERT = 22
def ICMP6_FILTER_SETBLOCKALL():
return struct.pack("I"*8, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF)
def ICMP6_FILTER_SETPASS(type, filterp):
return filterp[:type >> 5] + (bytes([(filterp[type >> 5] & ~(1 << ((type) & 31)))])) + filterp[(type >> 5) + 1:]
class InterfaceMLD(Interface):
IPv6_LINK_SCOPE_ALL_NODES = IPv6Address("ff02::1")
IPv6_LINK_SCOPE_ALL_ROUTERS = IPv6Address("ff02::2")
IPv6_ALL_ZEROS = IPv6Address("::")
FILTER_MLD = [
struct.pack('HBBI', 0x28, 0, 0, 0x0000000c),
struct.pack('HBBI', 0x15, 0, 9, 0x000086dd),
struct.pack('HBBI', 0x30, 0, 0, 0x00000014),
struct.pack('HBBI', 0x15, 0, 7, 0x00000000),
struct.pack('HBBI', 0x30, 0, 0, 0x00000036),
struct.pack('HBBI', 0x15, 0, 5, 0x0000003a),
struct.pack('HBBI', 0x30, 0, 0, 0x0000003e),
struct.pack('HBBI', 0x15, 2, 0, 0x00000082),
struct.pack('HBBI', 0x15, 1, 0, 0x00000083),
struct.pack('HBBI', 0x15, 0, 1, 0x00000084),
struct.pack('HBBI', 0x6, 0, 0, 0x00040000),
struct.pack('HBBI', 0x6, 0, 0, 0x00000000),
]
def __init__(self, interface_name: str, vif_index: int):
# SEND SOCKET
s = socket.socket(socket.AF_INET6, socket.SOCK_RAW, socket.IPPROTO_ICMPV6)
# set socket output interface
if_index = if_nametoindex(interface_name)
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, struct.pack('@I', if_index))
"""
# set ICMP6 filter to only receive MLD packets
icmp6_filter = ICMP6_FILTER_SETBLOCKALL()
icmp6_filter = ICMP6_FILTER_SETPASS(MULTICAST_LISTENER_QUERY_TYPE, icmp6_filter)
icmp6_filter = ICMP6_FILTER_SETPASS(MULTICAST_LISTENER_REPORT_TYPE, icmp6_filter)
icmp6_filter = ICMP6_FILTER_SETPASS(MULTICAST_LISTENER_DONE_TYPE, icmp6_filter)
s.setsockopt(socket.IPPROTO_ICMPV6, ICMP6_FILTER, icmp6_filter)
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_RECVPKTINFO, True)
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, False)
s.setsockopt(socket.IPPROTO_IPV6, self.IPV6_ROUTER_ALERT, 0)
rcv_s = s
"""
ip_interface = "::"
for if_addr in netifaces.ifaddresses(interface_name)[netifaces.AF_INET6]:
ip_interface = if_addr["addr"]
if ipaddress.IPv6Address(ip_interface.split("%")[0]).is_link_local:
# bind to interface
s.bind(socket.getaddrinfo(ip_interface, None, 0, socket.SOCK_RAW, 0, socket.AI_PASSIVE)[0][4])
ip_interface = ip_interface.split("%")[0]
break
self.ip_interface = ip_interface
# RECEIVE SOCKET
rcv_s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, socket.htons(ETH_P_IPV6))
# receive only MLD packets by setting a BPF filter
bpf_filter = b''.join(InterfaceMLD.FILTER_MLD)
b = create_string_buffer(bpf_filter)
mem_addr_of_filters = addressof(b)
fprog = struct.pack('HL', len(InterfaceMLD.FILTER_MLD), mem_addr_of_filters)
rcv_s.setsockopt(socket.SOL_SOCKET, SO_ATTACH_FILTER, fprog)
# bind to interface
rcv_s.bind((interface_name, ETH_P_IPV6))
super().__init__(interface_name=interface_name, recv_socket=rcv_s, send_socket=s, vif_index=vif_index)
self.interface_enabled = True
from .mld.RouterState import RouterState
self.interface_state = RouterState(self)
super()._enable()
@staticmethod
def _get_address_family():
return socket.AF_INET6
def get_ip(self):
return self.ip_interface
def send(self, data: bytes, address: str = "FF02::1"):
# send router alert option
cmsg_level = socket.IPPROTO_IPV6
cmsg_type = socket.IPV6_HOPOPTS
cmsg_data = b'\x3a\x00\x05\x02\x00\x00\x01\x00'
self._send_socket.sendmsg([data], [(cmsg_level, cmsg_type, cmsg_data)], 0, (address, 0))
"""
def receive(self):
while self.interface_enabled:
try:
(raw_bytes, ancdata, _, src_addr) = self._recv_socket.recvmsg(256 * 1024, 500)
if raw_bytes:
self._receive(raw_bytes, ancdata, src_addr)
except Exception:
import traceback
traceback.print_exc()
continue
"""
def _receive(self, raw_bytes, ancdata, src_addr):
if raw_bytes:
raw_bytes = raw_bytes[14:]
src_addr = (socket.inet_ntop(socket.AF_INET6, raw_bytes[8:24]),)
print("MLD IP_SRC bf= ", src_addr)
dst_addr = raw_bytes[24:40]
(next_header,) = struct.unpack("B", raw_bytes[6:7])
print("NEXT HEADER:", next_header)
payload_starts_at_len = 40
if next_header == 0:
# Hop by Hop options
(next_header,) = struct.unpack("B", raw_bytes[40:41])
if next_header != 58:
return
(hdr_ext_len,) = struct.unpack("B", raw_bytes[payload_starts_at_len +1:payload_starts_at_len + 2])
if hdr_ext_len > 0:
payload_starts_at_len = payload_starts_at_len + 1 + hdr_ext_len*8
else:
payload_starts_at_len = payload_starts_at_len + 8
raw_bytes = raw_bytes[payload_starts_at_len:]
ancdata = [(socket.IPPROTO_IPV6, socket.IPV6_PKTINFO, dst_addr)]
print("RECEIVE MLD")
print("ANCDATA: ", ancdata, "; SRC_ADDR: ", src_addr)
packet = ReceivedPacket_v6(raw_bytes, ancdata, src_addr, 58, self)
ip_src = packet.ip_header.ip_src
print("MLD IP_SRC = ", ip_src)
if not (ip_src == "::" or IPv6Address(ip_src).is_multicast):
self.PKT_FUNCTIONS.get(packet.payload.get_mld_type(), InterfaceMLD.receive_unknown_type)(self, packet)
"""
def _receive(self, raw_bytes, ancdata, src_addr):
if raw_bytes:
packet = ReceivedPacket_v6(raw_bytes, ancdata, src_addr, 58, self)
self.PKT_FUNCTIONS[packet.payload.get_mld_type(), InterfaceMLD.receive_unknown_type](self, packet)
"""
###########################################
# Recv packets
###########################################
def receive_multicast_listener_report(self, packet):
print("RECEIVE MULTICAST LISTENER REPORT")
ip_dst = packet.ip_header.ip_dst
mld_group = packet.payload.group_address
ipv6_group = IPv6Address(mld_group)
ipv6_dst = IPv6Address(ip_dst)
if ipv6_dst == ipv6_group and ipv6_group.is_multicast:
self.interface_state.receive_report(packet)
def receive_multicast_listener_done(self, packet):
print("RECEIVE MULTICAST LISTENER DONE")
ip_dst = packet.ip_header.ip_dst
mld_group = packet.payload.group_address
if IPv6Address(ip_dst) == self.IPv6_LINK_SCOPE_ALL_ROUTERS and IPv6Address(mld_group).is_multicast:
self.interface_state.receive_done(packet)
def receive_multicast_listener_query(self, packet):
print("RECEIVE MULTICAST LISTENER QUERY")
ip_dst = packet.ip_header.ip_dst
mld_group = packet.payload.group_address
ipv6_group = IPv6Address(mld_group)
ipv6_dst = IPv6Address(ip_dst)
if (ipv6_group.is_multicast and ipv6_dst == ipv6_group) or\
(ipv6_dst == self.IPv6_LINK_SCOPE_ALL_NODES and ipv6_group == self.IPv6_ALL_ZEROS):
self.interface_state.receive_query(packet)
def receive_unknown_type(self, packet):
raise Exception("UNKNOWN MLD TYPE: " + str(packet.payload.get_mld_type()))
PKT_FUNCTIONS = {
MULTICAST_LISTENER_REPORT_TYPE: receive_multicast_listener_report,
MULTICAST_LISTENER_DONE_TYPE: receive_multicast_listener_done,
MULTICAST_LISTENER_QUERY_TYPE: receive_multicast_listener_query,
}
##################
def remove(self):
super().remove()
self.interface_state.remove()
import socket
import random
from pimdm.Interface import Interface
from pimdm.Packet.ReceivedPacket import ReceivedPacket
from pimdm import Main
import logging
import netifaces
import traceback
from pimdm.RWLock.RWLock import RWLockWrite
from pimdm.Packet.PacketPimHelloOptions import *
from pimdm.Packet.PacketPimHello import PacketPimHello
from pimdm.Packet.PacketPimHeader import PacketPimHeader
from pimdm.Packet.Packet import Packet
from pimdm.utils import HELLO_HOLD_TIME_TIMEOUT
from threading import Timer
from pimdm.tree.globals import REFRESH_INTERVAL
import socket
import netifaces
import logging
from pimdm.Interface import Interface
from pimdm.packet.ReceivedPacket import ReceivedPacket
from pimdm import Main
from pimdm.rwlock.RWLock import RWLockWrite
from pimdm.packet.PacketPimHelloOptions import *
from pimdm.packet.PacketPimHello import PacketPimHello
from pimdm.packet.PacketPimHeader import PacketPimHeader
from pimdm.packet.Packet import Packet
from pimdm.tree.globals import HELLO_HOLD_TIME_TIMEOUT, REFRESH_INTERVAL
class InterfacePim(Interface):
......@@ -83,18 +83,37 @@ class InterfacePim(Interface):
self.force_send_hello()
def get_ip(self):
"""
Get IP of this interface
"""
return self.ip_interface
def _receive(self, raw_bytes):
@staticmethod
def get_kernel():
"""
Get Kernel object
"""
return Main.kernel
def _receive(self, raw_bytes, ancdata, src_addr):
"""
Interface received a new control packet
"""
if raw_bytes:
packet = ReceivedPacket(raw_bytes, self)
self.PKT_FUNCTIONS[packet.payload.get_pim_type()](self, packet)
self.PKT_FUNCTIONS.get(packet.payload.get_pim_type(), InterfacePim.receive_unknown)(self, packet)
def send(self, data: bytes, group_ip: str=MCAST_GRP):
"""
Send a new packet destined to group_ip IP
"""
super().send(data=data, group_ip=group_ip)
#Random interval for initial Hello message on bootup or triggered Hello message to a rebooting neighbor
def force_send_hello(self):
"""
Force the transmission of a new Hello message
"""
if self.hello_timer is not None:
self.hello_timer.cancel()
......@@ -103,6 +122,10 @@ class InterfacePim(Interface):
self.hello_timer.start()
def send_hello(self):
"""
Send a new Hello message
Include in it the HelloHoldTime and GenerationID
"""
self.interface_logger.debug('Send Hello message')
self.hello_timer.cancel()
......@@ -125,6 +148,10 @@ class InterfacePim(Interface):
self.hello_timer.start()
def remove(self):
"""
Remove this interface
Clear all state
"""
self.hello_timer.cancel()
self.hello_timer = None
......@@ -136,17 +163,20 @@ class InterfacePim(Interface):
packet = Packet(payload=ph)
self.send(packet.bytes())
Main.kernel.interface_change_number_of_neighbors()
self.get_kernel().interface_change_number_of_neighbors()
super().remove()
def check_number_of_neighbors(self):
has_neighbors = len(self.neighbors) > 0
if has_neighbors != self._had_neighbors:
self._had_neighbors = has_neighbors
Main.kernel.interface_change_number_of_neighbors()
self.get_kernel().interface_change_number_of_neighbors()
def new_or_reset_neighbor(self, neighbor_ip):
Main.kernel.new_or_reset_neighbor(self.vif_index, neighbor_ip)
"""
React to new neighbor or restart of known neighbor
"""
self.get_kernel().new_or_reset_neighbor(self.vif_index, neighbor_ip)
'''
def add_neighbor(self, ip, random_number, hello_hold_time):
......@@ -160,27 +190,44 @@ class InterfacePim(Interface):
'''
def get_neighbors(self):
"""
Get list of known neighbors
"""
with self.neighbors_lock.genRlock():
return self.neighbors.values()
def get_neighbor(self, ip):
"""
Get specific neighbor by its IP
"""
with self.neighbors_lock.genRlock():
return self.neighbors.get(ip)
def remove_neighbor(self, ip):
"""
Remove known neighbor
"""
with self.neighbors_lock.genWlock():
del self.neighbors[ip]
self.interface_logger.debug("Remove neighbor: " + ip)
self.check_number_of_neighbors()
def set_state_refresh_capable(self, value):
"""
Change StateRefresh capability of interface
"""
self._state_refresh_capable = value
def is_state_refresh_enabled(self):
"""
Check if state refresh is enabled
"""
return self._state_refresh_capable
# check if Interface is StateRefreshCapable
def is_state_refresh_capable(self):
"""
Check StateRefresh capability of interface neighbors
"""
with self.neighbors_lock.genWlock():
if len(self.neighbors) == 0:
return False
......@@ -214,6 +261,9 @@ class InterfacePim(Interface):
# Recv packets
###########################################
def receive_hello(self, packet):
"""
Receive an Hello packet
"""
ip = packet.ip_header.ip_src
print("ip = ", ip)
options = packet.payload.payload.get_options()
......@@ -226,7 +276,6 @@ class InterfacePim(Interface):
state_refresh_capable = (21 in options)
with self.neighbors_lock.genWlock():
if ip not in self.neighbors:
if hello_hold_time == 0:
......@@ -244,17 +293,23 @@ class InterfacePim(Interface):
neighbor.receive_hello(generation_id, hello_hold_time, state_refresh_capable)
def receive_assert(self, packet):
"""
Receive an Assert packet
"""
pkt_assert = packet.payload.payload # type: PacketPimAssert
source = pkt_assert.source_address
group = pkt_assert.multicast_group_address
source_group = (source, group)
try:
Main.kernel.get_routing_entry(source_group).recv_assert_msg(self.vif_index, packet)
self.get_kernel().get_routing_entry(source_group).recv_assert_msg(self.vif_index, packet)
except:
traceback.print_exc()
def receive_join_prune(self, packet):
"""
Receive Join/Prune packet
"""
pkt_join_prune = packet.payload.payload # type: PacketPimJoinPrune
join_prune_groups = pkt_join_prune.groups
......@@ -266,7 +321,7 @@ class InterfacePim(Interface):
for source_address in joined_src_addresses:
source_group = (source_address, multicast_group)
try:
Main.kernel.get_routing_entry(source_group).recv_join_msg(self.vif_index, packet)
self.get_kernel().get_routing_entry(source_group).recv_join_msg(self.vif_index, packet)
except:
traceback.print_exc()
continue
......@@ -274,12 +329,15 @@ class InterfacePim(Interface):
for source_address in pruned_src_addresses:
source_group = (source_address, multicast_group)
try:
Main.kernel.get_routing_entry(source_group).recv_prune_msg(self.vif_index, packet)
self.get_kernel().get_routing_entry(source_group).recv_prune_msg(self.vif_index, packet)
except:
traceback.print_exc()
continue
def receive_graft(self, packet):
"""
Receive Graft packet
"""
pkt_join_prune = packet.payload.payload # type: PacketPimGraft
join_prune_groups = pkt_join_prune.groups
......@@ -290,12 +348,15 @@ class InterfacePim(Interface):
for source_address in joined_src_addresses:
source_group = (source_address, multicast_group)
try:
Main.kernel.get_routing_entry(source_group).recv_graft_msg(self.vif_index, packet)
self.get_kernel().get_routing_entry(source_group).recv_graft_msg(self.vif_index, packet)
except:
traceback.print_exc()
continue
def receive_graft_ack(self, packet):
"""
Receive an GraftAck packet
"""
pkt_join_prune = packet.payload.payload # type: PacketPimGraftAck
join_prune_groups = pkt_join_prune.groups
......@@ -306,12 +367,15 @@ class InterfacePim(Interface):
for source_address in joined_src_addresses:
source_group = (source_address, multicast_group)
try:
Main.kernel.get_routing_entry(source_group).recv_graft_ack_msg(self.vif_index, packet)
self.get_kernel().get_routing_entry(source_group).recv_graft_ack_msg(self.vif_index, packet)
except:
traceback.print_exc()
continue
def receive_state_refresh(self, packet):
"""
Receive an StateRefresh packet
"""
if not self.is_state_refresh_enabled():
return
pkt_state_refresh = packet.payload.payload # type: PacketPimStateRefresh
......@@ -320,10 +384,15 @@ class InterfacePim(Interface):
group = pkt_state_refresh.multicast_group_adress
source_group = (source, group)
try:
Main.kernel.get_routing_entry(source_group).recv_state_refresh_msg(self.vif_index, packet)
self.get_kernel().get_routing_entry(source_group).recv_state_refresh_msg(self.vif_index, packet)
except:
traceback.print_exc()
def receive_unknown(self, packet):
"""
Receive an unknown packet
"""
raise Exception("Unknown PIM type: " + str(packet.payload.get_pim_type()))
PKT_FUNCTIONS = {
0: receive_hello,
......
import socket
import random
import struct
import logging
import ipaddress
import netifaces
from pimdm import Main
from socket import if_nametoindex
from pimdm.Interface import Interface
from .InterfacePIM import InterfacePim
from pimdm.rwlock.RWLock import RWLockWrite
from pimdm.packet.ReceivedPacket import ReceivedPacket_v6
class InterfacePim6(InterfacePim):
MCAST_GRP = "ff02::d"
def __init__(self, interface_name: str, vif_index:int, state_refresh_capable:bool=False):
# generation id
self.generation_id = random.getrandbits(32)
# When PIM is enabled on an interface or when a router first starts, the Hello Timer (HT)
# MUST be set to random value between 0 and Triggered_Hello_Delay
self.hello_timer = None
# state refresh capable
self._state_refresh_capable = state_refresh_capable
self._neighbors_state_refresh_capable = False
# todo: lan delay enabled
self._lan_delay_enabled = False
# todo: propagation delay
self._propagation_delay = self.PROPAGATION_DELAY
# todo: override interval
self._override_interval = self.OVERRIDE_INTERNAL
# pim neighbors
self._had_neighbors = False
self.neighbors = {}
self.neighbors_lock = RWLockWrite()
self.interface_logger = logging.LoggerAdapter(InterfacePim.LOGGER, {'vif': vif_index, 'interfacename': interface_name})
# SOCKET
s = socket.socket(socket.AF_INET6, socket.SOCK_RAW, socket.IPPROTO_PIM)
ip_interface = ""
for if_addr in netifaces.ifaddresses(interface_name)[netifaces.AF_INET6]:
ip_interface = if_addr["addr"]
if ipaddress.IPv6Address(if_addr['addr'].split("%")[0]).is_link_local:
ip_interface = if_addr['addr'].split("%")[0]
# bind to interface
s.bind(socket.getaddrinfo(if_addr['addr'], None, 0, socket.SOCK_RAW, 0, socket.AI_PASSIVE)[0][4])
break
self.ip_interface = ip_interface
# allow other sockets to bind this port too
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# explicitly join the multicast group on the interface specified
if_index = if_nametoindex(interface_name)
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP,
socket.inet_pton(socket.AF_INET6, InterfacePim6.MCAST_GRP) + struct.pack('@I', if_index))
s.setsockopt(socket.SOL_SOCKET, 25, str(interface_name + '\0').encode('utf-8'))
# set socket output interface
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, struct.pack('@I', if_index))
# set socket TTL to 1
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 1)
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_UNICAST_HOPS, 1)
# don't receive outgoing packets
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, 0)
Interface.__init__(self, interface_name, s, s, vif_index)
Interface._enable(self)
self.force_send_hello()
@staticmethod
def get_kernel():
return Main.kernel_v6
def send(self, data: bytes, group_ip: str=MCAST_GRP):
super().send(data=data, group_ip=group_ip)
def _receive(self, raw_bytes, ancdata, src_addr):
if raw_bytes:
packet = ReceivedPacket_v6(raw_bytes, ancdata, src_addr, 103, self)
self.PKT_FUNCTIONS[packet.payload.get_pim_type()](self, packet)
import socket
import struct
from threading import RLock, Thread
import traceback
import ipaddress
import traceback
from socket import if_nametoindex
from threading import RLock, Thread
from abc import abstractmethod, ABCMeta
from pimdm.RWLock.RWLock import RWLockWrite
from pimdm import UnicastRouting, Main
from pimdm.rwlock.RWLock import RWLockWrite
from pimdm.InterfacePIM import InterfacePim
from pimdm.InterfaceMLD import InterfaceMLD
from pimdm.InterfaceIGMP import InterfaceIGMP
from pimdm.InterfacePIM import InterfacePim
from pimdm.InterfacePIM6 import InterfacePim6
from pimdm.tree.KernelEntry import KernelEntry
from pimdm import UnicastRouting, Main
class Kernel:
# MRT
MRT_BASE = 200
MRT_INIT = (MRT_BASE) # /* Activate the kernel mroute code */
MRT_DONE = (MRT_BASE + 1) # /* Shutdown the kernel mroute */
MRT_ADD_VIF = (MRT_BASE + 2) # /* Add a virtual interface */
MRT_DEL_VIF = (MRT_BASE + 3) # /* Delete a virtual interface */
MRT_ADD_MFC = (MRT_BASE + 4) # /* Add a multicast forwarding entry */
MRT_DEL_MFC = (MRT_BASE + 5) # /* Delete a multicast forwarding entry */
MRT_VERSION = (MRT_BASE + 6) # /* Get the kernel multicast version */
MRT_ASSERT = (MRT_BASE + 7) # /* Activate PIM assert mode */
MRT_PIM = (MRT_BASE + 8) # /* enable PIM code */
MRT_TABLE = (MRT_BASE + 9) # /* Specify mroute table ID */
#MRT_ADD_MFC_PROXY = (MRT_BASE + 10) # /* Add a (*,*|G) mfc entry */
#MRT_DEL_MFC_PROXY = (MRT_BASE + 11) # /* Del a (*,*|G) mfc entry */
#MRT_MAX = (MRT_BASE + 11)
from pimdm.tree.KernelEntryInterface import KernelEntry4Interface, KernelEntry6Interface
class Kernel(metaclass=ABCMeta):
# Max Number of Virtual Interfaces
MAXVIFS = 32
# SIGNAL MSG TYPE
IGMPMSG_NOCACHE = 1
IGMPMSG_WRONGVIF = 2
IGMPMSG_WHOLEPKT = 3 # NOT USED ON PIM-DM
# Interface flags
VIFF_TUNNEL = 0x1 # IPIP tunnel
VIFF_SRCRT = 0x2 # NI
VIFF_REGISTER = 0x4 # register vif
VIFF_USE_IFINDEX = 0x8 # use vifc_lcl_ifindex instead of vifc_lcl_addr to find an interface
def __init__(self):
def __init__(self, kernel_socket):
# Kernel is running
self.running = True
# KEY : interface_ip, VALUE : vif_index
self.vif_dic = {}
self.vif_index_to_name_dic = {}
self.vif_name_to_index_dic = {}
self.vif_index_to_name_dic = {} # KEY : vif_index, VALUE : interface_name
self.vif_name_to_index_dic = {} # KEY : interface_name, VALUE : vif_index
# KEY : source_ip, VALUE : {group_ip: KernelEntry}
self.routing = {}
s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
# MRT INIT
s.setsockopt(socket.IPPROTO_IP, Kernel.MRT_INIT, 1)
# MRT PIM
s.setsockopt(socket.IPPROTO_IP, Kernel.MRT_PIM, 0)
s.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ASSERT, 1)
self.socket = s
self.socket = kernel_socket
self.rwlock = RWLockWrite()
self.interface_lock = RLock()
......@@ -74,9 +40,9 @@ class Kernel:
# todo useless in PIM-DM... useful in PIM-SM
#self.create_virtual_interface("0.0.0.0", "pimreg", index=0, flags=Kernel.VIFF_REGISTER)
# interfaces being monitored by this process
self.pim_interface = {} # name: interface_pim
self.igmp_interface = {} # name: interface_igmp
self.membership_interface = {} # name: interface_igmp or interface_mld
# logs
self.interface_logger = Main.logger.getChild('KernelInterface')
......@@ -101,55 +67,43 @@ class Kernel:
struct in_addr vifc_rmt_addr; /* IPIP tunnel addr */
};
'''
@abstractmethod
def create_virtual_interface(self, ip_interface: str or bytes, interface_name: str, index, flags=0x0):
if type(ip_interface) is str:
ip_interface = socket.inet_aton(ip_interface)
struct_mrt_add_vif = struct.pack("HBBI 4s 4s", index, flags, 1, 0, ip_interface,
socket.inet_aton("0.0.0.0"))
with self.rwlock.genWlock():
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_VIF, struct_mrt_add_vif)
self.vif_dic[socket.inet_ntoa(ip_interface)] = index
self.vif_index_to_name_dic[index] = interface_name
self.vif_name_to_index_dic[interface_name] = index
for source_dict in list(self.routing.values()):
for kernel_entry in list(source_dict.values()):
kernel_entry.new_interface(index)
self.interface_logger.debug('Create virtual interface: %s -> %d', interface_name, index)
return index
raise NotImplementedError
def create_pim_interface(self, interface_name: str, state_refresh_capable:bool):
with self.interface_lock:
pim_interface = self.pim_interface.get(interface_name)
igmp_interface = self.igmp_interface.get(interface_name)
vif_already_exists = pim_interface or igmp_interface
membership_interface = self.membership_interface.get(interface_name)
vif_already_exists = pim_interface or membership_interface
if pim_interface:
# already exists
pim_interface.set_state_refresh_capable(state_refresh_capable)
return
elif igmp_interface:
index = igmp_interface.vif_index
elif membership_interface:
index = membership_interface.vif_index
else:
index = list(range(0, self.MAXVIFS) - self.vif_index_to_name_dic.keys())[0]
ip_interface = None
if interface_name not in self.pim_interface:
pim_interface = InterfacePim(interface_name, index, state_refresh_capable)
pim_interface = self._create_pim_interface_object(interface_name, index, state_refresh_capable)
self.pim_interface[interface_name] = pim_interface
ip_interface = pim_interface.ip_interface
if not vif_already_exists:
self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index)
def create_igmp_interface(self, interface_name: str):
@abstractmethod
def _create_pim_interface_object(self, interface_name, index, state_refresh_capable):
raise NotImplementedError
def create_membership_interface(self, interface_name: str):
with self.interface_lock:
pim_interface = self.pim_interface.get(interface_name)
igmp_interface = self.igmp_interface.get(interface_name)
vif_already_exists = pim_interface or igmp_interface
if igmp_interface:
membership_interface = self.membership_interface.get(interface_name)
vif_already_exists = pim_interface or membership_interface
if membership_interface:
# already exists
return
elif pim_interface:
......@@ -158,47 +112,222 @@ class Kernel:
index = list(range(0, self.MAXVIFS) - self.vif_index_to_name_dic.keys())[0]
ip_interface = None
if interface_name not in self.igmp_interface:
igmp_interface = InterfaceIGMP(interface_name, index)
self.igmp_interface[interface_name] = igmp_interface
if interface_name not in self.membership_interface:
igmp_interface = self._create_membership_interface_object(interface_name, index)
self.membership_interface[interface_name] = igmp_interface
ip_interface = igmp_interface.ip_interface
if not vif_already_exists:
self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index)
@abstractmethod
def _create_membership_interface_object(self, interface_name, index):
raise NotImplementedError
def remove_interface(self, interface_name, igmp:bool=False, pim:bool=False):
def remove_interface(self, interface_name, membership: bool = False, pim: bool = False):
with self.interface_lock:
ip_interface = None
pim_interface = self.pim_interface.get(interface_name)
igmp_interface = self.igmp_interface.get(interface_name)
if (igmp and not igmp_interface) or (pim and not pim_interface) or (not igmp and not pim):
membership_interface = self.membership_interface.get(interface_name)
if (membership and not membership_interface) or (pim and not pim_interface) or (not membership and not pim):
return
if pim:
pim_interface = self.pim_interface.pop(interface_name)
ip_interface = pim_interface.ip_interface
pim_interface.remove()
elif igmp:
igmp_interface = self.igmp_interface.pop(interface_name)
ip_interface = igmp_interface.ip_interface
igmp_interface.remove()
elif membership:
membership_interface = self.membership_interface.pop(interface_name)
membership_interface.remove()
if not self.membership_interface.get(interface_name) and not self.pim_interface.get(interface_name):
self.remove_virtual_interface(interface_name)
@abstractmethod
def remove_virtual_interface(self, interface_name):
raise NotImplementedError
#############################################
# Manipulate multicast routing table
#############################################
@abstractmethod
def set_multicast_route(self, kernel_entry: KernelEntry):
raise NotImplementedError
@abstractmethod
def set_flood_multicast_route(self, source_ip, group_ip, inbound_interface_index):
raise NotImplementedError
@abstractmethod
def remove_multicast_route(self, kernel_entry: KernelEntry):
raise NotImplementedError
@abstractmethod
def exit(self):
raise NotImplementedError
@abstractmethod
def handler(self):
raise NotImplementedError
def get_routing_entry(self, source_group: tuple, create_if_not_existent=True):
ip_src = source_group[0]
ip_dst = source_group[1]
with self.rwlock.genRlock():
if ip_src in self.routing and ip_dst in self.routing[ip_src]:
return self.routing[ip_src][ip_dst]
with self.rwlock.genWlock():
if ip_src in self.routing and ip_dst in self.routing[ip_src]:
return self.routing[ip_src][ip_dst]
elif create_if_not_existent:
kernel_entry = KernelEntry(ip_src, ip_dst, self._get_kernel_entry_interface())
if ip_src not in self.routing:
self.routing[ip_src] = {}
iif = UnicastRouting.check_rpf(ip_src)
self.set_flood_multicast_route(ip_src, ip_dst, iif)
self.routing[ip_src][ip_dst] = kernel_entry
return kernel_entry
else:
return None
@staticmethod
@abstractmethod
def _get_kernel_entry_interface():
pass
# notify KernelEntries about changes at the unicast routing table
def notify_unicast_changes(self, subnet):
with self.rwlock.genWlock():
for source_ip in list(self.routing.keys()):
source_ip_obj = ipaddress.ip_address(source_ip)
if source_ip_obj not in subnet:
continue
for group_ip in list(self.routing[source_ip].keys()):
self.routing[source_ip][group_ip].network_update()
# notify about changes at the interface (IP)
'''
def notify_interface_change(self, interface_name):
with self.interface_lock:
# check if interface was already added
if interface_name not in self.vif_name_to_index_dic:
return
print("trying to change ip")
pim_interface = self.pim_interface.get(interface_name)
if pim_interface:
old_ip = pim_interface.get_ip()
pim_interface.change_interface()
new_ip = pim_interface.get_ip()
if old_ip != new_ip:
self.vif_dic[new_ip] = self.vif_dic.pop(old_ip)
igmp_interface = self.igmp_interface.get(interface_name)
if igmp_interface:
igmp_interface.change_interface()
'''
# When interface changes number of neighbors verify if olist changes and prune/forward respectively
def interface_change_number_of_neighbors(self):
with self.rwlock.genRlock():
for groups_dict in self.routing.values():
for entry in groups_dict.values():
entry.change_at_number_of_neighbors()
# When new neighbor connects try to resend last state refresh msg (if AssertWinner)
def new_or_reset_neighbor(self, vif_index, neighbor_ip):
with self.rwlock.genRlock():
for groups_dict in self.routing.values():
for entry in groups_dict.values():
entry.new_or_reset_neighbor(vif_index, neighbor_ip)
class Kernel4(Kernel):
# MRT
MRT_BASE = 200
MRT_INIT = (MRT_BASE) # /* Activate the kernel mroute code */
MRT_DONE = (MRT_BASE + 1) # /* Shutdown the kernel mroute */
MRT_ADD_VIF = (MRT_BASE + 2) # /* Add a virtual interface */
MRT_DEL_VIF = (MRT_BASE + 3) # /* Delete a virtual interface */
MRT_ADD_MFC = (MRT_BASE + 4) # /* Add a multicast forwarding entry */
MRT_DEL_MFC = (MRT_BASE + 5) # /* Delete a multicast forwarding entry */
MRT_VERSION = (MRT_BASE + 6) # /* Get the kernel multicast version */
MRT_ASSERT = (MRT_BASE + 7) # /* Activate PIM assert mode */
MRT_PIM = (MRT_BASE + 8) # /* enable PIM code */
MRT_TABLE = (MRT_BASE + 9) # /* Specify mroute table ID */
#MRT_ADD_MFC_PROXY = (MRT_BASE + 10) # /* Add a (*,*|G) mfc entry */
#MRT_DEL_MFC_PROXY = (MRT_BASE + 11) # /* Del a (*,*|G) mfc entry */
#MRT_MAX = (MRT_BASE + 11)
# Max Number of Virtual Interfaces
MAXVIFS = 32
# SIGNAL MSG TYPE
IGMPMSG_NOCACHE = 1
IGMPMSG_WRONGVIF = 2
IGMPMSG_WHOLEPKT = 3 # NOT USED ON PIM-DM
# Interface flags
VIFF_TUNNEL = 0x1 # IPIP tunnel
VIFF_SRCRT = 0x2 # NI
VIFF_REGISTER = 0x4 # register vif
VIFF_USE_IFINDEX = 0x8 # use vifc_lcl_ifindex instead of vifc_lcl_addr to find an interface
def __init__(self):
s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
# MRT INIT
s.setsockopt(socket.IPPROTO_IP, self.MRT_INIT, 1)
# MRT PIM
s.setsockopt(socket.IPPROTO_IP, self.MRT_PIM, 0)
s.setsockopt(socket.IPPROTO_IP, self.MRT_ASSERT, 1)
super().__init__(s)
'''
Structure to create/remove virtual interfaces
struct vifctl {
vifi_t vifc_vifi; /* Index of VIF */
unsigned char vifc_flags; /* VIFF_ flags */
unsigned char vifc_threshold; /* ttl limit */
unsigned int vifc_rate_limit; /* Rate limiter values (NI) */
union {
struct in_addr vifc_lcl_addr; /* Local interface address */
int vifc_lcl_ifindex; /* Local interface index */
};
struct in_addr vifc_rmt_addr; /* IPIP tunnel addr */
};
'''
def create_virtual_interface(self, ip_interface: str or bytes, interface_name: str, index, flags=0x0):
if type(ip_interface) is str:
ip_interface = socket.inet_aton(ip_interface)
if (not self.igmp_interface.get(interface_name) and not self.pim_interface.get(interface_name)):
self.remove_virtual_interface(ip_interface)
struct_mrt_add_vif = struct.pack("HBBI 4s 4s", index, flags, 1, 0, ip_interface,
socket.inet_aton("0.0.0.0"))
with self.rwlock.genWlock():
self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_ADD_VIF, struct_mrt_add_vif)
self.vif_index_to_name_dic[index] = interface_name
self.vif_name_to_index_dic[interface_name] = index
for source_dict in list(self.routing.values()):
for kernel_entry in list(source_dict.values()):
kernel_entry.new_interface(index)
def remove_virtual_interface(self, ip_interface):
self.interface_logger.debug('Create virtual interface: %s -> %d', interface_name, index)
return index
def remove_virtual_interface(self, interface_name):
#with self.interface_lock:
index = self.vif_dic[ip_interface]
index = self.vif_name_to_index_dic.pop(interface_name, None)
struct_vifctl = struct.pack("HBBI 4s 4s", index, 0, 0, 0, socket.inet_aton("0.0.0.0"), socket.inet_aton("0.0.0.0"))
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_DEL_VIF, struct_vifctl)
self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_DEL_VIF, struct_vifctl)
del self.vif_dic[ip_interface]
del self.vif_name_to_index_dic[self.vif_index_to_name_dic[index]]
interface_name = self.vif_index_to_name_dic.pop(index)
# alterar MFC's para colocar a 0 esta interface
# change MFC's to not forward traffic by this interface (set OIL to 0 for this interface)
with self.rwlock.genWlock():
for source_dict in list(self.routing.values()):
for kernel_entry in list(source_dict.values()):
......@@ -235,7 +364,7 @@ class Kernel:
#outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5
#struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0)
struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, kernel_entry.inbound_interface_index, *outbound_interfaces_and_other_parameters)
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_MFC, struct_mfcctl)
self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_ADD_MFC, struct_mfcctl)
def set_flood_multicast_route(self, source_ip, group_ip, inbound_interface_index):
source_ip = socket.inet_aton(source_ip)
......@@ -250,7 +379,7 @@ class Kernel:
#outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5
#struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0)
struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces_and_other_parameters)
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_MFC, struct_mfcctl)
self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_ADD_MFC, struct_mfcctl)
def remove_multicast_route(self, kernel_entry: KernelEntry):
source_ip = socket.inet_aton(kernel_entry.source_ip)
......@@ -258,7 +387,7 @@ class Kernel:
outbound_interfaces_and_other_parameters = [0] + [0]*Kernel.MAXVIFS + [0]*4
struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, *outbound_interfaces_and_other_parameters)
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_DEL_MFC, struct_mfcctl)
self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_DEL_MFC, struct_mfcctl)
self.routing[kernel_entry.source_ip].pop(kernel_entry.group_ip)
if len(self.routing[kernel_entry.source_ip]) == 0:
self.routing.pop(kernel_entry.source_ip)
......@@ -267,7 +396,7 @@ class Kernel:
self.running = False
# MRT DONE
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_DONE, 1)
self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_DONE, 1)
self.socket.close()
......@@ -304,10 +433,10 @@ class Kernel:
ip_src = socket.inet_ntoa(im_src)
ip_dst = socket.inet_ntoa(im_dst)
if im_msgtype == Kernel.IGMPMSG_NOCACHE:
if im_msgtype == self.IGMPMSG_NOCACHE:
print("IGMP NO CACHE")
self.igmpmsg_nocache_handler(ip_src, ip_dst, im_vif)
elif im_msgtype == Kernel.IGMPMSG_WRONGVIF:
elif im_msgtype == self.IGMPMSG_WRONGVIF:
print("WRONG VIF HANDLER")
self.igmpmsg_wrongvif_handler(ip_src, ip_dst, im_vif)
#elif im_msgtype == Kernel.IGMPMSG_WHOLEPKT:
......@@ -338,73 +467,266 @@ class Kernel:
#kernel_entry.recv_data_msg(iif)
'''
@staticmethod
def _get_kernel_entry_interface():
return KernelEntry4Interface
def _create_pim_interface_object(self, interface_name, index, state_refresh_capable):
return InterfacePim(interface_name, index, state_refresh_capable)
def _create_membership_interface_object(self, interface_name, index):
return InterfaceIGMP(interface_name, index)
class Kernel6(Kernel):
# MRT6
MRT6_BASE = 200
MRT6_INIT = (MRT6_BASE) # /* Activate the kernel mroute code */
MRT6_DONE = (MRT6_BASE + 1) # /* Shutdown the kernel mroute */
MRT6_ADD_MIF = (MRT6_BASE + 2) # /* Add a virtual interface */
MRT6_DEL_MIF = (MRT6_BASE + 3) # /* Delete a virtual interface */
MRT6_ADD_MFC = (MRT6_BASE + 4) # /* Add a multicast forwarding entry */
MRT6_DEL_MFC = (MRT6_BASE + 5) # /* Delete a multicast forwarding entry */
MRT6_VERSION = (MRT6_BASE + 6) # /* Get the kernel multicast version */
MRT6_ASSERT = (MRT6_BASE + 7) # /* Activate PIM assert mode */
MRT6_PIM = (MRT6_BASE + 8) # /* enable PIM code */
MRT6_TABLE = (MRT6_BASE + 9) # /* Specify mroute table ID */
MRT6_ADD_MFC_PROXY = (MRT6_BASE + 10) # /* Add a (*,*|G) mfc entry */
MRT6_DEL_MFC_PROXY = (MRT6_BASE + 11) # /* Del a (*,*|G) mfc entry */
MRT6_MAX = (MRT6_BASE + 11)
# define SIOCGETMIFCNT_IN6 SIOCPROTOPRIVATE /* IP protocol privates */
# define SIOCGETSGCNT_IN6 (SIOCPROTOPRIVATE+1)
# define SIOCGETRPF (SIOCPROTOPRIVATE+2)
# Max Number of Virtual Interfaces
MAXVIFS = 32
def get_routing_entry(self, source_group: tuple, create_if_not_existent=True):
ip_src = source_group[0]
ip_dst = source_group[1]
with self.rwlock.genRlock():
if ip_src in self.routing and ip_dst in self.routing[ip_src]:
return self.routing[ip_src][ip_dst]
# SIGNAL MSG TYPE
MRT6MSG_NOCACHE = 1
MRT6MSG_WRONGMIF = 2
MRT6MSG_WHOLEPKT = 3 # /* used for use level encap */
# Interface flags
MIFF_REGISTER = 0x1 # /* register vif */
def __init__(self):
s = socket.socket(socket.AF_INET6, socket.SOCK_RAW, socket.IPPROTO_ICMPV6)
# MRT INIT
s.setsockopt(socket.IPPROTO_IPV6, self.MRT6_INIT, 1)
# MRT PIM
s.setsockopt(socket.IPPROTO_IPV6, self.MRT6_PIM, 0)
s.setsockopt(socket.IPPROTO_IPV6, self.MRT6_ASSERT, 1)
super().__init__(s)
'''
Structure to create/remove multicast interfaces
struct mif6ctl {
mifi_t mif6c_mifi; /* Index of MIF */
unsigned char mif6c_flags; /* MIFF_ flags */
unsigned char vifc_threshold; /* ttl limit */
__u16 mif6c_pifi; /* the index of the physical IF */
unsigned int vifc_rate_limit; /* Rate limiter values (NI) */
};
'''
def create_virtual_interface(self, ip_interface, interface_name: str, index, flags=0x0):
physical_if_index = if_nametoindex(interface_name)
struct_mrt_add_vif = struct.pack("HBBHI", index, flags, 1, physical_if_index, 0)
with self.rwlock.genWlock():
if ip_src in self.routing and ip_dst in self.routing[ip_src]:
return self.routing[ip_src][ip_dst]
elif create_if_not_existent:
kernel_entry = KernelEntry(ip_src, ip_dst)
if ip_src not in self.routing:
self.routing[ip_src] = {}
self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_ADD_MIF, struct_mrt_add_vif)
self.vif_index_to_name_dic[index] = interface_name
self.vif_name_to_index_dic[interface_name] = index
iif = UnicastRouting.check_rpf(ip_src)
self.set_flood_multicast_route(ip_src, ip_dst, iif)
self.routing[ip_src][ip_dst] = kernel_entry
return kernel_entry
else:
return None
for source_dict in list(self.routing.values()):
for kernel_entry in list(source_dict.values()):
kernel_entry.new_interface(index)
# notify KernelEntries about changes at the unicast routing table
def notify_unicast_changes(self, subnet):
self.interface_logger.debug('Create virtual interface: %s -> %d', interface_name, index)
return index
def remove_virtual_interface(self, interface_name):
# with self.interface_lock:
mif_index = self.vif_name_to_index_dic.pop(interface_name, None)
interface_name = self.vif_index_to_name_dic.pop(mif_index)
physical_if_index = if_nametoindex(interface_name)
struct_vifctl = struct.pack("HBBHI", mif_index, 0, 0, physical_if_index, 0)
self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_DEL_MIF, struct_vifctl)
# alterar MFC's para colocar a 0 esta interface
with self.rwlock.genWlock():
for source_ip in list(self.routing.keys()):
source_ip_obj = ipaddress.ip_address(source_ip)
if source_ip_obj not in subnet:
continue
for group_ip in list(self.routing[source_ip].keys()):
self.routing[source_ip][group_ip].network_update()
for source_dict in list(self.routing.values()):
for kernel_entry in list(source_dict.values()):
kernel_entry.remove_interface(mif_index)
self.interface_logger.debug('Remove virtual interface: %s -> %d', interface_name, mif_index)
# notify about changes at the interface (IP)
'''
def notify_interface_change(self, interface_name):
with self.interface_lock:
# check if interface was already added
if interface_name not in self.vif_name_to_index_dic:
return
/* Cache manipulation structures for mrouted and PIMd */
typedef __u32 if_mask;
typedef struct if_set {
if_mask ifs_bits[__KERNEL_DIV_ROUND_UP(IF_SETSIZE, NIFBITS)];
} if_set;
struct mf6cctl {
struct sockaddr_in6 mf6cc_origin; /* Origin of mcast */
struct sockaddr_in6 mf6cc_mcastgrp; /* Group in question */
mifi_t mf6cc_parent; /* Where it arrived */
struct if_set mf6cc_ifset; /* Where it is going */
};
'''
def set_multicast_route(self, kernel_entry: KernelEntry):
source_ip = socket.inet_pton(socket.AF_INET6, kernel_entry.source_ip)
sockaddr_in6_source = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, source_ip, 0)
group_ip = socket.inet_pton(socket.AF_INET6, kernel_entry.group_ip)
sockaddr_in6_group = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, group_ip, 0)
print("trying to change ip")
pim_interface = self.pim_interface.get(interface_name)
if pim_interface:
old_ip = pim_interface.get_ip()
pim_interface.change_interface()
new_ip = pim_interface.get_ip()
if old_ip != new_ip:
self.vif_dic[new_ip] = self.vif_dic.pop(old_ip)
outbound_interfaces = kernel_entry.get_outbound_interfaces_indexes()
if len(outbound_interfaces) != 8:
raise Exception
# outbound_interfaces_and_other_parameters = list(kernel_entry.outbound_interfaces) + [0]*4
# outbound_interfaces_and_other_parameters = outbound_interfaces + [0]*4
outgoing_interface_list = outbound_interfaces
# outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5
# struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0)
struct_mf6cctl = struct.pack("28s 28s H " + "I" * 8, sockaddr_in6_source, sockaddr_in6_group,
kernel_entry.inbound_interface_index,
*outgoing_interface_list)
self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_ADD_MFC, struct_mf6cctl)
def set_flood_multicast_route(self, source_ip, group_ip, inbound_interface_index):
source_ip = socket.inet_pton(socket.AF_INET6, source_ip)
sockaddr_in6_source = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, source_ip, 0)
group_ip = socket.inet_pton(socket.AF_INET6, group_ip)
sockaddr_in6_group = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, group_ip, 0)
outbound_interfaces = [255] * 8
outbound_interfaces[inbound_interface_index // 32] = 0xFFFFFFFF & ~(1 << (inbound_interface_index % 32))
# outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5
# struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0)
# struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces_and_other_parameters)
struct_mf6cctl = struct.pack("28s 28s H " + "I" * 8, sockaddr_in6_source, sockaddr_in6_group,
inbound_interface_index, *outbound_interfaces)
self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_ADD_MFC, struct_mf6cctl)
def remove_multicast_route(self, kernel_entry: KernelEntry):
source_ip = socket.inet_pton(socket.AF_INET6, kernel_entry.source_ip)
sockaddr_in6_source = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, source_ip, 0)
group_ip = socket.inet_pton(socket.AF_INET6, kernel_entry.group_ip)
sockaddr_in6_group = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, group_ip, 0)
outbound_interfaces = [0] * 8
# struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, *outbound_interfaces_and_other_parameters)
struct_mf6cctl = struct.pack("28s 28s H " + "I" * 8, sockaddr_in6_source, sockaddr_in6_group, 0,
*outbound_interfaces)
self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_DEL_MFC, struct_mf6cctl)
self.routing[kernel_entry.source_ip].pop(kernel_entry.group_ip)
if len(self.routing[kernel_entry.source_ip]) == 0:
self.routing.pop(kernel_entry.source_ip)
def exit(self):
self.running = False
# MRT DONE
self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_DONE, 1)
self.socket.close()
igmp_interface = self.igmp_interface.get(interface_name)
if igmp_interface:
igmp_interface.change_interface()
'''
/*
* Structure used to communicate from kernel to multicast router.
* We'll overlay the structure onto an MLD header (not an IPv6 heder like igmpmsg{}
* used for IPv4 implementation). This is because this structure will be passed via an
* IPv6 raw socket, on which an application will only receiver the payload i.e the data after
* the IPv6 header and all the extension headers. (See section 3 of RFC 3542)
*/
# When interface changes number of neighbors verify if olist changes and prune/forward respectively
def interface_change_number_of_neighbors(self):
with self.rwlock.genRlock():
for groups_dict in self.routing.values():
for entry in groups_dict.values():
entry.change_at_number_of_neighbors()
struct mrt6msg {
__u8 im6_mbz; /* must be zero */
__u8 im6_msgtype; /* what type of message */
__u16 im6_mif; /* mif rec'd on */
__u32 im6_pad; /* padding for 64 bit arch */
struct in6_addr im6_src, im6_dst;
};
# When new neighbor connects try to resend last state refresh msg (if AssertWinner)
def new_or_reset_neighbor(self, vif_index, neighbor_ip):
with self.rwlock.genRlock():
for groups_dict in self.routing.values():
for entry in groups_dict.values():
entry.new_or_reset_neighbor(vif_index, neighbor_ip)
/* ip6mr netlink cache report attributes */
enum {
IP6MRA_CREPORT_UNSPEC,
IP6MRA_CREPORT_MSGTYPE,
IP6MRA_CREPORT_MIF_ID,
IP6MRA_CREPORT_SRC_ADDR,
IP6MRA_CREPORT_DST_ADDR,
IP6MRA_CREPORT_PKT,
__IP6MRA_CREPORT_MAX
};
'''
def handler(self):
while self.running:
try:
msg = self.socket.recv(500)
if len(msg) < 40:
continue
(im6_mbz, im6_msgtype, im6_mif, _, im6_src, im6_dst) = struct.unpack("B B H I 16s 16s", msg[:40])
# print((im_msgtype, im_mbz, socket.inet_ntoa(im_src), socket.inet_ntoa(im_dst)))
if im6_mbz != 0:
continue
print(im6_mbz)
print(im6_msgtype)
print(im6_mif)
print(socket.inet_ntop(socket.AF_INET6, im6_src))
print(socket.inet_ntop(socket.AF_INET6, im6_dst))
# print((im_msgtype, im_mbz, socket.inet_ntoa(im_src), socket.inet_ntoa(im_dst)))
ip_src = socket.inet_ntop(socket.AF_INET6, im6_src)
ip_dst = socket.inet_ntop(socket.AF_INET6, im6_dst)
if im6_msgtype == self.MRT6MSG_NOCACHE:
print("MRT6 NO CACHE")
self.msg_nocache_handler(ip_src, ip_dst, im6_mif)
elif im6_msgtype == self.MRT6MSG_WRONGMIF:
print("WRONG MIF HANDLER")
self.msg_wrongvif_handler(ip_src, ip_dst, im6_mif)
# elif im_msgtype == Kernel.IGMPMSG_WHOLEPKT:
# print("IGMP_WHOLEPKT")
# self.igmpmsg_wholepacket_handler(ip_src, ip_dst)
else:
raise Exception
except Exception:
traceback.print_exc()
continue
# receive multicast (S,G) packet and multicast routing table has no (S,G) entry
def msg_nocache_handler(self, ip_src, ip_dst, iif):
source_group_pair = (ip_src, ip_dst)
self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg(iif)
# receive multicast (S,G) packet in a outbound_interface
def msg_wrongvif_handler(self, ip_src, ip_dst, iif):
source_group_pair = (ip_src, ip_dst)
self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg(iif)
''' useless in PIM-DM... useful in PIM-SM
def msg_wholepacket_handler(self, ip_src, ip_dst):
#kernel_entry = self.routing[(ip_src, ip_dst)]
source_group_pair = (ip_src, ip_dst)
self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg()
#kernel_entry.recv_data_msg(iif)
'''
@staticmethod
def _get_kernel_entry_interface():
return KernelEntry6Interface
def _create_pim_interface_object(self, interface_name, index, state_refresh_capable):
return InterfacePim6(interface_name, index, state_refresh_capable)
def _create_membership_interface_object(self, interface_name, index):
return InterfaceMLD(interface_name, index)
import sys
import time
import netifaces
import logging, logging.handlers
import logging
import logging.handlers
from prettytable import PrettyTable
from pimdm.TestLogger import RootFilter
from pimdm import UnicastRouting
from pimdm import UnicastRouting
from pimdm.TestLogger import RootFilter
interfaces = {} # interfaces with multicast routing enabled
igmp_interfaces = {} # igmp interfaces
interfaces_v6 = {} # pim v6 interfaces
mld_interfaces = {} # mld interfaces
kernel = None
kernel_v6 = None
unicast_routing = None
logger = None
def add_pim_interface(interface_name, state_refresh_capable:bool=False):
def add_pim_interface(interface_name, state_refresh_capable: bool = False, ipv4=True, ipv6=False):
if interface_name == "*":
for interface_name in netifaces.interfaces():
add_pim_interface(interface_name, ipv4, ipv6)
return
if ipv4 and kernel is not None:
kernel.create_pim_interface(interface_name=interface_name, state_refresh_capable=state_refresh_capable)
if ipv6 and kernel_v6 is not None:
kernel_v6.create_pim_interface(interface_name=interface_name, state_refresh_capable=state_refresh_capable)
def add_membership_interface(interface_name, ipv4=True, ipv6=False):
if interface_name == "*":
for interface_name in netifaces.interfaces():
add_membership_interface(interface_name, ipv4, ipv6)
return
if ipv4 and kernel is not None:
kernel.create_membership_interface(interface_name=interface_name)
if ipv6 and kernel_v6 is not None:
kernel_v6.create_membership_interface(interface_name=interface_name)
def add_igmp_interface(interface_name):
kernel.create_igmp_interface(interface_name=interface_name)
'''
def add_interface(interface_name, pim=False, igmp=False):
#if pim is True and interface_name not in interfaces:
# interface = InterfacePim(interface_name)
# interfaces[interface_name] = interface
# interface.create_virtual_interface()
#if igmp is True and interface_name not in igmp_interfaces:
# interface = InterfaceIGMP(interface_name)
# igmp_interfaces[interface_name] = interface
kernel.create_interface(interface_name=interface_name, pim=pim, igmp=igmp)
#if pim:
# interfaces[interface_name] = kernel.pim_interface[interface_name]
#if igmp:
# igmp_interfaces[interface_name] = kernel.igmp_interface[interface_name]
'''
def remove_interface(interface_name, pim=False, igmp=False):
#if pim is True and ((interface_name in interfaces) or interface_name == "*"):
# if interface_name == "*":
# interface_name_list = list(interfaces.keys())
# else:
# interface_name_list = [interface_name]
# for if_name in interface_name_list:
# interface_obj = interfaces.pop(if_name)
# interface_obj.remove()
# #interfaces[if_name].remove()
# #del interfaces[if_name]
# print("removido interface")
# print(interfaces)
#if igmp is True and ((interface_name in igmp_interfaces) or interface_name == "*"):
# if interface_name == "*":
# interface_name_list = list(igmp_interfaces.keys())
# else:
# interface_name_list = [interface_name]
# for if_name in interface_name_list:
# igmp_interfaces[if_name].remove()
# del igmp_interfaces[if_name]
# print("removido interface")
# print(igmp_interfaces)
kernel.remove_interface(interface_name, pim=pim, igmp=igmp)
def list_neighbors():
def remove_interface(interface_name, pim=False, membership=False, ipv4=True, ipv6=False):
if interface_name == "*":
for interface_name in netifaces.interfaces():
remove_interface(interface_name, pim, membership, ipv4, ipv6)
return
if ipv4 and kernel is not None:
kernel.remove_interface(interface_name, pim=pim, membership=membership)
if ipv6 and kernel_v6 is not None:
kernel_v6.remove_interface(interface_name, pim=pim, membership=membership)
def list_neighbors(ipv4=False, ipv6=False):
if ipv4:
interfaces_list = interfaces.values()
elif ipv6:
interfaces_list = interfaces_v6.values()
else:
return "Unknown IP family"
t = PrettyTable(['Interface', 'Neighbor IP', 'Hello Hold Time', "Generation ID", "Uptime"])
check_time = time.time()
for interface in interfaces_list:
......@@ -76,38 +74,62 @@ def list_neighbors():
print(t)
return str(t)
def list_enabled_interfaces():
global interfaces
def list_enabled_interfaces(ipv4=False, ipv6=False):
if ipv4:
t = PrettyTable(['Interface', 'IP', 'PIM/IGMP Enabled', 'State Refresh Enabled', 'IGMP State'])
family = netifaces.AF_INET
pim_interfaces = interfaces
membership_interfaces = igmp_interfaces
elif ipv6:
t = PrettyTable(['Interface', 'IP', 'PIM/MLD Enabled', 'State Refresh Enabled', 'MLD State'])
family = netifaces.AF_INET6
pim_interfaces = interfaces_v6
membership_interfaces = mld_interfaces
else:
return "Unknown IP family"
for interface in netifaces.interfaces():
try:
# TODO: fix same interface with multiple ips
ip = netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr']
pim_enabled = interface in interfaces
igmp_enabled = interface in igmp_interfaces
enabled = str(pim_enabled) + "/" + str(igmp_enabled)
ip = netifaces.ifaddresses(interface)[family][0]['addr']
pim_enabled = interface in pim_interfaces
membership_enabled = interface in membership_interfaces
enabled = str(pim_enabled) + "/" + str(membership_enabled)
state_refresh_enabled = "-"
if pim_enabled:
state_refresh_enabled = interfaces[interface].is_state_refresh_enabled()
igmp_state = "-"
if igmp_enabled:
igmp_state = igmp_interfaces[interface].interface_state.print_state()
t.add_row([interface, ip, enabled, state_refresh_enabled, igmp_state])
state_refresh_enabled = pim_interfaces[interface].is_state_refresh_enabled()
membership_state = "-"
if membership_enabled:
membership_state = membership_interfaces[interface].interface_state.print_state()
t.add_row([interface, ip, enabled, state_refresh_enabled, membership_state])
except Exception:
continue
print(t)
return str(t)
def list_state():
state_text = "IGMP State:\n" + list_igmp_state() + "\n\n\n\n" + "Multicast Routing State:\n" + list_routing_state()
def list_state(ipv4=True, ipv6=False):
state_text = ""
if ipv4:
state_text = "IGMP State:\n{}\n\n\n\nMulticast Routing State:\n{}"
elif ipv6:
state_text = "MLD State:\n{}\n\n\n\nMulticast Routing State:\n{}"
else:
return state_text
return state_text.format(list_membership_state(ipv4, ipv6), list_routing_state(ipv4, ipv6))
def list_igmp_state():
def list_membership_state(ipv4=True, ipv6=False):
t = PrettyTable(['Interface', 'RouterState', 'Group Adress', 'GroupState'])
for (interface_name, interface_obj) in list(igmp_interfaces.items()):
if ipv4:
membership_interfaces = igmp_interfaces
elif ipv6:
membership_interfaces = mld_interfaces
else:
membership_interfaces = {}
for (interface_name, interface_obj) in list(membership_interfaces.items()):
interface_state = interface_obj.interface_state
state_txt = interface_state.print_state()
print(interface_state.group_state.items())
......@@ -119,12 +141,22 @@ def list_igmp_state():
return str(t)
def list_routing_state():
def list_routing_state(ipv4=False, ipv6=False):
if ipv4:
routes = kernel.routing.values()
vif_indexes = kernel.vif_index_to_name_dic.keys()
dict_index_to_name = kernel.vif_index_to_name_dic
elif ipv6:
routes = kernel_v6.routing.values()
vif_indexes = kernel_v6.vif_index_to_name_dic.keys()
dict_index_to_name = kernel_v6.vif_index_to_name_dic
else:
raise Exception("Unknown IP family")
routing_entries = []
for a in list(kernel.routing.values()):
for a in list(routes):
for b in list(a.values()):
routing_entries.append(b)
vif_indexes = kernel.vif_index_to_name_dic.keys()
t = PrettyTable(['SourceIP', 'GroupIP', 'Interface', 'PruneState', 'AssertState', 'LocalMembership', "Is Forwarding?"])
for entry in routing_entries:
......@@ -134,7 +166,7 @@ def list_routing_state():
for index in vif_indexes:
interface_state = entry.interface_state[index]
interface_name = kernel.vif_index_to_name_dic[index]
interface_name = dict_index_to_name[index]
local_membership = type(interface_state._local_membership_state).__name__
try:
assert_state = type(interface_state._assert_state).__name__
......@@ -154,8 +186,11 @@ def list_routing_state():
def stop():
remove_interface("*", pim=True, igmp=True)
remove_interface("*", pim=True, membership=True, ipv4=True, ipv6=True)
if kernel is not None:
kernel.exit()
if kernel_v6 is not None:
kernel_v6.exit()
unicast_routing.stop()
......@@ -169,6 +204,22 @@ def test(router_name, server_logger_ip):
logger.addHandler(socketHandler)
def enable_ipv6_kernel():
"""
Function to explicitly enable IPv6 Multicast Routing stack.
This may not be enabled by default due to some old linux kernels that may not have IPv6 stack or do not have
IPv6 multicast routing support
"""
global kernel_v6
from pimdm.Kernel import Kernel6
kernel_v6 = Kernel6()
global interfaces_v6
global mld_interfaces
interfaces_v6 = kernel_v6.pim_interface
mld_interfaces = kernel_v6.membership_interface
def main():
# logging
global logger
......@@ -177,8 +228,8 @@ def main():
logger.addHandler(logging.StreamHandler(sys.stdout))
global kernel
from pimdm.Kernel import Kernel
kernel = Kernel()
from pimdm.Kernel import Kernel4
kernel = Kernel4()
global unicast_routing
unicast_routing = UnicastRouting.UnicastRouting()
......@@ -186,4 +237,9 @@ def main():
global interfaces
global igmp_interfaces
interfaces = kernel.pim_interface
igmp_interfaces = kernel.igmp_interface
igmp_interfaces = kernel.membership_interface
try:
enable_ipv6_kernel()
except:
pass
from threading import Timer
import time
from pimdm.utils import HELLO_HOLD_TIME_NO_TIMEOUT, HELLO_HOLD_TIME_TIMEOUT, TYPE_CHECKING
from threading import Lock, RLock
import logging
from threading import Timer
from threading import Lock, RLock
from pimdm.tree.globals import HELLO_HOLD_TIME_NO_TIMEOUT, HELLO_HOLD_TIME_TIMEOUT
from pimdm.utils import TYPE_CHECKING
if TYPE_CHECKING:
from pimdm.InterfacePIM import InterfacePim
......@@ -10,7 +12,6 @@ if TYPE_CHECKING:
class Neighbor:
LOGGER = logging.getLogger('pim.Interface.Neighbor')
def __init__(self, contact_interface: "InterfacePim", ip, generation_id: int, hello_hold_time: int,
state_refresh_capable: bool):
if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT:
......@@ -37,7 +38,6 @@ class Neighbor:
self.tree_interface_nlt_subscribers = []
self.tree_interface_nlt_subscribers_lock = RLock()
def set_hello_hold_time(self, hello_hold_time: int):
self.hello_hold_time = hello_hold_time
if self.neighbor_liveness_timer is not None:
......@@ -85,11 +85,9 @@ class Neighbor:
for tree_if in self.tree_interface_nlt_subscribers:
tree_if.assert_winner_nlt_expires()
def reset(self):
self.contact_interface.new_or_reset_neighbor(self.ip)
def receive_hello(self, generation_id, hello_hold_time, state_refresh_capable):
self.neighbor_logger.debug('Receive Hello message with HelloHoldTime: ' + str(hello_hold_time) +
'; GenerationID: ' + str(generation_id) + '; StateRefreshCapable: ' +
......
import struct
import socket
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|Version| IHL |Type of Service| Total Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Identification |Flags| Fragment Offset |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Time to Live | Protocol | Header Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Destination Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Options | Padding |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketIpHeader:
IP_HDR = "! BBH HH BBH 4s 4s"
IP_HDR_LEN = struct.calcsize(IP_HDR)
def __init__(self, ver, hdr_len, ttl, proto, ip_src, ip_dst):
self.version = ver
self.hdr_length = hdr_len
self.ttl = ttl
self.proto = proto
self.ip_src = ip_src
self.ip_dst = ip_dst
def __len__(self):
return self.hdr_length
@staticmethod
def parse_bytes(data: bytes):
(verhlen, tos, iplen, ipid, frag, ttl, proto, cksum, src, dst) = \
struct.unpack(PacketIpHeader.IP_HDR, data)
ver = (verhlen & 0xf0) >> 4
hlen = (verhlen & 0x0f) * 4
'''
"VER": ver,
"HLEN": hlen,
"TOS": tos,
"IPLEN": iplen,
"IPID": ipid,
"FRAG": frag,
"TTL": ttl,
"PROTO": proto,
"CKSUM": cksum,
"SRC": socket.inet_ntoa(src),
"DST": socket.inet_ntoa(dst)
'''
src_ip = socket.inet_ntoa(src)
dst_ip = socket.inet_ntoa(dst)
return PacketIpHeader(ver, hlen, ttl, proto, src_ip, dst_ip)
#!/usr/bin/env python3
from pimdm.Daemon.Daemon import Daemon
from pimdm import Main
import _pickle as pickle
import socket
import sys
import os
import sys
import socket
import argparse
import traceback
import _pickle as pickle
from pimdm import Main
from pimdm.daemon.Daemon import Daemon
VERSION = "1.1"
VERSION = "1.0.4.2"
def client_socket(data_to_send):
# Create a UDS socket
......@@ -58,26 +60,36 @@ class MyDaemon(Daemon):
print(sys.stderr, 'sending data back to the client')
print(pickle.loads(data))
args = pickle.loads(data)
if 'ipv4' not in args and 'ipv6' not in args or not (args.ipv4 or args.ipv6):
args.ipv4 = True
args.ipv6 = False
if 'list_interfaces' in args and args.list_interfaces:
connection.sendall(pickle.dumps(Main.list_enabled_interfaces()))
connection.sendall(pickle.dumps(Main.list_enabled_interfaces(ipv4=args.ipv4, ipv6=args.ipv6)))
elif 'list_neighbors' in args and args.list_neighbors:
connection.sendall(pickle.dumps(Main.list_neighbors()))
connection.sendall(pickle.dumps(Main.list_neighbors(ipv4=args.ipv4, ipv6=args.ipv6)))
elif 'list_state' in args and args.list_state:
connection.sendall(pickle.dumps(Main.list_state()))
connection.sendall(pickle.dumps(Main.list_state(ipv4=args.ipv4, ipv6=args.ipv6)))
elif 'add_interface' in args and args.add_interface:
Main.add_pim_interface(args.add_interface[0], False)
Main.add_pim_interface(args.add_interface[0], False, ipv4=args.ipv4, ipv6=args.ipv6)
connection.shutdown(socket.SHUT_RDWR)
elif 'add_interface_sr' in args and args.add_interface_sr:
Main.add_pim_interface(args.add_interface_sr[0], True)
Main.add_pim_interface(args.add_interface_sr[0], True, ipv4=args.ipv4, ipv6=args.ipv6)
connection.shutdown(socket.SHUT_RDWR)
elif 'add_interface_igmp' in args and args.add_interface_igmp:
Main.add_igmp_interface(args.add_interface_igmp[0])
Main.add_membership_interface(interface_name=args.add_interface_igmp[0], ipv4=True, ipv6=False)
connection.shutdown(socket.SHUT_RDWR)
elif 'add_interface_mld' in args and args.add_interface_mld:
Main.add_membership_interface(interface_name=args.add_interface_mld[0], ipv4=False, ipv6=True)
connection.shutdown(socket.SHUT_RDWR)
elif 'remove_interface' in args and args.remove_interface:
Main.remove_interface(args.remove_interface[0], pim=True)
Main.remove_interface(args.remove_interface[0], pim=True, ipv4=args.ipv4, ipv6=args.ipv6)
connection.shutdown(socket.SHUT_RDWR)
elif 'remove_interface_igmp' in args and args.remove_interface_igmp:
Main.remove_interface(args.remove_interface_igmp[0], igmp=True)
Main.remove_interface(args.remove_interface_igmp[0], membership=True, ipv4=True, ipv6=False)
connection.shutdown(socket.SHUT_RDWR)
elif 'remove_interface_mld' in args and args.remove_interface_mld:
Main.remove_interface(args.remove_interface_mld[0], membership=True, ipv4=False, ipv6=True)
connection.shutdown(socket.SHUT_RDWR)
elif 'stop' in args and args.stop:
Main.stop()
......@@ -102,18 +114,30 @@ def main():
group.add_argument("-start", "--start", action="store_true", default=False, help="Start PIM")
group.add_argument("-stop", "--stop", action="store_true", default=False, help="Stop PIM")
group.add_argument("-restart", "--restart", action="store_true", default=False, help="Restart PIM")
group.add_argument("-li", "--list_interfaces", action="store_true", default=False, help="List All PIM Interfaces")
group.add_argument("-ln", "--list_neighbors", action="store_true", default=False, help="List All PIM Neighbors")
group.add_argument("-ls", "--list_state", action="store_true", default=False, help="List state of IGMP")
group.add_argument("-mr", "--multicast_routes", action="store_true", default=False, help="List Multicast Routing table")
group.add_argument("-ai", "--add_interface", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface")
group.add_argument("-aisr", "--add_interface_sr", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface with State Refresh enabled")
group.add_argument("-li", "--list_interfaces", action="store_true", default=False, help="List All PIM Interfaces. "
"Use -4 or -6 to specify IPv4 or IPv6 interfaces.")
group.add_argument("-ln", "--list_neighbors", action="store_true", default=False, help="List All PIM Neighbors. "
"Use -4 or -6 to specify IPv4 or IPv6 PIM neighbors.")
group.add_argument("-ls", "--list_state", action="store_true", default=False, help="List IGMP/MLD and PIM-DM state machines."
" Use -4 or -6 to specify IPv4 or IPv6 state respectively.")
group.add_argument("-mr", "--multicast_routes", action="store_true", default=False, help="List Multicast Routing table. "
"Use -4 or -6 to specify IPv4 or IPv6 multicast routing table.")
group.add_argument("-ai", "--add_interface", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface. "
"Use -4 or -6 to specify IPv4 or IPv6 interface.")
group.add_argument("-aisr", "--add_interface_sr", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface with State Refresh enabled. "
"Use -4 or -6 to specify IPv4 or IPv6 interface.")
group.add_argument("-aiigmp", "--add_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Add IGMP interface")
group.add_argument("-ri", "--remove_interface", nargs=1, metavar='INTERFACE_NAME', help="Remove PIM interface")
group.add_argument("-aimld", "--add_interface_mld", nargs=1, metavar='INTERFACE_NAME', help="Add MLD interface")
group.add_argument("-ri", "--remove_interface", nargs=1, metavar='INTERFACE_NAME', help="Remove PIM interface. "
"Use -4 or -6 to specify IPv4 or IPv6 interface.")
group.add_argument("-riigmp", "--remove_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Remove IGMP interface")
group.add_argument("-rimld", "--remove_interface_mld", nargs=1, metavar='INTERFACE_NAME', help="Remove MLD interface")
group.add_argument("-v", "--verbose", action="store_true", default=False, help="Verbose (print all debug messages)")
group.add_argument("-t", "--test", nargs=2, metavar=('ROUTER_NAME', 'SERVER_LOG_IP'), help="Tester... send log information to SERVER_LOG_IP. Set the router name to ROUTER_NAME")
group.add_argument("--version", action='version', version='%(prog)s ' + VERSION)
group_ipversion = parser.add_mutually_exclusive_group(required=False)
group_ipversion.add_argument("-4", "--ipv4", action="store_true", default=False, help="Setting for IPv4")
group_ipversion.add_argument("-6", "--ipv6", action="store_true", default=False, help="Setting for IPv6")
args = parser.parse_args()
#print(parser.parse_args())
......@@ -137,7 +161,10 @@ def main():
os.system("tail -f /var/log/pimdm/stdout")
sys.exit(0)
elif args.multicast_routes:
if args.ipv4 or not args.ipv6:
os.system("ip mroute show")
elif args.ipv6:
os.system("ip -6 mroute show")
sys.exit(0)
elif not daemon.is_running():
print("PIM-DM is not running")
......
import socket
import ipaddress
from pyroute2 import IPDB
from threading import RLock
from pimdm.utils import if_indextoname
from socket import if_indextoname
from pyroute2 import IPDB
def get_route(ip_dst: str):
......@@ -48,27 +47,34 @@ class UnicastRouting(object):
@staticmethod
def get_route(ip_dst: str):
ip_bytes = socket.inet_aton(ip_dst)
ip_int = int.from_bytes(ip_bytes, byteorder='big')
ip_version = ipaddress.ip_address(ip_dst).version
if ip_version == 4:
family = socket.AF_INET
full_mask = 32
elif ip_version == 6:
family = socket.AF_INET6
full_mask = 128
else:
raise Exception("Unknown IP version")
info = None
with UnicastRouting.lock:
ipdb = UnicastRouting.ipdb # type:IPDB
for mask_len in range(32, 0, -1):
ip_bytes = (ip_int & (0xFFFFFFFF << (32 - mask_len))).to_bytes(4, "big")
ip_dst = socket.inet_ntoa(ip_bytes) + "/" + str(mask_len)
print(ip_dst)
if ip_dst in ipdb.routes:
for mask_len in range(full_mask, 0, -1):
dst_network = str(ipaddress.ip_interface(ip_dst + "/" + str(mask_len)).network)
print(dst_network)
if dst_network in ipdb.routes:
print(info)
if ipdb.routes[ip_dst]['ipdb_scope'] != 'gc':
info = ipdb.routes[ip_dst]
if ipdb.routes[{'dst': dst_network, 'family': family}]['ipdb_scope'] != 'gc':
info = ipdb.routes[dst_network]
break
else:
continue
if not info:
print("0.0.0.0/0")
print("0.0.0.0/0 or ::/0")
if "default" in ipdb.routes:
info = ipdb.routes["default"]
info = ipdb.routes[{'dst': 'default', 'family': family}]
print(info)
return info
......@@ -85,13 +91,16 @@ class UnicastRouting(object):
oif = unicast_route.get("oif")
next_hop = unicast_route["gateway"]
multipaths = unicast_route["multipath"]
# prefsrc = unicast_route.get("prefsrc")
#prefsrc = unicast_route.get("prefsrc")
# rpf_node = ip_dst if (next_hop is None and prefsrc is not None) else next_hop
#rpf_node = ip_dst if (next_hop is None and prefsrc is not None) else next_hop
rpf_node = next_hop if next_hop is not None else ip_dst
if ipaddress.ip_address(ip_dst).version == 4:
highest_ip = ipaddress.ip_address("0.0.0.0")
else:
highest_ip = ipaddress.ip_address("::")
for m in multipaths:
if m["gateway"] is None:
if m.get("gateway", None) is None:
oif = m.get('oif')
rpf_node = ip_dst
break
......@@ -107,14 +116,22 @@ class UnicastRouting(object):
interface_name = None if oif is None else if_indextoname(int(oif))
from pimdm import Main
if ipaddress.ip_address(ip_dst).version == 4:
rpf_if = Main.kernel.vif_name_to_index_dic.get(interface_name)
else:
rpf_if = Main.kernel_v6.vif_name_to_index_dic.get(interface_name)
return (metric_administrative_distance, metric_cost, rpf_node, rpf_if, mask)
@staticmethod
def unicast_changes(ipdb, msg, action):
"""
Kernel notified about a change
Verify the type of change and recheck all trees if necessary
"""
print("unicast change?")
print(action)
UnicastRouting.lock.acquire()
family = msg['family']
if action == "RTM_NEWROUTE" or action == "RTM_DELROUTE":
print(ipdb.routes)
mask_len = msg["dst_len"]
......@@ -126,8 +143,10 @@ class UnicastRouting(object):
if key == "RTA_DST":
network_address = value
break
if network_address is None:
if network_address is None and family == socket.AF_INET:
network_address = "0.0.0.0"
elif network_address is None and family == socket.AF_INET6:
network_address = "::"
print(network_address)
print(mask_len)
print(network_address + "/" + str(mask_len))
......@@ -135,7 +154,10 @@ class UnicastRouting(object):
print(str(subnet))
UnicastRouting.lock.release()
from pimdm import Main
if family == socket.AF_INET:
Main.kernel.notify_unicast_changes(subnet)
elif family == socket.AF_INET6:
Main.kernel_v6.notify_unicast_changes(subnet)
'''
elif action == "RTM_NEWADDR" or action == "RTM_DELADDR":
print(action)
......@@ -154,7 +176,7 @@ class UnicastRouting(object):
import traceback
traceback.print_exc()
pass
bnet = ipaddress.ip_network("0.0.0.0/0")
subnet = ipaddress.ip_network("0.0.0.0/0")
Main.kernel.notify_unicast_changes(subnet)
elif action == "RTM_NEWLINK" or action == "RTM_DELLINK":
attrs = msg["attrs"]
......@@ -172,7 +194,7 @@ class UnicastRouting(object):
print(if_name + ": " + operation)
UnicastRouting.lock.release()
if operation == 'DOWN':
Main.kernel.remove_interface(if_name, igmp=True, pim=True)
Main.kernel.remove_interface(if_name, membership=True, pim=True)
subnet = ipaddress.ip_network("0.0.0.0/0")
Main.kernel.notify_unicast_changes(subnet)
'''
......@@ -180,6 +202,10 @@ class UnicastRouting(object):
UnicastRouting.lock.release()
def stop(self):
"""
No longer monitor unicast changes....
Invoked whenever the protocol is stopped
"""
if self._ipdb:
self._ipdb.release()
if UnicastRouting.ipdb:
......
......@@ -129,13 +129,13 @@ class GroupState(object):
with self.multicast_interface_state_lock:
print("notify+", self.multicast_interface_state)
for interface_state in self.multicast_interface_state:
interface_state.notify_igmp(has_members=True)
interface_state.notify_membership(has_members=True)
def notify_routing_remove(self):
with self.multicast_interface_state_lock:
print("notify-", self.multicast_interface_state)
for interface_state in self.multicast_interface_state:
interface_state.notify_igmp(has_members=False)
interface_state.notify_membership(has_members=False)
def add_multicast_routing_entry(self, kernel_entry):
with self.multicast_interface_state_lock:
......@@ -155,5 +155,5 @@ class GroupState(object):
self.clear_timer()
self.clear_v1_host_timer()
for interface_state in self.multicast_interface_state:
interface_state.notify_igmp(has_members=False)
interface_state.notify_membership(has_members=False)
del self.multicast_interface_state[:]
from threading import Timer
import logging
from pimdm.Packet.PacketIGMPHeader import PacketIGMPHeader
from pimdm.Packet.ReceivedPacket import ReceivedPacket
from pimdm.packet.PacketIGMPHeader import PacketIGMPHeader
from pimdm.packet.ReceivedPacket import ReceivedPacket
from pimdm.utils import TYPE_CHECKING
from pimdm.RWLock.RWLock import RWLockWrite
from pimdm.rwlock.RWLock import RWLockWrite
from .querier.Querier import Querier
from .nonquerier.NonQuerier import NonQuerier
from .GroupState import GroupState
......
from ipaddress import IPv4Address
from pimdm.utils import TYPE_CHECKING
from ..igmp_globals import Membership_Query, QueryResponseInterval, LastMemberQueryCount
from pimdm.Packet.PacketIGMPHeader import PacketIGMPHeader
from pimdm.Packet.ReceivedPacket import ReceivedPacket
from pimdm.packet.PacketIGMPHeader import PacketIGMPHeader
from pimdm.packet.ReceivedPacket import ReceivedPacket
from . import NoMembersPresent, MembersPresent, CheckingMembership
if TYPE_CHECKING:
......
from pimdm.Packet.PacketIGMPHeader import PacketIGMPHeader
from pimdm.packet.PacketIGMPHeader import PacketIGMPHeader
from pimdm.utils import TYPE_CHECKING
from ..igmp_globals import Membership_Query, LastMemberQueryInterval
from ..wrapper import NoMembersPresent, MembersPresent, Version1MembersPresent
......
from pimdm.Packet.PacketIGMPHeader import PacketIGMPHeader
from pimdm.packet.PacketIGMPHeader import PacketIGMPHeader
from pimdm.utils import TYPE_CHECKING
from ..igmp_globals import Membership_Query, LastMemberQueryInterval
from ..wrapper import Version1MembersPresent, CheckingMembership, NoMembersPresent
......
......@@ -3,8 +3,8 @@ from ipaddress import IPv4Address
from pimdm.utils import TYPE_CHECKING
from ..igmp_globals import Membership_Query, QueryResponseInterval, LastMemberQueryCount, LastMemberQueryInterval
from pimdm.Packet.PacketIGMPHeader import PacketIGMPHeader
from pimdm.Packet.ReceivedPacket import ReceivedPacket
from pimdm.packet.PacketIGMPHeader import PacketIGMPHeader
from pimdm.packet.ReceivedPacket import ReceivedPacket
from . import CheckingMembership, MembersPresent, Version1MembersPresent, NoMembersPresent
if TYPE_CHECKING:
......
import logging
from threading import Lock
from threading import Timer
from pimdm.utils import TYPE_CHECKING
from .wrapper import NoListenersPresent
from .mld_globals import MulticastListenerInterval, LastListenerQueryInterval
if TYPE_CHECKING:
from .RouterState import RouterState
class GroupState(object):
LOGGER = logging.getLogger('pim.mld.RouterState.GroupState')
def __init__(self, router_state: 'RouterState', group_ip: str):
#logger
extra_dict_logger = router_state.router_state_logger.extra.copy()
extra_dict_logger['tree'] = '(*,' + group_ip + ')'
self.group_state_logger = logging.LoggerAdapter(GroupState.LOGGER, extra_dict_logger)
#timers and state
self.router_state = router_state
self.group_ip = group_ip
self.state = NoListenersPresent
self.timer = None
self.retransmit_timer = None
# lock
self.lock = Lock()
# KernelEntry's instances to notify change of igmp state
self.multicast_interface_state = []
self.multicast_interface_state_lock = Lock()
def print_state(self):
return self.state.print_state()
###########################################
# Set state
###########################################
def set_state(self, state):
self.state = state
self.group_state_logger.debug("change membership state to: " + state.print_state())
###########################################
# Set timers
###########################################
def set_timer(self, alternative: bool=False, max_response_time: int=None):
self.clear_timer()
if not alternative:
time = MulticastListenerInterval
else:
time = self.router_state.interface_state.get_group_membership_time(max_response_time)
timer = Timer(time, self.group_membership_timeout)
timer.start()
self.timer = timer
def clear_timer(self):
if self.timer is not None:
self.timer.cancel()
def set_retransmit_timer(self):
self.clear_retransmit_timer()
retransmit_timer = Timer(LastListenerQueryInterval, self.retransmit_timeout)
retransmit_timer.start()
self.retransmit_timer = retransmit_timer
def clear_retransmit_timer(self):
if self.retransmit_timer is not None:
self.retransmit_timer.cancel()
###########################################
# Get group state from specific interface state
###########################################
def get_interface_group_state(self):
return self.state.get_state(self.router_state)
###########################################
# Timer timeout
###########################################
def group_membership_timeout(self):
with self.lock:
self.get_interface_group_state().group_membership_timeout(self)
def retransmit_timeout(self):
with self.lock:
self.get_interface_group_state().retransmit_timeout(self)
###########################################
# Receive Packets
###########################################
def receive_report(self):
with self.lock:
self.get_interface_group_state().receive_report(self)
def receive_done(self):
with self.lock:
self.get_interface_group_state().receive_done(self)
def receive_group_specific_query(self, max_response_time: int):
with self.lock:
self.get_interface_group_state().receive_group_specific_query(self, max_response_time)
###########################################
# Notify Routing
###########################################
def notify_routing_add(self):
with self.multicast_interface_state_lock:
print("notify+", self.multicast_interface_state)
for interface_state in self.multicast_interface_state:
interface_state.notify_membership(has_members=True)
def notify_routing_remove(self):
with self.multicast_interface_state_lock:
print("notify-", self.multicast_interface_state)
for interface_state in self.multicast_interface_state:
interface_state.notify_membership(has_members=False)
def add_multicast_routing_entry(self, kernel_entry):
with self.multicast_interface_state_lock:
self.multicast_interface_state.append(kernel_entry)
return self.has_members()
def remove_multicast_routing_entry(self, kernel_entry):
with self.multicast_interface_state_lock:
self.multicast_interface_state.remove(kernel_entry)
def has_members(self):
return self.state is not NoListenersPresent
def remove(self):
with self.multicast_interface_state_lock:
self.clear_retransmit_timer()
self.clear_timer()
for interface_state in self.multicast_interface_state:
interface_state.notify_membership(has_members=False)
del self.multicast_interface_state[:]
import logging
from threading import Timer
from pimdm.packet.PacketMLDHeader import PacketMLDHeader
from pimdm.packet.ReceivedPacket import ReceivedPacket
from pimdm.utils import TYPE_CHECKING
from pimdm.rwlock.RWLock import RWLockWrite
from .querier.Querier import Querier
from .nonquerier.NonQuerier import NonQuerier
from .GroupState import GroupState
from .mld_globals import QueryResponseInterval, QueryInterval, OtherQuerierPresentInterval, MULTICAST_LISTENER_QUERY_TYPE
if TYPE_CHECKING:
from pimdm.InterfaceMLD import InterfaceMLD
class RouterState(object):
ROUTER_STATE_LOGGER = logging.getLogger('pim.mld.RouterState')
def __init__(self, interface: 'InterfaceMLD'):
#logger
logger_extra = dict()
logger_extra['vif'] = interface.vif_index
logger_extra['interfacename'] = interface.interface_name
self.router_state_logger = logging.LoggerAdapter(RouterState.ROUTER_STATE_LOGGER, logger_extra)
# interface of the router connected to the network
self.interface = interface
# state of the router (Querier/NonQuerier)
self.interface_state = Querier
# state of each group
# Key: GroupIPAddress, Value: GroupState object
self.group_state = {}
self.group_state_lock = RWLockWrite()
# send general query
packet = PacketMLDHeader(type=MULTICAST_LISTENER_QUERY_TYPE, max_resp_delay=QueryResponseInterval*1000)
self.interface.send(packet.bytes())
# set initial general query timer
timer = Timer(QueryInterval, self.general_query_timeout)
timer.start()
self.general_query_timer = timer
# present timer
self.other_querier_present_timer = None
# Send packet via interface
def send(self, data: bytes, address: str):
self.interface.send(data, address)
############################################
# interface_state methods
############################################
def print_state(self):
return self.interface_state.state_name()
def set_general_query_timer(self):
self.clear_general_query_timer()
general_query_timer = Timer(QueryInterval, self.general_query_timeout)
general_query_timer.start()
self.general_query_timer = general_query_timer
def clear_general_query_timer(self):
if self.general_query_timer is not None:
self.general_query_timer.cancel()
def set_other_querier_present_timer(self):
self.clear_other_querier_present_timer()
other_querier_present_timer = Timer(OtherQuerierPresentInterval, self.other_querier_present_timeout)
other_querier_present_timer.start()
self.other_querier_present_timer = other_querier_present_timer
def clear_other_querier_present_timer(self):
if self.other_querier_present_timer is not None:
self.other_querier_present_timer.cancel()
def general_query_timeout(self):
self.interface_state.general_query_timeout(self)
def other_querier_present_timeout(self):
self.interface_state.other_querier_present_timeout(self)
def change_interface_state(self, querier: bool):
if querier:
self.interface_state = Querier
self.router_state_logger.debug('change querier state to -> Querier')
else:
self.interface_state = NonQuerier
self.router_state_logger.debug('change querier state to -> NonQuerier')
############################################
# group state methods
############################################
def get_group_state(self, group_ip):
with self.group_state_lock.genRlock():
if group_ip in self.group_state:
return self.group_state[group_ip]
with self.group_state_lock.genWlock():
if group_ip in self.group_state:
group_state = self.group_state[group_ip]
else:
group_state = GroupState(self, group_ip)
self.group_state[group_ip] = group_state
return group_state
def receive_report(self, packet: ReceivedPacket):
mld_group = packet.payload.group_address
#if igmp_group not in self.group_state:
# self.group_state[igmp_group] = GroupState(self, igmp_group)
#self.group_state[igmp_group].receive_v2_membership_report()
self.get_group_state(mld_group).receive_report()
def receive_done(self, packet: ReceivedPacket):
mld_group = packet.payload.group_address
#if igmp_group in self.group_state:
# self.group_state[igmp_group].receive_leave_group()
self.get_group_state(mld_group).receive_done()
def receive_query(self, packet: ReceivedPacket):
self.interface_state.receive_query(self, packet)
mld_group = packet.payload.group_address
# process group specific query
if mld_group != "::" and mld_group in self.group_state:
#if igmp_group != "0.0.0.0":
max_response_time = packet.payload.max_resp_delay
#self.group_state[igmp_group].receive_group_specific_query(max_response_time)
self.get_group_state(mld_group).receive_group_specific_query(max_response_time)
def remove(self):
for group in self.group_state.values():
group.remove()
#MLD timers (in seconds)
RobustnessVariable = 2
QueryInterval = 125
QueryResponseInterval = 10
MulticastListenerInterval = (RobustnessVariable * QueryInterval) + (QueryResponseInterval)
OtherQuerierPresentInterval = (RobustnessVariable * QueryInterval) + 0.5 * QueryResponseInterval
StartupQueryInterval = (1/4) * QueryInterval
StartupQueryCount = RobustnessVariable
LastListenerQueryInterval = 1
LastListenerQueryCount = RobustnessVariable
UnsolicitedReportInterval = 10
# MLD msg type
MULTICAST_LISTENER_QUERY_TYPE = 130
MULTICAST_LISTENER_REPORT_TYPE = 131
MULTICAST_LISTENER_DONE_TYPE = 132
\ No newline at end of file
from pimdm.utils import TYPE_CHECKING
from ..wrapper import NoListenersPresent
from ..wrapper import ListenersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def receive_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingListeners: receive_report')
group_state.set_timer()
group_state.set_state(ListenersPresent)
def receive_done(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingListeners: receive_done')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('NonQuerier CheckingListeners: receive_group_specific_query')
# do nothing
return
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingListeners: group_membership_timeout')
group_state.set_state(NoListenersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingListeners: retransmit_timeout')
# do nothing
return
from pimdm.utils import TYPE_CHECKING
from ..wrapper import NoListenersPresent
from ..wrapper import CheckingListeners
if TYPE_CHECKING:
from ..GroupState import GroupState
def receive_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier ListenersPresent: receive_report')
group_state.set_timer()
def receive_done(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier ListenersPresent: receive_done')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('NonQuerier ListenersPresent: receive_group_specific_query')
group_state.set_timer(alternative=True, max_response_time=max_response_time)
group_state.set_state(CheckingListeners)
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier ListenersPresent: group_membership_timeout')
group_state.set_state(NoListenersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier ListenersPresent: retransmit_timeout')
# do nothing
return
from pimdm.utils import TYPE_CHECKING
from ..wrapper import ListenersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def receive_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoListenersPresent: receive_report')
group_state.set_timer()
group_state.set_state(ListenersPresent)
# NOTIFY ROUTING + !!!!
group_state.notify_routing_add()
def receive_done(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoListenersPresent: receive_done')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('NonQuerier NoListenersPresent: receive_group_specific_query')
# do nothing
return
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoListenersPresent: group_membership_timeout')
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoListenersPresent: retransmit_timeout')
# do nothing
return
from ipaddress import IPv6Address
from pimdm.utils import TYPE_CHECKING
from ..mld_globals import QueryResponseInterval, LastListenerQueryCount
from pimdm.packet.PacketMLDHeader import PacketMLDHeader
from pimdm.packet.ReceivedPacket import ReceivedPacket
from . import NoListenersPresent, ListenersPresent, CheckingListeners
if TYPE_CHECKING:
from ..RouterState import RouterState
class NonQuerier:
@staticmethod
def general_query_timeout(router_state: 'RouterState'):
router_state.router_state_logger.debug('NonQuerier state: general_query_timeout')
# do nothing
return
@staticmethod
def other_querier_present_timeout(router_state: 'RouterState'):
router_state.router_state_logger.debug('NonQuerier state: other_querier_present_timeout')
#change state to Querier
router_state.change_interface_state(querier=True)
# send general query
packet = PacketMLDHeader(type=PacketMLDHeader.MULTICAST_LISTENER_QUERY_TYPE,
max_resp_delay=QueryResponseInterval*1000)
router_state.interface.send(packet.bytes())
# set general query timer
router_state.set_general_query_timer()
@staticmethod
def receive_query(router_state: 'RouterState', packet: ReceivedPacket):
router_state.router_state_logger.debug('NonQuerier state: receive_query')
source_ip = packet.ip_header.ip_src
# if source ip of membership query not lower than the ip of the received interface => ignore
if IPv6Address(source_ip) >= IPv6Address(router_state.interface.get_ip()):
return
# reset other present querier timer
router_state.set_other_querier_present_timer()
# TODO ver se existe uma melhor maneira de fazer isto
@staticmethod
def state_name():
return "Non Querier"
@staticmethod
def get_group_membership_time(max_response_time: int):
return (max_response_time/1000.0) * LastListenerQueryCount
# State
@staticmethod
def get_checking_listeners_state():
return CheckingListeners
@staticmethod
def get_listeners_present_state():
return ListenersPresent
@staticmethod
def get_no_listeners_present_state():
return NoListenersPresent
from pimdm.packet.PacketMLDHeader import PacketMLDHeader
from pimdm.utils import TYPE_CHECKING
from ..mld_globals import LastListenerQueryInterval
from ..wrapper import ListenersPresent, NoListenersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def receive_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingListeners: receive_report')
group_state.set_timer()
group_state.clear_retransmit_timer()
group_state.set_state(ListenersPresent)
def receive_done(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingListeners: receive_done')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('Querier CheckingListeners: receive_group_specific_query')
# do nothing
return
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingListeners: group_membership_timeout')
group_state.clear_retransmit_timer()
group_state.set_state(NoListenersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingListeners: retransmit_timeout')
group_addr = group_state.group_ip
packet = PacketMLDHeader(type=PacketMLDHeader.MULTICAST_LISTENER_QUERY_TYPE,
max_resp_delay=LastListenerQueryInterval*1000, group_address=group_addr)
group_state.router_state.send(data=packet.bytes(), address=group_addr)
group_state.set_retransmit_timer()
from pimdm.packet.PacketMLDHeader import PacketMLDHeader
from pimdm.utils import TYPE_CHECKING
from ..mld_globals import LastListenerQueryInterval
from ..wrapper import CheckingListeners, NoListenersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def receive_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier ListenersPresent: receive_report')
group_state.set_timer()
def receive_done(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier ListenersPresent: receive_done')
group_ip = group_state.group_ip
group_state.set_timer(alternative=True)
group_state.set_retransmit_timer()
packet = PacketMLDHeader(type=PacketMLDHeader.MULTICAST_LISTENER_QUERY_TYPE,
max_resp_delay=LastListenerQueryInterval*1000, group_address=group_ip)
group_state.router_state.send(data=packet.bytes(), address=group_ip)
group_state.set_state(CheckingListeners)
def receive_group_specific_query(group_state: 'GroupState', max_response_time):
group_state.group_state_logger.debug('Querier ListenersPresent: receive_group_specific_query')
# do nothing
return
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier ListenersPresent: group_membership_timeout')
group_state.set_state(NoListenersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier ListenersPresent: retransmit_timeout')
# do nothing
return
from pimdm.utils import TYPE_CHECKING
from ..wrapper import ListenersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def receive_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoListenersPresent: receive_report')
group_state.set_timer()
group_state.set_state(ListenersPresent)
# NOTIFY ROUTING + !!!!
group_state.notify_routing_add()
def receive_done(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoListenersPresent: receive_done')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('Querier NoListenersPresent: receive_group_specific_query')
# do nothing
return
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoListenersPresent: group_membership_timeout')
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoListenersPresent: retransmit_timeout')
# do nothing
return
from ipaddress import IPv6Address
from pimdm.utils import TYPE_CHECKING
from ..mld_globals import LastListenerQueryInterval, LastListenerQueryCount, QueryResponseInterval
from pimdm.packet.PacketMLDHeader import PacketMLDHeader
from pimdm.packet.ReceivedPacket import ReceivedPacket
from . import CheckingListeners, ListenersPresent, NoListenersPresent
if TYPE_CHECKING:
from ..RouterState import RouterState
class Querier:
@staticmethod
def general_query_timeout(router_state: 'RouterState'):
router_state.router_state_logger.debug('Querier state: general_query_timeout')
# send general query
packet = PacketMLDHeader(type=PacketMLDHeader.MULTICAST_LISTENER_QUERY_TYPE,
max_resp_delay=QueryResponseInterval*1000)
router_state.interface.send(packet.bytes())
# set general query timer
router_state.set_general_query_timer()
@staticmethod
def receive_query(router_state: 'RouterState', packet: ReceivedPacket):
router_state.router_state_logger.debug('Querier state: receive_query')
source_ip = packet.ip_header.ip_src
# if source ip of membership query not lower than the ip of the received interface => ignore
if IPv6Address(source_ip) >= IPv6Address(router_state.interface.get_ip()):
return
# if source ip of membership query lower than the ip of the received interface => change state
# change state of interface
# Querier -> Non Querier
router_state.change_interface_state(querier=False)
# set other present querier timer
router_state.clear_general_query_timer()
router_state.set_other_querier_present_timer()
@staticmethod
def other_querier_present_timeout(router_state: 'RouterState'):
router_state.router_state_logger.debug('Querier state: other_querier_present_timeout')
# do nothing
return
# TODO ver se existe uma melhor maneira de fazer isto
@staticmethod
def state_name():
return "Querier"
@staticmethod
def get_group_membership_time(max_response_time: int):
return LastListenerQueryInterval * LastListenerQueryCount
# State
@staticmethod
def get_checking_listeners_state():
return CheckingListeners
@staticmethod
def get_listeners_present_state():
return ListenersPresent
@staticmethod
def get_no_listeners_present_state():
return NoListenersPresent
from pimdm.utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_checking_listeners_state()
def print_state():
return "CheckingListeners"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
\ No newline at end of file
from pimdm.utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_listeners_present_state()
def print_state():
return "ListenersPresent"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
\ No newline at end of file
from pimdm.utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_no_listeners_present_state()
def print_state():
return "NoListenersPresent"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
import struct
import socket
class PacketIpHeader:
"""
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|Version|
+-+-+-+-+
"""
IP_HDR = "! B"
IP_HDR_LEN = struct.calcsize(IP_HDR)
def __init__(self, ver, hdr_len):
self.version = ver
self.hdr_length = hdr_len
def __len__(self):
return self.hdr_length
@staticmethod
def parse_bytes(data: bytes):
(verhlen, ) = struct.unpack(PacketIpHeader.IP_HDR, data[:PacketIpHeader.IP_HDR_LEN])
ver = (verhlen & 0xF0) >> 4
print("ver:", ver)
return PACKET_HEADER.get(ver).parse_bytes(data)
class PacketIpv4Header(PacketIpHeader):
"""
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|Version| IHL |Type of Service| Total Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Identification |Flags| Fragment Offset |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Time to Live | Protocol | Header Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Destination Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Options | Padding |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
"""
IP_HDR = "! BBH HH BBH 4s 4s"
IP_HDR_LEN = struct.calcsize(IP_HDR)
def __init__(self, ver, hdr_len, ttl, proto, ip_src, ip_dst):
super().__init__(ver, hdr_len)
self.ttl = ttl
self.proto = proto
self.ip_src = ip_src
self.ip_dst = ip_dst
def __len__(self):
return self.hdr_length
@staticmethod
def parse_bytes(data: bytes):
(verhlen, tos, iplen, ipid, frag, ttl, proto, cksum, src, dst) = \
struct.unpack(PacketIpv4Header.IP_HDR, data[:PacketIpv4Header.IP_HDR_LEN])
ver = (verhlen & 0xf0) >> 4
hlen = (verhlen & 0x0f) * 4
'''
"VER": ver,
"HLEN": hlen,
"TOS": tos,
"IPLEN": iplen,
"IPID": ipid,
"FRAG": frag,
"TTL": ttl,
"PROTO": proto,
"CKSUM": cksum,
"SRC": socket.inet_ntoa(src),
"DST": socket.inet_ntoa(dst)
'''
src_ip = socket.inet_ntoa(src)
dst_ip = socket.inet_ntoa(dst)
return PacketIpv4Header(ver, hlen, ttl, proto, src_ip, dst_ip)
class PacketIpv6Header(PacketIpHeader):
"""
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|Version| Traffic Class | Flow Label |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Payload Length | Next Header | Hop Limit |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ +
| |
+ Source Address +
| |
+ +
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ +
| |
+ Destination Address +
| |
+ +
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
"""
IP6_HDR = "! I HBB 16s 16s"
IP6_HDR_LEN = struct.calcsize(IP6_HDR)
def __init__(self, ver, next_header, hop_limit, ip_src, ip_dst):
# TODO: confirm hdr_length in case of multiple options/headers
super().__init__(ver, PacketIpv6Header.IP6_HDR_LEN)
self.next_header = next_header
self.hop_limit = hop_limit
self.ip_src = ip_src
self.ip_dst = ip_dst
def __len__(self):
return PacketIpv6Header.IP6_HDR_LEN
@staticmethod
def parse_bytes(data: bytes):
(ver_tc_fl, _, next_header, hop_limit, src, dst) = \
struct.unpack(PacketIpv6Header.IP6_HDR, data[:PacketIpv6Header.IP6_HDR_LEN])
ver = (ver_tc_fl & 0xf0000000) >> 28
#tc = (ver_tc_fl & 0x0ff00000) >> 20
#fl = (ver_tc_fl & 0x000fffff)
'''
"VER": ver,
"TRAFFIC CLASS": tc,
"FLOW LABEL": fl,
"PAYLOAD LEN": payload_length,
"NEXT HEADER": next_header,
"HOP LIMIT": hop_limit,
"SRC": socket.inet_atop(socket.AF_INET6, src),
"DST": socket.inet_atop(socket.AF_INET6, dst)
'''
src_ip = socket.inet_ntop(socket.AF_INET6, src)
dst_ip = socket.inet_ntop(socket.AF_INET6, dst)
return PacketIpv6Header(ver, next_header, hop_limit, src_ip, dst_ip)
PACKET_HEADER = {
4: PacketIpv4Header,
6: PacketIpv6Header,
}
import struct
import socket
from .PacketPayload import PacketPayload
"""
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type | Code | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Maximum Response Delay | Reserved |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ +
| |
+ Multicast Address +
| |
+ +
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
"""
class PacketMLDHeader(PacketPayload):
MLD_TYPE = 58
MLD_HDR = "! BB H H H 16s"
MLD_HDR_LEN = struct.calcsize(MLD_HDR)
MULTICAST_LISTENER_QUERY_TYPE = 130
MULTICAST_LISTENER_REPORT_TYPE = 131
MULTICAST_LISTENER_DONE_TYPE = 132
def __init__(self, type: int, max_resp_delay: int, group_address: str = "::"):
# todo check type
self.type = type
self.max_resp_delay = max_resp_delay
self.group_address = group_address
def get_mld_type(self):
return self.type
def bytes(self) -> bytes:
# obter mensagem e criar checksum
msg_without_chcksum = struct.pack(PacketMLDHeader.MLD_HDR, self.type, 0, 0, self.max_resp_delay, 0,
socket.inet_pton(socket.AF_INET6, self.group_address))
#mld_checksum = checksum(msg_without_chcksum)
#msg = msg_without_chcksum[0:2] + struct.pack("! H", mld_checksum) + msg_without_chcksum[4:]
# checksum handled by linux kernel
return msg_without_chcksum
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
mld_hdr = data[0:PacketMLDHeader.MLD_HDR_LEN]
if len(mld_hdr) < PacketMLDHeader.MLD_HDR_LEN:
raise Exception("MLD packet length is lower than expected")
(mld_type, _, _, max_resp_delay, _, group_address) = struct.unpack(PacketMLDHeader.MLD_HDR, mld_hdr)
# checksum is handled by linux kernel
mld_hdr = mld_hdr[PacketMLDHeader.MLD_HDR_LEN:]
group_address = socket.inet_ntop(socket.AF_INET6, group_address)
pkt = PacketMLDHeader(mld_type, max_resp_delay, group_address)
return pkt
......@@ -55,7 +55,7 @@ class PacketPimEncodedGroupAddress:
elif version == 6:
return (PacketPimEncodedGroupAddress.IPV6_HDR, PacketPimEncodedGroupAddress.FAMILY_IPV6, socket.AF_INET6)
else:
raise Exception
raise Exception("Unknown address family")
def __len__(self):
version = ipaddress.ip_address(self.group_address).version
......@@ -64,7 +64,7 @@ class PacketPimEncodedGroupAddress:
elif version == 6:
return self.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6
else:
raise Exception
raise Exception("Unknown address family")
@staticmethod
def parse_bytes(data: bytes):
......@@ -72,13 +72,14 @@ class PacketPimEncodedGroupAddress:
(addr_family, encoding, _, mask_len) = struct.unpack(PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_MULTICAST_ADDRESS, data_without_group_addr)
data_group_addr = data[PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_ADDRESS_LEN:]
ip = None
if addr_family == PacketPimEncodedGroupAddress.FAMILY_IPV4:
(ip,) = struct.unpack("! " + PacketPimEncodedGroupAddress.IPV4_HDR, data_group_addr[:4])
ip = socket.inet_ntop(socket.AF_INET, ip)
elif addr_family == PacketPimEncodedGroupAddress.FAMILY_IPV6:
(ip,) = struct.unpack("! " + PacketPimEncodedGroupAddress.IPV6_HDR, data_group_addr[:16])
ip = socket.inet_ntop(socket.AF_INET6, ip)
else:
raise Exception("Unknown address family")
if encoding != 0:
print("unknown encoding")
......
......@@ -57,7 +57,7 @@ class PacketPimEncodedSourceAddress:
elif version == 6:
return (PacketPimEncodedSourceAddress.IPV6_HDR, PacketPimEncodedSourceAddress.FAMILY_IPV6, socket.AF_INET6)
else:
raise Exception
raise Exception("Unknown address family")
def __len__(self):
version = ipaddress.ip_address(self.source_address).version
......@@ -66,7 +66,7 @@ class PacketPimEncodedSourceAddress:
elif version == 6:
return self.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN_IPV6
else:
raise Exception
raise Exception("Unknown address family")
@staticmethod
def parse_bytes(data: bytes):
......@@ -74,13 +74,14 @@ class PacketPimEncodedSourceAddress:
(addr_family, encoding, _, mask_len) = struct.unpack(PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS, data_without_source_addr)
data_source_addr = data[PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS_LEN:]
ip = None
if addr_family == PacketPimEncodedSourceAddress.FAMILY_IPV4:
(ip,) = struct.unpack("! " + PacketPimEncodedSourceAddress.IPV4_HDR, data_source_addr[:4])
ip = socket.inet_ntop(socket.AF_INET, ip)
elif addr_family == PacketPimEncodedSourceAddress.FAMILY_IPV6:
(ip,) = struct.unpack("! " + PacketPimEncodedSourceAddress.IPV6_HDR, data_source_addr[:16])
ip = socket.inet_ntop(socket.AF_INET6, ip)
else:
raise Exception("Unknown address family")
if encoding != 0:
print("unknown encoding")
......
......@@ -46,7 +46,7 @@ class PacketPimEncodedUnicastAddress:
elif version == 6:
return (PacketPimEncodedUnicastAddress.IPV6_HDR, PacketPimEncodedUnicastAddress.FAMILY_IPV6, socket.AF_INET6)
else:
raise Exception
raise Exception("Unknown address family")
def __len__(self):
version = ipaddress.ip_address(self.unicast_address).version
......@@ -55,7 +55,7 @@ class PacketPimEncodedUnicastAddress:
elif version == 6:
return self.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6
else:
raise Exception
raise Exception("Unknown address family")
@staticmethod
def parse_bytes(data: bytes):
......@@ -69,6 +69,8 @@ class PacketPimEncodedUnicastAddress:
elif addr_family == PacketPimEncodedUnicastAddress.FAMILY_IPV6:
(ip,) = struct.unpack("! " + PacketPimEncodedUnicastAddress.IPV6_HDR, data_unicast_addr[:16])
ip = socket.inet_ntop(socket.AF_INET6, ip)
else:
raise Exception("Unknown address family")
if encoding != 0:
print("unknown encoding")
......
import socket
from .Packet import Packet
from .PacketIpHeader import PacketIpHeader
from .PacketIGMPHeader import PacketIGMPHeader
from .PacketPimHeader import PacketPimHeader
from .PacketMLDHeader import PacketMLDHeader
from .PacketIGMPHeader import PacketIGMPHeader
from .PacketIpHeader import PacketIpv4Header, PacketIpv6Header
from pimdm.utils import TYPE_CHECKING
if TYPE_CHECKING:
from pimdm.Interface import Interface
......@@ -13,13 +15,32 @@ class ReceivedPacket(Packet):
def __init__(self, raw_packet: bytes, interface: 'Interface'):
self.interface = interface
# Parse ao packet e preencher objeto Packet
packet_ip_hdr = raw_packet[:PacketIpHeader.IP_HDR_LEN]
ip_header = PacketIpHeader.parse_bytes(packet_ip_hdr)
# Parse packet and fill Packet super class
ip_header = PacketIpv4Header.parse_bytes(raw_packet)
protocol_number = ip_header.proto
packet_without_ip_hdr = raw_packet[ip_header.hdr_length:]
payload = ReceivedPacket.payload_protocol[protocol_number].parse_bytes(packet_without_ip_hdr)
super().__init__(ip_header=ip_header, payload=payload)
class ReceivedPacket_v6(Packet):
# choose payload protocol class based on ip protocol number
payload_protocol_v6 = {58: PacketMLDHeader, 103: PacketPimHeader}
def __init__(self, raw_packet: bytes, ancdata: list, src_addr: str, next_header: int, interface: 'Interface'):
self.interface = interface
# Parse packet and fill Packet super class
dst_addr = "::"
for cmsg_level, cmsg_type, cmsg_data in ancdata:
if cmsg_level == socket.IPPROTO_IPV6 and cmsg_type == socket.IPV6_PKTINFO:
dst_addr = socket.inet_ntop(socket.AF_INET6, cmsg_data[:16])
break
src_addr = src_addr[0].split("%")[0]
ipv6_packet = PacketIpv6Header(ver=6, hop_limit=1, next_header=next_header, ip_src=src_addr, ip_dst=dst_addr)
payload = ReceivedPacket_v6.payload_protocol_v6[next_header].parse_bytes(raw_packet)
super().__init__(ip_header=ipv6_packet, payload=payload)
import logging
from time import time
from threading import Lock, RLock
from pimdm import UnicastRouting
from .metric import AssertMetric
from .tree_if_upstream import TreeInterfaceUpstream
from .tree_if_downstream import TreeInterfaceDownstream
from .tree_interface import TreeInterface
from threading import Lock, RLock
from .metric import AssertMetric
from pimdm import UnicastRouting, Main
from time import time
import logging
class KernelEntry:
TREE_TIMEOUT = 180
KERNEL_LOGGER = logging.getLogger('pim.KernelEntry')
def __init__(self, source_ip: str, group_ip: str):
self.kernel_entry_logger = logging.LoggerAdapter(KernelEntry.KERNEL_LOGGER, {'tree': '(' + source_ip + ',' + group_ip + ')'})
def __init__(self, source_ip: str, group_ip: str, kernel_entry_interface):
self.kernel_entry_logger = logging.LoggerAdapter(KernelEntry.KERNEL_LOGGER,
{'tree': '(' + source_ip + ',' + group_ip + ')'})
self.kernel_entry_logger.debug('Create KernelEntry')
self.source_ip = source_ip
self.group_ip = group_ip
self._kernel_entry_interface = kernel_entry_interface
# OBTAIN UNICAST ROUTING INFORMATION###################################################
(metric_administrative_distance, metric_cost, rpf_node, root_if, mask) = \
UnicastRouting.get_unicast_info(source_ip)
......@@ -38,7 +42,7 @@ class KernelEntry:
self.interface_state = {} # type: Dict[int, TreeInterface]
with self.CHANGE_STATE_LOCK:
for i in Main.kernel.vif_index_to_name_dic.keys():
for i in self.get_kernel().vif_index_to_name_dic.keys():
try:
if i == self.inbound_interface_index:
self.interface_state[i] = TreeInterfaceUpstream(self, i)
......@@ -55,22 +59,31 @@ class KernelEntry:
print('Tree created')
def get_inbound_interface_index(self):
"""
Get VIF of root interface of this tree
"""
return self.inbound_interface_index
def get_outbound_interfaces_indexes(self):
outbound_indexes = [0] * Main.kernel.MAXVIFS
for (index, state) in self.interface_state.items():
outbound_indexes[index] = state.is_forwarding()
return outbound_indexes
"""
Get OIL of this tree
"""
return self._kernel_entry_interface.get_outbound_interfaces_indexes(self)
################################################
# Receive (S,G) data packets or control packets
################################################
def recv_data_msg(self, index):
"""
Receive data packet regarding this tree in interface with VIF index
"""
print("recv data")
self.interface_state[index].recv_data_msg()
def recv_assert_msg(self, index, packet):
"""
Receive assert packet regarding this tree in interface with VIF index
"""
print("recv assert")
pkt_assert = packet.payload.payload
metric = pkt_assert.metric
......@@ -81,28 +94,43 @@ class KernelEntry:
self.interface_state[index].recv_assert_msg(received_metric)
def recv_prune_msg(self, index, packet):
"""
Receive Prune packet regarding this tree in interface with VIF index
"""
print("recv prune msg")
holdtime = packet.payload.payload.hold_time
upstream_neighbor_address = packet.payload.payload.upstream_neighbor_address
self.interface_state[index].recv_prune_msg(upstream_neighbor_address=upstream_neighbor_address, holdtime=holdtime)
def recv_join_msg(self, index, packet):
"""
Receive Join packet regarding this tree in interface with VIF index
"""
print("recv join msg")
upstream_neighbor_address = packet.payload.payload.upstream_neighbor_address
self.interface_state[index].recv_join_msg(upstream_neighbor_address)
def recv_graft_msg(self, index, packet):
"""
Receive Graft packet regarding this tree in interface with VIF index
"""
print("recv graft msg")
upstream_neighbor_address = packet.payload.payload.upstream_neighbor_address
source_ip = packet.ip_header.ip_src
self.interface_state[index].recv_graft_msg(upstream_neighbor_address, source_ip)
def recv_graft_ack_msg(self, index, packet):
"""
Receive GraftAck packet regarding this tree in interface with VIF index
"""
print("recv graft ack msg")
source_ip = packet.ip_header.ip_src
self.interface_state[index].recv_graft_ack_msg(source_ip)
def recv_state_refresh_msg(self, index, packet):
"""
Receive StateRefresh packet regarding this tree in interface with VIF index
"""
print("recv state refresh msg")
source_of_state_refresh = packet.ip_header.ip_src
......@@ -129,11 +157,13 @@ class KernelEntry:
self.forward_state_refresh_msg(packet.payload.payload)
################################################
# Send state refresh msg
################################################
def forward_state_refresh_msg(self, state_refresh_packet):
"""
Forward StateRefresh packet through all interfaces
"""
for interface in self.interface_state.values():
interface.send_state_refresh(state_refresh_packet)
......@@ -142,6 +172,9 @@ class KernelEntry:
# Unicast Changes to RPF
###############################################################
def network_update(self):
"""
Unicast routing table suffered an update and this tree might be affected by it
"""
# TODO TALVEZ OUTRO LOCK PARA BLOQUEAR ENTRADA DE PACOTES
with self.CHANGE_STATE_LOCK:
......@@ -184,24 +217,34 @@ class KernelEntry:
self.rpf_node = rpf_node
self.interface_state[self.inbound_interface_index].change_on_unicast_routing()
# check if add/removal of neighbors from interface afects olist and forward/prune state of interface
def change_at_number_of_neighbors(self):
"""
Check if modification of number of neighbors causes changes to OIL and interest of interface
"""
with self.CHANGE_STATE_LOCK:
self.change()
self.evaluate_olist_change()
def new_or_reset_neighbor(self, if_index, neighbor_ip):
"""
An interface identified by if_index has a new neighbor
"""
# todo maybe lock de interfaces
self.interface_state[if_index].new_or_reset_neighbor(neighbor_ip)
def is_olist_null(self):
"""
Check if olist is null
"""
for interface in self.interface_state.values():
if interface.is_forwarding():
return False
return True
def evaluate_olist_change(self):
"""
React to changes on the olist
"""
with self._lock_test2:
is_olist_null = self.is_olist_null()
......@@ -214,33 +257,74 @@ class KernelEntry:
self._was_olist_null = is_olist_null
def get_source(self):
"""
Get source IP of multicast source
"""
return self.source_ip
def get_group(self):
"""
Get group IP of multicast tree
"""
return self.group_ip
def change(self):
"""
Trigger an update on the multicast routing table
"""
with self._multicast_change:
Main.kernel.set_multicast_route(self)
self.get_kernel().set_multicast_route(self)
def delete(self):
"""
Remove kernel entry
"""
with self._multicast_change:
for state in self.interface_state.values():
state.delete()
Main.kernel.remove_multicast_route(self)
self.get_kernel().remove_multicast_route(self)
def get_interface_name(self, interface_id):
"""
Get interface name of interface identified by interface_id
"""
return self._kernel_entry_interface.get_interface_name(interface_id)
def get_interface(self, interface_id):
"""
Get PIM interface
"""
return self._kernel_entry_interface.get_interface(self, interface_id)
def get_membership_interface(self, interface_id):
"""
Get IGMP/MLD interface
"""
return self._kernel_entry_interface.get_membership_interface(self, interface_id)
def get_kernel(self):
"""
Get kernel
"""
return self._kernel_entry_interface.get_kernel()
######################################
# Interface change
#######################################
def new_interface(self, index):
"""
React to a new interface that was added and in which a tree was already built
"""
with self.CHANGE_STATE_LOCK:
self.interface_state[index] = TreeInterfaceDownstream(self, index)
self.change()
self.evaluate_olist_change()
def remove_interface(self, index):
"""
React to removal of an interface of a tree that was already built
"""
with self.CHANGE_STATE_LOCK:
#check if removed interface is root interface
if self.inbound_interface_index == index:
......
from pimdm import Main
from abc import abstractmethod, ABCMeta
class KernelEntryInterface(metaclass=ABCMeta):
@staticmethod
@abstractmethod
def get_outbound_interfaces_indexes(kernel_tree):
"""
Get OIL of this tree
"""
pass
@staticmethod
@abstractmethod
def get_interface_name(interface_id):
"""
Get name of interface from vif id
"""
pass
@staticmethod
@abstractmethod
def get_interface(kernel_tree, interface_id):
"""
Get PIM interface from interface id
"""
pass
@staticmethod
@abstractmethod
def get_membership_interface(kernel_tree, interface_id):
"""
Get IGMP/MLD interface from interface id
"""
pass
@staticmethod
@abstractmethod
def get_kernel():
"""
Get kernel
"""
pass
class KernelEntry4Interface(KernelEntryInterface):
@staticmethod
def get_outbound_interfaces_indexes(kernel_tree):
"""
Get OIL of this tree
"""
outbound_indexes = [0] * Main.kernel.MAXVIFS
for (index, state) in kernel_tree.interface_state.items():
outbound_indexes[index] = state.is_forwarding()
return outbound_indexes
@staticmethod
def get_interface_name(interface_id):
"""
Get name of interface from vif id
"""
return Main.kernel.vif_index_to_name_dic[interface_id]
@staticmethod
def get_interface(kernel_tree, interface_id):
"""
Get PIM interface from interface id
"""
interface_name = kernel_tree.get_interface_name(interface_id)
return Main.interfaces.get(interface_name, None)
@staticmethod
def get_membership_interface(kernel_tree, interface_id):
"""
Get IGMP interface from interface id
"""
interface_name = kernel_tree.get_interface_name(interface_id)
return Main.igmp_interfaces.get(interface_name, None) # type: InterfaceIGMP
@staticmethod
def get_kernel():
"""
Get kernel
"""
return Main.kernel
class KernelEntry6Interface(KernelEntryInterface):
@staticmethod
def get_outbound_interfaces_indexes(kernel_tree):
"""
Get OIL of this tree
"""
outbound_indexes = [0] * 8
for (index, state) in kernel_tree.interface_state.items():
outbound_indexes[index // 32] |= state.is_forwarding() << (index % 32)
return outbound_indexes
@staticmethod
def get_interface_name(interface_id):
"""
Get name of interface from vif id
"""
return Main.kernel_v6.vif_index_to_name_dic[interface_id]
@staticmethod
def get_interface(kernel_tree, interface_id):
"""
Get PIM interface from interface id
"""
interface_name = kernel_tree.get_interface_name(interface_id)
return Main.interfaces_v6.get(interface_name, None)
@staticmethod
def get_membership_interface(kernel_tree, interface_id):
"""
Get MLD interface from interface id
"""
interface_name = kernel_tree.get_interface_name(interface_id)
return Main.mld_interfaces.get(interface_name, None) # type: InterfaceMLD
@staticmethod
def get_kernel():
"""
Get kernel
"""
return Main.kernel_v6
......@@ -112,12 +112,10 @@ class AssertStateABC(metaclass=ABCMeta):
"""
raise NotImplementedError()
def _sendAssert_setAT(interface: "TreeInterfaceDownstream"):
interface.set_assert_timer(pim_globals.ASSERT_TIME)
interface.send_assert()
# Override
def __str__(self) -> str:
return "AssertSM:" + self.__class__.__name__
......@@ -289,7 +287,6 @@ class WinnerState(AssertStateABC):
return "Winner"
class LoserState(AssertStateABC):
'''
I am Assert Loser (L)
......@@ -370,6 +367,7 @@ class LoserState(AssertStateABC):
def __str__(self) -> str:
return "Loser"
class AssertState():
NoInfo = NoInfoState()
Winner = WinnerState()
......
import subprocess
import struct
import socket
import ipaddress
import subprocess
from ctypes import create_string_buffer, addressof
SO_ATTACH_FILTER = 26
ETH_P_IP = 0x0800 # Internet Protocol packet
ETH_P_IPV6 = 0x86DD # IPv6 over bluebook
SO_RCVBUFFORCE = 33
def get_s_g_bpf_filter_code(source, group, interface_name):
#cmd = "tcpdump -ddd \"(udp or icmp) and host %s and dst %s\"" % (source, group)
ip_source_version = ipaddress.ip_address(source).version
ip_group_version = ipaddress.ip_address(source).version
if ip_source_version == ip_group_version == 4:
# cmd = "tcpdump -ddd \"(udp or icmp) and host %s and dst %s\"" % (source, group)
cmd = "tcpdump -ddd \"(ip proto not 2) and host %s and dst %s\"" % (source, group)
protocol = ETH_P_IP
elif ip_source_version == ip_group_version == 6:
# TODO: allow ICMPv6 echo request/echo response to be considered multicast packets
cmd = "tcpdump -ddd \"(ip6 proto not 58) and host %s and dst %s\"" % (source, group)
protocol = ETH_P_IPV6
else:
raise Exception("Unknown IP family")
result = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
bpf_filter = b''
......@@ -28,10 +42,10 @@ def get_s_g_bpf_filter_code(source, group, interface_name):
# Create listening socket with filters
s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, ETH_P_IP)
s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, protocol)
s.setsockopt(socket.SOL_SOCKET, SO_ATTACH_FILTER, fprog)
# todo pequeno ajuste (tamanho de buffer pequeno para o caso de trafego em rajadas):
#s.setsockopt(socket.SOL_SOCKET, SO_RCVBUFFORCE, 1)
s.bind((interface_name, ETH_P_IP))
s.bind((interface_name, protocol))
return s
......@@ -7,4 +7,8 @@ REFRESH_INTERVAL = 60 # State Refresh Interval
SOURCE_LIFETIME = 210
T_LIMIT = 210
HELLO_HOLD_TIME_NO_TIMEOUT = 0xFFFF
HELLO_HOLD_TIME = 160
HELLO_HOLD_TIME_TIMEOUT = 0
ASSERT_CANCEL_METRIC = 0xFFFFFFFF
\ No newline at end of file
from abc import ABCMeta, abstractstaticmethod
from abc import ABCMeta, abstractmethod
class OriginatorStateABC(metaclass=ABCMeta):
@abstractstaticmethod
@staticmethod
@abstractmethod
def recvDataMsgFromSource(tree):
pass
@abstractstaticmethod
@staticmethod
@abstractmethod
def SRTexpires(tree):
pass
@abstractstaticmethod
@staticmethod
@abstractmethod
def SATexpires(tree):
pass
@abstractstaticmethod
@staticmethod
@abstractmethod
def SourceNotConnected(tree):
pass
......
'''
Created on Jul 16, 2015
@author: alex
'''
from threading import Timer
from pimdm.CustomTimer.RemainingTimer import RemainingTimer
from .assert_ import AssertState
from pimdm.custom_timer.RemainingTimer import RemainingTimer
from .assert_state import AssertState
from .downstream_prune import DownstreamState, DownstreamStateABS
from .tree_interface import TreeInterface
from pimdm.Packet.PacketPimStateRefresh import PacketPimStateRefresh
from pimdm.Packet.Packet import Packet
from pimdm.Packet.PacketPimHeader import PacketPimHeader
from pimdm.packet.PacketPimStateRefresh import PacketPimStateRefresh
from pimdm.packet.Packet import Packet
from pimdm.packet.PacketPimHeader import PacketPimHeader
import traceback
import logging
from .. import Main
class TreeInterfaceDownstream(TreeInterface):
......@@ -22,7 +16,7 @@ class TreeInterfaceDownstream(TreeInterface):
def __init__(self, kernel_entry, interface_id):
extra_dict_logger = kernel_entry.kernel_entry_logger.extra.copy()
extra_dict_logger['vif'] = interface_id
extra_dict_logger['interfacename'] = Main.kernel.vif_index_to_name_dic[interface_id]
extra_dict_logger['interfacename'] = kernel_entry.get_interface_name(interface_id)
logger = logging.LoggerAdapter(TreeInterfaceDownstream.LOGGER, extra_dict_logger)
TreeInterface.__init__(self, kernel_entry, interface_id, logger)
self.logger.debug('Created DownstreamInterface')
......
'''
Created on Jul 16, 2015
@author: alex
'''
from .tree_interface import TreeInterface
from .upstream_prune import UpstreamState
from threading import Timer
from pimdm.CustomTimer.RemainingTimer import RemainingTimer
from pimdm.custom_timer.RemainingTimer import RemainingTimer
from .globals import *
import random
from .metric import AssertMetric
from .originator import OriginatorState, OriginatorStateABC
from pimdm.Packet.PacketPimStateRefresh import PacketPimStateRefresh
from pimdm.packet.PacketPimStateRefresh import PacketPimStateRefresh
import traceback
from . import DataPacketsSocket
from . import data_packets_socket
import threading
import logging
from .. import Main
class TreeInterfaceUpstream(TreeInterface):
......@@ -25,7 +19,7 @@ class TreeInterfaceUpstream(TreeInterface):
def __init__(self, kernel_entry, interface_id):
extra_dict_logger = kernel_entry.kernel_entry_logger.extra.copy()
extra_dict_logger['vif'] = interface_id
extra_dict_logger['interfacename'] = Main.kernel.vif_index_to_name_dic[interface_id]
extra_dict_logger['interfacename'] = kernel_entry.get_interface_name(interface_id)
logger = logging.LoggerAdapter(TreeInterfaceUpstream.LOGGER, extra_dict_logger)
TreeInterface.__init__(self, kernel_entry, interface_id, logger)
......@@ -47,15 +41,16 @@ class TreeInterfaceUpstream(TreeInterface):
if self.is_S_directly_conn():
self._graft_prune_state.sourceIsNowDirectConnect(self)
if self.get_interface().is_state_refresh_enabled():
interface = self.get_interface()
if interface is not None and interface.is_state_refresh_enabled():
self._originator_state.recvDataMsgFromSource(self)
# TODO TESTE SOCKET RECV DATA PCKTS
self.socket_is_enabled = True
(s,g) = self.get_tree_id()
interface_name = self.get_interface().interface_name
self.socket_pkt = DataPacketsSocket.get_s_g_bpf_filter_code(s, g, interface_name)
(s, g) = self.get_tree_id()
interface_name = self.get_interface_name()
self.socket_pkt = data_packets_socket.get_s_g_bpf_filter_code(s, g, interface_name)
# run receive method in background
receive_thread = threading.Thread(target=self.socket_recv)
......@@ -182,7 +177,9 @@ class TreeInterfaceUpstream(TreeInterface):
def recv_data_msg(self):
if not self.is_prune_limit_timer_running() and not self.is_S_directly_conn() and self.is_olist_null():
self._graft_prune_state.dataArrivesRPFinterface_OListNull_PLTstoped(self)
elif self.is_S_directly_conn() and self.get_interface().is_state_refresh_enabled():
elif self.is_S_directly_conn():
interface = self.get_interface()
if interface is not None and interface.is_state_refresh_enabled():
self._originator_state.recvDataMsgFromSource(self)
......
'''
Created on Jul 16, 2015
@author: alex
'''
from abc import ABCMeta, abstractmethod
from .. import Main
from threading import RLock
import traceback
from .downstream_prune import DownstreamState
from .assert_ import AssertState, AssertStateABC
from .assert_state import AssertState, AssertStateABC
from pimdm.Packet.PacketPimGraft import PacketPimGraft
from pimdm.Packet.PacketPimGraftAck import PacketPimGraftAck
from pimdm.Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
from pimdm.Packet.PacketPimHeader import PacketPimHeader
from pimdm.Packet.Packet import Packet
from pimdm.packet.PacketPimGraft import PacketPimGraft
from pimdm.packet.PacketPimGraftAck import PacketPimGraftAck
from pimdm.packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
from pimdm.packet.PacketPimHeader import PacketPimHeader
from pimdm.packet.Packet import Packet
from pimdm.Packet.PacketPimJoinPrune import PacketPimJoinPrune
from pimdm.Packet.PacketPimAssert import PacketPimAssert
from pimdm.Packet.PacketPimStateRefresh import PacketPimStateRefresh
from pimdm.packet.PacketPimJoinPrune import PacketPimJoinPrune
from pimdm.packet.PacketPimAssert import PacketPimAssert
from pimdm.packet.PacketPimStateRefresh import PacketPimStateRefresh
from .metric import AssertMetric
from threading import Timer
from .local_membership import LocalMembership
from .globals import *
from .globals import T_LIMIT
import logging
class TreeInterface(metaclass=ABCMeta):
def __init__(self, kernel_entry, interface_id, logger: logging.LoggerAdapter):
self._kernel_entry = kernel_entry
......@@ -36,9 +31,8 @@ class TreeInterface(metaclass=ABCMeta):
# Local Membership State
try:
interface_name = Main.kernel.vif_index_to_name_dic[interface_id]
igmp_interface = Main.igmp_interfaces[interface_name] # type: InterfaceIGMP
group_state = igmp_interface.interface_state.get_group_state(kernel_entry.group_ip)
membership_interface = self.get_membership_interface()
group_state = membership_interface.interface_state.get_group_state(kernel_entry.group_ip)
#self._igmp_has_members = group_state.add_multicast_routing_entry(self)
igmp_has_members = group_state.add_multicast_routing_entry(self)
self._local_membership_state = LocalMembership.Include if igmp_has_members else LocalMembership.NoInfo
......@@ -60,8 +54,7 @@ class TreeInterface(metaclass=ABCMeta):
# Received prune hold time
self._received_prune_holdtime = None
self._igmp_lock = RLock()
self._membership_lock = RLock()
############################################
# Set ASSERT State
......@@ -90,7 +83,6 @@ class TreeInterface(metaclass=ABCMeta):
finally:
self._assert_winner_metric = new_assert_metric
############################################
# ASSERT Timer
############################################
......@@ -106,7 +98,6 @@ class TreeInterface(metaclass=ABCMeta):
def assert_timeout(self):
self._assert_state.assertTimerExpires(self)
###########################################
# Recv packets
###########################################
......@@ -145,7 +136,6 @@ class TreeInterface(metaclass=ABCMeta):
def recv_state_refresh_msg(self, received_metric: AssertMetric, prune_indicator):
self.recv_assert_msg(received_metric)
######################################
# Send messages
######################################
......@@ -163,7 +153,6 @@ class TreeInterface(metaclass=ABCMeta):
traceback.print_exc()
return
def send_graft_ack(self, ip_sender):
print("send graft ack")
try:
......@@ -177,7 +166,6 @@ class TreeInterface(metaclass=ABCMeta):
traceback.print_exc()
return
def send_prune(self, holdtime=None):
if holdtime is None:
holdtime = T_LIMIT
......@@ -195,7 +183,6 @@ class TreeInterface(metaclass=ABCMeta):
traceback.print_exc()
return
def send_pruneecho(self):
holdtime = T_LIMIT
try:
......@@ -210,7 +197,6 @@ class TreeInterface(metaclass=ABCMeta):
traceback.print_exc()
return
def send_join(self):
print("send join")
......@@ -225,7 +211,6 @@ class TreeInterface(metaclass=ABCMeta):
traceback.print_exc()
return
def send_assert(self):
print("send assert")
......@@ -240,7 +225,6 @@ class TreeInterface(metaclass=ABCMeta):
traceback.print_exc()
return
def send_assert_cancel(self):
print("send assert cancel")
......@@ -254,7 +238,6 @@ class TreeInterface(metaclass=ABCMeta):
traceback.print_exc()
return
def send_state_refresh(self, state_refresh_msg_received: PacketPimStateRefresh):
pass
......@@ -282,9 +265,8 @@ class TreeInterface(metaclass=ABCMeta):
(s, g) = self.get_tree_id()
# unsubscribe igmp information
try:
interface_name = Main.kernel.vif_index_to_name_dic[self._interface_id]
igmp_interface = Main.igmp_interfaces[interface_name] # type: InterfaceIGMP
group_state = igmp_interface.interface_state.get_group_state(g)
membership_interface = self.get_membership_interface()
group_state = membership_interface.interface_state.get_group_state(g)
group_state.remove_multicast_routing_entry(self)
except:
pass
......@@ -306,29 +288,29 @@ class TreeInterface(metaclass=ABCMeta):
def evaluate_ingroup(self):
self._kernel_entry.evaluate_olist_change()
#############################################################
# Local Membership (IGMP)
############################################################
def notify_igmp(self, has_members: bool):
def notify_membership(self, has_members: bool):
with self.get_state_lock():
with self._igmp_lock:
with self._membership_lock:
if has_members != self._local_membership_state.has_members():
self._local_membership_state = LocalMembership.Include if has_members else LocalMembership.NoInfo
self.change_tree()
self.evaluate_ingroup()
def igmp_has_members(self):
with self._igmp_lock:
with self._membership_lock:
return self._local_membership_state.has_members()
def get_interface_name(self):
return self._kernel_entry.get_interface_name(self._interface_id)
def get_interface(self):
kernel = Main.kernel
interface_name = kernel.vif_index_to_name_dic[self._interface_id]
interface = Main.interfaces[interface_name]
return interface
return self._kernel_entry.get_interface(self._interface_id)
def get_membership_interface(self):
return self._kernel_entry.get_membership_interface(self._interface_id)
def get_ip(self):
ip = self.get_interface().get_ip()
......@@ -353,9 +335,6 @@ class TreeInterface(metaclass=ABCMeta):
def is_downstream(self):
raise NotImplementedError()
# obtain ip of RPF'(S)
def get_neighbor_RPF(self):
'''
......@@ -375,8 +354,6 @@ class TreeInterface(metaclass=ABCMeta):
def get_received_prune_holdtime(self):
return self._received_prune_holdtime
###################################################
# ASSERT
###################################################
......
import array
'''
import struct
if struct.pack("H",1) == "\x00\x01": # big endian
def checksum(pkt):
if len(pkt) % 2 == 1:
pkt += "\0"
s = sum(array.array("H", pkt))
s = (s >> 16) + (s & 0xffff)
s += s >> 16
s = ~s
return s & 0xffff
else:
def checksum(pkt):
if len(pkt) % 2 == 1:
pkt += "\0"
s = sum(array.array("H", pkt))
s = (s >> 16) + (s & 0xffff)
s += s >> 16
s = ~s
return (((s>>8)&0xff)|s<<8) & 0xffff
'''
HELLO_HOLD_TIME_NO_TIMEOUT = 0xFFFF
HELLO_HOLD_TIME = 160
HELLO_HOLD_TIME_TIMEOUT = 0
def checksum(pkt: bytes) -> bytes:
......@@ -36,35 +11,6 @@ def checksum(pkt: bytes) -> bytes:
return (((s >> 8) & 0xff) | s << 8) & 0xffff
import ctypes
import ctypes.util
libc = ctypes.CDLL(ctypes.util.find_library('c'))
def if_nametoindex(name):
if not isinstance(name, str):
raise TypeError('name must be a string.')
ret = libc.if_nametoindex(name)
if not ret:
raise RuntimeError("Invalid Name")
return ret
def if_indextoname(index):
if not isinstance(index, int):
raise TypeError('index must be an int.')
libc.if_indextoname.argtypes = [ctypes.c_uint32, ctypes.c_char_p]
libc.if_indextoname.restype = ctypes.c_char_p
ifname = ctypes.create_string_buffer(32)
ifname = libc.if_indextoname(index, ifname)
if not ifname:
raise RuntimeError ("Inavlid Index")
return ifname.decode("utf-8")
# obtain TYPE_CHECKING (for type hinting)
try:
from typing import TYPE_CHECKING
......
......@@ -12,8 +12,8 @@ setup(
description="PIM-DM protocol",
long_description=open("README.md", "r").read(),
long_description_content_type="text/markdown",
keywords="PIM-DM Multicast Routing Protocol Dense-Mode Router RFC3973",
version="1.0.4.2",
keywords="PIM-DM Multicast Routing Protocol Dense-Mode Router RFC3973 IPv4 IPv6",
version="1.1",
url="http://github.com/pedrofran12/pim_dm",
author="Pedro Oliveira",
author_email="pedro.francisco.oliveira@tecnico.ulisboa.pt",
......@@ -38,7 +38,6 @@ setup(
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.2",
"Programming Language :: Python :: 3.3",
"Programming Language :: Python :: 3.4",
"Programming Language :: Python :: 3.5",
......@@ -46,5 +45,5 @@ setup(
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
],
python_requires=">=3.2",
python_requires=">=3.3",
)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment