Commit dad57116 authored by Gregory P. Smith's avatar Gregory P. Smith

Fixes Issue #14635: telnetlib will use poll() rather than select() when possible

to avoid failing due to the select() file descriptor limit.
parent 4774946c
...@@ -34,6 +34,7 @@ To do: ...@@ -34,6 +34,7 @@ To do:
# Imported modules # Imported modules
import errno
import sys import sys
import socket import socket
import select import select
...@@ -205,6 +206,7 @@ class Telnet: ...@@ -205,6 +206,7 @@ class Telnet:
self.sb = 0 # flag for SB and SE sequence. self.sb = 0 # flag for SB and SE sequence.
self.sbdataq = b'' self.sbdataq = b''
self.option_callback = None self.option_callback = None
self._has_poll = hasattr(select, 'poll')
if host is not None: if host is not None:
self.open(host, port, timeout) self.open(host, port, timeout)
...@@ -286,6 +288,61 @@ class Telnet: ...@@ -286,6 +288,61 @@ class Telnet:
possibly the empty string. Raise EOFError if the connection possibly the empty string. Raise EOFError if the connection
is closed and no cooked data is available. is closed and no cooked data is available.
"""
if self._has_poll:
return self._read_until_with_poll(match, timeout)
else:
return self._read_until_with_select(match, timeout)
def _read_until_with_poll(self, match, timeout):
"""Read until a given string is encountered or until timeout.
This method uses select.poll() to implement the timeout.
"""
n = len(match)
call_timeout = timeout
if timeout is not None:
from time import time
time_start = time()
self.process_rawq()
i = self.cookedq.find(match)
if i < 0:
poller = select.poll()
poll_in_or_priority_flags = select.POLLIN | select.POLLPRI
poller.register(self, poll_in_or_priority_flags)
while i < 0 and not self.eof:
try:
ready = poller.poll(call_timeout)
except select.error as e:
if e.errno == errno.EINTR:
if timeout is not None:
elapsed = time() - time_start
call_timeout = timeout-elapsed
continue
raise
for fd, mode in ready:
if mode & poll_in_or_priority_flags:
i = max(0, len(self.cookedq)-n)
self.fill_rawq()
self.process_rawq()
i = self.cookedq.find(match, i)
if timeout is not None:
elapsed = time() - time_start
if elapsed >= timeout:
break
call_timeout = timeout-elapsed
poller.unregister(self)
if i >= 0:
i = i + n
buf = self.cookedq[:i]
self.cookedq = self.cookedq[i:]
return buf
return self.read_very_lazy()
def _read_until_with_select(self, match, timeout=None):
"""Read until a given string is encountered or until timeout.
The timeout is implemented using select.select().
""" """
n = len(match) n = len(match)
self.process_rawq() self.process_rawq()
...@@ -588,6 +645,79 @@ class Telnet: ...@@ -588,6 +645,79 @@ class Telnet:
or if more than one expression can match the same input, the or if more than one expression can match the same input, the
results are undeterministic, and may depend on the I/O timing. results are undeterministic, and may depend on the I/O timing.
"""
if self._has_poll:
return self._expect_with_poll(list, timeout)
else:
return self._expect_with_select(list, timeout)
def _expect_with_poll(self, expect_list, timeout=None):
"""Read until one from a list of a regular expressions matches.
This method uses select.poll() to implement the timeout.
"""
re = None
expect_list = expect_list[:]
indices = range(len(expect_list))
for i in indices:
if not hasattr(expect_list[i], "search"):
if not re: import re
expect_list[i] = re.compile(expect_list[i])
call_timeout = timeout
if timeout is not None:
from time import time
time_start = time()
self.process_rawq()
m = None
for i in indices:
m = expect_list[i].search(self.cookedq)
if m:
e = m.end()
text = self.cookedq[:e]
self.cookedq = self.cookedq[e:]
break
if not m:
poller = select.poll()
poll_in_or_priority_flags = select.POLLIN | select.POLLPRI
poller.register(self, poll_in_or_priority_flags)
while not m and not self.eof:
try:
ready = poller.poll(call_timeout)
except select.error as e:
if e.errno == errno.EINTR:
if timeout is not None:
elapsed = time() - time_start
call_timeout = timeout-elapsed
continue
raise
for fd, mode in ready:
if mode & poll_in_or_priority_flags:
self.fill_rawq()
self.process_rawq()
for i in indices:
m = expect_list[i].search(self.cookedq)
if m:
e = m.end()
text = self.cookedq[:e]
self.cookedq = self.cookedq[e:]
break
if timeout is not None:
elapsed = time() - time_start
if elapsed >= timeout:
break
call_timeout = timeout-elapsed
poller.unregister(self)
if m:
return (i, m, text)
text = self.read_very_lazy()
if not text and self.eof:
raise EOFError
return (-1, None, text)
def _expect_with_select(self, list, timeout=None):
"""Read until one from a list of a regular expressions matches.
The timeout is implemented using select.select().
""" """
re = None re = None
list = list[:] list = list[:]
......
...@@ -75,8 +75,8 @@ class GeneralTests(TestCase): ...@@ -75,8 +75,8 @@ class GeneralTests(TestCase):
class SocketStub(object): class SocketStub(object):
''' a socket proxy that re-defines sendall() ''' ''' a socket proxy that re-defines sendall() '''
def __init__(self, reads=[]): def __init__(self, reads=()):
self.reads = reads self.reads = list(reads) # Intentionally make a copy.
self.writes = [] self.writes = []
self.block = False self.block = False
def sendall(self, data): def sendall(self, data):
...@@ -102,7 +102,7 @@ class TelnetAlike(telnetlib.Telnet): ...@@ -102,7 +102,7 @@ class TelnetAlike(telnetlib.Telnet):
self._messages += out.getvalue() self._messages += out.getvalue()
return return
def new_select(*s_args): def mock_select(*s_args):
block = False block = False
for l in s_args: for l in s_args:
for fob in l: for fob in l:
...@@ -113,6 +113,30 @@ def new_select(*s_args): ...@@ -113,6 +113,30 @@ def new_select(*s_args):
else: else:
return s_args return s_args
class MockPoller(object):
test_case = None # Set during TestCase setUp.
def __init__(self):
self._file_objs = []
def register(self, fd, eventmask):
self.test_case.assertTrue(hasattr(fd, 'fileno'), fd)
self.test_case.assertEqual(eventmask, select.POLLIN|select.POLLPRI)
self._file_objs.append(fd)
def poll(self, timeout=None):
block = False
for fob in self._file_objs:
if isinstance(fob, TelnetAlike):
block = fob.sock.block
if block:
return []
else:
return zip(self._file_objs, [select.POLLIN]*len(self._file_objs))
def unregister(self, fd):
self._file_objs.remove(fd)
@contextlib.contextmanager @contextlib.contextmanager
def test_socket(reads): def test_socket(reads):
def new_conn(*ignored): def new_conn(*ignored):
...@@ -125,7 +149,7 @@ def test_socket(reads): ...@@ -125,7 +149,7 @@ def test_socket(reads):
socket.create_connection = old_conn socket.create_connection = old_conn
return return
def test_telnet(reads=[], cls=TelnetAlike): def test_telnet(reads=(), cls=TelnetAlike, use_poll=None):
''' return a telnetlib.Telnet object that uses a SocketStub with ''' return a telnetlib.Telnet object that uses a SocketStub with
reads queued up to be read ''' reads queued up to be read '''
for x in reads: for x in reads:
...@@ -133,15 +157,28 @@ def test_telnet(reads=[], cls=TelnetAlike): ...@@ -133,15 +157,28 @@ def test_telnet(reads=[], cls=TelnetAlike):
with test_socket(reads): with test_socket(reads):
telnet = cls('dummy', 0) telnet = cls('dummy', 0)
telnet._messages = '' # debuglevel output telnet._messages = '' # debuglevel output
if use_poll is not None:
if use_poll and not telnet._has_poll:
raise unittest.SkipTest('select.poll() required.')
telnet._has_poll = use_poll
return telnet return telnet
class ReadTests(TestCase):
class ExpectAndReadTestCase(TestCase):
def setUp(self): def setUp(self):
self.old_select = select.select self.old_select = select.select
select.select = new_select self.old_poll = select.poll
select.select = mock_select
select.poll = MockPoller
MockPoller.test_case = self
def tearDown(self): def tearDown(self):
MockPoller.test_case = None
select.poll = self.old_poll
select.select = self.old_select select.select = self.old_select
class ReadTests(ExpectAndReadTestCase):
def test_read_until(self): def test_read_until(self):
""" """
read_until(expected, timeout=None) read_until(expected, timeout=None)
...@@ -158,6 +195,21 @@ class ReadTests(TestCase): ...@@ -158,6 +195,21 @@ class ReadTests(TestCase):
data = telnet.read_until(b'match') data = telnet.read_until(b'match')
self.assertEqual(data, expect) self.assertEqual(data, expect)
def test_read_until_with_poll(self):
"""Use select.poll() to implement telnet.read_until()."""
want = [b'x' * 10, b'match', b'y' * 10]
telnet = test_telnet(want, use_poll=True)
select.select = lambda *_: self.fail('unexpected select() call.')
data = telnet.read_until(b'match')
self.assertEqual(data, b''.join(want[:-1]))
def test_read_until_with_select(self):
"""Use select.select() to implement telnet.read_until()."""
want = [b'x' * 10, b'match', b'y' * 10]
telnet = test_telnet(want, use_poll=False)
select.poll = lambda *_: self.fail('unexpected poll() call.')
data = telnet.read_until(b'match')
self.assertEqual(data, b''.join(want[:-1]))
def test_read_all(self): def test_read_all(self):
""" """
...@@ -349,8 +401,38 @@ class OptionTests(TestCase): ...@@ -349,8 +401,38 @@ class OptionTests(TestCase):
self.assertRegex(telnet._messages, r'0.*test') self.assertRegex(telnet._messages, r'0.*test')
class ExpectTests(ExpectAndReadTestCase):
def test_expect(self):
"""
expect(expected, [timeout])
Read until the expected string has been seen, or a timeout is
hit (default is no timeout); may block.
"""
want = [b'x' * 10, b'match', b'y' * 10]
telnet = test_telnet(want)
(_,_,data) = telnet.expect([b'match'])
self.assertEqual(data, b''.join(want[:-1]))
def test_expect_with_poll(self):
"""Use select.poll() to implement telnet.expect()."""
want = [b'x' * 10, b'match', b'y' * 10]
telnet = test_telnet(want, use_poll=True)
select.select = lambda *_: self.fail('unexpected select() call.')
(_,_,data) = telnet.expect([b'match'])
self.assertEqual(data, b''.join(want[:-1]))
def test_expect_with_select(self):
"""Use select.select() to implement telnet.expect()."""
want = [b'x' * 10, b'match', b'y' * 10]
telnet = test_telnet(want, use_poll=False)
select.poll = lambda *_: self.fail('unexpected poll() call.')
(_,_,data) = telnet.expect([b'match'])
self.assertEqual(data, b''.join(want[:-1]))
def test_main(verbose=None): def test_main(verbose=None):
support.run_unittest(GeneralTests, ReadTests, WriteTests, OptionTests) support.run_unittest(GeneralTests, ReadTests, WriteTests, OptionTests,
ExpectTests)
if __name__ == '__main__': if __name__ == '__main__':
test_main() test_main()
...@@ -410,6 +410,7 @@ Chris Hoffman ...@@ -410,6 +410,7 @@ Chris Hoffman
Albert Hofkamp Albert Hofkamp
Tomas Hoger Tomas Hoger
Jonathan Hogg Jonathan Hogg
Akintayo Holder
Gerrit Holl Gerrit Holl
Shane Holloway Shane Holloway
Rune Holm Rune Holm
......
...@@ -87,6 +87,9 @@ Core and Builtins ...@@ -87,6 +87,9 @@ Core and Builtins
Library Library
------- -------
- Issue #14635: telnetlib will use poll() rather than select() when possible
to avoid failing due to the select() file descriptor limit.
- Issue #15180: Clarify posixpath.join() error message when mixing str & bytes - Issue #15180: Clarify posixpath.join() error message when mixing str & bytes
- Issue #15230: runpy.run_path now correctly sets __package__ as described - Issue #15230: runpy.run_path now correctly sets __package__ as described
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment