Commit ecb56d5f authored by Martín Ferrari's avatar Martín Ferrari

X11 forwarding support

parent 64c2e142
......@@ -32,6 +32,7 @@ PROC ABRT 200 (5)
PROC POLL <pid> 200 <code>/450/500 check if process alive
PROC WAIT <pid> 200 <code>/500 waitpid(pid)
PROC KILL <pid> <signal> 200/500 kill(pid, signal)
X11 <prot> <data> 354+200/500 (6)
(1) valid arguments: mtu <n>, up <0|1>, name <name>, lladdr <addr>,
broadcast <addr>, multicast <0|1>, arp <0|1>.
......@@ -52,6 +53,10 @@ command. Answers 200/500 after processing the file descriptor.
was successful, the process is started and the process ID is returned as the
first token of the reply.
(6) Enable X11 forwarding, using the specified protocol and data for
authentication. A opened socket ready to receive X connections is passed over
the channel. Answers 200/500 after transmitting the file descriptor.
Sample session
--------------
......
# vim:ts=4:sw=4:et:ai:sts=4
import os, os.path, subprocess, sys, syslog
import os, os.path, socket, subprocess, sys, syslog
from syslog import LOG_ERR, LOG_WARNING, LOG_NOTICE, LOG_INFO, LOG_DEBUG
__all__ = ["ip_path", "tc_path", "brctl_path", "sysctl_path", "hz"]
__all__ += ["tcpdump_path", "netperf_path"]
__all__ += ["tcpdump_path", "netperf_path", "xauth_path", "xdpyinfo_path"]
__all__ += ["execute", "backticks"]
__all__ += ["find_listen_port"]
__all__ += ["LOG_ERR", "LOG_WARNING", "LOG_NOTICE", "LOG_INFO", "LOG_DEBUG"]
__all__ += ["set_log_level", "logger"]
__all__ += ["error", "warning", "notice", "info", "debug"]
......@@ -44,6 +45,8 @@ sysctl_path = find_bin_or_die("sysctl")
# Optional tools
tcpdump_path = find_bin("tcpdump")
netperf_path = find_bin("netperf")
xauth_path = find_bin("xauth")
xdpyinfo_path = find_bin("xdpyinfo")
# Seems this is completely bogus. At least, we can assume that the internal HZ
# is bigger than this.
......@@ -72,6 +75,17 @@ def backticks(cmd):
raise RuntimeError("Error executing `%s': %s" % (" ".join(cmd), err))
return out
def find_listen_port(family = socket.AF_INET, type = socket.SOCK_STREAM,
proto = 0, addr = "127.0.0.1", min_port = 1, max_port = 65535):
s = socket.socket(family, type, proto)
for p in range(min_port, max_port + 1):
try:
s.bind((addr, p))
return s, p
except socket.error:
pass
raise RuntimeError("Cannot find an usable port in the range specified")
# Logging
_log_level = LOG_WARNING
_log_use_syslog = False
......
......@@ -15,7 +15,7 @@ class Node(object):
s = sorted(Node._nodes.items(), key = lambda x: x[0])
return [x[1] for x in s]
def __init__(self, nonetns = False):
def __init__(self, nonetns = False, forward_X11 = False):
"""Create a new node in the emulation. Implemented as a separate
process in a new network name space. Requires root privileges to run.
......@@ -32,6 +32,8 @@ class Node(object):
self._pid = pid
debug("Node(0x%x).__init__(), pid = %s" % (id(self), pid))
self._slave = netns.protocol.Client(fd, fd)
if forward_X11:
self._slave.enable_x11_forwarding()
Node._nodes[Node._nextnode] = self
Node._nextnode += 1
......
#!/usr/bin/env python
# vim:ts=4:sw=4:et:ai:sts=4
import base64, os, passfd, re, signal, sys, traceback, unshare
import base64, os, passfd, re, select, signal, socket, sys, tempfile, time
import traceback, unshare
import netns.subprocess_, netns.iproute
from netns.environ import *
......@@ -24,6 +25,10 @@ except:
_proto_commands = {
"QUIT": { None: ("", "") },
"HELP": { None: ("", "") },
"X11": {
"SET": ("ss", ""),
"SOCK": ("", "")
},
"IF": {
"LIST": ("", "i"),
"SET": ("iss", "s*"),
......@@ -63,6 +68,8 @@ _proc_commands = {
}
}
KILL_WAIT = 3 # seconds
class Server(object):
"""Class that implements the communication protocol and dispatches calls
to the required functions. Also works as the main loop for the slave
......@@ -77,10 +84,37 @@ class Server(object):
self._children = set()
# Buffer and flag for PROC mode
self._proc = None
# temporary xauth files
self._xauthfiles = {}
# X11 forwarding info
self._xfwd = None
self._xsock = None
self._rfd = _get_file(rfd, "r")
self._wfd = _get_file(wfd, "w")
def clean(self):
for pid in self._children:
os.kill(pid, signal.SIGTERM)
now = time.time()
while time.time() - now < KILL_WAIT:
ch = [pid for pid in self._children
if not netns.subprocess_.poll(pid)]
if not ch:
break
time.sleep(0.1)
for pid in ch:
warning("Killing forcefully process %d." % pid)
os.kill(pid, signal.SIGKILL)
for pid in self._children:
netns.subprocess_.poll(pid)
for f in self._xauthfiles.values():
try:
os.unlink(f)
except:
pass
def reply(self, code, text):
"Send back a reply to the client; handle multiline messages"
if not hasattr(text, '__iter__'):
......@@ -204,6 +238,7 @@ class Server(object):
self._wfd.close()
except:
pass
self.clean()
debug("Server(0x%x) exiting" % id(self))
# FIXME: cleanup
......@@ -274,6 +309,38 @@ class Server(object):
self._proc = None
self._commands = _proto_commands
if 'env' not in params:
params['env'] = dict(os.environ) # copy
xauth = None
if self._xfwd:
display, protoname, hexkey = self._xfwd
user = params['user'] if 'user' in params else None
try:
fd, xauth = tempfile.mkstemp()
os.close(fd)
# stupid xauth format: needs the 'hostname' for local
# connections
execute([xauth_path, "-f", xauth, "add",
"%s/unix:%d" % (socket.gethostname(), display),
protoname, hexkey])
if user:
user, uid, gid = netns.subprocess_.get_user(user)
os.chown(xauth, uid, gid)
params['env']['DISPLAY'] = "127.0.0.1:%d" % display
params['env']['XAUTHORITY'] = xauth
except Exception, e:
warning("Cannot forward X: %s" % e)
try:
os.unlink(xauth)
except:
pass
else:
if 'DISPLAY' in params['env']:
del params['env']['DISPLAY']
try:
chld = netns.subprocess_.spawn(**params)
finally:
......@@ -283,6 +350,7 @@ class Server(object):
os.close(params[d])
self._children.add(chld)
self._xauthfiles[chld] = xauth
self.reply(200, "%d running." % chld)
def do_PROC_ABRT(self, cmdname):
......@@ -301,6 +369,12 @@ class Server(object):
if ret != None:
self._children.remove(pid)
if pid in self._xauthfiles:
try:
os.unlink(self._xauthfiles[pid])
except:
pass
del self._xauthfiles[pid]
self.reply(200, "%d exitcode." % ret)
else:
self.reply(450, "Not finished yet.")
......@@ -387,6 +461,39 @@ class Server(object):
nexthop, ifnr or None, metric))
self.reply(200, "Done.")
def do_X11_SET(self, cmdname, protoname, hexkey):
if not xauth_path:
self.reply(500, "Impossible to forward X: xauth not present")
return
skt, port = None, None
try:
skt, port = find_listen_port(min_port = 6010, max_port = 6099)
except:
self.reply(500, "Cannot allocate a port for X forwarding.")
return
display = port - 6000
self.reply(200, "Socket created on port %d. Use X11 SOCK to get the "
"file descriptor "
"(fixed 1-byte payload before protocol response).")
self._xfwd = display, protoname, hexkey
self._xsock = skt
def do_X11_SOCK(self, cmdname):
if not self._xsock:
self.reply(500, "X forwarding not set up.")
return
# Needs to be a separate command to handle synch & buffering issues
try:
passfd.sendfd(self._wfd, self._xsock.fileno(), "1")
except:
# need to fill the buffer on the other side, nevertheless
self._wfd.write("1")
self.reply(500, "Error sending file descriptor.")
return
self._xsock = None
self.reply(200, "Will set up X forwarding.")
# ============================================================================
#
# Client-side protocol implementation.
......@@ -398,6 +505,7 @@ class Client(object):
debug("Client(0x%x).__init__()" % id(self))
self._rfd = _get_file(rfd, "r")
self._wfd = _get_file(wfd, "w")
self._forwarder = None
# Wait for slave to send banner
self._read_and_check_reply()
......@@ -455,6 +563,9 @@ class Client(object):
self._rfd = None
self._wfd.close()
self._wfd = None
if self._forwarder:
os.kill(self._forwarder, signal.SIGTERM)
self._forwarder = None
def _send_fd(self, name, fd):
"Pass a file descriptor"
......@@ -614,7 +725,35 @@ class Client(object):
route.interface or 0, route.metric or 0]
self._send_cmd(*args)
self._read_and_check_reply()
def set_x11(self, protoname, hexkey):
# Returns a socket ready to accept() connections
self._send_cmd("X11", "SET", protoname, hexkey)
self._read_and_check_reply()
# Receive the socket
self._send_cmd("X11", "SOCK")
fd, payload = passfd.recvfd(self._rfd, 1)
self._read_and_check_reply()
skt = socket.fromfd(fd, socket.AF_INET, socket.SOCK_DGRAM)
os.close(fd) # fromfd dup()'s
return skt
def enable_x11_forwarding(self):
xinfo = _parse_display()
if not xinfo:
raise RuntimeError("Impossible to forward X: DISPLAY variable not "
"set or invalid")
if not xauth_path:
raise RuntimeError("Impossible to forward X: xauth not present")
auth = backticks([xauth_path, "list", os.environ["DISPLAY"]])
match = re.match(r"\S+\s+(\S+)\s+(\S+)\n", auth)
if not match:
raise RuntimeError("Impossible to forward X: invalid DISPLAY")
protoname, hexkey = match.groups()
server = self.set_x11(protoname, hexkey)
self._forwarder = _spawn_x11_forwarder(server, *xinfo)
def _b64(text):
if text == None:
# easier this way
......@@ -639,3 +778,96 @@ def _get_file(fd, mode):
nfd = os.dup(fd)
return os.fdopen(nfd, mode, 1)
def _parse_display():
if "DISPLAY" not in os.environ:
return None
dpy = os.environ["DISPLAY"]
match = re.search(r"^(.*):(\d+)(?:\.(\d+))$", dpy)
if not match:
return None
if match.group(1):
sock = (socket.AF_INET, socket.SOCK_STREAM, 0)
addr = (match.group(1), 6000 + int(match.group(2)))
else:
sock = (socket.AF_UNIX, socket.SOCK_STREAM, 0)
addr = ("/tmp/.X11-unix/X%d" % int(match.group(2)))
return sock, addr
def _spawn_x11_forwarder(server, xsock, xaddr):
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server.listen(10) # arbitrary
pid = os.fork()
if pid:
return pid
# XXX: clear signals, etc
try:
_x11_forwarder(server, xsock, xaddr)
except:
traceback.print_exc(file=sys.stderr)
os._exit(1)
def _x11_forwarder(server, xsock, xaddr):
commr = {}
commw = {}
while(True):
toread = [x for x in commr.keys() if commr[x]["in"]] + [server]
towrite = [x for x in commw.keys() if commw[x]["buf"]]
(rr, wr, er) = select.select(toread, towrite, [])
if server in rr:
xconn = socket.socket(*xsock)
xconn.connect(xaddr)
client, addr = server.accept()
commr[xconn.fileno()] = commw[client.fileno()] = {
"in": xconn,
"out": client,
"buf": []}
commw[xconn.fileno()] = commr[client.fileno()] = {
"in": client,
"out": xconn,
"buf": []}
continue
for fd in rr:
try:
s = os.read(fd, 4096)
except OSError, e:
if e.errno != errno.EINTR:
raise
if s == "":
continue
if s == "":
# fd closed
#commr[fd]["in"].shutdown(socket.SHUT_RD)
if not commr[fd]["buf"]:
commr[fd]["out"].shutdown(socket.SHUT_WR)
commr[fd]["in"] = None
else:
commr[fd]["buf"].append(s)
for fd in wr:
try:
x = os.write(fd, commw[fd]["buf"][0])
except OSError, e:
if e.errno == errno.EINTR:
if x > 0:
pass
else:
continue
if e.errno != errno.EPIPE:
raise
# broken pipe, discard output and close
commw[fd]["in"].shutdown(socket.SHUT_RD)
commw[fd]["out"].shutdown(socket.SHUT_WR)
del commr[commw[fd]["in"].fileno()]
del commw[fd]
continue
if x < len(commw[fd]["buf"][0]):
commw[fd]["buf"][0] = commw[fd]["buf"][0][x:]
else:
del commw[fd]["buf"][0]
if not commw[fd]["buf"] and not commw[fd]["in"]:
commw[fd]["out"].shutdown(socket.SHUT_WR)
del commw[fd]["out"]
......@@ -4,7 +4,7 @@
import fcntl, grp, os, pickle, pwd, signal, select, sys, time, traceback
__all__ = [ 'PIPE', 'STDOUT', 'Popen', 'Subprocess', 'spawn', 'wait', 'poll',
'system', 'backticks', 'backticks_raise' ]
'get_user', 'system', 'backticks', 'backticks_raise' ]
# User-facing interfaces
......@@ -280,19 +280,7 @@ def spawn(executable, argv = None, cwd = None, env = None, close_fds = False,
assert not (set([0, 1, 2]) & set(filtered_userfd))
if user != None:
if str(user).isdigit():
uid = int(user)
try:
user = pwd.getpwuid(uid)[0]
except:
raise ValueError("UID %d does not exist" % int(user))
else:
try:
uid = pwd.getpwnam(str(user))[2]
except:
raise ValueError("User %s does not exist" % str(user))
gid = pwd.getpwuid(uid)[3]
user, uid, gid = get_user(user)
groups = [x[2] for x in grp.getgrall() if user in x[3]]
(r, w) = os.pipe()
......@@ -390,6 +378,21 @@ def wait(pid):
"""Wait for process to die and return the exit code."""
return _eintr_wrapper(os.waitpid, pid, 0)[1]
def get_user(user):
"Take either an username or an uid, and return a tuple (user, uid, gid)."
if str(user).isdigit():
uid = int(user)
try:
user = pwd.getpwuid(uid)[0]
except KeyError:
raise ValueError("UID %d does not exist" % int(user))
else:
try:
uid = pwd.getpwnam(str(user))[2]
except KeyError:
raise ValueError("User %s does not exist" % str(user))
gid = pwd.getpwuid(uid)[3]
return user, uid, gid
# internal stuff, do not look!
......
......@@ -189,5 +189,31 @@ class TestGlobal(unittest.TestCase):
self.assertEquals(a.wait(), 0)
class TestX11(unittest.TestCase):
@test_util.skipUnless("DISPLAY" in os.environ, "Test requires working X11")
@test_util.skipUnless(netns.environ.xdpyinfo_path, "Test requires xdpyinfo")
def test_run_xdpyinfo(self):
xdpy = netns.environ.xdpyinfo_path
info = netns.environ.backticks([xdpy])
# remove first line, contains the display name
info = info.partition("\n")[2]
n = netns.Node(nonetns = True, forward_X11 = True)
info2 = n.backticks([xdpy])
info2 = info2.partition("\n")[2]
self.assertEquals(info, info2)
@test_util.skipUnless(os.getuid() == 0, "Test requires root privileges")
@test_util.skipUnless("DISPLAY" in os.environ, "Test requires working X11")
@test_util.skipUnless(netns.environ.xdpyinfo_path, "Test requires xdpyinfo")
def test_run_xdpyinfo_netns(self):
xdpy = netns.environ.xdpyinfo_path
info = netns.environ.backticks([xdpy])
# remove first line, contains the display name
info = info.partition("\n")[2]
n = netns.Node(forward_X11 = True)
info2 = n.backticks([xdpy])
info2 = info2.partition("\n")[2]
self.assertEquals(info, info2)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment