Commit 65a7a2bb authored by Denis Bilenko's avatar Denis Bilenko

greentest: instead of parsing stderr, hook up into Hub.handle_error

parent 097890fc
...@@ -24,8 +24,6 @@ import sys ...@@ -24,8 +24,6 @@ import sys
import unittest import unittest
from unittest import TestCase as BaseTestCase from unittest import TestCase as BaseTestCase
import time import time
import traceback
import re
import os import os
from os.path import basename, splitext from os.path import basename, splitext
import gevent import gevent
...@@ -86,6 +84,45 @@ def wrap_refcount(method): ...@@ -86,6 +84,45 @@ def wrap_refcount(method):
return wrapped return wrapped
def wrap_error_fatal(method):
@wraps(method)
def wrapped(self, *args, **kwargs):
SYSTEM_ERROR = self._hub.SYSTEM_ERROR
self._hub.SYSTEM_ERROR = object
try:
return method(self, *args, **kwargs)
finally:
self._hub.SYSTEM_ERROR = SYSTEM_ERROR
return wrapped
def wrap_restore_handle_error(method):
@wraps(method)
def wrapped(self, *args, **kwargs):
old = self._hub.handle_error
try:
return method(self, *args, **kwargs)
finally:
self._hub.handle_error = old
if self.peek_error()[0] is not None:
gevent.getcurrent().throw(*self.peek_error()[1:])
return wrapped
def _get_class_attr(classDict, bases, attr, default=AttributeError):
NONE = object()
value = classDict.get(attr, NONE)
if value is not NONE:
return value
for base in bases:
value = getattr(bases[0], attr, NONE)
if value is not NONE:
return value
if default is AttributeError:
raise AttributeError('Attribute %r not found\n%s\n%s\n' % (attr, classDict, bases))
return default
class TestCaseMetaClass(type): class TestCaseMetaClass(type):
# wrap each test method with # wrap each test method with
# a) timeout check # a) timeout check
...@@ -94,13 +131,18 @@ class TestCaseMetaClass(type): ...@@ -94,13 +131,18 @@ class TestCaseMetaClass(type):
timeout = classDict.get('__timeout__', 'NONE') timeout = classDict.get('__timeout__', 'NONE')
if timeout == 'NONE': if timeout == 'NONE':
timeout = getattr(bases[0], '__timeout__', None) timeout = getattr(bases[0], '__timeout__', None)
check_totalrefcount = classDict.get('check_totalrefcount') check_totalrefcount = _get_class_attr(classDict, bases, 'check_totalrefcount', True)
if check_totalrefcount is None: error_fatal = _get_class_attr(classDict, bases, 'error_fatal', True)
check_totalrefcount = getattr(bases[0], 'check_totalrefcount', True)
for key, value in classDict.items(): for key, value in classDict.items():
if (key.startswith('test_') or key == 'test') and callable(value): if key.startswith('test') and callable(value):
classDict.pop(key) classDict.pop(key)
value = wrap_timeout(timeout, value) value = wrap_timeout(timeout, value)
my_error_fatal = getattr(value, 'error_fatal', None)
if my_error_fatal is None:
my_error_fatal = error_fatal
if my_error_fatal:
value = wrap_error_fatal(value)
value = wrap_restore_handle_error(value)
if check_totalrefcount: if check_totalrefcount:
value = wrap_refcount(value) value = wrap_refcount(value)
classDict[key] = value classDict[key] = value
...@@ -112,8 +154,8 @@ class TestCase(BaseTestCase): ...@@ -112,8 +154,8 @@ class TestCase(BaseTestCase):
__metaclass__ = TestCaseMetaClass __metaclass__ = TestCaseMetaClass
__timeout__ = 1 __timeout__ = 1
switch_expected = 'default' switch_expected = 'default'
error_fatal = True
_switch_count = None _switch_count = None
check_totalrefcount = True
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
BaseTestCase.__init__(self, *args, **kwargs) BaseTestCase.__init__(self, *args, **kwargs)
...@@ -126,22 +168,25 @@ class TestCase(BaseTestCase): ...@@ -126,22 +168,25 @@ class TestCase(BaseTestCase):
return BaseTestCase.run(self, *args, **kwargs) return BaseTestCase.run(self, *args, **kwargs)
def setUp(self): def setUp(self):
hub = gevent.hub._get_hub() self._hub.loop.update_now()
if hub is not None:
hub.loop.update_now()
if hasattr(self._hub, 'switch_count'): if hasattr(self._hub, 'switch_count'):
self._switch_count = self._hub.switch_count self._switch_count = self._hub.switch_count
def tearDown(self): def tearDown(self):
if hasattr(self, 'cleanup'): if hasattr(self, 'cleanup'):
self.cleanup() self.cleanup()
try: if self._switch_count is not None and hasattr(self._hub, 'switch_count'):
if not hasattr(self, 'stderr'): msg = ''
self.unhook_stderr() if self._hub.switch_count < self._switch_count:
if hasattr(self, 'stderr'): msg = 'hub.switch_count decreased?\n'
sys.__stderr__.write(self.stderr) elif self._hub.switch_count == self._switch_count:
except: if self.switch_expected:
traceback.print_exc() msg = '%s.%s did not switch\n' % (type(self).__name__, self.testname)
elif self._hub.switch_count > self._switch_count:
if not self.switch_expected:
msg = '%s.%s switched but expected not to\n' % (type(self).__name__, self.testname)
if msg:
print >> sys.stderr, 'WARNING: ' + msg
@property @property
def testname(self): def testname(self):
...@@ -163,86 +208,44 @@ class TestCase(BaseTestCase): ...@@ -163,86 +208,44 @@ class TestCase(BaseTestCase):
def fullname(self): def fullname(self):
return splitext(basename(self.modulename))[0] + '.' + self.testcasename return splitext(basename(self.modulename))[0] + '.' + self.testcasename
def hook_stderr(self): _none = (None, None, None)
if VERBOSE: _error = _none
return
from cStringIO import StringIO
self.new_stderr = StringIO()
self.old_stderr = sys.stderr
sys.stderr = self.new_stderr
def unhook_stderr(self):
if VERBOSE:
return
try:
value = self.new_stderr.getvalue()
except AttributeError:
return None
sys.stderr = self.old_stderr
self.stderr = value
return value
def assert_no_stderr(self): def expect_one_error(self):
stderr = self.unhook_stderr() assert self._error == self._none, self._error
assert not stderr, 'Expected no stderr, got:\n__________\n%s\n^^^^^^^^^^\n\n' % (stderr, ) self._old_handle_error = self._hub.handle_error
self._hub.handle_error = self._store_error
def assert_stderr_traceback(self, typ, value=None):
if VERBOSE: def _store_error(self, where, type, value, tb):
return del tb
if isinstance(typ, Exception): if self._error != self._none:
if value is None: if self._hub is gevent.getcurrent():
value = str(typ) self._hub.parent.throw(type, value)
typ = typ.__class__.__name__
else:
typ = getattr(typ, '__name__', typ)
stderr = self.unhook_stderr()
assert stderr is not None, repr(stderr)
traceback_re = '^(Traceback \\(most recent call last\\):\n( +.*?\n)+)?^(?P<type>\w+(\.\w+)*): (?P<value>.*?)$'
self.extract_re(traceback_re, type=typ, value=value)
def assert_stderr(self, message):
if VERBOSE:
return
exact_re = '^' + message + '.*?\n$.*'
if re.match(exact_re, self.stderr):
self.extract_re(exact_re)
else:
words_re = '^' + '.*?'.join(message.split()) + '.*?\n$'
if re.match(words_re, self.stderr):
self.extract_re(words_re)
else: else:
if message.endswith('...'): self._hub.loop.run_callback(self._hub.parent.throw, type, value)
another_re = '^' + '.*?'.join(message.split()) + '.*?(\n +.*?$){2,5}\n\n' else:
self.extract_re(another_re) self._error = (where, type, value)
else:
raise AssertionError('%r did not match:\n%r' % (message, self.stderr))
def assert_mainloop_assertion(self, message=None):
self.assert_stderr_traceback('AssertionError', 'Impossible to call blocking function in the event loop callback')
if message is not None:
self.assert_stderr(message)
def _extract_re(self, regex, type=None, value=None):
assert self.stderr is not None
m = re.search(regex, self.stderr, re.DOTALL | re.M)
if m is None:
raise AssertionError('Cannot find traceback:\nregex: %s\nstderr:\n%s' % (regex, self.stderr))
if type is not None:
if m.group('type') != type and m.group('type').split('.')[-1] != type:
raise AssertionError('Unexpected exception type: %r (expected %r)' % (m.group(type), type))
if value is not None:
self.assertEqual(m.group('value'), value)
if DEBUG:
ate = '\n#ATE#: ' + self.stderr[m.start(0):m.end(0)].replace('\n', '\n#ATE#: ') + '\n'
sys.__stderr__.write(ate)
self.stderr = self.stderr[:m.start(0)] + self.stderr[m.end(0) + 1:]
def extract_re(self, regex, type=None, value=None): def peek_error(self):
return self._error
def get_error(self):
try: try:
return self._extract_re(regex, type, value) return self._error
except Exception: finally:
print 'failed to process: %r' % (self.stderr, ) self._error = self._none
raise
def assert_error(self, type=None, value=None, error=None):
if error is None:
error = self.get_error()
if type is not None:
assert error[1] is type, error
if value is not None:
if isinstance(value, str):
assert str(error[2]) == value, error
else:
assert error[2] is value, error
return error
main = unittest.main main = unittest.main
......
...@@ -26,6 +26,7 @@ def patch_all(timeout=None): ...@@ -26,6 +26,7 @@ def patch_all(timeout=None):
import greentest import greentest
unittest.TestCase = greentest.TestCase unittest.TestCase = greentest.TestCase
unittest.TestCase.check_totalrefcount = False unittest.TestCase.check_totalrefcount = False
unittest.TestCase.error_fatal = False
if timeout is not None: if timeout is not None:
unittest.TestCase.__timeout__ = timeout unittest.TestCase.__timeout__ = timeout
......
...@@ -59,7 +59,8 @@ class TestAsyncResult(greentest.TestCase): ...@@ -59,7 +59,8 @@ class TestAsyncResult(greentest.TestCase):
g.kill() g.kill()
class TestAsync_ResultAsLinkTarget(greentest.TestCase): class TestAsyncResultAsLinkTarget(greentest.TestCase):
error_fatal = False
def test_set(self): def test_set(self):
g = gevent.spawn(lambda: 1) g = gevent.spawn(lambda: 1)
...@@ -73,7 +74,7 @@ class TestAsync_ResultAsLinkTarget(greentest.TestCase): ...@@ -73,7 +74,7 @@ class TestAsync_ResultAsLinkTarget(greentest.TestCase):
def test_set_exception(self): def test_set_exception(self):
def func(): def func():
raise greentest.ExpectedException('TestAsync_ResultAsLinkTarget.test_set_exception') raise greentest.ExpectedException('TestAsyncResultAsLinkTarget.test_set_exception')
g = gevent.spawn(func) g = gevent.spawn(func)
s1, s2, s3 = AsyncResult(), AsyncResult(), AsyncResult() s1, s2, s3 = AsyncResult(), AsyncResult(), AsyncResult()
g.link(s1) g.link(s1)
......
import gevent import gevent
import sys import sys
import greentest import greentest
from gevent.hub import get_hub
sys.exc_clear() sys.exc_clear()
...@@ -34,11 +33,10 @@ class Test(greentest.TestCase): ...@@ -34,11 +33,10 @@ class Test(greentest.TestCase):
try: try:
raise error raise error
except: except:
self.hook_stderr() self.expect_one_error()
g = gevent.spawn(hello) g = gevent.spawn(hello)
g.join() g.join()
self.assert_stderr_traceback(expected_error) self.assert_error(ExpectedError, expected_error)
self.assert_stderr('Ignoring ExpectedError in <Greenlet')
if not isinstance(g.exception, ExpectedError): if not isinstance(g.exception, ExpectedError):
raise g.exception raise g.exception
try: try:
...@@ -47,7 +45,7 @@ class Test(greentest.TestCase): ...@@ -47,7 +45,7 @@ class Test(greentest.TestCase):
assert ex is error, (ex, error) assert ex is error, (ex, error)
def test2(self): def test2(self):
timer = get_hub().loop.timer(0) timer = self._hub.loop.timer(0)
timer.start(hello2) timer.start(hello2)
gevent.sleep(0.1) gevent.sleep(0.1)
assert sys.exc_info() == (None, None, None), sys.exc_info() assert sys.exc_info() == (None, None, None), sys.exc_info()
......
...@@ -29,6 +29,7 @@ from gevent.event import AsyncResult ...@@ -29,6 +29,7 @@ from gevent.event import AsyncResult
from gevent.queue import Queue from gevent.queue import Queue
DELAY = 0.01 DELAY = 0.01
greentest.TestCase.error_fatal = False
class ExpectedError(Exception): class ExpectedError(Exception):
......
...@@ -53,16 +53,22 @@ class TestExceptionInMainloop(greentest.TestCase): ...@@ -53,16 +53,22 @@ class TestExceptionInMainloop(greentest.TestCase):
assert delay >= DELAY * 0.9, 'sleep returned after %s seconds (was scheduled for %s)' % (delay, DELAY) assert delay >= DELAY * 0.9, 'sleep returned after %s seconds (was scheduled for %s)' % (delay, DELAY)
error = greentest.ExpectedException('TestExceptionInMainloop.test_sleep/fail')
def fail(): def fail():
raise greentest.ExpectedException('TestExceptionInMainloop.test_sleep/fail') raise error
t = get_hub().loop.timer(0.001) t = get_hub().loop.timer(0.001)
t.start(fail) t.start(fail)
self.expect_one_error()
start = time.time() start = time.time()
gevent.sleep(DELAY) gevent.sleep(DELAY)
delay = time.time() - start delay = time.time() - start
self.assert_error(value=error)
assert delay >= DELAY * 0.9, 'sleep returned after %s seconds (was scheduled for %s)' % (delay, DELAY) assert delay >= DELAY * 0.9, 'sleep returned after %s seconds (was scheduled for %s)' % (delay, DELAY)
......
...@@ -628,20 +628,25 @@ class TestInputReadlines(TestInputReadline): ...@@ -628,20 +628,25 @@ class TestInputReadlines(TestInputReadline):
class TestError(TestCase): class TestError(TestCase):
@staticmethod error = greentest.ExpectedException('TestError.application')
def application(env, start_response): error_fatal = False
raise greentest.ExpectedException('TestError.application')
def application(self, env, start_response):
raise self.error
def test(self): def test(self):
self.expect_one_error()
self.urlopen(code=500) self.urlopen(code=500)
self.assert_error(greentest.ExpectedException, self.error)
class TestError_after_start_response(TestError): class TestError_after_start_response(TestError):
@staticmethod error = greentest.ExpectedException('TestError_after_start_response.application')
def application(env, start_response):
def application(self, env, start_response):
start_response('200 OK', [('Content-Type', 'text/plain')]) start_response('200 OK', [('Content-Type', 'text/plain')])
raise greentest.ExpectedException('TestError_after_start_response.application') raise self.error
class TestEmptyYield(TestCase): class TestEmptyYield(TestCase):
...@@ -845,7 +850,8 @@ class ChunkedInputTests(TestCase): ...@@ -845,7 +850,8 @@ class ChunkedInputTests(TestCase):
read_http(fd, body='this is chunked\nline 2\nline3') read_http(fd, body='this is chunked\nline 2\nline3')
def test_close_before_finished(self): def test_close_before_finished(self):
self.hook_stderr() if server_implements_chunked:
self.expect_one_error()
body = '4\r\nthi' body = '4\r\nthi'
req = "POST /short-read HTTP/1.1\r\ntransfer-encoding: Chunked\r\n\r\n" + body req = "POST /short-read HTTP/1.1\r\ntransfer-encoding: Chunked\r\n\r\n" + body
fd = self.connect().makefile(bufsize=1) fd = self.connect().makefile(bufsize=1)
...@@ -853,7 +859,7 @@ class ChunkedInputTests(TestCase): ...@@ -853,7 +859,7 @@ class ChunkedInputTests(TestCase):
fd.close() fd.close()
gevent.sleep(0.01) gevent.sleep(0.01)
if server_implements_chunked: if server_implements_chunked:
self.assert_stderr_traceback(IOError, 'unexpected end of file while parsing chunked data') self.assert_error(IOError, 'unexpected end of file while parsing chunked data')
class Expect100ContinueTests(TestCase): class Expect100ContinueTests(TestCase):
......
...@@ -155,12 +155,11 @@ class TestCase(greentest.TestCase): ...@@ -155,12 +155,11 @@ class TestCase(greentest.TestCase):
def _test_invalid_callback(self): def _test_invalid_callback(self):
try: try:
self.hook_stderr() self.expect_one_error()
self.server = self.ServerClass(('127.0.0.1', 0), lambda: None) self.server = self.ServerClass(('127.0.0.1', 0), lambda: None)
self.server.start() self.server.start()
self.assert500() self.assert500()
self.assert_stderr_traceback('TypeError') self.assert_error(TypeError)
self.assert_stderr(self.invalid_callback_message)
finally: finally:
self.server.stop() self.server.stop()
...@@ -279,13 +278,11 @@ class TestDefaultSpawn(TestCase): ...@@ -279,13 +278,11 @@ class TestDefaultSpawn(TestCase):
def test_error_in_spawn(self): def test_error_in_spawn(self):
self.init_server() self.init_server()
assert self.server.started assert self.server.started
self.hook_stderr()
error = ExpectedError('test_error_in_spawn') error = ExpectedError('test_error_in_spawn')
self.server._spawn = lambda *args: gevent.getcurrent().throw(error) self.server._spawn = lambda *args: gevent.getcurrent().throw(error)
self.expect_one_error()
self.assertAcceptedConnectionError() self.assertAcceptedConnectionError()
self.assert_stderr_traceback(error) self.assert_error(ExpectedError, error)
#self.assert_stderr('^WARNING: <SimpleStreamServer .*?>: ignoring test_error_in_spawn \\(sleeping \d.\d+ seconds\\)\n$')
self.assert_stderr('<.*?>: Failed to handle...')
return return
if Settings.restartable: if Settings.restartable:
assert not self.server.started assert not self.server.started
...@@ -323,6 +320,8 @@ class TestPoolSpawn(TestDefaultSpawn): ...@@ -323,6 +320,8 @@ class TestPoolSpawn(TestDefaultSpawn):
gevent.sleep(0.1) gevent.sleep(0.1)
self.assertRequestSucceeded() self.assertRequestSucceeded()
test_pool_full.error_fatal = False
class TestNoneSpawn(TestCase): class TestNoneSpawn(TestCase):
...@@ -339,9 +338,9 @@ class TestNoneSpawn(TestCase): ...@@ -339,9 +338,9 @@ class TestNoneSpawn(TestCase):
gevent.sleep(0) gevent.sleep(0)
self.server = Settings.ServerClass(('127.0.0.1', 0), sleep, spawn=None) self.server = Settings.ServerClass(('127.0.0.1', 0), sleep, spawn=None)
self.server.start() self.server.start()
self.hook_stderr() self.expect_one_error()
self.assert500() self.assert500()
self.assert_mainloop_assertion(self.invalid_callback_message) self.assert_error(AssertionError, 'Impossible to call blocking function in the event loop callback')
class ExpectedError(Exception): class ExpectedError(Exception):
......
...@@ -15,6 +15,7 @@ if hasattr(signal, 'SIGALRM'): ...@@ -15,6 +15,7 @@ if hasattr(signal, 'SIGALRM'):
class TestSignal(greentest.TestCase): class TestSignal(greentest.TestCase):
error_fatal = False
__timeout__ = 5 __timeout__ = 5
def test(self): def test(self):
......
...@@ -13,6 +13,8 @@ MSG = 'should be re-raised and caught' ...@@ -13,6 +13,8 @@ MSG = 'should be re-raised and caught'
class Test(greentest.TestCase): class Test(greentest.TestCase):
error_fatal = False
def test_sys_exit(self): def test_sys_exit(self):
self.start(sys.exit, MSG) self.start(sys.exit, MSG)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment