Commit 11978100 authored by zhifan huang's avatar zhifan huang

test: add GFW simulate test

GFW will reset the TCP connection as soon as it finds a HTTP GET packet,
and will randomly drop packets, if it finds the traffic too large.
parent 70edd01f
"""this program contain two function,
RST http get,
random drop packet of some connection
"""
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
import logging
import queue
import random
import socket
import subprocess
import time
from threading import Event, Lock
from typing import Dict
from netfilterqueue import NetfilterQueue
import nftables
from pathlib2 import Path
from scapy.layers import http
from scapy.all import IP, TCP, UDP, Raw
from scapy.all import send
logger = logging.getLogger(__name__)
LEVEL = logging.DEBUG
QUEUE_SIZE = 100
WORKERS = 2
NFT_FILE = str(Path(__file__).parent.resolve() / "ip_rules")
Conn = namedtuple("ConnectionParameter", ["src", "sport", "dst", "dport"])
class TCPControlBlock:
"""class to track tcp connection
GFW only detect one direction of a TCP
"""
_renew_time = 600
_max_traffic = 10
def __init__(self, src, sport, dst, dport) -> None:
self.src = src
self.sport = sport
self.dst = dst
self.dport = dport
self.ack = 0
self.seq = 0
# TODO: Actually, I need (src, dst, dport) to track traffic
self.traffic = 0
self.last_active = time.time()
self.state = None
self.lock = Lock()
def update(self, packet: TCP):
"""when receive new packet, update connection info
for simplicity, trust all incoming packet
"""
# GFW doesn't check ack field, and ignore packet with repeat seq
# maybe receive packet in different order, just track latest seq, ack
# a forged packet can cheat the track system.
with self.lock:
if time.time() - self.last_active > self._renew_time:
self.traffic = 0
self.last_active = time.time()
self.ack = max(self.ack, packet.ack)
self.seq = max(self.seq, packet.seq + len(packet.payload))
self.traffic += len(packet.payload)
logger.info("traffic is %s", self.traffic)
if self.traffic > self._max_traffic:
self.bad_service()
def reset_connection(self):
"""GFW seems not to distinguish which side is in China or not.
It simply sends the same things to each side.
"""
logger.info(
"RST connection %s:%s -> %s:%s", self.src, self.sport, self.dst, self.dport
)
# to recipient
send_rst_packet(self.src, self.dst, self.sport, self.dport, self.seq, self.ack)
# to sender
send_rst_packet(self.dst, self.src, self.dport, self.sport, self.ack, self.seq)
def bad_service(self):
"""add src, dst to nft set 'bad_service', for randomly drop packet"""
logger.info("limit connection %s -> %s", self.src, self.dst)
nft = nftables.Nftables()
nft.cmd("add element inet filter bad_service {%s,%s}" % (self.src, self.dst))
# class TCPConnTable(dict):
# pass
packet_queue = queue.Queue()
tcp_table: Dict[Conn, TCPControlBlock] = {}
def accept_and_record(pkt):
"event box, provide packet"
logger.debug("find a packet %s", pkt)
packet = IP(pkt.get_payload())
if packet.proto in (socket.IPPROTO_TCP, socket.IPPROTO_UDP):
# record only tcp, udp
try:
packet_queue.put(packet, block=False)
except queue.Full:
logger.error("recv too many packet")
time.sleep(0.3)
# gfw is a intrusion detect system, allow all packet
pkt.accept()
def analysis(event: Event):
"analysis box, consume packet"
while not event.is_set():
try:
packet: IP = packet_queue.get(timeout=1)
except queue.Empty:
continue
if packet.haslayer(TCP):
logger.info("Analysis tcp %s", packet.summary())
conn_param = Conn(packet.src, packet.sport, packet.dst, packet.dport)
tcb = tcp_table.get(conn_param, None)
if not tcb:
tcp_table[conn_param] = tcb = TCPControlBlock(*conn_param)
tcb.update(packet.payload)
# interpret the payload of tcp
if packet.haslayer(http.HTTPRequest):
logger.info("find a http request")
http_p = packet.getlayer(http.HTTPRequest)
if http_p.Method == b"GET":
tcb.reset_connection()
elif packet.haslayer(http.HTTPResponse):
logger.info("find a http response")
else:
# not a http packet
pass
if packet.haslayer(UDP):
udp_p = packet.payload
if udp_p.dport == 53:
# dns query, need dns poison
pass
def send_rst_packet(src, dst, sport, dport, seq, ack):
"""send 2 type RST, 1 RST_1, 3 RST_2
common:
IP: ttl random
TCP:
window size: random
option: None(even context have)
RST_1:
IP: id = 0, FLAGS=0
TCP: FLAGS = "R"
RST_2
IP: id random, FLAGS="DF"
TCP: FLAGS = "AR"
"""
# type 1
i = IP(src=src, dst=dst, id=0)
t = TCP(sport=sport, dport=dport, seq=seq, flags="R")
send(i / t)
# type 2
i.id = random.randint(0, 1 << 16 - 1)
i.flags = "DF"
t.flags = "RA"
t.ack = ack
send(i / t)
send(i / t)
send(i / t)
def main():
subprocess.run(["nft", "-f", NFT_FILE], check=True)
event = Event()
# nft = nftables.Nftables()
with ThreadPoolExecutor(max_workers=3) as executor:
for _ in range(WORKERS):
executor.submit(analysis, event)
nfqueue = NetfilterQueue()
nfqueue.bind(1, accept_and_record)
try:
nfqueue.run()
except KeyboardInterrupt:
event.set()
logger.debug("set event and exit program")
nfqueue.unbind()
subprocess.run(("nft", "flush", "ruleset"), check=True)
if __name__ == "__main__":
logging.basicConfig(
level=LEVEL,
filename="gfw.log",
filemode="w",
format="%(asctime)s %(levelname)s %(message)s",
datefmt="%I:%M:%S",
)
try:
main()
except Exception as e:
logger.error(e)
#!/usr/sbin/nft -f
table inet filter {
set blocked_ip {
type ipv4_addr;
elements = {
10.0.10.5
};
}
set tcp_connection {
type ipv4_addr;
flags dynamic;
size 65536;
timeout 10m;
}
set bad_service {
type ipv4_addr;
flags dynamic;
size 65536;
timeout 5m;
elements = {
10.0.11.3
};
}
chain input {
type filter hook input priority filter;
}
chain forward_ipv4 {
# gfw mainly prevent china people from connecting outside
ip daddr @blocked_ip drop
# 50 percent to drop packet
ip daddr @bad_service numgen random mod 10 < 5 drop
ct state new ip protocol tcp \
update @tcp_connection {ip saddr} \
update @tcp_connection {ip daddr}
ip daddr @tcp_connection queue num 1
}
chain forward_ipv6 {
}
chain forward {
# base chain
type filter hook forward priority 0;
meta protocol vmap {ip: jump forward_ipv4, ip6: jump forward_ipv6 }
}
}
#!/bin/python2
import atexit
import ipaddress
from subprocess import PIPE
import nemu
from pathlib2 import Path
from re6st.tests.test_network.network_build import Node, NetManager
GFW = str(Path(__file__).parent.resolve() / "gfw.py")
def net_gfw():
"""Underlying network
registry .2-----
|
10.0.0|
.1 |
---------------Internet
|.1 |.1
|10.1.0 |
|.2 |
gateway1(GFW) s3:10.0.1
|.1 |.2 |.3 |.4
s1:10.1.1 m3 m4 m5
|.2 |.3
m1 m2
"""
internet = Node()
gateway1 = Node()
registry = Node()
m1 = Node()
m2 = Node()
m3 = Node()
m4 = Node()
m5 = Node()
switch1 = nemu.Switch()
switch2 = nemu.Switch()
nm = NetManager()
nm.object = [internet, switch1, switch2, gateway1]
nm.registries = {registry: [m1, m2, m3, m4, m5]}
re_if_0, in_if_0 = nemu.P2PInterface.create_pair(registry, internet)
g1_if_0, in_if_1 = nemu.P2PInterface.create_pair(gateway1, internet)
re_if_0.add_v4_address(address="10.0.0.2", prefix_len=24)
in_if_0.add_v4_address(address="10.0.0.1", prefix_len=24)
g1_if_0.add_v4_address(address="10.1.0.2", prefix_len=24)
in_if_1.add_v4_address(address="10.1.0.1", prefix_len=24)
for iface in (re_if_0, in_if_0, g1_if_0, in_if_1):
nm.object.append(iface)
iface.up = True
ip = ipaddress.ip_address(u"10.1.1.1")
for i, node in enumerate([gateway1, m1, m2]):
iface = node.connect_switch(switch1, str(ip + i))
nm.object.append(iface)
if i: # except the first
node.add_route(prefix="10.0.0.0", prefix_len=8, nexthop=ip)
ip = ipaddress.ip_address(u"10.0.1.1")
for i, node in enumerate([internet, m3, m4, m5]):
iface = node.connect_switch(switch2, str(ip + i))
nm.object.append(iface)
if i: # except the first
node.add_route(prefix="10.0.0.0", prefix_len=8, nexthop=ip)
registry.add_route(prefix="10.0.0.0", prefix_len=8, nexthop="10.0.0.1")
gateway1.add_route(prefix="10.0.0.0", prefix_len=8, nexthop="10.1.0.1")
internet.add_route(prefix="10.1.0.0", prefix_len=16, nexthop="10.1.0.2")
switch1.up = switch2.up = True
nm.connectable_test()
gateway1.gfw = gateway1.Popen(["python3", GFW], stdout=PIPE)
atexit.register(gateway1.gfw.destroy)
return nm
import logging
import os
import time
import unittest
import network_build_gfw
from re6st.tests.test_network import re6st_wrap, test_net
@unittest.skipIf(os.geteuid(), "Using root or creating a user namespace")
class TestGFWNet(unittest.TestCase):
"""network with gfw test case"""
@classmethod
def setUpClass(cls):
"""create work dir"""
logging.basicConfig(level=logging.INFO)
re6st_wrap.initial()
@classmethod
def tearDownClass(cls):
"""watch any process leaked after tests"""
logging.basicConfig(level=logging.WARNING)
def test_gfw_ping(self):
"""create a network in a net segment, test the connectivity by ping"""
nm = network_build_gfw.net_gfw()
nodes, _ = test_net.deploy_re6st(nm)
test_net.wait_stable(nodes, 40)
time.sleep(10)
self.assertFalse(test_net.wait_stable(nodes, 30), " ping test success")
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment