Commit 9174fcee authored by Guido van Rossum's avatar Guido van Rossum

Misc asyncio improvements from upstream

parent b8e15d38
......@@ -115,24 +115,16 @@ def _ipaddr_info(host, port, family, type, proto):
if port is None:
port = 0
elif isinstance(port, bytes):
if port == b'':
port = 0
else:
try:
port = int(port)
except ValueError:
# Might be a service name like b"http".
port = socket.getservbyname(port.decode('ascii'))
elif isinstance(port, str):
if port == '':
port = 0
else:
try:
port = int(port)
except ValueError:
# Might be a service name like "http".
port = socket.getservbyname(port)
elif isinstance(port, bytes) and port == b'':
port = 0
elif isinstance(port, str) and port == '':
port = 0
else:
# If port's a service name like "http", don't skip getaddrinfo.
try:
port = int(port)
except (TypeError, ValueError):
return None
if family == socket.AF_UNSPEC:
afs = [socket.AF_INET, socket.AF_INET6]
......
......@@ -3,7 +3,6 @@ import subprocess
import warnings
from . import compat
from . import futures
from . import protocols
from . import transports
from .coroutines import coroutine
......
......@@ -120,8 +120,8 @@ class CoroWrapper:
def send(self, value):
return self.gen.send(value)
def throw(self, exc):
return self.gen.throw(exc)
def throw(self, type, value=None, traceback=None):
return self.gen.throw(type, value, traceback)
def close(self):
return self.gen.close()
......
......@@ -7,7 +7,6 @@ import heapq
from . import compat
from . import events
from . import futures
from . import locks
from .coroutines import coroutine
......
......@@ -594,6 +594,10 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False):
"""Return a future aggregating results from the given coroutines
or futures.
Coroutines will be wrapped in a future and scheduled in the event
loop. They will not necessarily be scheduled in the same order as
passed in.
All futures must share the same event loop. If all the tasks are
done successfully, the returned future's result is the list of
results (in the order of the original sequence, not necessarily
......
......@@ -142,26 +142,6 @@ class BaseEventTests(test_utils.TestCase):
(INET, STREAM, TCP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', b'1', INET, STREAM, TCP))
def test_getaddrinfo_servname(self):
INET = socket.AF_INET
STREAM = socket.SOCK_STREAM
TCP = socket.IPPROTO_TCP
self.assertEqual(
(INET, STREAM, TCP, '', ('1.2.3.4', 80)),
base_events._ipaddr_info('1.2.3.4', 'http', INET, STREAM, TCP))
self.assertEqual(
(INET, STREAM, TCP, '', ('1.2.3.4', 80)),
base_events._ipaddr_info('1.2.3.4', b'http', INET, STREAM, TCP))
# Raises "service/proto not found".
with self.assertRaises(OSError):
base_events._ipaddr_info('1.2.3.4', 'nonsense', INET, STREAM, TCP)
with self.assertRaises(OSError):
base_events._ipaddr_info('1.2.3.4', 'nonsense', INET, STREAM, TCP)
@patch_socket
def test_ipaddr_info_no_inet_pton(self, m_socket):
del m_socket.inet_pton
......@@ -1209,6 +1189,37 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
def test_create_connection_no_inet_pton(self, m_socket):
self._test_create_connection_ip_addr(m_socket, False)
@patch_socket
def test_create_connection_service_name(self, m_socket):
m_socket.getaddrinfo = socket.getaddrinfo
sock = m_socket.socket.return_value
self.loop.add_reader = mock.Mock()
self.loop.add_reader._is_coroutine = False
self.loop.add_writer = mock.Mock()
self.loop.add_writer._is_coroutine = False
for service, port in ('http', 80), (b'http', 80):
coro = self.loop.create_connection(asyncio.Protocol,
'127.0.0.1', service)
t, p = self.loop.run_until_complete(coro)
try:
sock.connect.assert_called_with(('127.0.0.1', port))
_, kwargs = m_socket.socket.call_args
self.assertEqual(kwargs['family'], m_socket.AF_INET)
self.assertEqual(kwargs['type'], m_socket.SOCK_STREAM)
finally:
t.close()
test_utils.run_briefly(self.loop) # allow transport to close
for service in 'nonsense', b'nonsense':
coro = self.loop.create_connection(asyncio.Protocol,
'127.0.0.1', service)
with self.assertRaises(OSError):
self.loop.run_until_complete(coro)
def test_create_connection_no_local_addr(self):
@asyncio.coroutine
def getaddrinfo(host, *args, **kw):
......
......@@ -2,6 +2,8 @@
import errno
import socket
import threading
import time
import unittest
from unittest import mock
try:
......@@ -1784,5 +1786,89 @@ class SelectorDatagramTransportTests(test_utils.TestCase):
'Fatal error on transport\nprotocol:.*\ntransport:.*'),
exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY))
class SelectorLoopFunctionalTests(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
@asyncio.coroutine
def recv_all(self, sock, nbytes):
buf = b''
while len(buf) < nbytes:
buf += yield from self.loop.sock_recv(sock, nbytes - len(buf))
return buf
def test_sock_connect_sock_write_race(self):
TIMEOUT = 3.0
PAYLOAD = b'DATA' * 1024 * 1024
class Server(threading.Thread):
def __init__(self, *args, srv_sock, **kwargs):
super().__init__(*args, **kwargs)
self.srv_sock = srv_sock
def run(self):
with self.srv_sock:
srv_sock.listen(100)
sock, addr = self.srv_sock.accept()
sock.settimeout(TIMEOUT)
with sock:
sock.sendall(b'helo')
buf = bytearray()
while len(buf) < len(PAYLOAD):
pack = sock.recv(1024 * 65)
if not pack:
break
buf.extend(pack)
@asyncio.coroutine
def client(addr):
sock = socket.socket()
with sock:
sock.setblocking(False)
started = time.monotonic()
while True:
if time.monotonic() - started > TIMEOUT:
self.fail('unable to connect to the socket')
return
try:
yield from self.loop.sock_connect(sock, addr)
except OSError:
yield from asyncio.sleep(0.05, loop=self.loop)
else:
break
# Give 'Server' thread a chance to accept and send b'helo'
time.sleep(0.1)
data = yield from self.recv_all(sock, 4)
self.assertEqual(data, b'helo')
yield from self.loop.sock_sendall(sock, PAYLOAD)
srv_sock = socket.socket()
srv_sock.settimeout(TIMEOUT)
srv_sock.bind(('127.0.0.1', 0))
srv_addr = srv_sock.getsockname()
srv = Server(srv_sock=srv_sock, daemon=True)
srv.start()
try:
self.loop.run_until_complete(
asyncio.wait_for(client(srv_addr), loop=self.loop,
timeout=TIMEOUT))
finally:
srv.join()
if __name__ == '__main__':
unittest.main()
......@@ -1723,6 +1723,37 @@ class TaskTests(test_utils.TestCase):
wd['cw'] = cw # Would fail without __weakref__ slot.
cw.gen = None # Suppress warning from __del__.
def test_corowrapper_throw(self):
# Issue 429: CoroWrapper.throw must be compatible with gen.throw
def foo():
value = None
while True:
try:
value = yield value
except Exception as e:
value = e
exception = Exception("foo")
cw = asyncio.coroutines.CoroWrapper(foo())
cw.send(None)
self.assertIs(exception, cw.throw(exception))
cw = asyncio.coroutines.CoroWrapper(foo())
cw.send(None)
self.assertIs(exception, cw.throw(Exception, exception))
cw = asyncio.coroutines.CoroWrapper(foo())
cw.send(None)
exception = cw.throw(Exception, "foo")
self.assertIsInstance(exception, Exception)
self.assertEqual(exception.args, ("foo", ))
cw = asyncio.coroutines.CoroWrapper(foo())
cw.send(None)
exception = cw.throw(Exception, "foo", None)
self.assertIsInstance(exception, Exception)
self.assertEqual(exception.args, ("foo", ))
@unittest.skipUnless(PY34,
'need python 3.4 or later')
def test_log_destroyed_pending_task(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment