Commit f8a4cba9 authored by Jason Madden's avatar Jason Madden

Fix test__example_udp_servers.py by not subclassing _socket.socket, wrapping...

Fix test__example_udp_servers.py by not subclassing _socket.socket, wrapping it instead and providing access to the wrapped socket. This is more like what Python2 does as well.
parent 62360059
......@@ -13,7 +13,7 @@ class EchoServer(DatagramServer):
def handle(self, data, address):
print('%s: got %r' % (address[0], data))
self.socket.sendto('Received %s bytes' % len(data), address)
self.socket.sendto(('Received %s bytes' % len(data)).encode('utf-8'), address)
if __name__ == '__main__':
......
......@@ -27,23 +27,35 @@ def _get_memory(string, offset):
timeout_default = object()
class _wrefsocket(_socket.socket):
# Plain stdlib socket.socket objects subclass _socket.socket
# and add weakref ability. The ssl module, for one, counts on this.
# We don't create socket.socket objects (because they may have been
# monkey patched to be the object from this module), but we still
# need to make sure what we do create can be weakrefd.
class socket(_socket.socket):
__slots__ = ["__weakref__",]
__slots__ = ["__weakref__", "_io_refs", "_closed", "hub", "_read_event", "_write_event", "timeout"]
class socket(object):
def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None):
_socket.socket.__init__(self, family, type, proto, fileno)
# Take the same approach as socket2: wrap a real socket object,
# don't subclass it. This lets code that needs the raw _sock (not tied to the hub)
# get it. This shows up in tests like test__example_udp_server.
self._sock = _wrefsocket(family, type, proto, fileno)
self._io_refs = 0
self._closed = False
_socket.socket.setblocking(self, False)
fileno = _socket.socket.fileno(self)
_socket.socket.setblocking(self._sock, False)
fileno = _socket.socket.fileno(self._sock)
self.hub = get_hub()
io_class = self.hub.loop.io
self._read_event = io_class(fileno, 1)
self._write_event = io_class(fileno, 2)
self.timeout = _socket.getdefaulttimeout()
def __getattr__(self, name):
return getattr(self._sock, name)
if hasattr(_socket, 'SOCK_NONBLOCK'):
# Only defined under Linux
@property
......@@ -59,7 +71,7 @@ class socket(_socket.socket):
def __repr__(self):
"""Wrap __repr__() to reveal the real class name."""
s = _socket.socket.__repr__(self)
s = _socket.socket.__repr__(self._sock)
if s.startswith("<socket object"):
s = "<%s.%s%s%s" % (self.__class__.__module__,
self.__class__.__name__,
......@@ -185,7 +197,7 @@ class socket(_socket.socket):
# This function should not reference any globals. See Python issue #808164.
self.hub.cancel_wait(self._read_event, cancel_wait_ex)
self.hub.cancel_wait(self._write_event, cancel_wait_ex)
_ss.close(self)
_ss.close(self._sock)
def close(self):
# This function should not reference any globals. See Python issue #808164.
......@@ -205,11 +217,11 @@ class socket(_socket.socket):
can be reused for other purposes. The file descriptor is returned.
"""
self._closed = True
return super().detach()
return self._sock.detach()
def connect(self, address):
if self.timeout == 0.0:
return _socket.socket.connect(self, address)
return _socket.socket.connect(self._sock, address)
if isinstance(address, tuple):
r = getaddrinfo(address[0], address[1], self.family)
address = r[0][-1]
......@@ -222,7 +234,7 @@ class socket(_socket.socket):
err = self.getsockopt(SOL_SOCKET, SO_ERROR)
if err:
raise error(err, strerror(err))
result = _socket.socket.connect_ex(self, address)
result = _socket.socket.connect_ex(self._sock, address)
if not result or result == EISCONN:
break
elif (result in (EWOULDBLOCK, EINPROGRESS, EALREADY)) or (result == EINVAL and is_windows):
......@@ -247,7 +259,7 @@ class socket(_socket.socket):
def recv(self, *args):
while True:
try:
return _socket.socket.recv(self, *args)
return _socket.socket.recv(self._sock, *args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
......@@ -256,7 +268,7 @@ class socket(_socket.socket):
def recvfrom(self, *args):
while True:
try:
return _socket.socket.recvfrom(self, *args)
return _socket.socket.recvfrom(self._sock, *args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
......@@ -265,7 +277,7 @@ class socket(_socket.socket):
def recvfrom_into(self, *args):
while True:
try:
return _socket.socket.recvfrom_into(self, *args)
return _socket.socket.recvfrom_into(self._sock, *args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
......@@ -274,7 +286,7 @@ class socket(_socket.socket):
def recv_into(self, *args):
while True:
try:
return _socket.socket.recv_into(self, *args)
return _socket.socket.recv_into(self._sock, *args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
......@@ -284,13 +296,13 @@ class socket(_socket.socket):
if timeout is timeout_default:
timeout = self.timeout
try:
return _socket.socket.send(self, data, flags)
return _socket.socket.send(self._sock, data, flags)
except error as ex:
if ex.args[0] != EWOULDBLOCK or timeout == 0.0:
raise
self._wait(self._write_event)
try:
return _socket.socket.send(self, data, flags)
return _socket.socket.send(self._sock, data, flags)
except error as ex2:
if ex2.args[0] == EWOULDBLOCK:
return 0
......@@ -315,13 +327,13 @@ class socket(_socket.socket):
def sendto(self, *args):
try:
return _socket.socket.sendto(self, *args)
return _socket.socket.sendto(self._sock, *args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or timeout == 0.0:
raise
self._wait(self._write_event)
try:
return _socket.socket.sendto(self, *args)
return _socket.socket.sendto(self._sock, *args)
except error as ex2:
if ex2.args[0] == EWOULDBLOCK:
return 0
......@@ -355,7 +367,7 @@ class socket(_socket.socket):
else:
self.hub.cancel_wait(self._read_event, cancel_wait_ex)
self.hub.cancel_wait(self._write_event, cancel_wait_ex)
super().shutdown(how)
self._sock.shutdown(how)
SocketType = socket
......
......@@ -126,7 +126,7 @@ class SSLSocket(socket):
if connected:
# create the SSL object
try:
self._sslobj = self.context._wrap_socket(self, server_side,
self._sslobj = self.context._wrap_socket(getattr(self, '_sock', self), server_side,
server_hostname)
if do_handshake_on_connect:
timeout = self.gettimeout()
......
......@@ -9,9 +9,10 @@ class Test(util.TestServer):
def _run_all_tests(self):
sock = socket.socket(type=socket.SOCK_DGRAM)
sock.connect(('127.0.0.1', 9000))
sock.send('Test udp_server')
sock.send(b'Test udp_server')
data, address = sock.recvfrom(8192)
self.assertEqual(data, 'Received 15 bytes')
self.assertEqual(data, b'Received 15 bytes')
sock.close()
if __name__ == '__main__':
......
......@@ -83,7 +83,6 @@ if PYPY:
if PY3:
# No idea / TODO
FAILING_TESTS += '''
test__example_udp_server.py
test_threading_2.py
test__refcount.py
test__all__.py
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment