Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
C
cpython
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
Analytics
Analytics
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Commits
Issue Boards
Open sidebar
Kirill Smelkov
cpython
Commits
f111b3dc
Commit
f111b3dc
authored
Dec 30, 2017
by
Yury Selivanov
Committed by
GitHub
Dec 30, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
bpo-23749: Implement loop.start_tls() (#5039)
parent
bbdb17d1
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
580 additions
and
54 deletions
+580
-54
Doc/library/asyncio-eventloop.rst
Doc/library/asyncio-eventloop.rst
+32
-0
Lib/asyncio/base_events.py
Lib/asyncio/base_events.py
+44
-1
Lib/asyncio/events.py
Lib/asyncio/events.py
+11
-0
Lib/asyncio/proactor_events.py
Lib/asyncio/proactor_events.py
+2
-0
Lib/asyncio/selector_events.py
Lib/asyncio/selector_events.py
+2
-0
Lib/test/test_asyncio/functional.py
Lib/test/test_asyncio/functional.py
+279
-0
Lib/test/test_asyncio/test_events.py
Lib/test/test_asyncio/test_events.py
+14
-53
Lib/test/test_asyncio/test_sslproto.py
Lib/test/test_asyncio/test_sslproto.py
+152
-0
Lib/test/test_asyncio/utils.py
Lib/test/test_asyncio/utils.py
+43
-0
Misc/NEWS.d/next/Library/2017-12-29-00-44-42.bpo-23749.QL1Cxd.rst
...S.d/next/Library/2017-12-29-00-44-42.bpo-23749.QL1Cxd.rst
+1
-0
No files found.
Doc/library/asyncio-eventloop.rst
View file @
f111b3dc
...
...
@@ -537,6 +537,38 @@ Creating listening connections
.. versionadded:: 3.5.3
TLS Upgrade
-----------
.. coroutinemethod:: AbstractEventLoop.start_tls(transport, protocol, sslcontext, \*, server_side=False, server_hostname=None, ssl_handshake_timeout=None)
Upgrades an existing connection to TLS.
Returns a new transport instance, that the *protocol* must start using
immediately after the *await*. The *transport* instance passed to
the *start_tls* method should never be used again.
Parameters:
* *transport* and *protocol* instances that methods like
:meth:`~AbstractEventLoop.create_server` and
:meth:`~AbstractEventLoop.create_connection` return.
* *sslcontext*: a configured instance of :class:`~ssl.SSLContext`.
* *server_side* pass ``True`` when a server-side connection is being
upgraded (like the one created by :meth:`~AbstractEventLoop.create_server`).
* *server_hostname*: sets or overrides the host name that the target
server's certificate will be matched against.
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to
wait for the SSL handshake to complete before aborting the connection.
``10.0`` seconds if ``None`` (default).
.. versionadded:: 3.7
Watch file descriptors
----------------------
...
...
Lib/asyncio/base_events.py
View file @
f111b3dc
...
...
@@ -29,9 +29,15 @@ import sys
import
warnings
import
weakref
try
:
import
ssl
except
ImportError
:
# pragma: no cover
ssl
=
None
from
.
import
coroutines
from
.
import
events
from
.
import
futures
from
.
import
sslproto
from
.
import
tasks
from
.log
import
logger
...
...
@@ -279,7 +285,8 @@ class BaseEventLoop(events.AbstractEventLoop):
self
,
rawsock
,
protocol
,
sslcontext
,
waiter
=
None
,
*
,
server_side
=
False
,
server_hostname
=
None
,
extra
=
None
,
server
=
None
,
ssl_handshake_timeout
=
None
):
ssl_handshake_timeout
=
None
,
call_connection_made
=
True
):
"""Create SSL transport."""
raise
NotImplementedError
...
...
@@ -795,6 +802,42 @@ class BaseEventLoop(events.AbstractEventLoop):
return
transport
,
protocol
async
def
start_tls
(
self
,
transport
,
protocol
,
sslcontext
,
*
,
server_side
=
False
,
server_hostname
=
None
,
ssl_handshake_timeout
=
None
):
"""Upgrade transport to TLS.
Return a new transport that *protocol* should start using
immediately.
"""
if
ssl
is
None
:
raise
RuntimeError
(
'Python ssl module is not available'
)
if
not
isinstance
(
sslcontext
,
ssl
.
SSLContext
):
raise
TypeError
(
f'sslcontext is expected to be an instance of ssl.SSLContext, '
f'got
{
sslcontext
!
r
}
'
)
if
not
getattr
(
transport
,
'_start_tls_compatible'
,
False
):
raise
TypeError
(
f'transport
{
self
!
r
}
is not supported by start_tls()'
)
waiter
=
self
.
create_future
()
ssl_protocol
=
sslproto
.
SSLProtocol
(
self
,
protocol
,
sslcontext
,
waiter
,
server_side
,
server_hostname
,
ssl_handshake_timeout
=
ssl_handshake_timeout
,
call_connection_made
=
False
)
transport
.
set_protocol
(
ssl_protocol
)
self
.
call_soon
(
ssl_protocol
.
connection_made
,
transport
)
if
not
transport
.
is_reading
():
self
.
call_soon
(
transport
.
resume_reading
)
await
waiter
return
ssl_protocol
.
_app_transport
async
def
create_datagram_endpoint
(
self
,
protocol_factory
,
local_addr
=
None
,
remote_addr
=
None
,
*
,
family
=
0
,
proto
=
0
,
flags
=
0
,
...
...
Lib/asyncio/events.py
View file @
f111b3dc
...
...
@@ -305,6 +305,17 @@ class AbstractEventLoop:
"""
raise
NotImplementedError
async
def
start_tls
(
self
,
transport
,
protocol
,
sslcontext
,
*
,
server_side
=
False
,
server_hostname
=
None
,
ssl_handshake_timeout
=
None
):
"""Upgrade a transport to TLS.
Return a new transport that *protocol* should start using
immediately.
"""
raise
NotImplementedError
async
def
create_unix_connection
(
self
,
protocol_factory
,
path
=
None
,
*
,
ssl
=
None
,
sock
=
None
,
...
...
Lib/asyncio/proactor_events.py
View file @
f111b3dc
...
...
@@ -223,6 +223,8 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport,
transports
.
WriteTransport
):
"""Transport for write pipes."""
_start_tls_compatible
=
True
def
write
(
self
,
data
):
if
not
isinstance
(
data
,
(
bytes
,
bytearray
,
memoryview
)):
raise
TypeError
(
...
...
Lib/asyncio/selector_events.py
View file @
f111b3dc
...
...
@@ -694,6 +694,8 @@ class _SelectorTransport(transports._FlowControlMixin,
class
_SelectorSocketTransport
(
_SelectorTransport
):
_start_tls_compatible
=
True
def
__init__
(
self
,
loop
,
sock
,
protocol
,
waiter
=
None
,
extra
=
None
,
server
=
None
):
super
().
__init__
(
loop
,
sock
,
protocol
,
extra
,
server
)
...
...
Lib/test/test_asyncio/functional.py
0 → 100644
View file @
f111b3dc
import
asyncio
import
asyncio.events
import
contextlib
import
os
import
pprint
import
select
import
socket
import
ssl
import
tempfile
import
threading
class
FunctionalTestCaseMixin
:
def
new_loop
(
self
):
return
asyncio
.
new_event_loop
()
def
run_loop_briefly
(
self
,
*
,
delay
=
0.01
):
self
.
loop
.
run_until_complete
(
asyncio
.
sleep
(
delay
,
loop
=
self
.
loop
))
def
loop_exception_handler
(
self
,
loop
,
context
):
self
.
__unhandled_exceptions
.
append
(
context
)
self
.
loop
.
default_exception_handler
(
context
)
def
setUp
(
self
):
self
.
loop
=
self
.
new_loop
()
asyncio
.
set_event_loop
(
None
)
self
.
loop
.
set_exception_handler
(
self
.
loop_exception_handler
)
self
.
__unhandled_exceptions
=
[]
# Disable `_get_running_loop`.
self
.
_old_get_running_loop
=
asyncio
.
events
.
_get_running_loop
asyncio
.
events
.
_get_running_loop
=
lambda
:
None
def
tearDown
(
self
):
try
:
self
.
loop
.
close
()
if
self
.
__unhandled_exceptions
:
print
(
'Unexpected calls to loop.call_exception_handler():'
)
pprint
.
pprint
(
self
.
__unhandled_exceptions
)
self
.
fail
(
'unexpected calls to loop.call_exception_handler()'
)
finally
:
asyncio
.
events
.
_get_running_loop
=
self
.
_old_get_running_loop
asyncio
.
set_event_loop
(
None
)
self
.
loop
=
None
def
tcp_server
(
self
,
server_prog
,
*
,
family
=
socket
.
AF_INET
,
addr
=
None
,
timeout
=
5
,
backlog
=
1
,
max_clients
=
10
):
if
addr
is
None
:
if
hasattr
(
socket
,
'AF_UNIX'
)
and
family
==
socket
.
AF_UNIX
:
with
tempfile
.
NamedTemporaryFile
()
as
tmp
:
addr
=
tmp
.
name
else
:
addr
=
(
'127.0.0.1'
,
0
)
sock
=
socket
.
socket
(
family
,
socket
.
SOCK_STREAM
)
if
timeout
is
None
:
raise
RuntimeError
(
'timeout is required'
)
if
timeout
<=
0
:
raise
RuntimeError
(
'only blocking sockets are supported'
)
sock
.
settimeout
(
timeout
)
try
:
sock
.
bind
(
addr
)
sock
.
listen
(
backlog
)
except
OSError
as
ex
:
sock
.
close
()
raise
ex
return
TestThreadedServer
(
self
,
sock
,
server_prog
,
timeout
,
max_clients
)
def
tcp_client
(
self
,
client_prog
,
family
=
socket
.
AF_INET
,
timeout
=
10
):
sock
=
socket
.
socket
(
family
,
socket
.
SOCK_STREAM
)
if
timeout
is
None
:
raise
RuntimeError
(
'timeout is required'
)
if
timeout
<=
0
:
raise
RuntimeError
(
'only blocking sockets are supported'
)
sock
.
settimeout
(
timeout
)
return
TestThreadedClient
(
self
,
sock
,
client_prog
,
timeout
)
def
unix_server
(
self
,
*
args
,
**
kwargs
):
if
not
hasattr
(
socket
,
'AF_UNIX'
):
raise
NotImplementedError
return
self
.
tcp_server
(
*
args
,
family
=
socket
.
AF_UNIX
,
**
kwargs
)
def
unix_client
(
self
,
*
args
,
**
kwargs
):
if
not
hasattr
(
socket
,
'AF_UNIX'
):
raise
NotImplementedError
return
self
.
tcp_client
(
*
args
,
family
=
socket
.
AF_UNIX
,
**
kwargs
)
@
contextlib
.
contextmanager
def
unix_sock_name
(
self
):
with
tempfile
.
TemporaryDirectory
()
as
td
:
fn
=
os
.
path
.
join
(
td
,
'sock'
)
try
:
yield
fn
finally
:
try
:
os
.
unlink
(
fn
)
except
OSError
:
pass
def
_abort_socket_test
(
self
,
ex
):
try
:
self
.
loop
.
stop
()
finally
:
self
.
fail
(
ex
)
##############################################################################
# Socket Testing Utilities
##############################################################################
class
TestSocketWrapper
:
def
__init__
(
self
,
sock
):
self
.
__sock
=
sock
def
recv_all
(
self
,
n
):
buf
=
b''
while
len
(
buf
)
<
n
:
data
=
self
.
recv
(
n
-
len
(
buf
))
if
data
==
b''
:
raise
ConnectionAbortedError
buf
+=
data
return
buf
def
start_tls
(
self
,
ssl_context
,
*
,
server_side
=
False
,
server_hostname
=
None
):
assert
isinstance
(
ssl_context
,
ssl
.
SSLContext
)
ssl_sock
=
ssl_context
.
wrap_socket
(
self
.
__sock
,
server_side
=
server_side
,
server_hostname
=
server_hostname
,
do_handshake_on_connect
=
False
)
ssl_sock
.
do_handshake
()
self
.
__sock
.
close
()
self
.
__sock
=
ssl_sock
def
__getattr__
(
self
,
name
):
return
getattr
(
self
.
__sock
,
name
)
def
__repr__
(
self
):
return
'<{} {!r}>'
.
format
(
type
(
self
).
__name__
,
self
.
__sock
)
class
SocketThread
(
threading
.
Thread
):
def
stop
(
self
):
self
.
_active
=
False
self
.
join
()
def
__enter__
(
self
):
self
.
start
()
return
self
def
__exit__
(
self
,
*
exc
):
self
.
stop
()
class
TestThreadedClient
(
SocketThread
):
def
__init__
(
self
,
test
,
sock
,
prog
,
timeout
):
threading
.
Thread
.
__init__
(
self
,
None
,
None
,
'test-client'
)
self
.
daemon
=
True
self
.
_timeout
=
timeout
self
.
_sock
=
sock
self
.
_active
=
True
self
.
_prog
=
prog
self
.
_test
=
test
def
run
(
self
):
try
:
self
.
_prog
(
TestSocketWrapper
(
self
.
_sock
))
except
Exception
as
ex
:
self
.
_test
.
_abort_socket_test
(
ex
)
class
TestThreadedServer
(
SocketThread
):
def
__init__
(
self
,
test
,
sock
,
prog
,
timeout
,
max_clients
):
threading
.
Thread
.
__init__
(
self
,
None
,
None
,
'test-server'
)
self
.
daemon
=
True
self
.
_clients
=
0
self
.
_finished_clients
=
0
self
.
_max_clients
=
max_clients
self
.
_timeout
=
timeout
self
.
_sock
=
sock
self
.
_active
=
True
self
.
_prog
=
prog
self
.
_s1
,
self
.
_s2
=
socket
.
socketpair
()
self
.
_s1
.
setblocking
(
False
)
self
.
_test
=
test
def
stop
(
self
):
try
:
if
self
.
_s2
and
self
.
_s2
.
fileno
()
!=
-
1
:
try
:
self
.
_s2
.
send
(
b'stop'
)
except
OSError
:
pass
finally
:
super
().
stop
()
def
run
(
self
):
try
:
with
self
.
_sock
:
self
.
_sock
.
setblocking
(
0
)
self
.
_run
()
finally
:
self
.
_s1
.
close
()
self
.
_s2
.
close
()
def
_run
(
self
):
while
self
.
_active
:
if
self
.
_clients
>=
self
.
_max_clients
:
return
r
,
w
,
x
=
select
.
select
(
[
self
.
_sock
,
self
.
_s1
],
[],
[],
self
.
_timeout
)
if
self
.
_s1
in
r
:
return
if
self
.
_sock
in
r
:
try
:
conn
,
addr
=
self
.
_sock
.
accept
()
except
BlockingIOError
:
continue
except
socket
.
timeout
:
if
not
self
.
_active
:
return
else
:
raise
else
:
self
.
_clients
+=
1
conn
.
settimeout
(
self
.
_timeout
)
try
:
with
conn
:
self
.
_handle_client
(
conn
)
except
Exception
as
ex
:
self
.
_active
=
False
try
:
raise
finally
:
self
.
_test
.
_abort_socket_test
(
ex
)
def
_handle_client
(
self
,
sock
):
self
.
_prog
(
TestSocketWrapper
(
sock
))
@
property
def
addr
(
self
):
return
self
.
_sock
.
getsockname
()
Lib/test/test_asyncio/test_events.py
View file @
f111b3dc
...
...
@@ -31,21 +31,7 @@ from asyncio import events
from
asyncio
import
proactor_events
from
asyncio
import
selector_events
from
test.test_asyncio
import
utils
as
test_utils
try
:
from
test
import
support
except
ImportError
:
from
asyncio
import
test_support
as
support
def
data_file
(
filename
):
if
hasattr
(
support
,
'TEST_HOME_DIR'
):
fullname
=
os
.
path
.
join
(
support
.
TEST_HOME_DIR
,
filename
)
if
os
.
path
.
isfile
(
fullname
):
return
fullname
fullname
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
filename
)
if
os
.
path
.
isfile
(
fullname
):
return
fullname
raise
FileNotFoundError
(
filename
)
from
test
import
support
def
osx_tiger
():
...
...
@@ -80,23 +66,6 @@ class CoroLike:
pass
ONLYCERT
=
data_file
(
'ssl_cert.pem'
)
ONLYKEY
=
data_file
(
'ssl_key.pem'
)
SIGNED_CERTFILE
=
data_file
(
'keycert3.pem'
)
SIGNING_CA
=
data_file
(
'pycacert.pem'
)
PEERCERT
=
{
'serialNumber'
:
'B09264B1F2DA21D1'
,
'version'
:
1
,
'subject'
:
(((
'countryName'
,
'XY'
),),
((
'localityName'
,
'Castle Anthrax'
),),
((
'organizationName'
,
'Python Software Foundation'
),),
((
'commonName'
,
'localhost'
),)),
'issuer'
:
(((
'countryName'
,
'XY'
),),
((
'organizationName'
,
'Python Software Foundation CA'
),),
((
'commonName'
,
'our-ca-server'
),)),
'notAfter'
:
'Nov 13 19:47:07 2022 GMT'
,
'notBefore'
:
'Jan 4 19:47:07 2013 GMT'
}
class
MyBaseProto
(
asyncio
.
Protocol
):
connected
=
None
done
=
None
...
...
@@ -853,16 +822,8 @@ class EventLoopTestsMixin:
'SSL not supported with proactor event loops before Python 3.5'
)
server_context
=
ssl
.
SSLContext
(
ssl
.
PROTOCOL_TLS_SERVER
)
server_context
.
load_cert_chain
(
ONLYCERT
,
ONLYKEY
)
if
hasattr
(
server_context
,
'check_hostname'
):
server_context
.
check_hostname
=
False
server_context
.
verify_mode
=
ssl
.
CERT_NONE
client_context
=
ssl
.
SSLContext
(
ssl
.
PROTOCOL_TLS_CLIENT
)
if
hasattr
(
server_context
,
'check_hostname'
):
client_context
.
check_hostname
=
False
client_context
.
verify_mode
=
ssl
.
CERT_NONE
server_context
=
test_utils
.
simple_server_sslcontext
()
client_context
=
test_utils
.
simple_client_sslcontext
()
self
.
test_connect_accepted_socket
(
server_context
,
client_context
)
...
...
@@ -1048,7 +1009,7 @@ class EventLoopTestsMixin:
def
test_create_server_ssl
(
self
):
proto
=
MyProto
(
loop
=
self
.
loop
)
server
,
host
,
port
=
self
.
_make_ssl_server
(
lambda
:
proto
,
ONLYCERT
,
ONLYKEY
)
lambda
:
proto
,
test_utils
.
ONLYCERT
,
test_utils
.
ONLYKEY
)
f_c
=
self
.
loop
.
create_connection
(
MyBaseProto
,
host
,
port
,
ssl
=
test_utils
.
dummy_ssl_context
())
...
...
@@ -1081,7 +1042,7 @@ class EventLoopTestsMixin:
def
test_create_unix_server_ssl
(
self
):
proto
=
MyProto
(
loop
=
self
.
loop
)
server
,
path
=
self
.
_make_ssl_unix_server
(
lambda
:
proto
,
ONLYCERT
,
ONLYKEY
)
lambda
:
proto
,
test_utils
.
ONLYCERT
,
test_utils
.
ONLYKEY
)
f_c
=
self
.
loop
.
create_unix_connection
(
MyBaseProto
,
path
,
ssl
=
test_utils
.
dummy_ssl_context
(),
...
...
@@ -1111,7 +1072,7 @@ class EventLoopTestsMixin:
def
test_create_server_ssl_verify_failed
(
self
):
proto
=
MyProto
(
loop
=
self
.
loop
)
server
,
host
,
port
=
self
.
_make_ssl_server
(
lambda
:
proto
,
SIGNED_CERTFILE
)
lambda
:
proto
,
test_utils
.
SIGNED_CERTFILE
)
sslcontext_client
=
ssl
.
SSLContext
(
ssl
.
PROTOCOL_TLS_CLIENT
)
sslcontext_client
.
options
|=
ssl
.
OP_NO_SSLv2
...
...
@@ -1141,7 +1102,7 @@ class EventLoopTestsMixin:
def
test_create_unix_server_ssl_verify_failed
(
self
):
proto
=
MyProto
(
loop
=
self
.
loop
)
server
,
path
=
self
.
_make_ssl_unix_server
(
lambda
:
proto
,
SIGNED_CERTFILE
)
lambda
:
proto
,
test_utils
.
SIGNED_CERTFILE
)
sslcontext_client
=
ssl
.
SSLContext
(
ssl
.
PROTOCOL_TLS_CLIENT
)
sslcontext_client
.
options
|=
ssl
.
OP_NO_SSLv2
...
...
@@ -1170,13 +1131,13 @@ class EventLoopTestsMixin:
def
test_create_server_ssl_match_failed
(
self
):
proto
=
MyProto
(
loop
=
self
.
loop
)
server
,
host
,
port
=
self
.
_make_ssl_server
(
lambda
:
proto
,
SIGNED_CERTFILE
)
lambda
:
proto
,
test_utils
.
SIGNED_CERTFILE
)
sslcontext_client
=
ssl
.
SSLContext
(
ssl
.
PROTOCOL_TLS_CLIENT
)
sslcontext_client
.
options
|=
ssl
.
OP_NO_SSLv2
sslcontext_client
.
verify_mode
=
ssl
.
CERT_REQUIRED
sslcontext_client
.
load_verify_locations
(
cafile
=
SIGNING_CA
)
cafile
=
test_utils
.
SIGNING_CA
)
if
hasattr
(
sslcontext_client
,
'check_hostname'
):
sslcontext_client
.
check_hostname
=
True
...
...
@@ -1199,12 +1160,12 @@ class EventLoopTestsMixin:
def
test_create_unix_server_ssl_verified
(
self
):
proto
=
MyProto
(
loop
=
self
.
loop
)
server
,
path
=
self
.
_make_ssl_unix_server
(
lambda
:
proto
,
SIGNED_CERTFILE
)
lambda
:
proto
,
test_utils
.
SIGNED_CERTFILE
)
sslcontext_client
=
ssl
.
SSLContext
(
ssl
.
PROTOCOL_TLS_CLIENT
)
sslcontext_client
.
options
|=
ssl
.
OP_NO_SSLv2
sslcontext_client
.
verify_mode
=
ssl
.
CERT_REQUIRED
sslcontext_client
.
load_verify_locations
(
cafile
=
SIGNING_CA
)
sslcontext_client
.
load_verify_locations
(
cafile
=
test_utils
.
SIGNING_CA
)
if
hasattr
(
sslcontext_client
,
'check_hostname'
):
sslcontext_client
.
check_hostname
=
True
...
...
@@ -1224,12 +1185,12 @@ class EventLoopTestsMixin:
def
test_create_server_ssl_verified
(
self
):
proto
=
MyProto
(
loop
=
self
.
loop
)
server
,
host
,
port
=
self
.
_make_ssl_server
(
lambda
:
proto
,
SIGNED_CERTFILE
)
lambda
:
proto
,
test_utils
.
SIGNED_CERTFILE
)
sslcontext_client
=
ssl
.
SSLContext
(
ssl
.
PROTOCOL_TLS_CLIENT
)
sslcontext_client
.
options
|=
ssl
.
OP_NO_SSLv2
sslcontext_client
.
verify_mode
=
ssl
.
CERT_REQUIRED
sslcontext_client
.
load_verify_locations
(
cafile
=
SIGNING_CA
)
sslcontext_client
.
load_verify_locations
(
cafile
=
test_utils
.
SIGNING_CA
)
if
hasattr
(
sslcontext_client
,
'check_hostname'
):
sslcontext_client
.
check_hostname
=
True
...
...
@@ -1241,7 +1202,7 @@ class EventLoopTestsMixin:
# extra info is available
self
.
check_ssl_extra_info
(
client
,
peername
=
(
host
,
port
),
peercert
=
PEERCERT
)
peercert
=
test_utils
.
PEERCERT
)
# close connection
proto
.
transport
.
close
()
...
...
Lib/test/test_asyncio/test_sslproto.py
View file @
f111b3dc
...
...
@@ -13,6 +13,7 @@ from asyncio import log
from
asyncio
import
sslproto
from
asyncio
import
tasks
from
test.test_asyncio
import
utils
as
test_utils
from
test.test_asyncio
import
functional
as
func_tests
@
unittest
.
skipIf
(
ssl
is
None
,
'No ssl module'
)
...
...
@@ -158,5 +159,156 @@ class SslProtoHandshakeTests(test_utils.TestCase):
self
.
assertIs
(
ssl_proto
.
_app_protocol
,
new_app_proto
)
##############################################################################
# Start TLS Tests
##############################################################################
class
BaseStartTLS
(
func_tests
.
FunctionalTestCaseMixin
):
def
new_loop
(
self
):
raise
NotImplementedError
def
test_start_tls_client_1
(
self
):
HELLO_MSG
=
b'1'
*
1024
*
1024
*
5
server_context
=
test_utils
.
simple_server_sslcontext
()
client_context
=
test_utils
.
simple_client_sslcontext
()
def
serve
(
sock
):
data
=
sock
.
recv_all
(
len
(
HELLO_MSG
))
self
.
assertEqual
(
len
(
data
),
len
(
HELLO_MSG
))
sock
.
start_tls
(
server_context
,
server_side
=
True
)
sock
.
sendall
(
b'O'
)
data
=
sock
.
recv_all
(
len
(
HELLO_MSG
))
self
.
assertEqual
(
len
(
data
),
len
(
HELLO_MSG
))
sock
.
close
()
class
ClientProto
(
asyncio
.
Protocol
):
def
__init__
(
self
,
on_data
,
on_eof
):
self
.
on_data
=
on_data
self
.
on_eof
=
on_eof
self
.
con_made_cnt
=
0
def
connection_made
(
proto
,
tr
):
proto
.
con_made_cnt
+=
1
# Ensure connection_made gets called only once.
self
.
assertEqual
(
proto
.
con_made_cnt
,
1
)
def
data_received
(
self
,
data
):
self
.
on_data
.
set_result
(
data
)
def
eof_received
(
self
):
self
.
on_eof
.
set_result
(
True
)
async
def
client
(
addr
):
on_data
=
self
.
loop
.
create_future
()
on_eof
=
self
.
loop
.
create_future
()
tr
,
proto
=
await
self
.
loop
.
create_connection
(
lambda
:
ClientProto
(
on_data
,
on_eof
),
*
addr
)
tr
.
write
(
HELLO_MSG
)
new_tr
=
await
self
.
loop
.
start_tls
(
tr
,
proto
,
client_context
)
self
.
assertEqual
(
await
on_data
,
b'O'
)
new_tr
.
write
(
HELLO_MSG
)
await
on_eof
new_tr
.
close
()
with
self
.
tcp_server
(
serve
)
as
srv
:
self
.
loop
.
run_until_complete
(
asyncio
.
wait_for
(
client
(
srv
.
addr
),
loop
=
self
.
loop
,
timeout
=
10
))
def
test_start_tls_server_1
(
self
):
HELLO_MSG
=
b'1'
*
1024
*
1024
*
5
server_context
=
test_utils
.
simple_server_sslcontext
()
client_context
=
test_utils
.
simple_client_sslcontext
()
def
client
(
sock
,
addr
):
sock
.
connect
(
addr
)
data
=
sock
.
recv_all
(
len
(
HELLO_MSG
))
self
.
assertEqual
(
len
(
data
),
len
(
HELLO_MSG
))
sock
.
start_tls
(
client_context
)
sock
.
sendall
(
HELLO_MSG
)
sock
.
close
()
class
ServerProto
(
asyncio
.
Protocol
):
def
__init__
(
self
,
on_con
,
on_eof
):
self
.
on_con
=
on_con
self
.
on_eof
=
on_eof
self
.
data
=
b''
def
connection_made
(
self
,
tr
):
self
.
on_con
.
set_result
(
tr
)
def
data_received
(
self
,
data
):
self
.
data
+=
data
def
eof_received
(
self
):
self
.
on_eof
.
set_result
(
1
)
async
def
main
():
tr
=
await
on_con
tr
.
write
(
HELLO_MSG
)
self
.
assertEqual
(
proto
.
data
,
b''
)
new_tr
=
await
self
.
loop
.
start_tls
(
tr
,
proto
,
server_context
,
server_side
=
True
)
await
on_eof
self
.
assertEqual
(
proto
.
data
,
HELLO_MSG
)
new_tr
.
close
()
server
.
close
()
await
server
.
wait_closed
()
on_con
=
self
.
loop
.
create_future
()
on_eof
=
self
.
loop
.
create_future
()
proto
=
ServerProto
(
on_con
,
on_eof
)
server
=
self
.
loop
.
run_until_complete
(
self
.
loop
.
create_server
(
lambda
:
proto
,
'127.0.0.1'
,
0
))
addr
=
server
.
sockets
[
0
].
getsockname
()
with
self
.
tcp_client
(
lambda
sock
:
client
(
sock
,
addr
)):
self
.
loop
.
run_until_complete
(
asyncio
.
wait_for
(
main
(),
loop
=
self
.
loop
,
timeout
=
10
))
def
test_start_tls_wrong_args
(
self
):
async
def
main
():
with
self
.
assertRaisesRegex
(
TypeError
,
'SSLContext, got'
):
await
self
.
loop
.
start_tls
(
None
,
None
,
None
)
sslctx
=
test_utils
.
simple_server_sslcontext
()
with
self
.
assertRaisesRegex
(
TypeError
,
'is not supported'
):
await
self
.
loop
.
start_tls
(
None
,
None
,
sslctx
)
self
.
loop
.
run_until_complete
(
main
())
@
unittest
.
skipIf
(
ssl
is
None
,
'No ssl module'
)
class
SelectorStartTLS
(
BaseStartTLS
,
unittest
.
TestCase
):
def
new_loop
(
self
):
return
asyncio
.
SelectorEventLoop
()
@
unittest
.
skipIf
(
ssl
is
None
,
'No ssl module'
)
@
unittest
.
skipUnless
(
hasattr
(
asyncio
,
'ProactorEventLoop'
),
'Windows only'
)
class
ProactorStartTLS
(
BaseStartTLS
,
unittest
.
TestCase
):
def
new_loop
(
self
):
return
asyncio
.
ProactorEventLoop
()
if
__name__
==
'__main__'
:
unittest
.
main
()
Lib/test/test_asyncio/utils.py
View file @
f111b3dc
...
...
@@ -35,6 +35,49 @@ from asyncio.log import logger
from
test
import
support
def
data_file
(
filename
):
if
hasattr
(
support
,
'TEST_HOME_DIR'
):
fullname
=
os
.
path
.
join
(
support
.
TEST_HOME_DIR
,
filename
)
if
os
.
path
.
isfile
(
fullname
):
return
fullname
fullname
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
filename
)
if
os
.
path
.
isfile
(
fullname
):
return
fullname
raise
FileNotFoundError
(
filename
)
ONLYCERT
=
data_file
(
'ssl_cert.pem'
)
ONLYKEY
=
data_file
(
'ssl_key.pem'
)
SIGNED_CERTFILE
=
data_file
(
'keycert3.pem'
)
SIGNING_CA
=
data_file
(
'pycacert.pem'
)
PEERCERT
=
{
'serialNumber'
:
'B09264B1F2DA21D1'
,
'version'
:
1
,
'subject'
:
(((
'countryName'
,
'XY'
),),
((
'localityName'
,
'Castle Anthrax'
),),
((
'organizationName'
,
'Python Software Foundation'
),),
((
'commonName'
,
'localhost'
),)),
'issuer'
:
(((
'countryName'
,
'XY'
),),
((
'organizationName'
,
'Python Software Foundation CA'
),),
((
'commonName'
,
'our-ca-server'
),)),
'notAfter'
:
'Nov 13 19:47:07 2022 GMT'
,
'notBefore'
:
'Jan 4 19:47:07 2013 GMT'
}
def
simple_server_sslcontext
():
server_context
=
ssl
.
SSLContext
(
ssl
.
PROTOCOL_TLS_SERVER
)
server_context
.
load_cert_chain
(
ONLYCERT
,
ONLYKEY
)
server_context
.
check_hostname
=
False
server_context
.
verify_mode
=
ssl
.
CERT_NONE
return
server_context
def
simple_client_sslcontext
():
client_context
=
ssl
.
SSLContext
(
ssl
.
PROTOCOL_TLS_CLIENT
)
client_context
.
check_hostname
=
False
client_context
.
verify_mode
=
ssl
.
CERT_NONE
return
client_context
def
dummy_ssl_context
():
if
ssl
is
None
:
return
None
...
...
Misc/NEWS.d/next/Library/2017-12-29-00-44-42.bpo-23749.QL1Cxd.rst
0 → 100644
View file @
f111b3dc
asyncio: Implement loop.start_tls()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment