Commit 1af99d4e authored by Olivier Tilmans's avatar Olivier Tilmans

Fix dns_matching example

* The name encoding function was not checking the constraints on domain
  names properly (<= 253 chars as one byte is needed for the name of the
  label and one for the terminating 0-len label; <= 63 chars per label).
* The name encoding function was erroring when assigning a struct value
  in the byte array (Python 3.6.3). Refactored to join successive
  subarrays, and moved the null padding to make it explicit that it is
  needed by the bpf map key (and not the dns encoding).
* Used builtin from argparse to have a list of domains in the command
  line arguments.
* Reset the non-block flag through fcntl directly instead of
  reconstructing a socket object.
* Exit gracefully when triggering SIGINT as hinted.
parent 61bc92ad
...@@ -4,34 +4,33 @@ from __future__ import print_function ...@@ -4,34 +4,33 @@ from __future__ import print_function
from bcc import BPF from bcc import BPF
from ctypes import * from ctypes import *
import sys
import socket
import os import os
import struct import sys
import fcntl
import dnslib import dnslib
import argparse import argparse
def encode_dns(name): def encode_dns(name):
size = 255 if len(name) + 1 > 255:
if len(name) > 255:
raise Exception("DNS Name too long.") raise Exception("DNS Name too long.")
b = bytearray(size) b = bytearray()
i = 0; for element in name.split('.'):
elements = name.split(".") sublen = len(element)
for element in elements: if sublen > 63:
b[i] = struct.pack("!B", len(element)) raise ValueError('DNS label %s is too long' % element)
i += 1 b.append(sublen)
for j in range(0, len(element)): b.extend(element.encode('ascii'))
b[i] = element[j] b.append(0) # Add 0-len octet label for the root server
i += 1 return b
return (c_ubyte * size).from_buffer(b)
def add_cache_entry(cache, name): def add_cache_entry(cache, name):
key = cache.Key() key = cache.Key()
key.p = encode_dns(name) key_len = len(key.p)
name_buffer = encode_dns(name)
# Pad the buffer with null bytes if it is too short
name_buffer.extend((0,) * (key_len - len(name_buffer)))
key.p = (c_ubyte * key_len).from_buffer(name_buffer)
leaf = cache.Leaf() leaf = cache.Leaf()
leaf.p = (c_ubyte * 4).from_buffer(bytearray(4)) leaf.p = (c_ubyte * 4).from_buffer(bytearray(4))
cache[key] = leaf cache[key] = leaf
...@@ -41,8 +40,8 @@ parser = argparse.ArgumentParser(usage='For detailed information about usage,\ ...@@ -41,8 +40,8 @@ parser = argparse.ArgumentParser(usage='For detailed information about usage,\
try with -h option') try with -h option')
req_args = parser.add_argument_group("Required arguments") req_args = parser.add_argument_group("Required arguments")
req_args.add_argument("-i", "--interface", type=str, required=True, help="Interface name") req_args.add_argument("-i", "--interface", type=str, required=True, help="Interface name")
req_args.add_argument("-d", "--domains", type=str, required=True, req_args.add_argument("-d", "--domains", type=str, required=True, nargs="+",
help='List of domain names separated by comma. For example: -d "abc.def, xyz.mno"') help='List of domain names separated by space. For example: -d "abc.def xyz.mno"')
args = parser.parse_args() args = parser.parse_args()
# initialize BPF - load source code from http-parse-simple.c # initialize BPF - load source code from http-parse-simple.c
...@@ -63,8 +62,7 @@ BPF.attach_raw_socket(function_dns_matching, args.interface) ...@@ -63,8 +62,7 @@ BPF.attach_raw_socket(function_dns_matching, args.interface)
cache = bpf.get_table("cache") cache = bpf.get_table("cache")
# Add cache entries # Add cache entries
entries = [i.strip() for i in args.domains.split(",")] for e in args.domains:
for e in entries:
print(">>>> Adding map entry: ", e) print(">>>> Adding map entry: ", e)
add_cache_entry(cache, e) add_cache_entry(cache, e)
...@@ -75,12 +73,15 @@ print("Packets received by user space program will be printed here") ...@@ -75,12 +73,15 @@ print("Packets received by user space program will be printed here")
print("\nHit Ctrl+C to end...") print("\nHit Ctrl+C to end...")
socket_fd = function_dns_matching.sock socket_fd = function_dns_matching.sock
sock = socket.fromfd(socket_fd, socket.PF_PACKET, socket.SOCK_RAW, socket.IPPROTO_IP) fl = fcntl.fcntl(socket_fd, fcntl.F_GETFL)
sock.setblocking(True) fcntl.fcntl(socket_fd, fcntl.F_SETFL, fl & (~os.O_NONBLOCK))
while 1: while 1:
#retrieve raw packet from socket #retrieve raw packet from socket
try:
packet_str = os.read(socket_fd, 2048) packet_str = os.read(socket_fd, 2048)
except KeyboardInterrupt:
sys.exit(0)
packet_bytearray = bytearray(packet_str) packet_bytearray = bytearray(packet_str)
ETH_HLEN = 14 ETH_HLEN = 14
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment