Commit f8914e80 authored by Tom Niget's avatar Tom Niget

fix lots of stuff

parent 093a0a1a
# vim: ts=4:sw=4:et:ai:sts=4 # vim: ts=4:sw=4:et:ai:sts=4
import csv, StringIO, subprocess import csv, subprocess
class Graph: class Graph:
[LINE, DOT, POINT, LINEPOINT] = range(0, 4) [LINE, DOT, POINT, LINEPOINT] = range(0, 4)
...@@ -14,12 +14,12 @@ class Graph: ...@@ -14,12 +14,12 @@ class Graph:
def generate(self, output_file): def generate(self, output_file):
lines = self.gen_output() lines = self.gen_output()
lines.insert(0, "set terminal postscript") lines.insert(0, "set terminal postscript")
lines.insert(0, "set output '%s'" % filename) lines.insert(0, "set output '%s'" % output_file)
gnuplot = subprocess.Popen(['gnuplot', '-'], gnuplot = subprocess.Popen(['gnuplot', '-'],
stdin = subprocess.PIPE, stdin = subprocess.PIPE,
stdout = subprocess.PIPE, stdout = subprocess.PIPE,
stderr = subprocess.STDOUT) stderr = subprocess.STDOUT)
gnuplot.communicate(input = "\n".join(lines)) gnuplot.communicate(input = "\n".join(lines).encode("utf-8"))
def Xplot(self, plotnr): def Xplot(self, plotnr):
lines = self.gen_output(plotnr) lines = self.gen_output(plotnr)
lines.insert(0, "set terminal wxt") lines.insert(0, "set terminal wxt")
...@@ -158,7 +158,7 @@ class Data: ...@@ -158,7 +158,7 @@ class Data:
for row in self._data: for row in self._data:
row.append(fn(row)) row.append(fn(row))
if colname: if colname:
self._colname.append(colname) self._colnames.append(colname)
for row in self._datadict: for row in self._datadict:
row[colname] = fn(row) row[colname] = fn(row)
return self.ncols() - 1 return self.ncols() - 1
......
...@@ -41,7 +41,7 @@ def main(): ...@@ -41,7 +41,7 @@ def main():
opts, args = getopt.getopt(sys.argv[1:], "hn:s:t:p:b:", [ opts, args = getopt.getopt(sys.argv[1:], "hn:s:t:p:b:", [
"help", "nodes=", "pktsize=", "time=", "packets=", "bytes=", "help", "nodes=", "pktsize=", "time=", "packets=", "bytes=",
"use-p2p", "delay=", "jitter=", "bandwidth=", "format=" ]) "use-p2p", "delay=", "jitter=", "bandwidth=", "format=" ])
except getopt.GetoptError, err: except getopt.GetoptError as err:
error = str(err) # opts will be empty error = str(err) # opts will be empty
pktsize = nr = time = packets = nbytes = None pktsize = nr = time = packets = nbytes = None
...@@ -155,12 +155,12 @@ def main(): ...@@ -155,12 +155,12 @@ def main():
out = out.strip() out = out.strip()
if format != "csv": if format != "csv":
print "Command line: %s" % (" ".join(sys.argv[1:])) print("Command line: %s" % (" ".join(sys.argv[1:])))
print out.strip() print(out.strip())
return return
data = out.split(" ") data = out.split(" ")
data = dict(map(lambda s: s.partition(":")[::2], data)) data = dict([s.partition(":")[::2] for s in data])
if sorted(data.keys()) != sorted(["brx", "prx", "pksz", "plsz", "err", if sorted(data.keys()) != sorted(["brx", "prx", "pksz", "plsz", "err",
"mind", "avgd", "maxd", "jit", "time"]): "mind", "avgd", "maxd", "jit", "time"]):
raise RuntimeError("Invalid output from udp-perf") raise RuntimeError("Invalid output from udp-perf")
...@@ -182,8 +182,8 @@ def main(): ...@@ -182,8 +182,8 @@ def main():
def ip2dec(ip): def ip2dec(ip):
match = re.search(r'^(\d+)\.(\d+)\.(\d+)\.(\d+)$', ip) match = re.search(r'^(\d+)\.(\d+)\.(\d+)\.(\d+)$', ip)
assert match assert match
return long(match.group(1)) * 2**24 + long(match.group(2)) * 2**16 + \ return int(match.group(1)) * 2**24 + int(match.group(2)) * 2**16 + \
long(match.group(3)) * 2**8 + long(match.group(4)) int(match.group(3)) * 2**8 + int(match.group(4))
def dec2ip(dec): def dec2ip(dec):
res = [None] * 4 res = [None] * 4
......
...@@ -42,7 +42,7 @@ for i in range(SIZE): ...@@ -42,7 +42,7 @@ for i in range(SIZE):
node[i].add_route(prefix='10.0.%d.0' % j, prefix_len=24, node[i].add_route(prefix='10.0.%d.0' % j, prefix_len=24,
nexthop='10.0.%d.2' % i) nexthop='10.0.%d.2' % i)
print "Nodes started with pids: %s" % str([n.pid for n in node]) print("Nodes started with pids: %s" % str([n.pid for n in node]))
#switch0 = nemu.Switch( #switch0 = nemu.Switch(
# bandwidth = 100 * 1024 * 1024, # bandwidth = 100 * 1024 * 1024,
...@@ -53,16 +53,15 @@ print "Nodes started with pids: %s" % str([n.pid for n in node]) ...@@ -53,16 +53,15 @@ print "Nodes started with pids: %s" % str([n.pid for n in node])
# Test connectivity first. Run process, hide output and check # Test connectivity first. Run process, hide output and check
# return code # return code
null = file("/dev/null", "w")
app0 = node[0].Popen("ping -c 1 10.0.%d.2" % (SIZE - 2), shell=True, app0 = node[0].Popen("ping -c 1 10.0.%d.2" % (SIZE - 2), shell=True,
stdout=null) stdout=subprocess.DEVNULL)
ret = app0.wait() ret = app0.wait()
assert ret == 0 assert ret == 0
app1 = node[-1].Popen("ping -c 1 10.0.0.1", shell = True, stdout = null) app1 = node[-1].Popen("ping -c 1 10.0.0.1", shell = True, stdout = subprocess.DEVNULL)
ret = app1.wait() ret = app1.wait()
assert ret == 0 assert ret == 0
print "Connectivity IPv4 OK!" print("Connectivity IPv4 OK!")
if X: if X:
app = [] app = []
......
#!/usr/bin/env python2 #!/usr/bin/env python2
# vim:ts=4:sw=4:et:ai:sts=4 # vim:ts=4:sw=4:et:ai:sts=4
import os, nemu, subprocess, time import os, nemu, subprocess, time
# Uncomment for verbose operation. # Uncomment for verbose operation.
...@@ -12,8 +13,8 @@ X = "DISPLAY" in os.environ and xterm ...@@ -12,8 +13,8 @@ X = "DISPLAY" in os.environ and xterm
node0 = nemu.Node(forward_X11 = X) node0 = nemu.Node(forward_X11 = X)
node1 = nemu.Node(forward_X11 = X) node1 = nemu.Node(forward_X11 = X)
node2 = nemu.Node(forward_X11 = X) node2 = nemu.Node(forward_X11 = X)
print "Nodes started with pids: %s" % str((node0.pid, node1.pid, print("Nodes started with pids: %s" % str((node0.pid, node1.pid,
node2.pid)) node2.pid)))
# interface object maps to a veth pair with one end in a netns # interface object maps to a veth pair with one end in a netns
if0 = nemu.NodeInterface(node0) if0 = nemu.NodeInterface(node0)
...@@ -47,19 +48,18 @@ node2.add_route(prefix = '10.0.0.0', prefix_len = 24, nexthop = '10.0.1.1') ...@@ -47,19 +48,18 @@ node2.add_route(prefix = '10.0.0.0', prefix_len = 24, nexthop = '10.0.1.1')
# Test connectivity first. Run process, hide output and check # Test connectivity first. Run process, hide output and check
# return code # return code
null = file("/dev/null", "w") app0 = node0.Popen("ping -c 1 10.0.1.2", shell = True, stdout = subprocess.DEVNULL)
app0 = node0.Popen("ping -c 1 10.0.1.2", shell = True, stdout = null)
ret = app0.wait() ret = app0.wait()
assert ret == 0 assert ret == 0
app1 = node2.Popen("ping -c 1 10.0.0.1", shell = True, stdout = null) app1 = node2.Popen("ping -c 1 10.0.0.1", shell = True, stdout = subprocess.DEVNULL)
ret = app1.wait() ret = app1.wait()
assert ret == 0 assert ret == 0
print "Connectivity IPv4 OK!" print("Connectivity IPv4 OK!")
# Some nice visual demo. # Some nice visual demo.
if X: if X:
print "Running ping and tcpdump in different nodes." print("Running ping and tcpdump in different nodes.")
app1 = node1.Popen("%s -geometry -0+0 -e %s -ni %s" % app1 = node1.Popen("%s -geometry -0+0 -e %s -ni %s" %
(xterm, nemu.environ.TCPDUMP_PATH, if1b.name), shell = True) (xterm, nemu.environ.TCPDUMP_PATH, if1b.name), shell = True)
time.sleep(3) time.sleep(3)
...@@ -69,12 +69,12 @@ if X: ...@@ -69,12 +69,12 @@ if X:
app1.signal() app1.signal()
app1.wait() app1.wait()
print "Running network conditions test." print("Running network conditions test.")
# When using a args list, the shell is not needed # When using a args list, the shell is not needed
app2 = node0.Popen(["ping", "-q", "-c1000", "-f", "10.0.1.2"], app2 = node0.Popen(["ping", "-q", "-c1000", "-f", "10.0.1.2"],
stdout = subprocess.PIPE) stdout = subprocess.PIPE)
out, err = app2.communicate() out, err = app2.communicate()
print "Ping outout:" print("Ping outout:")
print out print(out)
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
from distutils.core import setup, Extension, Command from distutils.core import setup, Extension, Command
passfd = Extension('_passfd', sources = ['src/passfd/passfd.c'])
setup( setup(
name = 'nemu', name = 'nemu',
version = '0.3.1', version = '0.3.1',
...@@ -15,6 +17,7 @@ setup( ...@@ -15,6 +17,7 @@ setup(
license = 'GPLv2', license = 'GPLv2',
platforms = 'Linux', platforms = 'Linux',
packages = ['nemu'], packages = ['nemu'],
install_requires = ['python-unshare', 'python-passfd'], install_requires = ['unshare', 'six'],
package_dir = {'': 'src'} package_dir = {'': 'src'},
ext_modules = [passfd]
) )
...@@ -24,10 +24,11 @@ and run and test programs in them. ...@@ -24,10 +24,11 @@ and run and test programs in them.
""" """
# pylint: disable=W0401,R0903 # pylint: disable=W0401,R0903
from __future__ import absolute_import import pwd
import os, pwd
from nemu.node import *
from nemu.interface import * from nemu.interface import *
from nemu.node import *
class _Config(object): class _Config(object):
"""Global configuration singleton for Nemu.""" """Global configuration singleton for Nemu."""
......
...@@ -17,11 +17,14 @@ ...@@ -17,11 +17,14 @@
# You should have received a copy of the GNU General Public License along with # You should have received a copy of the GNU General Public License along with
# Nemu. If not, see <http://www.gnu.org/licenses/>. # Nemu. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import import errno
import errno, os, os.path, socket, subprocess, sys, syslog import os
import os.path
import socket
import subprocess
import sys
import syslog
from syslog import LOG_ERR, LOG_WARNING, LOG_NOTICE, LOG_INFO, LOG_DEBUG from syslog import LOG_ERR, LOG_WARNING, LOG_NOTICE, LOG_INFO, LOG_DEBUG
from six.moves import range
__all__ = ["IP_PATH", "TC_PATH", "BRCTL_PATH", "SYSCTL_PATH", "HZ"] __all__ = ["IP_PATH", "TC_PATH", "BRCTL_PATH", "SYSCTL_PATH", "HZ"]
__all__ += ["TCPDUMP_PATH", "NETPERF_PATH", "XAUTH_PATH", "XDPYINFO_PATH"] __all__ += ["TCPDUMP_PATH", "NETPERF_PATH", "XAUTH_PATH", "XDPYINFO_PATH"]
...@@ -84,8 +87,7 @@ def execute(cmd): ...@@ -84,8 +87,7 @@ def execute(cmd):
RuntimeError: the command was unsuccessful (return code != 0). RuntimeError: the command was unsuccessful (return code != 0).
""" """
debug("execute(%s)" % cmd) debug("execute(%s)" % cmd)
null = open("/dev/null", "r+") proc = subprocess.Popen(cmd, stdout = subprocess.DEVNULL, stderr = subprocess.PIPE)
proc = subprocess.Popen(cmd, stdout = null, stderr = subprocess.PIPE)
_, err = proc.communicate() _, err = proc.communicate()
if proc.returncode != 0: if proc.returncode != 0:
raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err)) raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err))
...@@ -105,7 +107,7 @@ def backticks(cmd): ...@@ -105,7 +107,7 @@ def backticks(cmd):
out, err = proc.communicate() out, err = proc.communicate()
if proc.returncode != 0: if proc.returncode != 0:
raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err)) raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err))
return out return out.decode("utf-8")
def eintr_wrapper(func, *args): def eintr_wrapper(func, *args):
"Wraps some callable with a loop that retries on EINTR." "Wraps some callable with a loop that retries on EINTR."
...@@ -133,7 +135,7 @@ def find_listen_port(family = socket.AF_INET, type = socket.SOCK_STREAM, ...@@ -133,7 +135,7 @@ def find_listen_port(family = socket.AF_INET, type = socket.SOCK_STREAM,
raise RuntimeError("Cannot find an usable port in the range specified") raise RuntimeError("Cannot find an usable port in the range specified")
# Logging # Logging
_log_level = LOG_WARNING _log_level = LOG_DEBUG
_log_use_syslog = False _log_use_syslog = False
_log_stream = sys.stderr _log_stream = sys.stderr
_log_syslog_opts = () _log_syslog_opts = ()
......
...@@ -17,8 +17,9 @@ ...@@ -17,8 +17,9 @@
# You should have received a copy of the GNU General Public License along with # You should have received a copy of the GNU General Public License along with
# Nemu. If not, see <http://www.gnu.org/licenses/>. # Nemu. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import import os
import os, weakref import weakref
import nemu.iproute import nemu.iproute
from nemu.environ import * from nemu.environ import *
...@@ -39,7 +40,7 @@ class Interface(object): ...@@ -39,7 +40,7 @@ class Interface(object):
def _gen_if_name(): def _gen_if_name():
n = Interface._gen_next_id() n = Interface._gen_next_id()
# Max 15 chars # Max 15 chars
return "NETNSif-%.4x%.3x" % (os.getpid(), n) return "NETNSif-%.4x%.3x" % (os.getpid() % 0xffff, n)
def __init__(self, index): def __init__(self, index):
self._idx = index self._idx = index
...@@ -336,7 +337,7 @@ class ExternalInterface(Interface): ...@@ -336,7 +337,7 @@ class ExternalInterface(Interface):
nemu.iproute.del_addr(self.index, addr) nemu.iproute.del_addr(self.index, addr)
def get_addresses(self): def get_addresses(self):
addresses = nemu.iproute.get_addr_data(self.index) addresses = nemu.iproute.get_addr_data()
ret = [] ret = []
for a in addresses: for a in addresses:
if hasattr(a, 'broadcast'): if hasattr(a, 'broadcast'):
...@@ -385,7 +386,7 @@ class Switch(ExternalInterface): ...@@ -385,7 +386,7 @@ class Switch(ExternalInterface):
def _gen_br_name(): def _gen_br_name():
n = Switch._gen_next_id() n = Switch._gen_next_id()
# Max 15 chars # Max 15 chars
return "NETNSbr-%.4x%.3x" % (os.getpid(), n) return "NETNSbr-%.4x%.3x" % (os.getpid() % 0xffff, n)
def __init__(self, **args): def __init__(self, **args):
"""Creates a new Switch object, which models a linux bridge device. """Creates a new Switch object, which models a linux bridge device.
...@@ -434,7 +435,7 @@ class Switch(ExternalInterface): ...@@ -434,7 +435,7 @@ class Switch(ExternalInterface):
self._check_port(p) self._check_port(p)
self.up = False self.up = False
for p in self._ports.values(): for p in list(self._ports.values()):
self.disconnect(p) self.disconnect(p)
self._ports.clear() self._ports.clear()
......
...@@ -17,11 +17,19 @@ ...@@ -17,11 +17,19 @@
# You should have received a copy of the GNU General Public License along with # You should have received a copy of the GNU General Public License along with
# Nemu. If not, see <http://www.gnu.org/licenses/>. # Nemu. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import import copy
import copy, fcntl, os, re, socket, struct, subprocess, sys import fcntl
from nemu.environ import * import os
import re
import socket
import struct
import sys
import six import six
from nemu.environ import *
# helpers # helpers
def _any_to_bool(any): def _any_to_bool(any):
if isinstance(any, bool): if isinstance(any, bool):
...@@ -488,32 +496,10 @@ def del_addr(iface, address): ...@@ -488,32 +496,10 @@ def del_addr(iface, address):
"%s/%d" % (address.address, int(address.prefix_len))] "%s/%d" % (address.address, int(address.prefix_len))]
execute(cmd) execute(cmd)
def set_addr(iface, addresses, recover = True):
ifname = _get_if_name(iface)
addresses = get_addr_data()[1][ifname]
to_remove = set(orig_addresses) - set(addresses)
to_add = set(addresses) - set(orig_addresses)
for a in to_remove:
try:
del_addr(ifname, a)
except:
if recover:
set_addr(orig_addresses, recover = False) # rollback
raise
for a in to_add:
try:
add_addr(ifname, a)
except:
if recover:
set_addr(orig_addresses, recover = False) # rollback
raise
# Bridge handling # Bridge handling
def _sysfs_read_br(brname): def _sysfs_read_br(brname):
def readval(fname): def readval(fname):
f = open(fname) with open(fname) as f:
return f.readline().strip() return f.readline().strip()
p = "/sys/class/net/%s/bridge/" % brname p = "/sys/class/net/%s/bridge/" % brname
...@@ -882,7 +868,7 @@ def set_tc(iface, bandwidth = None, delay = None, delay_jitter = None, ...@@ -882,7 +868,7 @@ def set_tc(iface, bandwidth = None, delay = None, delay_jitter = None,
if bandwidth: if bandwidth:
rate = "%dbit" % int(bandwidth) rate = "%dbit" % int(bandwidth)
mtu = ifdata[iface.index].mtu mtu = ifdata[iface.index].mtu
burst = max(mtu, int(bandwidth) / HZ) burst = max(mtu, int(bandwidth) // HZ)
limit = burst * 2 # FIXME? limit = burst * 2 # FIXME?
handle = "1:" handle = "1:"
if cmd == "change": if cmd == "change":
...@@ -953,8 +939,9 @@ def create_tap(iface, use_pi = False, tun = False): ...@@ -953,8 +939,9 @@ def create_tap(iface, use_pi = False, tun = False):
fd = os.open("/dev/net/tun", os.O_RDWR) fd = os.open("/dev/net/tun", os.O_RDWR)
err = fcntl.ioctl(fd, TUNSETIFF, struct.pack("16sH", iface.name, mode)) try:
if err < 0: fcntl.ioctl(fd, TUNSETIFF, struct.pack("16sH", iface.name.encode("ascii"), mode))
except IOError:
os.close(fd) os.close(fd)
raise RuntimeError("Could not configure device %s" % iface.name) raise RuntimeError("Could not configure device %s" % iface.name)
......
...@@ -17,10 +17,17 @@ ...@@ -17,10 +17,17 @@
# You should have received a copy of the GNU General Public License along with # You should have received a copy of the GNU General Public License along with
# Nemu. If not, see <http://www.gnu.org/licenses/>. # Nemu. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import import os
import os, socket, sys, traceback, unshare, weakref import socket
import sys
import traceback
import unshare
import weakref
import nemu.interface
import nemu.protocol
import nemu.subprocess_
from nemu.environ import * from nemu.environ import *
import nemu.interface, nemu.protocol, nemu.subprocess_
__all__ = ['Node', 'get_nodes', 'import_if'] __all__ = ['Node', 'get_nodes', 'import_if']
...@@ -186,7 +193,7 @@ class Node(object): ...@@ -186,7 +193,7 @@ class Node(object):
# Handle the creation of the child; parent gets (fd, pid), child creates and # Handle the creation of the child; parent gets (fd, pid), child creates and
# runs a Server(); never returns. # runs a Server(); never returns.
# Requires CAP_SYS_ADMIN privileges to run. # Requires CAP_SYS_ADMIN privileges to run.
def _start_child(nonetns): def _start_child(nonetns) -> (socket.socket, int):
# Create socket pair to communicate # Create socket pair to communicate
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
# Spawn a child that will run in a loop # Spawn a child that will run in a loop
......
...@@ -17,18 +17,23 @@ ...@@ -17,18 +17,23 @@
# You should have received a copy of the GNU General Public License along with # You should have received a copy of the GNU General Public License along with
# Nemu. If not, see <http://www.gnu.org/licenses/>. # Nemu. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import import base64
import base64, errno, os, passfd, re, select, signal, socket, sys, tempfile import errno
import time, traceback, unshare import os
import nemu.subprocess_, nemu.iproute import passfd
import re
import select
import signal
import socket
import sys
import tempfile
import time
import traceback
from pickle import loads, dumps
import nemu.iproute
import nemu.subprocess_
from nemu.environ import * from nemu.environ import *
from six.moves import map
from six.moves import range
try:
from six.moves.cPickle import loads, dumps
except:
from pickle import loads, dumps
# ============================================================================ # ============================================================================
# Server-side protocol implementation # Server-side protocol implementation
...@@ -93,7 +98,7 @@ class Server(object): ...@@ -93,7 +98,7 @@ class Server(object):
"""Class that implements the communication protocol and dispatches calls """Class that implements the communication protocol and dispatches calls
to the required functions. Also works as the main loop for the slave to the required functions. Also works as the main loop for the slave
process.""" process."""
def __init__(self, rfd, wfd): def __init__(self, rfd: socket.socket, wfd: socket.socket):
debug("Server(0x%x).__init__()" % id(self)) debug("Server(0x%x).__init__()" % id(self))
# Dictionary of valid commands # Dictionary of valid commands
self._commands = _proto_commands self._commands = _proto_commands
...@@ -109,7 +114,9 @@ class Server(object): ...@@ -109,7 +114,9 @@ class Server(object):
self._xfwd = None self._xfwd = None
self._xsock = None self._xsock = None
self._rfd_socket = rfd
self._rfd = _get_file(rfd, "r") self._rfd = _get_file(rfd, "r")
self._wfd_socket = wfd
self._wfd = _get_file(wfd, "w") self._wfd = _get_file(wfd, "w")
def clean(self): def clean(self):
...@@ -152,7 +159,7 @@ class Server(object): ...@@ -152,7 +159,7 @@ class Server(object):
def reply(self, code, text): def reply(self, code, text):
"Send back a reply to the client; handle multiline messages" "Send back a reply to the client; handle multiline messages"
if not hasattr(text, '__iter__'): if type(text) != list:
text = [ text ] text = [ text ]
clean = [] clean = []
# Split lines with embedded \n # Split lines with embedded \n
...@@ -250,12 +257,12 @@ class Server(object): ...@@ -250,12 +257,12 @@ class Server(object):
return None return None
elif argstemplate[j] == 'b': elif argstemplate[j] == 'b':
try: try:
args[i] = _db64(args[i]) args[i] = _db64(args[i]).decode("utf-8")
except TypeError: except TypeError:
self.reply(500, "Invalid parameter: not base-64 encoded.") self.reply(500, "Invalid parameter: not base-64 encoded.")
return None return None
elif argstemplate[j] != 's': # pragma: no cover elif argstemplate[j] != 's': # pragma: no cover
raise RuntimeError("Invalid argument template: %s" % _argstmpl) raise RuntimeError("Invalid argument template: %s" % argstemplate)
# Nothing done for "s" parameters # Nothing done for "s" parameters
j += 1 j += 1
...@@ -323,10 +330,10 @@ class Server(object): ...@@ -323,10 +330,10 @@ class Server(object):
"Invalid number of arguments for PROC ENV: must be even.") "Invalid number of arguments for PROC ENV: must be even.")
return return
self._proc['env'] = {} self._proc['env'] = {}
for i in range(len(env)/2): for i in range(len(env)//2):
self._proc['env'][env[i * 2]] = env[i * 2 + 1] self._proc['env'][env[i * 2]] = env[i * 2 + 1]
self.reply(200, "%d environment definition(s) read." % (len(env) / 2)) self.reply(200, "%d environment definition(s) read." % (len(env) // 2))
def do_PROC_SIN(self, cmdname): def do_PROC_SIN(self, cmdname):
self.reply(354, self.reply(354,
...@@ -453,7 +460,7 @@ class Server(object): ...@@ -453,7 +460,7 @@ class Server(object):
"Invalid number of arguments for IF SET: must be even.") "Invalid number of arguments for IF SET: must be even.")
return return
d = {'index': ifnr} d = {'index': ifnr}
for i in range(len(args) / 2): for i in range(len(args) // 2):
d[str(args[i * 2])] = args[i * 2 + 1] d[str(args[i * 2])] = args[i * 2 + 1]
iface = nemu.iproute.interface(**d) iface = nemu.iproute.interface(**d)
...@@ -532,7 +539,7 @@ class Server(object): ...@@ -532,7 +539,7 @@ class Server(object):
return return
# Needs to be a separate command to handle synch & buffering issues # Needs to be a separate command to handle synch & buffering issues
try: try:
passfd.sendfd(self._wfd, self._xsock.fileno(), "1") passfd.sendfd(self._wfd, self._xsock.fileno(), b"1")
except: except:
# need to fill the buffer on the other side, nevertheless # need to fill the buffer on the other side, nevertheless
self._wfd.write("1") self._wfd.write("1")
...@@ -548,9 +555,11 @@ class Server(object): ...@@ -548,9 +555,11 @@ class Server(object):
class Client(object): class Client(object):
"""Client-side implementation of the communication protocol. Acts as a RPC """Client-side implementation of the communication protocol. Acts as a RPC
service.""" service."""
def __init__(self, rfd, wfd): def __init__(self, rfd: socket.socket, wfd: socket.socket):
debug("Client(0x%x).__init__()" % id(self)) debug("Client(0x%x).__init__()" % id(self))
self._rfd_socket = rfd
self._rfd = _get_file(rfd, "r") self._rfd = _get_file(rfd, "r")
self._wfd_socket = wfd
self._wfd = _get_file(wfd, "w") self._wfd = _get_file(wfd, "w")
self._forwarder = None self._forwarder = None
# Wait for slave to send banner # Wait for slave to send banner
...@@ -560,7 +569,7 @@ class Client(object): ...@@ -560,7 +569,7 @@ class Client(object):
debug("Client(0x%x).__del__()" % id(self)) debug("Client(0x%x).__del__()" % id(self))
self.shutdown() self.shutdown()
def _send_cmd(self, *args): def _send_cmd(self, *args: str):
if not self._wfd: if not self._wfd:
raise RuntimeError("Client already shut down.") raise RuntimeError("Client already shut down.")
s = " ".join(map(str, args)) + "\n" s = " ".join(map(str, args)) + "\n"
...@@ -593,8 +602,9 @@ class Client(object): ...@@ -593,8 +602,9 @@ class Client(object):
code, text = self._read_reply() code, text = self._read_reply()
if code == 550: # exception if code == 550: # exception
e = loads(_db64(text.partition("\n")[2])) e = loads(_db64(text.partition("\n")[2]))
sys.stderr.write(e.child_traceback)
raise e raise e
if code / 100 != expected: if code // 100 != expected:
raise RuntimeError("Error from slave: %d %s" % (code, text)) raise RuntimeError("Error from slave: %d %s" % (code, text))
return text return text
...@@ -607,19 +617,21 @@ class Client(object): ...@@ -607,19 +617,21 @@ class Client(object):
self._send_cmd("QUIT") self._send_cmd("QUIT")
self._read_and_check_reply() self._read_and_check_reply()
self._rfd.close() self._rfd.close()
self._rfd_socket.close()
self._rfd = None self._rfd = None
self._wfd.close() self._wfd.close()
self._rfd_socket.close()
self._wfd = None self._wfd = None
if self._forwarder: if self._forwarder:
os.kill(self._forwarder, signal.SIGTERM) os.kill(self._forwarder, signal.SIGTERM)
self._forwarder = None self._forwarder = None
def _send_fd(self, name, fd): def _send_fd(self, name: str, fd: int):
"Pass a file descriptor" "Pass a file descriptor"
self._send_cmd("PROC", name) self._send_cmd("PROC", name)
self._read_and_check_reply(3) self._read_and_check_reply(3)
try: try:
passfd.sendfd(self._wfd, fd, "PROC " + name) passfd.sendfd(self._wfd, fd, ("PROC " + name).encode("ascii"))
except: except:
# need to fill the buffer on the other side, nevertheless # need to fill the buffer on the other side, nevertheless
self._wfd.write("=" * (len(name) + 5) + "\n") self._wfd.write("=" * (len(name) + 5) + "\n")
...@@ -683,10 +695,10 @@ class Client(object): ...@@ -683,10 +695,10 @@ class Client(object):
Returns the exitcode if finished, None otherwise.""" Returns the exitcode if finished, None otherwise."""
self._send_cmd("PROC", "POLL", pid) self._send_cmd("PROC", "POLL", pid)
code, text = self._read_reply() code, text = self._read_reply()
if code / 100 == 2: if code // 100 == 2:
exitcode = int(text.split()[0]) exitcode = int(text.split()[0])
return exitcode return exitcode
if code / 100 == 4: if code // 100 == 4:
return None return None
else: else:
raise RuntimeError("Error on command: %d %s" % (code, text)) raise RuntimeError("Error on command: %d %s" % (code, text))
...@@ -802,20 +814,23 @@ class Client(object): ...@@ -802,20 +814,23 @@ class Client(object):
server = self.set_x11(protoname, hexkey) server = self.set_x11(protoname, hexkey)
self._forwarder = _spawn_x11_forwarder(server, sock, addr) self._forwarder = _spawn_x11_forwarder(server, sock, addr)
def _b64(text): def _b64(text: str | bytes) -> str:
if text == None: if text == None:
# easier this way # easier this way
text = '' text = ''
text = str(text) if type(text) is str:
if len(text) == 0 or [x for x in text if ord(x) <= ord(" ") or btext = text.encode("utf-8")
ord(x) > ord("z") or x == "="]: else:
return "=" + base64.b64encode(text) btext = text
if len(text) == 0 or any(x for x in btext if x <= ord(" ") or
x > ord("z") or x == ord("=")):
return "=" + base64.b64encode(btext).decode("ascii")
else: else:
return text return text
def _db64(text): def _db64(text: str) -> bytes:
if not text or text[0] != '=': if not text or text[0] != '=':
return text return text.encode("utf-8")
return base64.b64decode(text[1:]) return base64.b64decode(text[1:])
def _get_file(fd, mode): def _get_file(fd, mode):
......
...@@ -17,10 +17,18 @@ ...@@ -17,10 +17,18 @@
# You should have received a copy of the GNU General Public License along with # You should have received a copy of the GNU General Public License along with
# Nemu. If not, see <http://www.gnu.org/licenses/>. # Nemu. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import import fcntl
import fcntl, grp, os, pickle, pwd, signal, select, sys, time, traceback import grp
import os
import pickle
import pwd
import select
import signal
import sys
import time
import traceback
from nemu.environ import eintr_wrapper from nemu.environ import eintr_wrapper
from six.moves import range
__all__ = [ 'PIPE', 'STDOUT', 'Popen', 'Subprocess', 'spawn', 'wait', 'poll', __all__ = [ 'PIPE', 'STDOUT', 'Popen', 'Subprocess', 'spawn', 'wait', 'poll',
'get_user', 'system', 'backticks', 'backticks_raise' ] 'get_user', 'system', 'backticks', 'backticks_raise' ]
...@@ -190,7 +198,7 @@ class Popen(Subprocess): ...@@ -190,7 +198,7 @@ class Popen(Subprocess):
if getattr(self, k) != None: if getattr(self, k) != None:
eintr_wrapper(os.close, v) eintr_wrapper(os.close, v)
def communicate(self, input = None): def communicate(self, input: bytes = None) -> tuple[bytes, bytes]:
"""See Popen.communicate.""" """See Popen.communicate."""
# FIXME: almost verbatim from stdlib version, need to be removed or # FIXME: almost verbatim from stdlib version, need to be removed or
# something # something
...@@ -217,7 +225,7 @@ class Popen(Subprocess): ...@@ -217,7 +225,7 @@ class Popen(Subprocess):
if self.stdin in w: if self.stdin in w:
wrote = os.write(self.stdin.fileno(), wrote = os.write(self.stdin.fileno(),
#buffer(input, offset, select.PIPE_BUF)) #buffer(input, offset, select.PIPE_BUF))
buffer(input, offset, 512)) # XXX: py2.7 input[offset:offset+512]) # XXX: py2.7
offset += wrote offset += wrote
if offset >= len(input): if offset >= len(input):
self.stdin.close() self.stdin.close()
...@@ -226,7 +234,7 @@ class Popen(Subprocess): ...@@ -226,7 +234,7 @@ class Popen(Subprocess):
if i in r: if i in r:
d = os.read(i.fileno(), 1024) # No need for eintr wrapper d = os.read(i.fileno(), 1024) # No need for eintr wrapper
if d == "": if d == "":
i.close i.close()
rset.remove(i) rset.remove(i)
else: else:
if i == self.stdout: if i == self.stdout:
...@@ -235,9 +243,9 @@ class Popen(Subprocess): ...@@ -235,9 +243,9 @@ class Popen(Subprocess):
err.append(d) err.append(d)
if out != None: if out != None:
out = ''.join(out) out = b''.join(out)
if err != None: if err != None:
err = ''.join(err) err = b''.join(err)
self.wait() self.wait()
return (out, err) return (out, err)
...@@ -376,7 +384,7 @@ def spawn(executable, argv = None, cwd = None, env = None, close_fds = False, ...@@ -376,7 +384,7 @@ def spawn(executable, argv = None, cwd = None, env = None, close_fds = False,
eintr_wrapper(os.close, w) eintr_wrapper(os.close, w)
# read EOF for success, or a string as error info # read EOF for success, or a string as error info
s = "" s = b""
while True: while True:
s1 = eintr_wrapper(os.read, r, 4096) s1 = eintr_wrapper(os.read, r, 4096)
if s1 == "": if s1 == "":
...@@ -384,7 +392,7 @@ def spawn(executable, argv = None, cwd = None, env = None, close_fds = False, ...@@ -384,7 +392,7 @@ def spawn(executable, argv = None, cwd = None, env = None, close_fds = False,
s += s1 s += s1
eintr_wrapper(os.close, r) eintr_wrapper(os.close, r)
if s == "": if s == b"":
return pid return pid
# It was an error # It was an error
......
#
# This file includes code from python-passfd (https://github.com/NightTsarina/python-passfd).
# Copyright (c) 2010 Martina Ferrari <tina@tina.pm>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
import socket
import struct
from io import IOBase
def __check_socket(sock: socket.socket | IOBase):
if hasattr(sock, 'family') and sock.family != socket.AF_UNIX:
raise ValueError("Only AF_UNIX sockets are allowed")
if hasattr(sock, 'fileno'):
sock = socket.socket(fileno=sock.fileno())
if not isinstance(sock, socket.socket):
raise TypeError("An socket object or file descriptor was expected")
return sock
def __check_fd(fd):
try:
fd = fd.fileno()
except AttributeError:
pass
if not isinstance(fd, int):
raise TypeError("An file object or file descriptor was expected")
return fd
def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096):
"""
import _passfd
(ret, msg) = _passfd.recvfd(__check_socket(sock), msg_buf)
# -1 should raise OSError
if ret == -2:
raise RuntimeError("The message received did not contain exactly one" +
" file descriptor")
if ret == -3:
raise RuntimeError("The received file descriptor is not valid")
assert ret >= 0
return (ret, msg)
"""
"""
size = struct.calcsize("@i")
+ with socket.fromfd(conn.fileno(), socket.AF_UNIX, socket.SOCK_STREAM) as s:
+ msg, ancdata, flags, addr = s.recvmsg(1, socket.CMSG_LEN(size))
+ try:
+ cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+ if (cmsg_level == socket.SOL_SOCKET and
+ cmsg_type == socket.SCM_RIGHTS):
+ return struct.unpack("@i", cmsg_data[:size])[0]
+ except (ValueError, IndexError, struct.error):
+ pass
+ raise RuntimeError('Invalid data received')"""
size = struct.calcsize("@i")
msg, ancdata, flags, addr = __check_socket(sock).recvmsg(4096, socket.CMSG_LEN(size))
cmsg_level, cmsg_type, cmsg_data = ancdata[0]
if not (cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS):
raise RuntimeError("The message received did not contain exactly one" +
" file descriptor")
fd: int = struct.unpack("@i", cmsg_data[:size])[0]
if fd < 0:
raise RuntimeError("The received file descriptor is not valid")
return fd, msg.decode("utf-8")
def sendfd(sock: socket.socket | IOBase, fd: int, message: bytes = b"NONE"):
"""
import _passfd
return _passfd.sendfd(__check_socket(sock), __check_fd(fd), message)
"""
return __check_socket(sock).sendmsg(
[message],
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, struct.pack("@i", fd))])
\ No newline at end of file
#!/usr/bin/env python
# vim: set fileencoding=utf-8
# vim: ts=4:sw=4:et:ai:sts=4
# passfd.py: Python library to pass file descriptors across UNIX domain sockets.
'''This simple extension provides two functions to pass and receive file
descriptors across UNIX domain sockets, using the BSD-4.3+ sendmsg() and
recvmsg() interfaces.
Direct bindings to sendmsg and recvmsg are not provided, as the API does
not map nicely into Python.
Please note that this only supports BSD-4.3+ style file descriptor
passing, and was only tested on Linux. Patches are welcomed!
For more information, see one of the R. Stevens' books:
- Richard Stevens: Unix Network Programming, Prentice Hall, 1990;
chapter 6.10
- Richard Stevens: Advanced Programming in the UNIX Environment,
Addison-Wesley, 1993; chapter 15.3
'''
#
# Please note that this only supports BSD-4.3+ style file descriptor passing,
# and was only tested on Linux. Patches are welcomed!
#
# Copyright © 2010 Martina Ferrari <tina@tina.pm>
#
# Inspired by Socket::PassAccessRights, which is:
# Copyright (c) 2000 Sampo Kellomaki <sampo@iki.fi>
#
# For more information, see one of the R. Stevens' books:
# - Richard Stevens: Unix Network Programming, Prentice Hall, 1990;
# chapter 6.10
#
# - Richard Stevens: Advanced Programming in the UNIX Environment,
# Addison-Wesley, 1993; chapter 15.3
#
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation; either version 2 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
import socket
def __check_socket(sock):
if hasattr(sock, 'family') and sock.family != socket.AF_UNIX:
raise ValueError("Only AF_UNIX sockets are allowed")
if hasattr(sock, 'fileno'):
sock = sock.fileno()
if not isinstance(sock, int):
raise TypeError("An socket object or file descriptor was expected")
return sock
def __check_fd(fd):
try:
fd = fd.fileno()
except AttributeError:
pass
if not isinstance(fd, int):
raise TypeError("An file object or file descriptor was expected")
return fd
def sendfd(sock, fd, message = "NONE"):
"""Sends a message and piggybacks a file descriptor through a Unix
domain socket.
Note that the file descriptor cannot be sent by itself, at least
one byte of payload needs to be sent also.
Parameters:
sock: socket object or file descriptor for an AF_UNIX socket
fd: file object or file descriptor to pass
message: message to send
Return value:
On success, sendfd returns the number of bytes sent, not including
the file descriptor nor the control data. If there was no message
to send, 0 is returned."""
import _passfd
return _passfd.sendfd(__check_socket(sock), __check_fd(fd), message)
def recvfd(sock, msg_buf = 4096):
"""Receive a message and a file descriptor from a Unix domain socket.
Parameters:
sock: file descriptor or socket object for an AF_UNIX socket
buffersize: maximum message size to receive
Return value:
On success, recvfd returns a tuple containing the received
file descriptor and message. If recvmsg fails, an OSError exception
is raised. If the received data does not carry exactly one file
descriptor, or if the received file descriptor is not valid,
RuntimeError is raised."""
import _passfd
(ret, msg) = _passfd.recvfd(__check_socket(sock), msg_buf)
# -1 should raise OSError
if ret == -2:
raise RuntimeError("The message received did not contain exactly one" +
" file descriptor")
if ret == -3:
raise RuntimeError("The received file descriptor is not valid")
assert ret >= 0
return (ret, msg)
\ No newline at end of file
/* vim:ts=4:sw=4:et:ai:sts=4
*
* passfd.c: Functions to pass file descriptors across UNIX domain sockets.
*
* Please note that this only supports BSD-4.3+ style file descriptor passing,
* and was only tested on Linux. Patches are welcomed!
*
* Copyright © 2010 Martina Ferrari <tina@tina.pm>
*
* Inspired by Socket::PassAccessRights, which is:
* Copyright (c) 2000 Sampo Kellomaki <sampo@iki.fi>
*
* For more information, see one of the R. Stevens' books:
* - Richard Stevens: Unix Network Programming, Prentice Hall, 1990;
* chapter 6.10
*
* - Richard Stevens: Advanced Programming in the UNIX Environment,
* Addison-Wesley, 1993; chapter 15.3
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the Free
* Software Foundation; either version 2 of the License, or (at your option)
* any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along with
* this program; if not, write to the Free Software Foundation, Inc., 51
* Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#ifndef _GNU_SOURCE
# define _GNU_SOURCE
#endif
#include <sys/types.h>
#include <sys/socket.h>
#include <unistd.h>
int _sendfd(int sock, int fd, size_t len, const void *msg);
int _recvfd(int sock, size_t *len, void *buf);
/* Python wrapper for _sendfd */
static PyObject *
sendfd(PyObject *self, PyObject *args) {
const char *message;
char *buf;
int ret, sock, fd;
Py_ssize_t message_len;
if(!PyArg_ParseTuple(args, "iis#", &sock, &fd, &message, &message_len))
return NULL;
/* I don't know if I need to make a copy of the message buffer for thread
* safety, but let's do it just in case... */
buf = strndup(message, (size_t)message_len);
if(buf == NULL)
return PyErr_SetFromErrno(PyExc_OSError);
Py_BEGIN_ALLOW_THREADS;
ret = _sendfd(sock, fd, message_len, message);
Py_END_ALLOW_THREADS;
free(buf);
if(ret == -1)
return PyErr_SetFromErrno(PyExc_OSError);
return Py_BuildValue("i", ret);
}
/* Python wrapper for _recvfd */
static PyObject *
recvfd(PyObject *self, PyObject *args) {
char *buffer;
int ret, sock;
Py_ssize_t buffersize = 4096;
size_t _buffersize;
PyObject *retval;
if(!PyArg_ParseTuple(args, "i|i", &sock, &buffersize))
return NULL;
if((buffer = malloc(buffersize)) == NULL)
return PyErr_SetFromErrno(PyExc_OSError);
_buffersize = buffersize;
Py_BEGIN_ALLOW_THREADS;
ret = _recvfd(sock, &_buffersize, buffer);
Py_END_ALLOW_THREADS;
buffersize = _buffersize;
if(ret == -1) {
free(buffer);
return PyErr_SetFromErrno(PyExc_OSError);
}
retval = Py_BuildValue("is#", ret, buffer, buffersize);
free(buffer);
return retval;
}
static PyMethodDef methods[] = {
{"sendfd", sendfd, METH_VARARGS, "rv = sendfd(sock, fd, message)"},
{"recvfd", recvfd, METH_VARARGS, "(fd, message) = recvfd(sock, "
"buffersize = 4096)"},
{NULL, NULL, 0, NULL}
};
static struct PyModuleDef passfdmodule = {
PyModuleDef_HEAD_INIT,
"_passfd", /* name of module */
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
methods
};
PyMODINIT_FUNC PyInit__passfd(void) {
return PyModule_Create(&passfdmodule);
}
/* Size of the cmsg including one file descriptor */
#define CMSG_SIZE CMSG_SPACE(sizeof(int))
/*
* _sendfd(): send a message and piggyback a file descriptor.
*
* Note that the file descriptor cannot be sent by itself, at least one byte of
* payload needs to be sent.
*
* Parameters:
* sock: AF_UNIX socket
* fd: file descriptor to pass
* len: length of the message
* msg: the message itself
*
* Return value:
* On success, sendfd returns the number of characters from the message sent,
* the file descriptor information is not taken into account. If there was no
* message to send, 0 is returned. On error, -1 is returned, and errno is set
* appropriately.
*
*/
int _sendfd(int sock, int fd, size_t len, const void *msg) {
struct iovec iov[1];
struct msghdr msgh;
char buf[CMSG_SIZE];
struct cmsghdr *h;
int ret;
/* At least one byte needs to be sent, for some reason (?) */
if(len < 1)
return 0;
memset(&iov[0], 0, sizeof(struct iovec));
memset(&msgh, 0, sizeof(struct msghdr));
memset(buf, 0, CMSG_SIZE);
msgh.msg_name = NULL;
msgh.msg_namelen = 0;
msgh.msg_iov = iov;
msgh.msg_iovlen = 1;
msgh.msg_control = buf;
msgh.msg_controllen = CMSG_SIZE;
msgh.msg_flags = 0;
/* Message to be sent */
iov[0].iov_base = (void *)msg;
iov[0].iov_len = len;
/* Control data */
h = CMSG_FIRSTHDR(&msgh);
h->cmsg_len = CMSG_LEN(sizeof(int));
h->cmsg_level = SOL_SOCKET;
h->cmsg_type = SCM_RIGHTS;
((int *)CMSG_DATA(h))[0] = fd;
ret = sendmsg(sock, &msgh, 0);
return ret;
}
/*
* _recvfd(): receive a message and a file descriptor.
*
* Parameters:
* sock: AF_UNIX socket
* len: pointer to the length of the message buffer, modified on return
* buf: buffer to contain the received buffer
*
* If len is 0 or buf is NULL, the received message is stored in a temporary
* buffer and discarded later.
*
* Return value:
* On success, recvfd returns the received file descriptor, and len points to
* the size of the received message.
* If recvmsg fails, -1 is returned, and errno is set appropriately.
* If the received data does not carry exactly one file descriptor, -2 is
* returned. If the received file descriptor is not valid, -3 is returned.
*
*/
int _recvfd(int sock, size_t *len, void *buf) {
struct iovec iov[1];
struct msghdr msgh;
char cmsgbuf[CMSG_SIZE];
char extrabuf[4096];
struct cmsghdr *h;
int st, fd;
if(*len < 1 || buf == NULL) {
/* For some reason, again, one byte needs to be received. (it would not
* block?) */
iov[0].iov_base = extrabuf;
iov[0].iov_len = sizeof(extrabuf);
} else {
iov[0].iov_base = buf;
iov[0].iov_len = *len;
}
msgh.msg_name = NULL;
msgh.msg_namelen = 0;
msgh.msg_iov = iov;
msgh.msg_iovlen = 1;
msgh.msg_control = cmsgbuf;
msgh.msg_controllen = CMSG_SIZE;
msgh.msg_flags = 0;
st = recvmsg(sock, &msgh, 0);
if(st < 0)
return -1;
*len = st;
h = CMSG_FIRSTHDR(&msgh);
/* Check if we received what we expected */
if(h == NULL
|| h->cmsg_len != CMSG_LEN(sizeof(int))
|| h->cmsg_level != SOL_SOCKET
|| h->cmsg_type != SCM_RIGHTS) {
return -2;
}
fd = ((int *)CMSG_DATA(h))[0];
if(fd < 0)
return -3;
return fd;
}
\ No newline at end of file
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# vim:ts=4:sw=4:et:ai:sts=4 # vim:ts=4:sw=4:et:ai:sts=4
import grp, os, pwd, select, time, unittest import grp, os, pwd, select, time, unittest
import subprocess
import nemu, test_util import nemu, test_util
class TestConfigure(unittest.TestCase): class TestConfigure(unittest.TestCase):
...@@ -19,7 +21,7 @@ class TestConfigure(unittest.TestCase): ...@@ -19,7 +21,7 @@ class TestConfigure(unittest.TestCase):
try: try:
pwd.getpwnam('nobody') pwd.getpwnam('nobody')
nemu.config.run_as('nobody') nemu.config.run_as('nobody')
self.assertEquals(nemu.config.run_as, 'nobody') self.assertEqual(nemu.config.run_as, 'nobody')
except: except:
pass pass
...@@ -35,20 +37,19 @@ class TestGlobal(unittest.TestCase): ...@@ -35,20 +37,19 @@ class TestGlobal(unittest.TestCase):
i1.add_v4_address('10.0.0.1', 24) i1.add_v4_address('10.0.0.1', 24)
i2.add_v4_address('10.0.0.2', 24) i2.add_v4_address('10.0.0.2', 24)
null = file('/dev/null', 'wb') a1 = n1.Popen(['ping', '-qc1', '10.0.0.2'], stdout = subprocess.DEVNULL)
a1 = n1.Popen(['ping', '-qc1', '10.0.0.2'], stdout = null) a2 = n2.Popen(['ping', '-qc1', '10.0.0.1'], stdout = subprocess.DEVNULL)
a2 = n2.Popen(['ping', '-qc1', '10.0.0.1'], stdout = null) self.assertEqual(a1.wait(), 0)
self.assertEquals(a1.wait(), 0) self.assertEqual(a2.wait(), 0)
self.assertEquals(a2.wait(), 0)
# Test ipv6 autoconfigured addresses # Test ipv6 autoconfigured addresses
time.sleep(2) # Wait for autoconfiguration time.sleep(2) # Wait for autoconfiguration
a1 = n1.Popen(['ping6', '-qc1', '-I', i1.name, a1 = n1.Popen(['ping6', '-qc1', '-I', i1.name,
'fe80::d44b:3fff:fef7:ff7f'], stdout = null) 'fe80::d44b:3fff:fef7:ff7f'], stdout = subprocess.DEVNULL)
a2 = n2.Popen(['ping6', '-qc1', '-I', i2.name, a2 = n2.Popen(['ping6', '-qc1', '-I', i2.name,
'fe80::d44b:3fff:fef7:ff7e'], stdout = null) 'fe80::d44b:3fff:fef7:ff7e'], stdout = subprocess.DEVNULL)
self.assertEquals(a1.wait(), 0) self.assertEqual(a1.wait(), 0)
self.assertEquals(a2.wait(), 0) self.assertEqual(a2.wait(), 0)
@test_util.skipUnless(os.getuid() == 0, "Test requires root privileges") @test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
def test_run_ping_node_if(self): def test_run_ping_node_if(self):
...@@ -64,11 +65,10 @@ class TestGlobal(unittest.TestCase): ...@@ -64,11 +65,10 @@ class TestGlobal(unittest.TestCase):
i1.add_v4_address('10.0.0.1', 24) i1.add_v4_address('10.0.0.1', 24)
i2.add_v4_address('10.0.0.2', 24) i2.add_v4_address('10.0.0.2', 24)
null = file('/dev/null', 'wb') a1 = n1.Popen(['ping', '-qc1', '10.0.0.2'], stdout = subprocess.DEVNULL)
a1 = n1.Popen(['ping', '-qc1', '10.0.0.2'], stdout = null) a2 = n2.Popen(['ping', '-qc1', '10.0.0.1'], stdout = subprocess.DEVNULL)
a2 = n2.Popen(['ping', '-qc1', '10.0.0.1'], stdout = null) self.assertEqual(a1.wait(), 0)
self.assertEquals(a1.wait(), 0) self.assertEqual(a2.wait(), 0)
self.assertEquals(a2.wait(), 0)
@test_util.skipUnless(os.getuid() == 0, "Test requires root privileges") @test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
def test_run_ping_routing_p2p(self): def test_run_ping_routing_p2p(self):
...@@ -88,11 +88,10 @@ class TestGlobal(unittest.TestCase): ...@@ -88,11 +88,10 @@ class TestGlobal(unittest.TestCase):
n3.add_route(prefix = '10.0.0.0', prefix_len = 24, n3.add_route(prefix = '10.0.0.0', prefix_len = 24,
nexthop = '10.0.1.1') nexthop = '10.0.1.1')
null = file('/dev/null', 'wb') a1 = n1.Popen(['ping', '-qc1', '10.0.1.2'], stdout = subprocess.DEVNULL)
a1 = n1.Popen(['ping', '-qc1', '10.0.1.2'], stdout = null) a2 = n3.Popen(['ping', '-qc1', '10.0.0.1'], stdout = subprocess.DEVNULL)
a2 = n3.Popen(['ping', '-qc1', '10.0.0.1'], stdout = null) self.assertEqual(a1.wait(), 0)
self.assertEquals(a1.wait(), 0) self.assertEqual(a2.wait(), 0)
self.assertEquals(a2.wait(), 0)
@test_util.skipUnless(os.getuid() == 0, "Test requires root privileges") @test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
def test_run_ping_routing(self): def test_run_ping_routing(self):
...@@ -121,11 +120,10 @@ class TestGlobal(unittest.TestCase): ...@@ -121,11 +120,10 @@ class TestGlobal(unittest.TestCase):
n3.add_route(prefix = '10.0.0.0', prefix_len = 24, n3.add_route(prefix = '10.0.0.0', prefix_len = 24,
nexthop = '10.0.1.1') nexthop = '10.0.1.1')
null = file('/dev/null', 'wb') a1 = n1.Popen(['ping', '-qc1', '10.0.1.2'], stdout = subprocess.DEVNULL)
a1 = n1.Popen(['ping', '-qc1', '10.0.1.2'], stdout = null) a2 = n3.Popen(['ping', '-qc1', '10.0.0.1'], stdout = subprocess.DEVNULL)
a2 = n3.Popen(['ping', '-qc1', '10.0.0.1'], stdout = null) self.assertEqual(a1.wait(), 0)
self.assertEquals(a1.wait(), 0) self.assertEqual(a2.wait(), 0)
self.assertEquals(a2.wait(), 0)
def _forward_packets(self, subproc, if1, if2): def _forward_packets(self, subproc, if1, if2):
while(True): while(True):
...@@ -156,10 +154,9 @@ class TestGlobal(unittest.TestCase): ...@@ -156,10 +154,9 @@ class TestGlobal(unittest.TestCase):
tun1.add_v4_address('10.0.1.1', 24) tun1.add_v4_address('10.0.1.1', 24)
tun2.add_v4_address('10.0.1.2', 24) tun2.add_v4_address('10.0.1.2', 24)
null = file('/dev/null', 'wb') a = n1.Popen(['ping', '-qc1', '10.0.1.2'], stdout = subprocess.DEVNULL)
a = n1.Popen(['ping', '-qc1', '10.0.1.2'], stdout = null)
self._forward_packets(a, tun1, tun2) self._forward_packets(a, tun1, tun2)
self.assertEquals(a.wait(), 0) self.assertEqual(a.wait(), 0)
@test_util.skipUnless(os.getuid() == 0, "Test requires root privileges") @test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
def test_run_ping_tap(self): def test_run_ping_tap(self):
...@@ -175,10 +172,9 @@ class TestGlobal(unittest.TestCase): ...@@ -175,10 +172,9 @@ class TestGlobal(unittest.TestCase):
tap1.add_v4_address('10.0.1.1', 24) tap1.add_v4_address('10.0.1.1', 24)
tap2.add_v4_address('10.0.1.2', 24) tap2.add_v4_address('10.0.1.2', 24)
null = file('/dev/null', 'wb') a = n1.Popen(['ping', '-qc1', '10.0.1.2'], stdout = subprocess.DEVNULL)
a = n1.Popen(['ping', '-qc1', '10.0.1.2'], stdout = null)
self._forward_packets(a, tap1, tap2) self._forward_packets(a, tap1, tap2)
self.assertEquals(a.wait(), 0) self.assertEqual(a.wait(), 0)
@test_util.skipUnless(os.getuid() == 0, "Test requires root privileges") @test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
def test_run_ping_tap_routing(self): def test_run_ping_tap_routing(self):
...@@ -223,10 +219,9 @@ class TestGlobal(unittest.TestCase): ...@@ -223,10 +219,9 @@ class TestGlobal(unittest.TestCase):
n4.add_route(prefix = '10.0.1.0', prefix_len = 24, nexthop = '10.0.2.1') n4.add_route(prefix = '10.0.1.0', prefix_len = 24, nexthop = '10.0.2.1')
n4.add_route(prefix = '10.0.0.0', prefix_len = 24, nexthop = '10.0.2.1') n4.add_route(prefix = '10.0.0.0', prefix_len = 24, nexthop = '10.0.2.1')
null = file('/dev/null', 'wb') a = n1.Popen(['ping', '-qc1', '10.0.2.2'], stdout = subprocess.DEVNULL)
a = n1.Popen(['ping', '-qc1', '10.0.2.2'], stdout = null)
self._forward_packets(a, tap1, tap2) self._forward_packets(a, tap1, tap2)
self.assertEquals(a.wait(), 0) self.assertEqual(a.wait(), 0)
class TestX11(unittest.TestCase): class TestX11(unittest.TestCase):
@test_util.skipUnless("DISPLAY" in os.environ, "Test requires working X11") @test_util.skipUnless("DISPLAY" in os.environ, "Test requires working X11")
...@@ -239,7 +234,7 @@ class TestX11(unittest.TestCase): ...@@ -239,7 +234,7 @@ class TestX11(unittest.TestCase):
n = nemu.Node(nonetns = True, forward_X11 = True) n = nemu.Node(nonetns = True, forward_X11 = True)
info2 = n.backticks([xdpy]) info2 = n.backticks([xdpy])
info2 = info2.partition("\n")[2] info2 = info2.partition("\n")[2]
self.assertEquals(info, info2) self.assertEqual(info, info2)
@test_util.skipUnless(os.getuid() == 0, "Test requires root privileges") @test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
@test_util.skipUnless("DISPLAY" in os.environ, "Test requires working X11") @test_util.skipUnless("DISPLAY" in os.environ, "Test requires working X11")
...@@ -252,7 +247,7 @@ class TestX11(unittest.TestCase): ...@@ -252,7 +247,7 @@ class TestX11(unittest.TestCase):
n = nemu.Node(forward_X11 = True) n = nemu.Node(forward_X11 = True)
info2 = n.backticks([xdpy]) info2 = n.backticks([xdpy])
info2 = info2.partition("\n")[2] info2 = info2.partition("\n")[2]
self.assertEquals(info, info2) self.assertEqual(info, info2)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -13,7 +13,7 @@ class TestUtils(unittest.TestCase): ...@@ -13,7 +13,7 @@ class TestUtils(unittest.TestCase):
self.assertTrue(len(devs) > 0) self.assertTrue(len(devs) > 0)
self.assertTrue('lo' in devs) self.assertTrue('lo' in devs)
self.assertTrue(devs['lo']['up']) self.assertTrue(devs['lo']['up'])
self.assertEquals(devs['lo']['lladdr'], '00:00:00:00:00:00') self.assertEqual(devs['lo']['lladdr'], '00:00:00:00:00:00')
self.assertTrue( { self.assertTrue( {
'address': '127.0.0.1', 'prefix_len': 8, 'address': '127.0.0.1', 'prefix_len': 8,
'broadcast': None, 'family': 'inet' 'broadcast': None, 'family': 'inet'
...@@ -22,13 +22,13 @@ class TestUtils(unittest.TestCase): ...@@ -22,13 +22,13 @@ class TestUtils(unittest.TestCase):
class TestIPRouteStuff(unittest.TestCase): class TestIPRouteStuff(unittest.TestCase):
def test_fix_lladdr(self): def test_fix_lladdr(self):
fl = nemu.iproute._fix_lladdr fl = nemu.iproute._fix_lladdr
self.assertEquals(fl('42:71:e0:90:ca:42'), '42:71:e0:90:ca:42') self.assertEqual(fl('42:71:e0:90:ca:42'), '42:71:e0:90:ca:42')
self.assertEquals(fl('4271E090CA42'), '42:71:e0:90:ca:42', self.assertEqual(fl('4271E090CA42'), '42:71:e0:90:ca:42',
'Normalization of link-level address: missing colons and ' 'Normalization of link-level address: missing colons and '
'upper caps') 'upper caps')
self.assertEquals(fl('2:71:E:90:CA:42'), '02:71:0e:90:ca:42', self.assertEqual(fl('2:71:E:90:CA:42'), '02:71:0e:90:ca:42',
'Normalization of link-level address: missing zeroes') 'Normalization of link-level address: missing zeroes')
self.assertEquals(fl('271E090CA42'), '02:71:e0:90:ca:42', self.assertEqual(fl('271E090CA42'), '02:71:e0:90:ca:42',
'Automatic normalization of link-level address: missing ' 'Automatic normalization of link-level address: missing '
'colons and zeroes') 'colons and zeroes')
self.assertRaises(ValueError, fl, 'foo') self.assertRaises(ValueError, fl, 'foo')
...@@ -44,29 +44,29 @@ class TestIPRouteStuff(unittest.TestCase): ...@@ -44,29 +44,29 @@ class TestIPRouteStuff(unittest.TestCase):
def test_non_empty_str(self): def test_non_empty_str(self):
nes = nemu.iproute._non_empty_str nes = nemu.iproute._non_empty_str
self.assertEquals(nes(''), None) self.assertEqual(nes(''), None)
self.assertEquals(nes('Foo'), 'Foo') self.assertEqual(nes('Foo'), 'Foo')
self.assertEquals(nes(1), '1') self.assertEqual(nes(1), '1')
def test_interface(self): def test_interface(self):
i = nemu.iproute.interface(index = 1) i = nemu.iproute.interface(index = 1)
self.assertRaises(AttributeError, setattr, i, 'index', 2) self.assertRaises(AttributeError, setattr, i, 'index', 2)
self.assertRaises(ValueError, setattr, i, 'mtu', -1) self.assertRaises(ValueError, setattr, i, 'mtu', -1)
self.assertEquals(repr(i), 'nemu.iproute.interface(index = 1, ' self.assertEqual(repr(i), 'nemu.iproute.interface(index = 1, '
'name = None, up = None, mtu = None, lladdr = None, ' 'name = None, up = None, mtu = None, lladdr = None, '
'broadcast = None, multicast = None, arp = None)') 'broadcast = None, multicast = None, arp = None)')
i.name = 'foo'; i.up = 1; i.arp = True; i.mtu = 1500 i.name = 'foo'; i.up = 1; i.arp = True; i.mtu = 1500
self.assertEquals(repr(i), 'nemu.iproute.interface(index = 1, ' self.assertEqual(repr(i), 'nemu.iproute.interface(index = 1, '
'name = \'foo\', up = True, mtu = 1500, lladdr = None, ' 'name = \'foo\', up = True, mtu = 1500, lladdr = None, '
'broadcast = None, multicast = None, arp = True)') 'broadcast = None, multicast = None, arp = True)')
j = nemu.iproute.interface(index = 2) j = nemu.iproute.interface(index = 2)
j.name = 'bar'; j.up = False; j.arp = 1 j.name = 'bar'; j.up = False; j.arp = 1
# Modifications to turn j into i. # Modifications to turn j into i.
self.assertEquals(repr(i - j), 'nemu.iproute.interface(index = 1, ' self.assertEqual(repr(i - j), 'nemu.iproute.interface(index = 1, '
'name = \'foo\', up = True, mtu = 1500, lladdr = None, ' 'name = \'foo\', up = True, mtu = 1500, lladdr = None, '
'broadcast = None, multicast = None, arp = None)') 'broadcast = None, multicast = None, arp = None)')
# Modifications to turn i into j. # Modifications to turn i into j.
self.assertEquals(repr(j - i), 'nemu.iproute.interface(index = 2, ' self.assertEqual(repr(j - i), 'nemu.iproute.interface(index = 2, '
'name = \'bar\', up = False, mtu = None, lladdr = None, ' 'name = \'bar\', up = False, mtu = None, lladdr = None, '
'broadcast = None, multicast = None, arp = None)') 'broadcast = None, multicast = None, arp = None)')
...@@ -86,8 +86,8 @@ class TestInterfaces(unittest.TestCase): ...@@ -86,8 +86,8 @@ class TestInterfaces(unittest.TestCase):
node_devs = set(node0.get_interfaces()) node_devs = set(node0.get_interfaces())
self.assertTrue(set(ifaces).issubset(node_devs)) self.assertTrue(set(ifaces).issubset(node_devs))
loopback = node_devs - set(ifaces) # should be! loopback = node_devs - set(ifaces) # should be!
self.assertEquals(len(loopback), 1) self.assertEqual(len(loopback), 1)
self.assertEquals(loopback.pop().name, 'lo') self.assertEqual(loopback.pop().name, 'lo')
devs = get_devs() devs = get_devs()
for i in range(5): for i in range(5):
...@@ -98,22 +98,22 @@ class TestInterfaces(unittest.TestCase): ...@@ -98,22 +98,22 @@ class TestInterfaces(unittest.TestCase):
def test_interface_settings(self): def test_interface_settings(self):
node0 = nemu.Node() node0 = nemu.Node()
if0 = node0.add_if(lladdr = '42:71:e0:90:ca:42', mtu = 1492) if0 = node0.add_if(lladdr = '42:71:e0:90:ca:42', mtu = 1492)
self.assertEquals(if0.lladdr, '42:71:e0:90:ca:42', self.assertEqual(if0.lladdr, '42:71:e0:90:ca:42',
"Constructor parameters") "Constructor parameters")
self.assertEquals(if0.mtu, 1492, "Constructor parameters") self.assertEqual(if0.mtu, 1492, "Constructor parameters")
if0.lladdr = '4271E090CA42' if0.lladdr = '4271E090CA42'
self.assertEquals(if0.lladdr, '42:71:e0:90:ca:42', """Normalization of self.assertEqual(if0.lladdr, '42:71:e0:90:ca:42', """Normalization of
link-level address: missing colons and upper caps""") link-level address: missing colons and upper caps""")
if0.lladdr = '2:71:E0:90:CA:42' if0.lladdr = '2:71:E0:90:CA:42'
self.assertEquals(if0.lladdr, '02:71:e0:90:ca:42', self.assertEqual(if0.lladdr, '02:71:e0:90:ca:42',
"""Normalization of link-level address: missing zeroes""") """Normalization of link-level address: missing zeroes""")
if0.lladdr = '271E090CA42' if0.lladdr = '271E090CA42'
self.assertEquals(if0.lladdr, '02:71:e0:90:ca:42', self.assertEqual(if0.lladdr, '02:71:e0:90:ca:42',
"""Automatic normalization of link-level address: missing """Automatic normalization of link-level address: missing
colons and zeroes""") colons and zeroes""")
self.assertRaises(ValueError, setattr, if0, 'lladdr', 'foo') self.assertRaises(ValueError, setattr, if0, 'lladdr', 'foo')
self.assertRaises(ValueError, setattr, if0, 'lladdr', '1234567890123') self.assertRaises(ValueError, setattr, if0, 'lladdr', '1234567890123')
self.assertEquals(if0.mtu, 1492) self.assertEqual(if0.mtu, 1492)
# detected by setter # detected by setter
self.assertRaises(ValueError, setattr, if0, 'mtu', 0) self.assertRaises(ValueError, setattr, if0, 'mtu', 0)
# error from ip # error from ip
...@@ -123,8 +123,8 @@ class TestInterfaces(unittest.TestCase): ...@@ -123,8 +123,8 @@ class TestInterfaces(unittest.TestCase):
devs = get_devs_netns(node0) devs = get_devs_netns(node0)
self.assertTrue(if0.name in devs) self.assertTrue(if0.name in devs)
self.assertFalse(devs[if0.name]['up']) self.assertFalse(devs[if0.name]['up'])
self.assertEquals(devs[if0.name]['lladdr'], if0.lladdr) self.assertEqual(devs[if0.name]['lladdr'], if0.lladdr)
self.assertEquals(devs[if0.name]['mtu'], if0.mtu) self.assertEqual(devs[if0.name]['mtu'], if0.mtu)
if0.up = True if0.up = True
devs = get_devs_netns(node0) devs = get_devs_netns(node0)
...@@ -132,10 +132,10 @@ class TestInterfaces(unittest.TestCase): ...@@ -132,10 +132,10 @@ class TestInterfaces(unittest.TestCase):
# Verify that data is actually read from the kernel # Verify that data is actually read from the kernel
r = node0.system([IP_PATH, "link", "set", if0.name, "mtu", "1500"]) r = node0.system([IP_PATH, "link", "set", if0.name, "mtu", "1500"])
self.assertEquals(r, 0) self.assertEqual(r, 0)
devs = get_devs_netns(node0) devs = get_devs_netns(node0)
self.assertEquals(devs[if0.name]['mtu'], 1500) self.assertEqual(devs[if0.name]['mtu'], 1500)
self.assertEquals(devs[if0.name]['mtu'], if0.mtu) self.assertEqual(devs[if0.name]['mtu'], if0.mtu)
# FIXME: get_stats # FIXME: get_stats
...@@ -164,7 +164,7 @@ class TestInterfaces(unittest.TestCase): ...@@ -164,7 +164,7 @@ class TestInterfaces(unittest.TestCase):
} in devs[if0.name]['addr']) } in devs[if0.name]['addr'])
self.assertTrue(len(if0.get_addresses()) >= 2) self.assertTrue(len(if0.get_addresses()) >= 2)
self.assertEquals(if0.get_addresses(), devs[if0.name]['addr']) self.assertEqual(if0.get_addresses(), devs[if0.name]['addr'])
class TestWithDummy(unittest.TestCase): class TestWithDummy(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -177,7 +177,7 @@ class TestWithDummy(unittest.TestCase): ...@@ -177,7 +177,7 @@ class TestWithDummy(unittest.TestCase):
def test_interface_migration(self): def test_interface_migration(self):
node = nemu.Node() node = nemu.Node()
self.dummyname = "dummy%d" % os.getpid() self.dummyname = "dummy%d" % os.getpid()
self.assertEquals(os.system("%s link add name %s type dummy" % self.assertEqual(os.system("%s link add name %s type dummy" %
(IP_PATH, self.dummyname)), 0) (IP_PATH, self.dummyname)), 0)
devs = get_devs() devs = get_devs()
self.assertTrue(self.dummyname in devs) self.assertTrue(self.dummyname in devs)
...@@ -194,8 +194,8 @@ class TestWithDummy(unittest.TestCase): ...@@ -194,8 +194,8 @@ class TestWithDummy(unittest.TestCase):
devs = get_devs_netns(node) devs = get_devs_netns(node)
self.assertTrue(if0.name in devs) self.assertTrue(if0.name in devs)
self.assertEquals(devs[if0.name]['lladdr'], '42:71:e0:90:ca:43') self.assertEqual(devs[if0.name]['lladdr'], '42:71:e0:90:ca:43')
self.assertEquals(devs[if0.name]['mtu'], 1400) self.assertEqual(devs[if0.name]['mtu'], 1400)
node.destroy() node.destroy()
self.assertTrue(self.dummyname in get_devs()) self.assertTrue(self.dummyname in get_devs())
......
...@@ -9,13 +9,13 @@ class TestNode(unittest.TestCase): ...@@ -9,13 +9,13 @@ class TestNode(unittest.TestCase):
@test_util.skipUnless(os.getuid() == 0, "Test requires root privileges") @test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
def test_node(self): def test_node(self):
node = nemu.Node() node = nemu.Node()
self.failIfEqual(node.pid, os.getpid()) self.assertNotEqual(node.pid, os.getpid())
self.failIfEqual(node.pid, None) self.assertNotEqual(node.pid, None)
# check if it really exists # check if it really exists
os.kill(node.pid, 0) os.kill(node.pid, 0)
nodes = nemu.get_nodes() nodes = nemu.get_nodes()
self.assertEquals(nodes, [node]) self.assertEqual(nodes, [node])
self.assertTrue(node.get_interface("lo").up) self.assertTrue(node.get_interface("lo").up)
...@@ -28,7 +28,7 @@ class TestNode(unittest.TestCase): ...@@ -28,7 +28,7 @@ class TestNode(unittest.TestCase):
os._exit(0) os._exit(0)
os._exit(1) os._exit(1)
(pid, exitcode) = os.waitpid(chld, 0) (pid, exitcode) = os.waitpid(chld, 0)
self.assertEquals(exitcode, 0, "Node does not recognise forks") self.assertEqual(exitcode, 0, "Node does not recognise forks")
@test_util.skipUnless(os.getuid() == 0, "Test requires root privileges") @test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
def test_cleanup(self): def test_cleanup(self):
...@@ -44,8 +44,8 @@ class TestNode(unittest.TestCase): ...@@ -44,8 +44,8 @@ class TestNode(unittest.TestCase):
# Test automatic destruction # Test automatic destruction
orig_devs = len(test_util.get_devs()) orig_devs = len(test_util.get_devs())
create_stuff() create_stuff()
self.assertEquals(nemu.get_nodes(), []) self.assertEqual(nemu.get_nodes(), [])
self.assertEquals(orig_devs, len(test_util.get_devs())) self.assertEqual(orig_devs, len(test_util.get_devs()))
# Test at_exit hooks # Test at_exit hooks
orig_devs = len(test_util.get_devs()) orig_devs = len(test_util.get_devs())
...@@ -56,7 +56,7 @@ class TestNode(unittest.TestCase): ...@@ -56,7 +56,7 @@ class TestNode(unittest.TestCase):
create_stuff() create_stuff()
os._exit(0) os._exit(0)
os.waitpid(chld, 0) os.waitpid(chld, 0)
self.assertEquals(orig_devs, len(test_util.get_devs())) self.assertEqual(orig_devs, len(test_util.get_devs()))
# Test signal hooks # Test signal hooks
orig_devs = len(test_util.get_devs()) orig_devs = len(test_util.get_devs())
...@@ -70,7 +70,7 @@ class TestNode(unittest.TestCase): ...@@ -70,7 +70,7 @@ class TestNode(unittest.TestCase):
time.sleep(10) time.sleep(10)
os.kill(chld, signal.SIGTERM) os.kill(chld, signal.SIGTERM)
os.waitpid(chld, 0) os.waitpid(chld, 0)
self.assertEquals(orig_devs, len(test_util.get_devs())) self.assertEqual(orig_devs, len(test_util.get_devs()))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
#!/usr/bin/env python2 #!/usr/bin/env python2
# vim:ts=4:sw=4:et:ai:sts=4 # vim:ts=4:sw=4:et:ai:sts=4
import subprocess
import nemu.protocol import nemu.protocol
import os, socket, sys, threading, unittest import os, socket, sys, threading, unittest
import test_util
class TestServer(unittest.TestCase): class TestServer(unittest.TestCase):
def test_server_startup(self): def test_server_startup(self):
# Test the creation of the server object with different ways of passing # Test the creation of the server object with different ways of passing
...@@ -14,10 +18,10 @@ class TestServer(unittest.TestCase): ...@@ -14,10 +18,10 @@ class TestServer(unittest.TestCase):
def test_help(fd): def test_help(fd):
fd.write("HELP\n") fd.write("HELP\n")
# should be more than one line # should be more than one line
self.assertEquals(fd.readline()[0:4], "200-") self.assertEqual(fd.readline()[0:4], "200-")
while True: while True:
l = fd.readline() l = fd.readline()
self.assertEquals(l[0:3], "200") self.assertEqual(l[0:3], "200")
if l[3] == ' ': if l[3] == ' ':
break break
...@@ -31,13 +35,13 @@ class TestServer(unittest.TestCase): ...@@ -31,13 +35,13 @@ class TestServer(unittest.TestCase):
t.start() t.start()
s = os.fdopen(s1.fileno(), "r+", 1) s = os.fdopen(s1.fileno(), "r+", 1)
self.assertEquals(s.readline()[0:4], "220 ") self.assertEqual(s.readline()[0:4], "220 ")
test_help(s) test_help(s)
s.close() s.close()
s0.close() s0.close()
s = os.fdopen(s3.fileno(), "r+", 1) s = os.fdopen(s3.fileno(), "r+", 1)
self.assertEquals(s.readline()[0:4], "220 ") self.assertEqual(s.readline()[0:4], "220 ")
test_help(s) test_help(s)
s.close() s.close()
s2.close() s2.close()
...@@ -52,9 +56,8 @@ class TestServer(unittest.TestCase): ...@@ -52,9 +56,8 @@ class TestServer(unittest.TestCase):
t.start() t.start()
cli = nemu.protocol.Client(s1, s1) cli = nemu.protocol.Client(s1, s1)
null = file('/dev/null', 'wb')
argv = [ '/bin/sh', '-c', 'yes' ] argv = [ '/bin/sh', '-c', 'yes' ]
pid = cli.spawn(argv, stdout = null) pid = cli.spawn(argv, stdout = subprocess.DEVNULL)
self.assertTrue(os.path.exists("/proc/%d" % pid)) self.assertTrue(os.path.exists("/proc/%d" % pid))
# try to exit while there are still processes running # try to exit while there are still processes running
cli.shutdown() cli.shutdown()
...@@ -88,6 +91,7 @@ class TestServer(unittest.TestCase): ...@@ -88,6 +91,7 @@ class TestServer(unittest.TestCase):
t.join() t.join()
@test_util.skip("python 3 can't makefile a socket in r+")
def test_basic_stuff(self): def test_basic_stuff(self):
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
srv = nemu.protocol.Server(s0, s0) srv = nemu.protocol.Server(s0, s0)
...@@ -95,15 +99,15 @@ class TestServer(unittest.TestCase): ...@@ -95,15 +99,15 @@ class TestServer(unittest.TestCase):
def check_error(self, cmd, code = 500): def check_error(self, cmd, code = 500):
s1.write("%s\n" % cmd) s1.write("%s\n" % cmd)
self.assertEquals(srv.readcmd(), None) self.assertEqual(srv.readcmd(), None)
self.assertEquals(s1.readline()[0:4], "%d " % code) self.assertEqual(s1.readline()[0:4], "%d " % code)
def check_ok(self, cmd, func, args): def check_ok(self, cmd, func, args):
s1.write("%s\n" % cmd) s1.write("%s\n" % cmd)
ccmd = " ".join(cmd.upper().split()[0:2]) ccmd = " ".join(cmd.upper().split()[0:2])
if func == None: if func == None:
self.assertEquals(srv.readcmd()[1:3], (ccmd, args)) self.assertEqual(srv.readcmd()[1:3], (ccmd, args))
else: else:
self.assertEquals(srv.readcmd(), (func, ccmd, args)) self.assertEqual(srv.readcmd(), (func, ccmd, args))
check_ok(self, "quit", srv.do_QUIT, []) check_ok(self, "quit", srv.do_QUIT, [])
check_ok(self, " quit ", srv.do_QUIT, []) check_ok(self, " quit ", srv.do_QUIT, [])
......
...@@ -17,13 +17,13 @@ class TestRouting(unittest.TestCase): ...@@ -17,13 +17,13 @@ class TestRouting(unittest.TestCase):
@test_util.skipUnless(os.getuid() == 0, "Test requires root privileges") @test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
def test_routing(self): def test_routing(self):
node = nemu.Node() node = nemu.Node()
self.assertEquals(len(node.get_routes()), 0) self.assertEqual(len(node.get_routes()), 0)
if0 = node.add_if() if0 = node.add_if()
if0.add_v4_address('10.0.0.1', 24) if0.add_v4_address('10.0.0.1', 24)
if0.up = True if0.up = True
routes = node.get_routes() routes = node.get_routes()
self.assertEquals(routes, [node.route(prefix = '10.0.0.0', self.assertEqual(routes, [node.route(prefix = '10.0.0.0',
prefix_len = 24, interface = if0)]) prefix_len = 24, interface = if0)])
node.add_route(nexthop = '10.0.0.2') # default route node.add_route(nexthop = '10.0.0.2') # default route
...@@ -45,7 +45,7 @@ class TestRouting(unittest.TestCase): ...@@ -45,7 +45,7 @@ class TestRouting(unittest.TestCase):
node.del_route(prefix = '11.1.0.1', prefix_len = 32, interface = if0) node.del_route(prefix = '11.1.0.1', prefix_len = 32, interface = if0)
node.del_route(prefix = '10.0.0.0', prefix_len = 24, interface = if0) node.del_route(prefix = '10.0.0.0', prefix_len = 24, interface = if0)
self.assertEquals(node.get_routes(), []) self.assertEqual(node.get_routes(), [])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
#!/usr/bin/env python2 #!/usr/bin/env python2
# vim:ts=4:sw=4:et:ai:sts=4 # vim:ts=4:sw=4:et:ai:sts=4
import errno
import nemu, test_util import nemu, test_util
import nemu.subprocess_ as sp import nemu.subprocess_ as sp
...@@ -7,7 +8,7 @@ import grp, os, pwd, signal, socket, sys, time, unittest ...@@ -7,7 +8,7 @@ import grp, os, pwd, signal, socket, sys, time, unittest
def _stat(path): def _stat(path):
try: try:
return os.stat(user) return os.stat(path)
except: except:
return None return None
...@@ -24,20 +25,20 @@ def _getpwuid(uid): ...@@ -24,20 +25,20 @@ def _getpwuid(uid):
return None return None
def _readall(fd): def _readall(fd):
s = "" s = b""
while True: while True:
try: try:
s1 = os.read(fd, 4096) s1 = os.read(fd, 4096)
except OSError, e: except OSError as e:
if e.errno == errno.EINTR: if e.errno == errno.EINTR:
continue continue
else: else:
raise raise
if s1 == "": if s1 == b"":
break break
s += s1 s += s1
return s return s
_longstring = "Long string is long!\n" * 1000 _longstring = b"Long string is long!\n" * 1000
class TestSubprocess(unittest.TestCase): class TestSubprocess(unittest.TestCase):
def _check_ownership(self, user, pid): def _check_ownership(self, user, pid):
...@@ -49,11 +50,11 @@ class TestSubprocess(unittest.TestCase): ...@@ -49,11 +50,11 @@ class TestSubprocess(unittest.TestCase):
data = stat.readline() data = stat.readline()
fields = data.split() fields = data.split()
if fields[0] == 'Uid:': if fields[0] == 'Uid:':
self.assertEquals(fields[1:4], (uid,) * 4) self.assertEqual(fields[1:4], (uid,) * 4)
if fields[0] == 'Gid:': if fields[0] == 'Gid:':
self.assertEquals(fields[1:4], (gid,) * 4) self.assertEqual(fields[1:4], (gid,) * 4)
if fields[0] == 'Groups:': if fields[0] == 'Groups:':
self.assertEquals(set(fields[1:]), set(groups)) self.assertEqual(set(fields[1:]), set(groups))
break break
stat.close() stat.close()
...@@ -74,7 +75,7 @@ class TestSubprocess(unittest.TestCase): ...@@ -74,7 +75,7 @@ class TestSubprocess(unittest.TestCase):
pid = sp.spawn('/bin/sleep', ['/bin/sleep', '100'], user = user) pid = sp.spawn('/bin/sleep', ['/bin/sleep', '100'], user = user)
self._check_ownership(user, pid) self._check_ownership(user, pid)
os.kill(pid, signal.SIGTERM) os.kill(pid, signal.SIGTERM)
self.assertEquals(sp.wait(pid), signal.SIGTERM) self.assertEqual(sp.wait(pid), signal.SIGTERM)
@test_util.skipUnless(os.getuid() == 0, "Test requires root privileges") @test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
def test_Subprocess_chuser(self): def test_Subprocess_chuser(self):
...@@ -83,7 +84,7 @@ class TestSubprocess(unittest.TestCase): ...@@ -83,7 +84,7 @@ class TestSubprocess(unittest.TestCase):
p = node.Subprocess(['/bin/sleep', '1000'], user = user) p = node.Subprocess(['/bin/sleep', '1000'], user = user)
self._check_ownership(user, p.pid) self._check_ownership(user, p.pid)
p.signal() p.signal()
self.assertEquals(p.wait(), -signal.SIGTERM) self.assertEqual(p.wait(), -signal.SIGTERM)
def test_spawn_basic(self): def test_spawn_basic(self):
# User does not exist # User does not exist
...@@ -106,14 +107,14 @@ class TestSubprocess(unittest.TestCase): ...@@ -106,14 +107,14 @@ class TestSubprocess(unittest.TestCase):
r, w = os.pipe() r, w = os.pipe()
p = sp.spawn('/bin/echo', ['echo', 'hello world'], stdout = w) p = sp.spawn('/bin/echo', ['echo', 'hello world'], stdout = w)
os.close(w) os.close(w)
self.assertEquals(_readall(r), "hello world\n") self.assertEqual(_readall(r), b"hello world\n")
os.close(r) os.close(r)
# Check poll. # Check poll.
while True: while True:
ret = sp.poll(p) ret = sp.poll(p)
if ret is not None: if ret is not None:
self.assertEquals(ret, 0) self.assertEqual(ret, 0)
break break
time.sleep(0.2) # Wait a little bit. time.sleep(0.2) # Wait a little bit.
# It cannot be wait()ed again. # It cannot be wait()ed again.
...@@ -124,12 +125,12 @@ class TestSubprocess(unittest.TestCase): ...@@ -124,12 +125,12 @@ class TestSubprocess(unittest.TestCase):
p = sp.spawn('/bin/cat', stdout = w0, stdin = r1, close_fds = [r0, w1]) p = sp.spawn('/bin/cat', stdout = w0, stdin = r1, close_fds = [r0, w1])
os.close(w0) os.close(w0)
os.close(r1) os.close(r1)
self.assertEquals(sp.poll(p), None) self.assertEqual(sp.poll(p), None)
os.write(w1, "hello world\n") os.write(w1, b"hello world\n")
os.close(w1) os.close(w1)
self.assertEquals(_readall(r0), "hello world\n") self.assertEqual(_readall(r0), b"hello world\n")
os.close(r0) os.close(r0)
self.assertEquals(sp.wait(p), 0) self.assertEqual(sp.wait(p), 0)
def test_Subprocess_basic(self): def test_Subprocess_basic(self):
node = nemu.Node(nonetns = True) node = nemu.Node(nonetns = True)
...@@ -152,15 +153,15 @@ class TestSubprocess(unittest.TestCase): ...@@ -152,15 +153,15 @@ class TestSubprocess(unittest.TestCase):
# Argv # Argv
self.assertRaises(OSError, node.Subprocess, 'true; false') self.assertRaises(OSError, node.Subprocess, 'true; false')
self.assertEquals(node.Subprocess('true').wait(), 0) self.assertEqual(node.Subprocess('true').wait(), 0)
self.assertEquals(node.Subprocess('true; false', shell = True).wait(), self.assertEqual(node.Subprocess('true; false', shell = True).wait(),
1) 1)
# Piping # Piping
r, w = os.pipe() r, w = os.pipe()
p = node.Subprocess(['echo', 'hello world'], stdout = w) p = node.Subprocess(['echo', 'hello world'], stdout = w)
os.close(w) os.close(w)
self.assertEquals(_readall(r), "hello world\n") self.assertEqual(_readall(r), b"hello world\n")
os.close(r) os.close(r)
p.wait() p.wait()
...@@ -168,18 +169,18 @@ class TestSubprocess(unittest.TestCase): ...@@ -168,18 +169,18 @@ class TestSubprocess(unittest.TestCase):
r, w = os.pipe() r, w = os.pipe()
p = node.Subprocess('/bin/pwd', stdout = w, cwd = "/") p = node.Subprocess('/bin/pwd', stdout = w, cwd = "/")
os.close(w) os.close(w)
self.assertEquals(_readall(r), "/\n") self.assertEqual(_readall(r), b"/\n")
os.close(r) os.close(r)
p.wait() p.wait()
p = node.Subprocess(['sleep', '100']) p = node.Subprocess(['sleep', '100'])
self.assertTrue(p.pid > 0) self.assertTrue(p.pid > 0)
self.assertEquals(p.poll(), None) # not finished self.assertEqual(p.poll(), None) # not finished
p.signal() p.signal()
p.signal() # verify no-op (otherwise there will be an exception) p.signal() # verify no-op (otherwise there will be an exception)
self.assertEquals(p.wait(), -signal.SIGTERM) self.assertEqual(p.wait(), -signal.SIGTERM)
self.assertEquals(p.wait(), -signal.SIGTERM) # no-op self.assertEqual(p.wait(), -signal.SIGTERM) # no-op
self.assertEquals(p.poll(), -signal.SIGTERM) # no-op self.assertEqual(p.poll(), -signal.SIGTERM) # no-op
# destroy # destroy
p = node.Subprocess(['sleep', '100']) p = node.Subprocess(['sleep', '100'])
...@@ -196,22 +197,21 @@ class TestSubprocess(unittest.TestCase): ...@@ -196,22 +197,21 @@ class TestSubprocess(unittest.TestCase):
r, w = os.pipe() r, w = os.pipe()
p = node.Subprocess(cmd, shell = True, stdout = w) p = node.Subprocess(cmd, shell = True, stdout = w)
os.close(w) os.close(w)
self.assertEquals(_readall(r), "\n") # wait for trap to be installed self.assertEqual(_readall(r), b"\n") # wait for trap to be installed
os.close(r) os.close(r)
pid = p.pid pid = p.pid
os.kill(pid, 0) # verify process still there os.kill(pid, 0) # verify process still there
# Avoid the warning about the process being killed # Avoid the warning about the process being killed
orig_stderr = sys.stderr with open("/dev/null", "w") as sys.stderr:
sys.stderr = open("/dev/null", "w")
p.destroy() p.destroy()
sys.stderr = orig_stderr sys.stderr = sys.__stderr__
self.assertRaises(OSError, os.kill, pid, 0) # should be dead by now self.assertRaises(OSError, os.kill, pid, 0) # should be dead by now
p = node.Subprocess(['sleep', '100']) p = node.Subprocess(['sleep', '100'])
os.kill(p.pid, signal.SIGTERM) os.kill(p.pid, signal.SIGTERM)
time.sleep(0.2) time.sleep(0.2)
p.signal() # since it has not been waited for, it should not raise p.signal() # since it has not been waited for, it should not raise
self.assertEquals(p.wait(), -signal.SIGTERM) self.assertEqual(p.wait(), -signal.SIGTERM)
def test_Popen(self): def test_Popen(self):
node = nemu.Node(nonetns = True) node = nemu.Node(nonetns = True)
...@@ -222,74 +222,74 @@ class TestSubprocess(unittest.TestCase): ...@@ -222,74 +222,74 @@ class TestSubprocess(unittest.TestCase):
p = node.Popen('cat', stdout = w0, stdin = r1) p = node.Popen('cat', stdout = w0, stdin = r1)
os.close(w0) os.close(w0)
os.close(r1) os.close(r1)
os.write(w1, "hello world\n") os.write(w1, b"hello world\n")
os.close(w1) os.close(w1)
self.assertEquals(_readall(r0), "hello world\n") self.assertEqual(_readall(r0), b"hello world\n")
os.close(r0) os.close(r0)
# now with a socketpair, not using integers # now with a socketpair, not using integers
(s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0) (s0, s1) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM, 0)
p = node.Popen('cat', stdout = s0, stdin = s0) p = node.Popen('cat', stdout = s0, stdin = s0)
s0.close() s0.close()
s1.send("hello world\n") s1.send(b"hello world\n")
self.assertEquals(s1.recv(512), "hello world\n") self.assertEqual(s1.recv(512), b"hello world\n")
s1.close() s1.close()
# pipes # pipes
p = node.Popen('cat', stdin = sp.PIPE, stdout = sp.PIPE) p = node.Popen('cat', stdin = sp.PIPE, stdout = sp.PIPE)
p.stdin.write("hello world\n") p.stdin.write(b"hello world\n")
p.stdin.close() p.stdin.close()
self.assertEquals(p.stdout.readlines(), ["hello world\n"]) self.assertEqual(p.stdout.readlines(), [b"hello world\n"])
self.assertEquals(p.stderr, None) self.assertEqual(p.stderr, None)
self.assertEquals(p.wait(), 0) self.assertEqual(p.wait(), 0)
p = node.Popen('cat', stdin = sp.PIPE, stdout = sp.PIPE) p = node.Popen('cat', stdin = sp.PIPE, stdout = sp.PIPE)
self.assertEquals(p.communicate(_longstring), (_longstring, None)) self.assertEqual(p.communicate(_longstring), (_longstring, None))
p = node.Popen('cat', stdin = sp.PIPE, stdout = sp.PIPE) p = node.Popen('cat', stdin = sp.PIPE, stdout = sp.PIPE)
p.stdin.write(_longstring) p.stdin.write(_longstring)
self.assertEquals(p.communicate(), (_longstring, None)) self.assertEqual(p.communicate(), (_longstring, None))
p = node.Popen('cat', stdin = sp.PIPE) p = node.Popen('cat', stdin = sp.PIPE)
self.assertEquals(p.communicate(), (None, None)) self.assertEqual(p.communicate(), (None, None))
p = node.Popen('cat >&2', shell = True, stdin = sp.PIPE, p = node.Popen('cat >&2', shell = True, stdin = sp.PIPE,
stderr = sp.PIPE) stderr = sp.PIPE)
p.stdin.write("hello world\n") p.stdin.write("hello world\n")
p.stdin.close() p.stdin.close()
self.assertEquals(p.stderr.readlines(), ["hello world\n"]) self.assertEqual(p.stderr.readlines(), ["hello world\n"])
self.assertEquals(p.stdout, None) self.assertEqual(p.stdout, None)
self.assertEquals(p.wait(), 0) self.assertEqual(p.wait(), 0)
p = node.Popen(['sh', '-c', 'cat >&2'], stdin = sp.PIPE, p = node.Popen(['sh', '-c', 'cat >&2'], stdin = sp.PIPE,
stderr = sp.PIPE) stderr = sp.PIPE)
self.assertEquals(p.communicate(_longstring), (None, _longstring)) self.assertEqual(p.communicate(_longstring), (None, _longstring))
# #
p = node.Popen(['sh', '-c', 'cat >&2'], p = node.Popen(['sh', '-c', 'cat >&2'],
stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.STDOUT) stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.STDOUT)
p.stdin.write("hello world\n") p.stdin.write("hello world\n")
p.stdin.close() p.stdin.close()
self.assertEquals(p.stdout.readlines(), ["hello world\n"]) self.assertEqual(p.stdout.readlines(), ["hello world\n"])
self.assertEquals(p.stderr, None) self.assertEqual(p.stderr, None)
self.assertEquals(p.wait(), 0) self.assertEqual(p.wait(), 0)
p = node.Popen(['sh', '-c', 'cat >&2'], p = node.Popen(['sh', '-c', 'cat >&2'],
stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.STDOUT) stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.STDOUT)
self.assertEquals(p.communicate(_longstring), (_longstring, None)) self.assertEqual(p.communicate(_longstring), (_longstring, None))
# #
p = node.Popen(['tee', '/dev/stderr'], p = node.Popen(['tee', '/dev/stderr'],
stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.STDOUT) stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.STDOUT)
p.stdin.write("hello world\n") p.stdin.write("hello world\n")
p.stdin.close() p.stdin.close()
self.assertEquals(p.stdout.readlines(), ["hello world\n"] * 2) self.assertEqual(p.stdout.readlines(), ["hello world\n"] * 2)
self.assertEquals(p.stderr, None) self.assertEqual(p.stderr, None)
self.assertEquals(p.wait(), 0) self.assertEqual(p.wait(), 0)
p = node.Popen(['tee', '/dev/stderr'], p = node.Popen(['tee', '/dev/stderr'],
stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.STDOUT) stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.STDOUT)
self.assertEquals(p.communicate(_longstring[0:512]), self.assertEqual(p.communicate(_longstring[0:512]),
(_longstring[0:512] * 2, None)) (_longstring[0:512] * 2, None))
# #
...@@ -297,30 +297,30 @@ class TestSubprocess(unittest.TestCase): ...@@ -297,30 +297,30 @@ class TestSubprocess(unittest.TestCase):
stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.PIPE) stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.PIPE)
p.stdin.write("hello world\n") p.stdin.write("hello world\n")
p.stdin.close() p.stdin.close()
self.assertEquals(p.stdout.readlines(), ["hello world\n"]) self.assertEqual(p.stdout.readlines(), ["hello world\n"])
self.assertEquals(p.stderr.readlines(), ["hello world\n"]) self.assertEqual(p.stderr.readlines(), ["hello world\n"])
self.assertEquals(p.wait(), 0) self.assertEqual(p.wait(), 0)
p = node.Popen(['tee', '/dev/stderr'], p = node.Popen(['tee', '/dev/stderr'],
stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.PIPE) stdin = sp.PIPE, stdout = sp.PIPE, stderr = sp.PIPE)
self.assertEquals(p.communicate(_longstring), (_longstring, ) * 2) self.assertEqual(p.communicate(_longstring), (_longstring, ) * 2)
def test_backticks(self): def test_backticks(self):
node = nemu.Node(nonetns = True) node = nemu.Node(nonetns = True)
self.assertEquals(node.backticks("echo hello world"), "hello world\n") self.assertEqual(node.backticks("echo hello world"), "hello world\n")
self.assertEquals(node.backticks(r"echo hello\ \ world"), self.assertEqual(node.backticks(r"echo hello\ \ world"),
"hello world\n") "hello world\n")
self.assertEquals(node.backticks(["echo", "hello", "world"]), self.assertEqual(node.backticks(["echo", "hello", "world"]),
"hello world\n") "hello world\n")
self.assertEquals(node.backticks("echo hello world > /dev/null"), "") self.assertEqual(node.backticks("echo hello world > /dev/null"), "")
self.assertEquals(node.backticks_raise("true"), "") self.assertEqual(node.backticks_raise("true"), "")
self.assertRaises(RuntimeError, node.backticks_raise, "false") self.assertRaises(RuntimeError, node.backticks_raise, "false")
self.assertRaises(RuntimeError, node.backticks_raise, "kill $$") self.assertRaises(RuntimeError, node.backticks_raise, "kill $$")
def test_system(self): def test_system(self):
node = nemu.Node(nonetns = True) node = nemu.Node(nonetns = True)
self.assertEquals(node.system("true"), 0) self.assertEqual(node.system("true"), 0)
self.assertEquals(node.system("false"), 1) self.assertEqual(node.system("false"), 1)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
...@@ -21,27 +21,27 @@ class TestSwitch(unittest.TestCase): ...@@ -21,27 +21,27 @@ class TestSwitch(unittest.TestCase):
(n1, n2, i1, i2, l) = self.stuff (n1, n2, i1, i2, l) = self.stuff
l.mtu = 3000 l.mtu = 3000
ifdata = nemu.iproute.get_if_data()[0] ifdata = nemu.iproute.get_if_data()[0]
self.assertEquals(ifdata[l.index].mtu, 3000) self.assertEqual(ifdata[l.index].mtu, 3000)
self.assertEquals(ifdata[i1.control.index].mtu, 3000, self.assertEqual(ifdata[i1.control.index].mtu, 3000,
"MTU propagation") "MTU propagation")
self.assertEquals(ifdata[i2.control.index].mtu, 3000, self.assertEqual(ifdata[i2.control.index].mtu, 3000,
"MTU propagation") "MTU propagation")
i1.mtu = i2.mtu = 3000 i1.mtu = i2.mtu = 3000
self.assertEquals(ifdata[l.index].up, False) self.assertEqual(ifdata[l.index].up, False)
self.assertEquals(ifdata[i1.control.index].up, False, self.assertEqual(ifdata[i1.control.index].up, False,
"UP propagation") "UP propagation")
self.assertEquals(ifdata[i2.control.index].up, False, self.assertEqual(ifdata[i2.control.index].up, False,
"UP propagation") "UP propagation")
l.up = True l.up = True
ifdata = nemu.iproute.get_if_data()[0] ifdata = nemu.iproute.get_if_data()[0]
self.assertEquals(ifdata[i1.control.index].up, True, "UP propagation") self.assertEqual(ifdata[i1.control.index].up, True, "UP propagation")
self.assertEquals(ifdata[i2.control.index].up, True, "UP propagation") self.assertEqual(ifdata[i2.control.index].up, True, "UP propagation")
tcdata = nemu.iproute.get_tc_data()[0] tcdata = nemu.iproute.get_tc_data()[0]
self.assertEquals(tcdata[i1.control.index], {"qdiscs": {}}) self.assertEqual(tcdata[i1.control.index], {"qdiscs": {}})
self.assertEquals(tcdata[i2.control.index], {"qdiscs": {}}) self.assertEqual(tcdata[i2.control.index], {"qdiscs": {}})
@test_util.skipUnless(os.getuid() == 0, "Test requires root privileges") @test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
def test_switch_changes(self): def test_switch_changes(self):
...@@ -52,10 +52,10 @@ class TestSwitch(unittest.TestCase): ...@@ -52,10 +52,10 @@ class TestSwitch(unittest.TestCase):
"priomap 1 2 2 2 1 2 0 0 1 1 1 1 1 1 1 1") % "priomap 1 2 2 2 1 2 0 0 1 1 1 1 1 1 1 1") %
(nemu.environ.TC_PATH, i1.control.name)) (nemu.environ.TC_PATH, i1.control.name))
tcdata = nemu.iproute.get_tc_data()[0] tcdata = nemu.iproute.get_tc_data()[0]
self.assertEquals(tcdata[i1.control.index], "foreign") self.assertEqual(tcdata[i1.control.index], "foreign")
l.set_parameters(bandwidth = 13107200) # 100 mbits l.set_parameters(bandwidth = 13107200) # 100 mbits
tcdata = nemu.iproute.get_tc_data()[0] tcdata = nemu.iproute.get_tc_data()[0]
self.assertEquals(tcdata[i1.control.index], self.assertEqual(tcdata[i1.control.index],
{"bandwidth": 13107000, "qdiscs": {"tbf": "1"}}) {"bandwidth": 13107000, "qdiscs": {"tbf": "1"}})
# Test tc replacements # Test tc replacements
...@@ -77,36 +77,36 @@ class TestSwitch(unittest.TestCase): ...@@ -77,36 +77,36 @@ class TestSwitch(unittest.TestCase):
(n1, n2, i1, i2, l) = self.stuff (n1, n2, i1, i2, l) = self.stuff
l.set_parameters() l.set_parameters()
tcdata = nemu.iproute.get_tc_data()[0] tcdata = nemu.iproute.get_tc_data()[0]
self.assertEquals(tcdata[i1.control.index], {"qdiscs": {}}) self.assertEqual(tcdata[i1.control.index], {"qdiscs": {}})
self.assertEquals(tcdata[i2.control.index], {"qdiscs": {}}) self.assertEqual(tcdata[i2.control.index], {"qdiscs": {}})
def _test_tbf(self): def _test_tbf(self):
(n1, n2, i1, i2, l) = self.stuff (n1, n2, i1, i2, l) = self.stuff
l.set_parameters(bandwidth = 13107200) # 100 mbits l.set_parameters(bandwidth = 13107200) # 100 mbits
tcdata = nemu.iproute.get_tc_data()[0] tcdata = nemu.iproute.get_tc_data()[0]
self.assertEquals(tcdata[i1.control.index], self.assertEqual(tcdata[i1.control.index],
# adjust for tc rounding # adjust for tc rounding
{"bandwidth": 13107000, "qdiscs": {"tbf": "1"}}) {"bandwidth": 13107000, "qdiscs": {"tbf": "1"}})
self.assertEquals(tcdata[i2.control.index], self.assertEqual(tcdata[i2.control.index],
{"bandwidth": 13107000, "qdiscs": {"tbf": "1"}}) {"bandwidth": 13107000, "qdiscs": {"tbf": "1"}})
def _test_netem(self): def _test_netem(self):
(n1, n2, i1, i2, l) = self.stuff (n1, n2, i1, i2, l) = self.stuff
l.set_parameters(delay = 0.001) # 1ms l.set_parameters(delay = 0.001) # 1ms
tcdata = nemu.iproute.get_tc_data()[0] tcdata = nemu.iproute.get_tc_data()[0]
self.assertEquals(tcdata[i1.control.index], self.assertEqual(tcdata[i1.control.index],
{"delay": 0.001, "qdiscs": {"netem": "2"}}) {"delay": 0.001, "qdiscs": {"netem": "2"}})
self.assertEquals(tcdata[i2.control.index], self.assertEqual(tcdata[i2.control.index],
{"delay": 0.001, "qdiscs": {"netem": "2"}}) {"delay": 0.001, "qdiscs": {"netem": "2"}})
def _test_both(self): def _test_both(self):
(n1, n2, i1, i2, l) = self.stuff (n1, n2, i1, i2, l) = self.stuff
l.set_parameters(bandwidth = 13107200, delay = 0.001) # 100 mbits, 1ms l.set_parameters(bandwidth = 13107200, delay = 0.001) # 100 mbits, 1ms
tcdata = nemu.iproute.get_tc_data()[0] tcdata = nemu.iproute.get_tc_data()[0]
self.assertEquals(tcdata[i1.control.index], self.assertEqual(tcdata[i1.control.index],
{"bandwidth": 13107000, "delay": 0.001, {"bandwidth": 13107000, "delay": 0.001,
"qdiscs": {"tbf": "1", "netem": "2"}}) "qdiscs": {"tbf": "1", "netem": "2"}})
self.assertEquals(tcdata[i2.control.index], self.assertEqual(tcdata[i2.control.index],
{"bandwidth": 13107000, "delay": 0.001, {"bandwidth": 13107000, "delay": 0.001,
"qdiscs": {"tbf": "1", "netem": "2"}}) "qdiscs": {"tbf": "1", "netem": "2"}})
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment