Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
C
cpython
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
Analytics
Analytics
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Commits
Issue Boards
Open sidebar
Kirill Smelkov
cpython
Commits
519883d2
Commit
519883d2
authored
Feb 18, 2014
by
Yury Selivanov
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
asyncio: Add support for UNIX Domain Sockets.
parent
242e2659
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
738 additions
and
193 deletions
+738
-193
Lib/asyncio/base_events.py
Lib/asyncio/base_events.py
+7
-0
Lib/asyncio/events.py
Lib/asyncio/events.py
+26
-0
Lib/asyncio/streams.py
Lib/asyncio/streams.py
+38
-1
Lib/asyncio/test_utils.py
Lib/asyncio/test_utils.py
+119
-34
Lib/asyncio/unix_events.py
Lib/asyncio/unix_events.py
+73
-2
Lib/test/test_asyncio/test_base_events.py
Lib/test/test_asyncio/test_base_events.py
+1
-1
Lib/test/test_asyncio/test_events.py
Lib/test/test_asyncio/test_events.py
+235
-114
Lib/test/test_asyncio/test_selector_events.py
Lib/test/test_asyncio/test_selector_events.py
+2
-1
Lib/test/test_asyncio/test_streams.py
Lib/test/test_asyncio/test_streams.py
+156
-39
Lib/test/test_asyncio/test_unix_events.py
Lib/test/test_asyncio/test_unix_events.py
+81
-1
No files found.
Lib/asyncio/base_events.py
View file @
519883d2
...
...
@@ -407,6 +407,13 @@ class BaseEventLoop(events.AbstractEventLoop):
sock
.
setblocking
(
False
)
transport
,
protocol
=
yield
from
self
.
_create_connection_transport
(
sock
,
protocol_factory
,
ssl
,
server_hostname
)
return
transport
,
protocol
@
tasks
.
coroutine
def
_create_connection_transport
(
self
,
sock
,
protocol_factory
,
ssl
,
server_hostname
):
protocol
=
protocol_factory
()
waiter
=
futures
.
Future
(
loop
=
self
)
if
ssl
:
...
...
Lib/asyncio/events.py
View file @
519883d2
...
...
@@ -220,6 +220,32 @@ class AbstractEventLoop:
"""
raise
NotImplementedError
def
create_unix_connection
(
self
,
protocol_factory
,
path
,
*
,
ssl
=
None
,
sock
=
None
,
server_hostname
=
None
):
raise
NotImplementedError
def
create_unix_server
(
self
,
protocol_factory
,
path
,
*
,
sock
=
None
,
backlog
=
100
,
ssl
=
None
):
"""A coroutine which creates a UNIX Domain Socket server.
The return valud is a Server object, which can be used to stop
the service.
path is a str, representing a file systsem path to bind the
server socket to.
sock can optionally be specified in order to use a preexisting
socket object.
backlog is the maximum number of queued connections passed to
listen() (defaults to 100).
ssl can be set to an SSLContext to enable SSL over the
accepted connections.
"""
raise
NotImplementedError
def
create_datagram_endpoint
(
self
,
protocol_factory
,
local_addr
=
None
,
remote_addr
=
None
,
*
,
family
=
0
,
proto
=
0
,
flags
=
0
):
...
...
Lib/asyncio/streams.py
View file @
519883d2
"""Stream-related things."""
__all__
=
[
'StreamReader'
,
'StreamWriter'
,
'StreamReaderProtocol'
,
'open_connection'
,
'start_server'
,
'IncompleteReadError'
,
'open_connection'
,
'start_server'
,
'open_unix_connection'
,
'start_unix_server'
,
'IncompleteReadError'
,
]
import
socket
from
.
import
events
from
.
import
futures
from
.
import
protocols
...
...
@@ -93,6 +97,39 @@ def start_server(client_connected_cb, host=None, port=None, *,
return
(
yield
from
loop
.
create_server
(
factory
,
host
,
port
,
**
kwds
))
if
hasattr
(
socket
,
'AF_UNIX'
):
# UNIX Domain Sockets are supported on this platform
@
tasks
.
coroutine
def
open_unix_connection
(
path
=
None
,
*
,
loop
=
None
,
limit
=
_DEFAULT_LIMIT
,
**
kwds
):
"""Similar to `open_connection` but works with UNIX Domain Sockets."""
if
loop
is
None
:
loop
=
events
.
get_event_loop
()
reader
=
StreamReader
(
limit
=
limit
,
loop
=
loop
)
protocol
=
StreamReaderProtocol
(
reader
,
loop
=
loop
)
transport
,
_
=
yield
from
loop
.
create_unix_connection
(
lambda
:
protocol
,
path
,
**
kwds
)
writer
=
StreamWriter
(
transport
,
protocol
,
reader
,
loop
)
return
reader
,
writer
@
tasks
.
coroutine
def
start_unix_server
(
client_connected_cb
,
path
=
None
,
*
,
loop
=
None
,
limit
=
_DEFAULT_LIMIT
,
**
kwds
):
"""Similar to `start_server` but works with UNIX Domain Sockets."""
if
loop
is
None
:
loop
=
events
.
get_event_loop
()
def
factory
():
reader
=
StreamReader
(
limit
=
limit
,
loop
=
loop
)
protocol
=
StreamReaderProtocol
(
reader
,
client_connected_cb
,
loop
=
loop
)
return
protocol
return
(
yield
from
loop
.
create_unix_server
(
factory
,
path
,
**
kwds
))
class
FlowControlMixin
(
protocols
.
Protocol
):
"""Reusable flow control logic for StreamWriter.drain().
...
...
Lib/asyncio/test_utils.py
View file @
519883d2
...
...
@@ -4,12 +4,18 @@ import collections
import
contextlib
import
io
import
os
import
socket
import
socketserver
import
sys
import
tempfile
import
threading
import
time
import
unittest
import
unittest.mock
from
http.server
import
HTTPServer
from
wsgiref.simple_server
import
make_server
,
WSGIRequestHandler
,
WSGIServer
try
:
import
ssl
except
ImportError
:
# pragma: no cover
...
...
@@ -70,42 +76,51 @@ def run_once(loop):
loop
.
run_forever
()
@
contextlib
.
contextmanager
def
run_test_server
(
*
,
host
=
'127.0.0.1'
,
port
=
0
,
use_ssl
=
False
):
class
SilentWSGIRequestHandler
(
WSGIRequestHandler
):
class
SilentWSGIRequestHandler
(
WSGIRequestHandler
):
def
get_stderr
(
self
):
return
io
.
StringIO
()
def
get_stderr
(
self
):
return
io
.
StringIO
()
def
log_message
(
self
,
format
,
*
args
):
pass
def
log_message
(
self
,
format
,
*
args
):
pass
class
SilentWSGIServer
(
WSGIServer
):
def
handle_error
(
self
,
request
,
client_address
):
class
SilentWSGIServer
(
WSGIServer
):
def
handle_error
(
self
,
request
,
client_address
):
pass
class
SSLWSGIServerMixin
:
def
finish_request
(
self
,
request
,
client_address
):
# The relative location of our test directory (which
# contains the ssl key and certificate files) differs
# between the stdlib and stand-alone asyncio.
# Prefer our own if we can find it.
here
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'..'
,
'tests'
)
if
not
os
.
path
.
isdir
(
here
):
here
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
__file__
),
'test'
,
'test_asyncio'
)
keyfile
=
os
.
path
.
join
(
here
,
'ssl_key.pem'
)
certfile
=
os
.
path
.
join
(
here
,
'ssl_cert.pem'
)
ssock
=
ssl
.
wrap_socket
(
request
,
keyfile
=
keyfile
,
certfile
=
certfile
,
server_side
=
True
)
try
:
self
.
RequestHandlerClass
(
ssock
,
client_address
,
self
)
ssock
.
close
()
except
OSError
:
# maybe socket has been closed by peer
pass
class
SSLWSGIServer
(
SilentWSGIServer
):
def
finish_request
(
self
,
request
,
client_address
):
# The relative location of our test directory (which
# contains the ssl key and certificate files) differs
# between the stdlib and stand-alone asyncio.
# Prefer our own if we can find it.
here
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'..'
,
'tests'
)
if
not
os
.
path
.
isdir
(
here
):
here
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
__file__
),
'test'
,
'test_asyncio'
)
keyfile
=
os
.
path
.
join
(
here
,
'ssl_key.pem'
)
certfile
=
os
.
path
.
join
(
here
,
'ssl_cert.pem'
)
ssock
=
ssl
.
wrap_socket
(
request
,
keyfile
=
keyfile
,
certfile
=
certfile
,
server_side
=
True
)
try
:
self
.
RequestHandlerClass
(
ssock
,
client_address
,
self
)
ssock
.
close
()
except
OSError
:
# maybe socket has been closed by peer
pass
class
SSLWSGIServer
(
SSLWSGIServerMixin
,
SilentWSGIServer
):
pass
def
_run_test_server
(
*
,
address
,
use_ssl
=
False
,
server_cls
,
server_ssl_cls
):
def
app
(
environ
,
start_response
):
status
=
'200 OK'
...
...
@@ -115,9 +130,9 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
# Run the test WSGI server in a separate thread in order not to
# interfere with event handling in the main thread
server_class
=
SSLWSGIServer
if
use_ssl
else
SilentWSGIServer
httpd
=
make_server
(
host
,
port
,
app
,
server_class
,
SilentWSGIRequestHandler
)
server_class
=
server_ssl_cls
if
use_ssl
else
server_cls
httpd
=
server_class
(
address
,
SilentWSGIRequestHandler
)
httpd
.
set_app
(
app
)
httpd
.
address
=
httpd
.
server_address
server_thread
=
threading
.
Thread
(
target
=
httpd
.
serve_forever
)
server_thread
.
start
()
...
...
@@ -129,6 +144,75 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
server_thread
.
join
()
if
hasattr
(
socket
,
'AF_UNIX'
):
class
UnixHTTPServer
(
socketserver
.
UnixStreamServer
,
HTTPServer
):
def
server_bind
(
self
):
socketserver
.
UnixStreamServer
.
server_bind
(
self
)
self
.
server_name
=
'127.0.0.1'
self
.
server_port
=
80
class
UnixWSGIServer
(
UnixHTTPServer
,
WSGIServer
):
def
server_bind
(
self
):
UnixHTTPServer
.
server_bind
(
self
)
self
.
setup_environ
()
def
get_request
(
self
):
request
,
client_addr
=
super
().
get_request
()
# Code in the stdlib expects that get_request
# will return a socket and a tuple (host, port).
# However, this isn't true for UNIX sockets,
# as the second return value will be a path;
# hence we return some fake data sufficient
# to get the tests going
return
request
,
(
'127.0.0.1'
,
''
)
class
SilentUnixWSGIServer
(
UnixWSGIServer
):
def
handle_error
(
self
,
request
,
client_address
):
pass
class
UnixSSLWSGIServer
(
SSLWSGIServerMixin
,
SilentUnixWSGIServer
):
pass
def
gen_unix_socket_path
():
with
tempfile
.
NamedTemporaryFile
()
as
file
:
return
file
.
name
@
contextlib
.
contextmanager
def
unix_socket_path
():
path
=
gen_unix_socket_path
()
try
:
yield
path
finally
:
try
:
os
.
unlink
(
path
)
except
OSError
:
pass
@
contextlib
.
contextmanager
def
run_test_unix_server
(
*
,
use_ssl
=
False
):
with
unix_socket_path
()
as
path
:
yield
from
_run_test_server
(
address
=
path
,
use_ssl
=
use_ssl
,
server_cls
=
SilentUnixWSGIServer
,
server_ssl_cls
=
UnixSSLWSGIServer
)
@
contextlib
.
contextmanager
def
run_test_server
(
*
,
host
=
'127.0.0.1'
,
port
=
0
,
use_ssl
=
False
):
yield
from
_run_test_server
(
address
=
(
host
,
port
),
use_ssl
=
use_ssl
,
server_cls
=
SilentWSGIServer
,
server_ssl_cls
=
SSLWSGIServer
)
def
make_test_protocol
(
base
):
dct
=
{}
for
name
in
dir
(
base
):
...
...
@@ -275,5 +359,6 @@ class TestLoop(base_events.BaseEventLoop):
def
_write_to_self
(
self
):
pass
def
MockCallback
(
**
kwargs
):
return
unittest
.
mock
.
Mock
(
spec
=
[
'__call__'
],
**
kwargs
)
Lib/asyncio/unix_events.py
View file @
519883d2
...
...
@@ -11,6 +11,7 @@ import sys
import
threading
from
.
import
base_events
from
.
import
base_subprocess
from
.
import
constants
from
.
import
events
...
...
@@ -31,9 +32,9 @@ if sys.platform == 'win32': # pragma: no cover
class
_UnixSelectorEventLoop
(
selector_events
.
BaseSelectorEventLoop
):
"""Unix event loop
"""Unix event loop
.
Adds signal handling
to SelectorEventLoop
Adds signal handling
and UNIX Domain Socket support to SelectorEventLoop.
"""
def
__init__
(
self
,
selector
=
None
):
...
...
@@ -164,6 +165,76 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
def
_child_watcher_callback
(
self
,
pid
,
returncode
,
transp
):
self
.
call_soon_threadsafe
(
transp
.
_process_exited
,
returncode
)
@
tasks
.
coroutine
def
create_unix_connection
(
self
,
protocol_factory
,
path
,
*
,
ssl
=
None
,
sock
=
None
,
server_hostname
=
None
):
assert
server_hostname
is
None
or
isinstance
(
server_hostname
,
str
)
if
ssl
:
if
server_hostname
is
None
:
raise
ValueError
(
'you have to pass server_hostname when using ssl'
)
else
:
if
server_hostname
is
not
None
:
raise
ValueError
(
'server_hostname is only meaningful with ssl'
)
if
path
is
not
None
:
if
sock
is
not
None
:
raise
ValueError
(
'path and sock can not be specified at the same time'
)
try
:
sock
=
socket
.
socket
(
socket
.
AF_UNIX
,
socket
.
SOCK_STREAM
,
0
)
sock
.
setblocking
(
False
)
yield
from
self
.
sock_connect
(
sock
,
path
)
except
OSError
:
if
sock
is
not
None
:
sock
.
close
()
raise
else
:
if
sock
is
None
:
raise
ValueError
(
'no path and sock were specified'
)
sock
.
setblocking
(
False
)
transport
,
protocol
=
yield
from
self
.
_create_connection_transport
(
sock
,
protocol_factory
,
ssl
,
server_hostname
)
return
transport
,
protocol
@
tasks
.
coroutine
def
create_unix_server
(
self
,
protocol_factory
,
path
=
None
,
*
,
sock
=
None
,
backlog
=
100
,
ssl
=
None
):
if
isinstance
(
ssl
,
bool
):
raise
TypeError
(
'ssl argument must be an SSLContext or None'
)
if
path
is
not
None
:
sock
=
socket
.
socket
(
socket
.
AF_UNIX
,
socket
.
SOCK_STREAM
)
try
:
sock
.
bind
(
path
)
except
OSError
as
exc
:
if
exc
.
errno
==
errno
.
EADDRINUSE
:
# Let's improve the error message by adding
# with what exact address it occurs.
msg
=
'Address {!r} is already in use'
.
format
(
path
)
raise
OSError
(
errno
.
EADDRINUSE
,
msg
)
from
None
else
:
raise
else
:
if
sock
is
None
:
raise
ValueError
(
'path was not specified, and no sock specified'
)
if
sock
.
family
!=
socket
.
AF_UNIX
:
raise
ValueError
(
'A UNIX Domain Socket was expected, got {!r}'
.
format
(
sock
))
server
=
base_events
.
Server
(
self
,
[
sock
])
sock
.
listen
(
backlog
)
sock
.
setblocking
(
False
)
self
.
_start_serving
(
protocol_factory
,
sock
,
ssl
,
server
)
return
server
def
_set_nonblocking
(
fd
):
flags
=
fcntl
.
fcntl
(
fd
,
fcntl
.
F_GETFL
)
...
...
Lib/test/test_asyncio/test_base_events.py
View file @
519883d2
...
...
@@ -212,7 +212,7 @@ class BaseEventLoopTests(unittest.TestCase):
idx
=
-
1
data
=
[
10.0
,
10.0
,
10.3
,
13.0
]
self
.
loop
.
_scheduled
=
[
asyncio
.
TimerHandle
(
11.0
,
lambda
:
True
,
())]
self
.
loop
.
_scheduled
=
[
asyncio
.
TimerHandle
(
11.0
,
lambda
:
True
,
())]
self
.
loop
.
_run_once
()
self
.
assertEqual
(
logging
.
DEBUG
,
m_logger
.
log
.
call_args
[
0
][
0
])
...
...
Lib/test/test_asyncio/test_events.py
View file @
519883d2
...
...
@@ -39,13 +39,14 @@ def data_file(filename):
return
fullname
raise
FileNotFoundError
(
filename
)
ONLYCERT
=
data_file
(
'ssl_cert.pem'
)
ONLYKEY
=
data_file
(
'ssl_key.pem'
)
SIGNED_CERTFILE
=
data_file
(
'keycert3.pem'
)
SIGNING_CA
=
data_file
(
'pycacert.pem'
)
class
MyProto
(
asyncio
.
Protocol
):
class
My
Base
Proto
(
asyncio
.
Protocol
):
done
=
None
def
__init__
(
self
,
loop
=
None
):
...
...
@@ -59,7 +60,6 @@ class MyProto(asyncio.Protocol):
self
.
transport
=
transport
assert
self
.
state
==
'INITIAL'
,
self
.
state
self
.
state
=
'CONNECTED'
transport
.
write
(
b'GET / HTTP/1.0
\
r
\
n
Host: example.com
\
r
\
n
\
r
\
n
'
)
def
data_received
(
self
,
data
):
assert
self
.
state
==
'CONNECTED'
,
self
.
state
...
...
@@ -76,6 +76,12 @@ class MyProto(asyncio.Protocol):
self
.
done
.
set_result
(
None
)
class
MyProto
(
MyBaseProto
):
def
connection_made
(
self
,
transport
):
super
().
connection_made
(
transport
)
transport
.
write
(
b'GET / HTTP/1.0
\
r
\
n
Host: example.com
\
r
\
n
\
r
\
n
'
)
class
MyDatagramProto
(
asyncio
.
DatagramProtocol
):
done
=
None
...
...
@@ -357,22 +363,30 @@ class EventLoopTestsMixin:
r
.
close
()
self
.
assertGreaterEqual
(
len
(
data
),
200
)
def
_basetest_sock_client_ops
(
self
,
httpd
,
sock
):
sock
.
setblocking
(
False
)
self
.
loop
.
run_until_complete
(
self
.
loop
.
sock_connect
(
sock
,
httpd
.
address
))
self
.
loop
.
run_until_complete
(
self
.
loop
.
sock_sendall
(
sock
,
b'GET / HTTP/1.0
\
r
\
n
\
r
\
n
'
))
data
=
self
.
loop
.
run_until_complete
(
self
.
loop
.
sock_recv
(
sock
,
1024
))
# consume data
self
.
loop
.
run_until_complete
(
self
.
loop
.
sock_recv
(
sock
,
1024
))
sock
.
close
()
self
.
assertTrue
(
data
.
startswith
(
b'HTTP/1.0 200 OK'
))
def
test_sock_client_ops
(
self
):
with
test_utils
.
run_test_server
()
as
httpd
:
sock
=
socket
.
socket
()
sock
.
setblocking
(
False
)
self
.
loop
.
run_until_complete
(
self
.
loop
.
sock_connect
(
sock
,
httpd
.
address
))
self
.
loop
.
run_until_complete
(
self
.
loop
.
sock_sendall
(
sock
,
b'GET / HTTP/1.0
\
r
\
n
\
r
\
n
'
))
data
=
self
.
loop
.
run_until_complete
(
self
.
loop
.
sock_recv
(
sock
,
1024
))
# consume data
self
.
loop
.
run_until_complete
(
self
.
loop
.
sock_recv
(
sock
,
1024
))
sock
.
close
()
self
.
_basetest_sock_client_ops
(
httpd
,
sock
)
self
.
assertTrue
(
data
.
startswith
(
b'HTTP/1.0 200 OK'
))
@
unittest
.
skipUnless
(
hasattr
(
socket
,
'AF_UNIX'
),
'No UNIX Sockets'
)
def
test_unix_sock_client_ops
(
self
):
with
test_utils
.
run_test_unix_server
()
as
httpd
:
sock
=
socket
.
socket
(
socket
.
AF_UNIX
)
self
.
_basetest_sock_client_ops
(
httpd
,
sock
)
def
test_sock_client_fail
(
self
):
# Make sure that we will get an unused port
...
...
@@ -485,16 +499,26 @@ class EventLoopTestsMixin:
self
.
loop
.
run_forever
()
self
.
assertEqual
(
caught
,
1
)
def
_basetest_create_connection
(
self
,
connection_fut
):
tr
,
pr
=
self
.
loop
.
run_until_complete
(
connection_fut
)
self
.
assertIsInstance
(
tr
,
asyncio
.
Transport
)
self
.
assertIsInstance
(
pr
,
asyncio
.
Protocol
)
self
.
loop
.
run_until_complete
(
pr
.
done
)
self
.
assertGreater
(
pr
.
nbytes
,
0
)
tr
.
close
()
def
test_create_connection
(
self
):
with
test_utils
.
run_test_server
()
as
httpd
:
f
=
self
.
loop
.
create_connection
(
conn_fut
=
self
.
loop
.
create_connection
(
lambda
:
MyProto
(
loop
=
self
.
loop
),
*
httpd
.
address
)
tr
,
pr
=
self
.
loop
.
run_until_complete
(
f
)
self
.
assertIsInstance
(
tr
,
asyncio
.
Transport
)
self
.
assertIsInstance
(
pr
,
asyncio
.
Protocol
)
self
.
loop
.
run_until_complete
(
pr
.
done
)
self
.
assertGreater
(
pr
.
nbytes
,
0
)
tr
.
close
()
self
.
_basetest_create_connection
(
conn_fut
)
@
unittest
.
skipUnless
(
hasattr
(
socket
,
'AF_UNIX'
),
'No UNIX Sockets'
)
def
test_create_unix_connection
(
self
):
with
test_utils
.
run_test_unix_server
()
as
httpd
:
conn_fut
=
self
.
loop
.
create_unix_connection
(
lambda
:
MyProto
(
loop
=
self
.
loop
),
httpd
.
address
)
self
.
_basetest_create_connection
(
conn_fut
)
def
test_create_connection_sock
(
self
):
with
test_utils
.
run_test_server
()
as
httpd
:
...
...
@@ -524,20 +548,37 @@ class EventLoopTestsMixin:
self
.
assertGreater
(
pr
.
nbytes
,
0
)
tr
.
close
()
def
_basetest_create_ssl_connection
(
self
,
connection_fut
):
tr
,
pr
=
self
.
loop
.
run_until_complete
(
connection_fut
)
self
.
assertIsInstance
(
tr
,
asyncio
.
Transport
)
self
.
assertIsInstance
(
pr
,
asyncio
.
Protocol
)
self
.
assertTrue
(
'ssl'
in
tr
.
__class__
.
__name__
.
lower
())
self
.
assertIsNotNone
(
tr
.
get_extra_info
(
'sockname'
))
self
.
loop
.
run_until_complete
(
pr
.
done
)
self
.
assertGreater
(
pr
.
nbytes
,
0
)
tr
.
close
()
@
unittest
.
skipIf
(
ssl
is
None
,
'No ssl module'
)
def
test_create_ssl_connection
(
self
):
with
test_utils
.
run_test_server
(
use_ssl
=
True
)
as
httpd
:
f
=
self
.
loop
.
create_connection
(
lambda
:
MyProto
(
loop
=
self
.
loop
),
*
httpd
.
address
,
conn_fut
=
self
.
loop
.
create_connection
(
lambda
:
MyProto
(
loop
=
self
.
loop
),
*
httpd
.
address
,
ssl
=
test_utils
.
dummy_ssl_context
())
tr
,
pr
=
self
.
loop
.
run_until_complete
(
f
)
self
.
assertIsInstance
(
tr
,
asyncio
.
Transport
)
self
.
assertIsInstance
(
pr
,
asyncio
.
Protocol
)
self
.
assertTrue
(
'ssl'
in
tr
.
__class__
.
__name__
.
lower
())
self
.
assertIsNotNone
(
tr
.
get_extra_info
(
'sockname'
))
self
.
loop
.
run_until_complete
(
pr
.
done
)
self
.
assertGreater
(
pr
.
nbytes
,
0
)
tr
.
close
()
self
.
_basetest_create_ssl_connection
(
conn_fut
)
@
unittest
.
skipIf
(
ssl
is
None
,
'No ssl module'
)
@
unittest
.
skipUnless
(
hasattr
(
socket
,
'AF_UNIX'
),
'No UNIX Sockets'
)
def
test_create_ssl_unix_connection
(
self
):
with
test_utils
.
run_test_unix_server
(
use_ssl
=
True
)
as
httpd
:
conn_fut
=
self
.
loop
.
create_unix_connection
(
lambda
:
MyProto
(
loop
=
self
.
loop
),
httpd
.
address
,
ssl
=
test_utils
.
dummy_ssl_context
(),
server_hostname
=
'127.0.0.1'
)
self
.
_basetest_create_ssl_connection
(
conn_fut
)
def
test_create_connection_local_addr
(
self
):
with
test_utils
.
run_test_server
()
as
httpd
:
...
...
@@ -561,14 +602,8 @@ class EventLoopTestsMixin:
self
.
assertIn
(
str
(
httpd
.
address
),
cm
.
exception
.
strerror
)
def
test_create_server
(
self
):
proto
=
None
def
factory
():
nonlocal
proto
proto
=
MyProto
()
return
proto
f
=
self
.
loop
.
create_server
(
factory
,
'0.0.0.0'
,
0
)
proto
=
MyProto
()
f
=
self
.
loop
.
create_server
(
lambda
:
proto
,
'0.0.0.0'
,
0
)
server
=
self
.
loop
.
run_until_complete
(
f
)
self
.
assertEqual
(
len
(
server
.
sockets
),
1
)
sock
=
server
.
sockets
[
0
]
...
...
@@ -605,38 +640,76 @@ class EventLoopTestsMixin:
# close server
server
.
close
()
def
_make_ssl_server
(
self
,
factory
,
certfile
,
keyfile
=
None
):
def
_make_unix_server
(
self
,
factory
,
**
kwargs
):
path
=
test_utils
.
gen_unix_socket_path
()
self
.
addCleanup
(
lambda
:
os
.
path
.
exists
(
path
)
and
os
.
unlink
(
path
))
f
=
self
.
loop
.
create_unix_server
(
factory
,
path
,
**
kwargs
)
server
=
self
.
loop
.
run_until_complete
(
f
)
return
server
,
path
@
unittest
.
skipUnless
(
hasattr
(
socket
,
'AF_UNIX'
),
'No UNIX Sockets'
)
def
test_create_unix_server
(
self
):
proto
=
MyProto
()
server
,
path
=
self
.
_make_unix_server
(
lambda
:
proto
)
self
.
assertEqual
(
len
(
server
.
sockets
),
1
)
client
=
socket
.
socket
(
socket
.
AF_UNIX
)
client
.
connect
(
path
)
client
.
sendall
(
b'xxx'
)
test_utils
.
run_briefly
(
self
.
loop
)
test_utils
.
run_until
(
self
.
loop
,
lambda
:
proto
is
not
None
,
10
)
self
.
assertIsInstance
(
proto
,
MyProto
)
self
.
assertEqual
(
'INITIAL'
,
proto
.
state
)
test_utils
.
run_briefly
(
self
.
loop
)
self
.
assertEqual
(
'CONNECTED'
,
proto
.
state
)
test_utils
.
run_until
(
self
.
loop
,
lambda
:
proto
.
nbytes
>
0
,
timeout
=
10
)
self
.
assertEqual
(
3
,
proto
.
nbytes
)
# close connection
proto
.
transport
.
close
()
test_utils
.
run_briefly
(
self
.
loop
)
# windows iocp
self
.
assertEqual
(
'CLOSED'
,
proto
.
state
)
# the client socket must be closed after to avoid ECONNRESET upon
# recv()/send() on the serving socket
client
.
close
()
# close server
server
.
close
()
def
_create_ssl_context
(
self
,
certfile
,
keyfile
=
None
):
sslcontext
=
ssl
.
SSLContext
(
ssl
.
PROTOCOL_SSLv23
)
sslcontext
.
options
|=
ssl
.
OP_NO_SSLv2
sslcontext
.
load_cert_chain
(
certfile
,
keyfile
)
return
sslcontext
f
=
self
.
loop
.
create_server
(
factory
,
'127.0.0.1'
,
0
,
ssl
=
sslcontext
)
def
_make_ssl_server
(
self
,
factory
,
certfile
,
keyfile
=
None
):
sslcontext
=
self
.
_create_ssl_context
(
certfile
,
keyfile
)
f
=
self
.
loop
.
create_server
(
factory
,
'127.0.0.1'
,
0
,
ssl
=
sslcontext
)
server
=
self
.
loop
.
run_until_complete
(
f
)
sock
=
server
.
sockets
[
0
]
host
,
port
=
sock
.
getsockname
()
self
.
assertEqual
(
host
,
'127.0.0.1'
)
return
server
,
host
,
port
def
_make_ssl_unix_server
(
self
,
factory
,
certfile
,
keyfile
=
None
):
sslcontext
=
self
.
_create_ssl_context
(
certfile
,
keyfile
)
return
self
.
_make_unix_server
(
factory
,
ssl
=
sslcontext
)
@
unittest
.
skipIf
(
ssl
is
None
,
'No ssl module'
)
def
test_create_server_ssl
(
self
):
proto
=
None
class
ClientMyProto
(
MyProto
):
def
connection_made
(
self
,
transport
):
self
.
transport
=
transport
assert
self
.
state
==
'INITIAL'
,
self
.
state
self
.
state
=
'CONNECTED'
proto
=
MyProto
(
loop
=
self
.
loop
)
server
,
host
,
port
=
self
.
_make_ssl_server
(
lambda
:
proto
,
ONLYCERT
,
ONLYKEY
)
def
factory
():
nonlocal
proto
proto
=
MyProto
(
loop
=
self
.
loop
)
return
proto
server
,
host
,
port
=
self
.
_make_ssl_server
(
factory
,
ONLYCERT
,
ONLYKEY
)
f_c
=
self
.
loop
.
create_connection
(
ClientMyProto
,
host
,
port
,
f_c
=
self
.
loop
.
create_connection
(
MyBaseProto
,
host
,
port
,
ssl
=
test_utils
.
dummy_ssl_context
())
client
,
pr
=
self
.
loop
.
run_until_complete
(
f_c
)
...
...
@@ -667,16 +740,45 @@ class EventLoopTestsMixin:
server
.
close
()
@
unittest
.
skipIf
(
ssl
is
None
,
'No ssl module'
)
@
unittest
.
skipUnless
(
HAS_SNI
,
'No SNI support in ssl module'
)
def
test_create_server_ssl_verify_failed
(
self
):
proto
=
None
@
unittest
.
skipUnless
(
hasattr
(
socket
,
'AF_UNIX'
),
'No UNIX Sockets'
)
def
test_create_unix_server_ssl
(
self
):
proto
=
MyProto
(
loop
=
self
.
loop
)
server
,
path
=
self
.
_make_ssl_unix_server
(
lambda
:
proto
,
ONLYCERT
,
ONLYKEY
)
def
factory
():
nonlocal
proto
proto
=
MyProto
(
loop
=
self
.
loop
)
return
proto
f_c
=
self
.
loop
.
create_unix_connection
(
MyBaseProto
,
path
,
ssl
=
test_utils
.
dummy_ssl_context
(),
server_hostname
=
''
)
client
,
pr
=
self
.
loop
.
run_until_complete
(
f_c
)
client
.
write
(
b'xxx'
)
test_utils
.
run_briefly
(
self
.
loop
)
self
.
assertIsInstance
(
proto
,
MyProto
)
test_utils
.
run_briefly
(
self
.
loop
)
self
.
assertEqual
(
'CONNECTED'
,
proto
.
state
)
test_utils
.
run_until
(
self
.
loop
,
lambda
:
proto
.
nbytes
>
0
,
timeout
=
10
)
self
.
assertEqual
(
3
,
proto
.
nbytes
)
# close connection
proto
.
transport
.
close
()
self
.
loop
.
run_until_complete
(
proto
.
done
)
self
.
assertEqual
(
'CLOSED'
,
proto
.
state
)
# the client socket must be closed after to avoid ECONNRESET upon
# recv()/send() on the serving socket
client
.
close
()
server
,
host
,
port
=
self
.
_make_ssl_server
(
factory
,
SIGNED_CERTFILE
)
# stop serving
server
.
close
()
@
unittest
.
skipIf
(
ssl
is
None
,
'No ssl module'
)
@
unittest
.
skipUnless
(
HAS_SNI
,
'No SNI support in ssl module'
)
def
test_create_server_ssl_verify_failed
(
self
):
proto
=
MyProto
(
loop
=
self
.
loop
)
server
,
host
,
port
=
self
.
_make_ssl_server
(
lambda
:
proto
,
SIGNED_CERTFILE
)
sslcontext_client
=
ssl
.
SSLContext
(
ssl
.
PROTOCOL_SSLv23
)
sslcontext_client
.
options
|=
ssl
.
OP_NO_SSLv2
...
...
@@ -697,15 +799,36 @@ class EventLoopTestsMixin:
@
unittest
.
skipIf
(
ssl
is
None
,
'No ssl module'
)
@
unittest
.
skipUnless
(
HAS_SNI
,
'No SNI support in ssl module'
)
def
test_create_server_ssl_match_failed
(
self
):
proto
=
None
@
unittest
.
skipUnless
(
hasattr
(
socket
,
'AF_UNIX'
),
'No UNIX Sockets'
)
def
test_create_unix_server_ssl_verify_failed
(
self
):
proto
=
MyProto
(
loop
=
self
.
loop
)
server
,
path
=
self
.
_make_ssl_unix_server
(
lambda
:
proto
,
SIGNED_CERTFILE
)
def
factory
():
nonlocal
proto
proto
=
MyProto
(
loop
=
self
.
loop
)
return
proto
sslcontext_client
=
ssl
.
SSLContext
(
ssl
.
PROTOCOL_SSLv23
)
sslcontext_client
.
options
|=
ssl
.
OP_NO_SSLv2
sslcontext_client
.
verify_mode
=
ssl
.
CERT_REQUIRED
if
hasattr
(
sslcontext_client
,
'check_hostname'
):
sslcontext_client
.
check_hostname
=
True
server
,
host
,
port
=
self
.
_make_ssl_server
(
factory
,
SIGNED_CERTFILE
)
# no CA loaded
f_c
=
self
.
loop
.
create_unix_connection
(
MyProto
,
path
,
ssl
=
sslcontext_client
,
server_hostname
=
'invalid'
)
with
self
.
assertRaisesRegex
(
ssl
.
SSLError
,
'certificate verify failed '
):
self
.
loop
.
run_until_complete
(
f_c
)
# close connection
self
.
assertIsNone
(
proto
.
transport
)
server
.
close
()
@
unittest
.
skipIf
(
ssl
is
None
,
'No ssl module'
)
@
unittest
.
skipUnless
(
HAS_SNI
,
'No SNI support in ssl module'
)
def
test_create_server_ssl_match_failed
(
self
):
proto
=
MyProto
(
loop
=
self
.
loop
)
server
,
host
,
port
=
self
.
_make_ssl_server
(
lambda
:
proto
,
SIGNED_CERTFILE
)
sslcontext_client
=
ssl
.
SSLContext
(
ssl
.
PROTOCOL_SSLv23
)
sslcontext_client
.
options
|=
ssl
.
OP_NO_SSLv2
...
...
@@ -729,15 +852,36 @@ class EventLoopTestsMixin:
@
unittest
.
skipIf
(
ssl
is
None
,
'No ssl module'
)
@
unittest
.
skipUnless
(
HAS_SNI
,
'No SNI support in ssl module'
)
def
test_create_server_ssl_verified
(
self
):
proto
=
None
@
unittest
.
skipUnless
(
hasattr
(
socket
,
'AF_UNIX'
),
'No UNIX Sockets'
)
def
test_create_unix_server_ssl_verified
(
self
):
proto
=
MyProto
(
loop
=
self
.
loop
)
server
,
path
=
self
.
_make_ssl_unix_server
(
lambda
:
proto
,
SIGNED_CERTFILE
)
def
factory
():
nonlocal
proto
proto
=
MyProto
(
loop
=
self
.
loop
)
return
proto
sslcontext_client
=
ssl
.
SSLContext
(
ssl
.
PROTOCOL_SSLv23
)
sslcontext_client
.
options
|=
ssl
.
OP_NO_SSLv2
sslcontext_client
.
verify_mode
=
ssl
.
CERT_REQUIRED
sslcontext_client
.
load_verify_locations
(
cafile
=
SIGNING_CA
)
if
hasattr
(
sslcontext_client
,
'check_hostname'
):
sslcontext_client
.
check_hostname
=
True
server
,
host
,
port
=
self
.
_make_ssl_server
(
factory
,
SIGNED_CERTFILE
)
# Connection succeeds with correct CA and server hostname.
f_c
=
self
.
loop
.
create_unix_connection
(
MyProto
,
path
,
ssl
=
sslcontext_client
,
server_hostname
=
'localhost'
)
client
,
pr
=
self
.
loop
.
run_until_complete
(
f_c
)
# close connection
proto
.
transport
.
close
()
client
.
close
()
server
.
close
()
@
unittest
.
skipIf
(
ssl
is
None
,
'No ssl module'
)
@
unittest
.
skipUnless
(
HAS_SNI
,
'No SNI support in ssl module'
)
def
test_create_server_ssl_verified
(
self
):
proto
=
MyProto
(
loop
=
self
.
loop
)
server
,
host
,
port
=
self
.
_make_ssl_server
(
lambda
:
proto
,
SIGNED_CERTFILE
)
sslcontext_client
=
ssl
.
SSLContext
(
ssl
.
PROTOCOL_SSLv23
)
sslcontext_client
.
options
|=
ssl
.
OP_NO_SSLv2
...
...
@@ -915,19 +1059,15 @@ class EventLoopTestsMixin:
@
unittest
.
skipUnless
(
sys
.
platform
!=
'win32'
,
"Don't support pipes for Windows"
)
def
test_read_pipe
(
self
):
proto
=
None
def
factory
():
nonlocal
proto
proto
=
MyReadPipeProto
(
loop
=
self
.
loop
)
return
proto
proto
=
MyReadPipeProto
(
loop
=
self
.
loop
)
rpipe
,
wpipe
=
os
.
pipe
()
pipeobj
=
io
.
open
(
rpipe
,
'rb'
,
1024
)
@
asyncio
.
coroutine
def
connect
():
t
,
p
=
yield
from
self
.
loop
.
connect_read_pipe
(
factory
,
pipeobj
)
t
,
p
=
yield
from
self
.
loop
.
connect_read_pipe
(
lambda
:
proto
,
pipeobj
)
self
.
assertIs
(
p
,
proto
)
self
.
assertIs
(
t
,
proto
.
transport
)
self
.
assertEqual
([
'INITIAL'
,
'CONNECTED'
],
proto
.
state
)
...
...
@@ -959,19 +1099,14 @@ class EventLoopTestsMixin:
# Issue #20495: The test hangs on FreeBSD 7.2 but pass on FreeBSD 9
@
support
.
requires_freebsd_version
(
8
)
def
test_read_pty_output
(
self
):
proto
=
None
def
factory
():
nonlocal
proto
proto
=
MyReadPipeProto
(
loop
=
self
.
loop
)
return
proto
proto
=
MyReadPipeProto
(
loop
=
self
.
loop
)
master
,
slave
=
os
.
openpty
()
master_read_obj
=
io
.
open
(
master
,
'rb'
,
0
)
@
asyncio
.
coroutine
def
connect
():
t
,
p
=
yield
from
self
.
loop
.
connect_read_pipe
(
factory
,
t
,
p
=
yield
from
self
.
loop
.
connect_read_pipe
(
lambda
:
proto
,
master_read_obj
)
self
.
assertIs
(
p
,
proto
)
self
.
assertIs
(
t
,
proto
.
transport
)
...
...
@@ -999,21 +1134,17 @@ class EventLoopTestsMixin:
@
unittest
.
skipUnless
(
sys
.
platform
!=
'win32'
,
"Don't support pipes for Windows"
)
def
test_write_pipe
(
self
):
proto
=
None
proto
=
MyWritePipeProto
(
loop
=
self
.
loop
)
transport
=
None
def
factory
():
nonlocal
proto
proto
=
MyWritePipeProto
(
loop
=
self
.
loop
)
return
proto
rpipe
,
wpipe
=
os
.
pipe
()
pipeobj
=
io
.
open
(
wpipe
,
'wb'
,
1024
)
@
asyncio
.
coroutine
def
connect
():
nonlocal
transport
t
,
p
=
yield
from
self
.
loop
.
connect_write_pipe
(
factory
,
pipeobj
)
t
,
p
=
yield
from
self
.
loop
.
connect_write_pipe
(
lambda
:
proto
,
pipeobj
)
self
.
assertIs
(
p
,
proto
)
self
.
assertIs
(
t
,
proto
.
transport
)
self
.
assertEqual
(
'CONNECTED'
,
proto
.
state
)
...
...
@@ -1045,21 +1176,16 @@ class EventLoopTestsMixin:
@
unittest
.
skipUnless
(
sys
.
platform
!=
'win32'
,
"Don't support pipes for Windows"
)
def
test_write_pipe_disconnect_on_close
(
self
):
proto
=
None
proto
=
MyWritePipeProto
(
loop
=
self
.
loop
)
transport
=
None
def
factory
():
nonlocal
proto
proto
=
MyWritePipeProto
(
loop
=
self
.
loop
)
return
proto
rsock
,
wsock
=
test_utils
.
socketpair
()
pipeobj
=
io
.
open
(
wsock
.
detach
(),
'wb'
,
1024
)
@
asyncio
.
coroutine
def
connect
():
nonlocal
transport
t
,
p
=
yield
from
self
.
loop
.
connect_write_pipe
(
factory
,
t
,
p
=
yield
from
self
.
loop
.
connect_write_pipe
(
lambda
:
proto
,
pipeobj
)
self
.
assertIs
(
p
,
proto
)
self
.
assertIs
(
t
,
proto
.
transport
)
...
...
@@ -1084,21 +1210,16 @@ class EventLoopTestsMixin:
# older than 10.6 (Snow Leopard)
@
support
.
requires_mac_ver
(
10
,
6
)
def
test_write_pty
(
self
):
proto
=
None
proto
=
MyWritePipeProto
(
loop
=
self
.
loop
)
transport
=
None
def
factory
():
nonlocal
proto
proto
=
MyWritePipeProto
(
loop
=
self
.
loop
)
return
proto
master
,
slave
=
os
.
openpty
()
slave_write_obj
=
io
.
open
(
slave
,
'wb'
,
0
)
@
asyncio
.
coroutine
def
connect
():
nonlocal
transport
t
,
p
=
yield
from
self
.
loop
.
connect_write_pipe
(
factory
,
t
,
p
=
yield
from
self
.
loop
.
connect_write_pipe
(
lambda
:
proto
,
slave_write_obj
)
self
.
assertIs
(
p
,
proto
)
self
.
assertIs
(
t
,
proto
.
transport
)
...
...
Lib/test/test_asyncio/test_selector_events.py
View file @
519883d2
...
...
@@ -55,7 +55,8 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self
.
loop
.
remove_reader
=
unittest
.
mock
.
Mock
()
self
.
loop
.
remove_writer
=
unittest
.
mock
.
Mock
()
waiter
=
asyncio
.
Future
(
loop
=
self
.
loop
)
transport
=
self
.
loop
.
_make_ssl_transport
(
m
,
asyncio
.
Protocol
(),
m
,
waiter
)
transport
=
self
.
loop
.
_make_ssl_transport
(
m
,
asyncio
.
Protocol
(),
m
,
waiter
)
self
.
assertIsInstance
(
transport
,
_SelectorSslTransport
)
@
unittest
.
mock
.
patch
(
'asyncio.selector_events.ssl'
,
None
)
...
...
Lib/test/test_asyncio/test_streams.py
View file @
519883d2
"""Tests for streams.py."""
import
functools
import
gc
import
socket
import
unittest
import
unittest.mock
try
:
...
...
@@ -32,48 +34,85 @@ class StreamReaderTests(unittest.TestCase):
stream
=
asyncio
.
StreamReader
()
self
.
assertIs
(
stream
.
_loop
,
m_events
.
get_event_loop
.
return_value
)
def
_basetest_open_connection
(
self
,
open_connection_fut
):
reader
,
writer
=
self
.
loop
.
run_until_complete
(
open_connection_fut
)
writer
.
write
(
b'GET / HTTP/1.0
\
r
\
n
\
r
\
n
'
)
f
=
reader
.
readline
()
data
=
self
.
loop
.
run_until_complete
(
f
)
self
.
assertEqual
(
data
,
b'HTTP/1.0 200 OK
\
r
\
n
'
)
f
=
reader
.
read
()
data
=
self
.
loop
.
run_until_complete
(
f
)
self
.
assertTrue
(
data
.
endswith
(
b'
\
r
\
n
\
r
\
n
Test message'
))
writer
.
close
()
def
test_open_connection
(
self
):
with
test_utils
.
run_test_server
()
as
httpd
:
f
=
asyncio
.
open_connection
(
*
httpd
.
address
,
loop
=
self
.
loop
)
reader
,
writer
=
self
.
loop
.
run_until_complete
(
f
)
writer
.
write
(
b'GET / HTTP/1.0
\
r
\
n
\
r
\
n
'
)
f
=
reader
.
readline
()
data
=
self
.
loop
.
run_until_complete
(
f
)
self
.
assertEqual
(
data
,
b'HTTP/1.0 200 OK
\
r
\
n
'
)
f
=
reader
.
read
()
data
=
self
.
loop
.
run_until_complete
(
f
)
self
.
assertTrue
(
data
.
endswith
(
b'
\
r
\
n
\
r
\
n
Test message'
))
writer
.
close
()
conn_fut
=
asyncio
.
open_connection
(
*
httpd
.
address
,
loop
=
self
.
loop
)
self
.
_basetest_open_connection
(
conn_fut
)
@
unittest
.
skipUnless
(
hasattr
(
socket
,
'AF_UNIX'
),
'No UNIX Sockets'
)
def
test_open_unix_connection
(
self
):
with
test_utils
.
run_test_unix_server
()
as
httpd
:
conn_fut
=
asyncio
.
open_unix_connection
(
httpd
.
address
,
loop
=
self
.
loop
)
self
.
_basetest_open_connection
(
conn_fut
)
def
_basetest_open_connection_no_loop_ssl
(
self
,
open_connection_fut
):
try
:
reader
,
writer
=
self
.
loop
.
run_until_complete
(
open_connection_fut
)
finally
:
asyncio
.
set_event_loop
(
None
)
writer
.
write
(
b'GET / HTTP/1.0
\
r
\
n
\
r
\
n
'
)
f
=
reader
.
read
()
data
=
self
.
loop
.
run_until_complete
(
f
)
self
.
assertTrue
(
data
.
endswith
(
b'
\
r
\
n
\
r
\
n
Test message'
))
writer
.
close
()
@
unittest
.
skipIf
(
ssl
is
None
,
'No ssl module'
)
def
test_open_connection_no_loop_ssl
(
self
):
with
test_utils
.
run_test_server
(
use_ssl
=
True
)
as
httpd
:
try
:
asyncio
.
set_event_loop
(
self
.
loop
)
f
=
asyncio
.
open_connection
(
*
httpd
.
address
,
ssl
=
test_utils
.
dummy_ssl_context
())
reader
,
writer
=
self
.
loop
.
run_until_complete
(
f
)
finally
:
asyncio
.
set_event_loop
(
None
)
writer
.
write
(
b'GET / HTTP/1.0
\
r
\
n
\
r
\
n
'
)
f
=
reader
.
read
()
data
=
self
.
loop
.
run_until_complete
(
f
)
self
.
assertTrue
(
data
.
endswith
(
b'
\
r
\
n
\
r
\
n
Test message'
))
conn_fut
=
asyncio
.
open_connection
(
*
httpd
.
address
,
ssl
=
test_utils
.
dummy_ssl_context
(),
loop
=
self
.
loop
)
writer
.
close
()
self
.
_basetest_open_connection_no_loop_ssl
(
conn_fut
)
@
unittest
.
skipIf
(
ssl
is
None
,
'No ssl module'
)
@
unittest
.
skipUnless
(
hasattr
(
socket
,
'AF_UNIX'
),
'No UNIX Sockets'
)
def
test_open_unix_connection_no_loop_ssl
(
self
):
with
test_utils
.
run_test_unix_server
(
use_ssl
=
True
)
as
httpd
:
conn_fut
=
asyncio
.
open_unix_connection
(
httpd
.
address
,
ssl
=
test_utils
.
dummy_ssl_context
(),
server_hostname
=
''
,
loop
=
self
.
loop
)
self
.
_basetest_open_connection_no_loop_ssl
(
conn_fut
)
def
_basetest_open_connection_error
(
self
,
open_connection_fut
):
reader
,
writer
=
self
.
loop
.
run_until_complete
(
open_connection_fut
)
writer
.
_protocol
.
connection_lost
(
ZeroDivisionError
())
f
=
reader
.
read
()
with
self
.
assertRaises
(
ZeroDivisionError
):
self
.
loop
.
run_until_complete
(
f
)
writer
.
close
()
test_utils
.
run_briefly
(
self
.
loop
)
def
test_open_connection_error
(
self
):
with
test_utils
.
run_test_server
()
as
httpd
:
f
=
asyncio
.
open_connection
(
*
httpd
.
address
,
loop
=
self
.
loop
)
reader
,
writer
=
self
.
loop
.
run_until_complete
(
f
)
writer
.
_protocol
.
connection_lost
(
ZeroDivisionError
())
f
=
reader
.
read
()
with
self
.
assertRaises
(
ZeroDivisionError
):
self
.
loop
.
run_until_complete
(
f
)
conn_fut
=
asyncio
.
open_connection
(
*
httpd
.
address
,
loop
=
self
.
loop
)
self
.
_basetest_open_connection_error
(
conn_fut
)
writer
.
close
()
test_utils
.
run_briefly
(
self
.
loop
)
@
unittest
.
skipUnless
(
hasattr
(
socket
,
'AF_UNIX'
),
'No UNIX Sockets'
)
def
test_open_unix_connection_error
(
self
):
with
test_utils
.
run_test_unix_server
()
as
httpd
:
conn_fut
=
asyncio
.
open_unix_connection
(
httpd
.
address
,
loop
=
self
.
loop
)
self
.
_basetest_open_connection_error
(
conn_fut
)
def
test_feed_empty_data
(
self
):
stream
=
asyncio
.
StreamReader
(
loop
=
self
.
loop
)
...
...
@@ -415,10 +454,13 @@ class StreamReaderTests(unittest.TestCase):
client_writer
.
write
(
data
)
def
start
(
self
):
sock
=
socket
.
socket
()
sock
.
bind
((
'127.0.0.1'
,
0
))
self
.
server
=
self
.
loop
.
run_until_complete
(
asyncio
.
start_server
(
self
.
handle_client
,
'127.0.0.1'
,
12345
,
sock
=
sock
,
loop
=
self
.
loop
))
return
sock
.
getsockname
()
def
handle_client_callback
(
self
,
client_reader
,
client_writer
):
task
=
asyncio
.
Task
(
client_reader
.
readline
(),
loop
=
self
.
loop
)
...
...
@@ -429,10 +471,15 @@ class StreamReaderTests(unittest.TestCase):
task
.
add_done_callback
(
done
)
def
start_callback
(
self
):
sock
=
socket
.
socket
()
sock
.
bind
((
'127.0.0.1'
,
0
))
addr
=
sock
.
getsockname
()
sock
.
close
()
self
.
server
=
self
.
loop
.
run_until_complete
(
asyncio
.
start_server
(
self
.
handle_client_callback
,
'127.0.0.1'
,
12345
,
host
=
addr
[
0
],
port
=
addr
[
1
]
,
loop
=
self
.
loop
))
return
addr
def
stop
(
self
):
if
self
.
server
is
not
None
:
...
...
@@ -441,9 +488,9 @@ class StreamReaderTests(unittest.TestCase):
self
.
server
=
None
@
asyncio
.
coroutine
def
client
():
def
client
(
addr
):
reader
,
writer
=
yield
from
asyncio
.
open_connection
(
'127.0.0.1'
,
12345
,
loop
=
self
.
loop
)
*
addr
,
loop
=
self
.
loop
)
# send a line
writer
.
write
(
b"hello world!
\
n
"
)
# read it back
...
...
@@ -453,20 +500,90 @@ class StreamReaderTests(unittest.TestCase):
# test the server variant with a coroutine as client handler
server
=
MyServer
(
self
.
loop
)
server
.
start
()
msg
=
self
.
loop
.
run_until_complete
(
asyncio
.
Task
(
client
(),
addr
=
server
.
start
()
msg
=
self
.
loop
.
run_until_complete
(
asyncio
.
Task
(
client
(
addr
),
loop
=
self
.
loop
))
server
.
stop
()
self
.
assertEqual
(
msg
,
b"hello world!
\
n
"
)
# test the server variant with a callback as client handler
server
=
MyServer
(
self
.
loop
)
server
.
start_callback
()
msg
=
self
.
loop
.
run_until_complete
(
asyncio
.
Task
(
client
(),
addr
=
server
.
start_callback
()
msg
=
self
.
loop
.
run_until_complete
(
asyncio
.
Task
(
client
(
addr
),
loop
=
self
.
loop
))
server
.
stop
()
self
.
assertEqual
(
msg
,
b"hello world!
\
n
"
)
@
unittest
.
skipUnless
(
hasattr
(
socket
,
'AF_UNIX'
),
'No UNIX Sockets'
)
def
test_start_unix_server
(
self
):
class
MyServer
:
def
__init__
(
self
,
loop
,
path
):
self
.
server
=
None
self
.
loop
=
loop
self
.
path
=
path
@
asyncio
.
coroutine
def
handle_client
(
self
,
client_reader
,
client_writer
):
data
=
yield
from
client_reader
.
readline
()
client_writer
.
write
(
data
)
def
start
(
self
):
self
.
server
=
self
.
loop
.
run_until_complete
(
asyncio
.
start_unix_server
(
self
.
handle_client
,
path
=
self
.
path
,
loop
=
self
.
loop
))
def
handle_client_callback
(
self
,
client_reader
,
client_writer
):
task
=
asyncio
.
Task
(
client_reader
.
readline
(),
loop
=
self
.
loop
)
def
done
(
task
):
client_writer
.
write
(
task
.
result
())
task
.
add_done_callback
(
done
)
def
start_callback
(
self
):
self
.
server
=
self
.
loop
.
run_until_complete
(
asyncio
.
start_unix_server
(
self
.
handle_client_callback
,
path
=
self
.
path
,
loop
=
self
.
loop
))
def
stop
(
self
):
if
self
.
server
is
not
None
:
self
.
server
.
close
()
self
.
loop
.
run_until_complete
(
self
.
server
.
wait_closed
())
self
.
server
=
None
@
asyncio
.
coroutine
def
client
(
path
):
reader
,
writer
=
yield
from
asyncio
.
open_unix_connection
(
path
,
loop
=
self
.
loop
)
# send a line
writer
.
write
(
b"hello world!
\
n
"
)
# read it back
msgback
=
yield
from
reader
.
readline
()
writer
.
close
()
return
msgback
# test the server variant with a coroutine as client handler
with
test_utils
.
unix_socket_path
()
as
path
:
server
=
MyServer
(
self
.
loop
,
path
)
server
.
start
()
msg
=
self
.
loop
.
run_until_complete
(
asyncio
.
Task
(
client
(
path
),
loop
=
self
.
loop
))
server
.
stop
()
self
.
assertEqual
(
msg
,
b"hello world!
\
n
"
)
# test the server variant with a callback as client handler
with
test_utils
.
unix_socket_path
()
as
path
:
server
=
MyServer
(
self
.
loop
,
path
)
server
.
start_callback
()
msg
=
self
.
loop
.
run_until_complete
(
asyncio
.
Task
(
client
(
path
),
loop
=
self
.
loop
))
server
.
stop
()
self
.
assertEqual
(
msg
,
b"hello world!
\
n
"
)
if
__name__
==
'__main__'
:
unittest
.
main
()
Lib/test/test_asyncio/test_unix_events.py
View file @
519883d2
...
...
@@ -7,8 +7,10 @@ import io
import
os
import
pprint
import
signal
import
socket
import
stat
import
sys
import
tempfile
import
threading
import
unittest
import
unittest.mock
...
...
@@ -24,7 +26,7 @@ from asyncio import unix_events
@
unittest
.
skipUnless
(
signal
,
'Signals are not supported'
)
class
SelectorEventLoopTests
(
unittest
.
TestCase
):
class
SelectorEventLoop
Signal
Tests
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
loop
=
asyncio
.
SelectorEventLoop
()
...
...
@@ -200,6 +202,84 @@ class SelectorEventLoopTests(unittest.TestCase):
m_signal
.
set_wakeup_fd
.
assert_called_once_with
(
-
1
)
@
unittest
.
skipUnless
(
hasattr
(
socket
,
'AF_UNIX'
),
'UNIX Sockets are not supported'
)
class
SelectorEventLoopUnixSocketTests
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
loop
=
asyncio
.
SelectorEventLoop
()
asyncio
.
set_event_loop
(
None
)
def
tearDown
(
self
):
self
.
loop
.
close
()
def
test_create_unix_server_existing_path_sock
(
self
):
with
test_utils
.
unix_socket_path
()
as
path
:
sock
=
socket
.
socket
(
socket
.
AF_UNIX
)
sock
.
bind
(
path
)
coro
=
self
.
loop
.
create_unix_server
(
lambda
:
None
,
path
)
with
self
.
assertRaisesRegexp
(
OSError
,
'Address.*is already in use'
):
self
.
loop
.
run_until_complete
(
coro
)
def
test_create_unix_server_existing_path_nonsock
(
self
):
with
tempfile
.
NamedTemporaryFile
()
as
file
:
coro
=
self
.
loop
.
create_unix_server
(
lambda
:
None
,
file
.
name
)
with
self
.
assertRaisesRegexp
(
OSError
,
'Address.*is already in use'
):
self
.
loop
.
run_until_complete
(
coro
)
def
test_create_unix_server_ssl_bool
(
self
):
coro
=
self
.
loop
.
create_unix_server
(
lambda
:
None
,
path
=
'spam'
,
ssl
=
True
)
with
self
.
assertRaisesRegex
(
TypeError
,
'ssl argument must be an SSLContext'
):
self
.
loop
.
run_until_complete
(
coro
)
def
test_create_unix_server_nopath_nosock
(
self
):
coro
=
self
.
loop
.
create_unix_server
(
lambda
:
None
,
path
=
None
)
with
self
.
assertRaisesRegex
(
ValueError
,
'path was not specified, and no sock'
):
self
.
loop
.
run_until_complete
(
coro
)
def
test_create_unix_server_path_inetsock
(
self
):
coro
=
self
.
loop
.
create_unix_server
(
lambda
:
None
,
path
=
None
,
sock
=
socket
.
socket
())
with
self
.
assertRaisesRegex
(
ValueError
,
'A UNIX Domain Socket was expected'
):
self
.
loop
.
run_until_complete
(
coro
)
def
test_create_unix_connection_path_sock
(
self
):
coro
=
self
.
loop
.
create_unix_connection
(
lambda
:
None
,
'/dev/null'
,
sock
=
object
())
with
self
.
assertRaisesRegex
(
ValueError
,
'path and sock can not be'
):
self
.
loop
.
run_until_complete
(
coro
)
def
test_create_unix_connection_nopath_nosock
(
self
):
coro
=
self
.
loop
.
create_unix_connection
(
lambda
:
None
,
None
)
with
self
.
assertRaisesRegex
(
ValueError
,
'no path and sock were specified'
):
self
.
loop
.
run_until_complete
(
coro
)
def
test_create_unix_connection_nossl_serverhost
(
self
):
coro
=
self
.
loop
.
create_unix_connection
(
lambda
:
None
,
'/dev/null'
,
server_hostname
=
'spam'
)
with
self
.
assertRaisesRegex
(
ValueError
,
'server_hostname is only meaningful'
):
self
.
loop
.
run_until_complete
(
coro
)
def
test_create_unix_connection_ssl_noserverhost
(
self
):
coro
=
self
.
loop
.
create_unix_connection
(
lambda
:
None
,
'/dev/null'
,
ssl
=
True
)
with
self
.
assertRaisesRegexp
(
ValueError
,
'you have to pass server_hostname when using ssl'
):
self
.
loop
.
run_until_complete
(
coro
)
class
UnixReadPipeTransportTests
(
unittest
.
TestCase
):
def
setUp
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment