Commit 65a7a2bb authored by Denis Bilenko's avatar Denis Bilenko

greentest: instead of parsing stderr, hook up into Hub.handle_error

parent 097890fc
......@@ -24,8 +24,6 @@ import sys
import unittest
from unittest import TestCase as BaseTestCase
import time
import traceback
import re
import os
from os.path import basename, splitext
import gevent
......@@ -86,6 +84,45 @@ def wrap_refcount(method):
return wrapped
def wrap_error_fatal(method):
@wraps(method)
def wrapped(self, *args, **kwargs):
SYSTEM_ERROR = self._hub.SYSTEM_ERROR
self._hub.SYSTEM_ERROR = object
try:
return method(self, *args, **kwargs)
finally:
self._hub.SYSTEM_ERROR = SYSTEM_ERROR
return wrapped
def wrap_restore_handle_error(method):
@wraps(method)
def wrapped(self, *args, **kwargs):
old = self._hub.handle_error
try:
return method(self, *args, **kwargs)
finally:
self._hub.handle_error = old
if self.peek_error()[0] is not None:
gevent.getcurrent().throw(*self.peek_error()[1:])
return wrapped
def _get_class_attr(classDict, bases, attr, default=AttributeError):
NONE = object()
value = classDict.get(attr, NONE)
if value is not NONE:
return value
for base in bases:
value = getattr(bases[0], attr, NONE)
if value is not NONE:
return value
if default is AttributeError:
raise AttributeError('Attribute %r not found\n%s\n%s\n' % (attr, classDict, bases))
return default
class TestCaseMetaClass(type):
# wrap each test method with
# a) timeout check
......@@ -94,13 +131,18 @@ class TestCaseMetaClass(type):
timeout = classDict.get('__timeout__', 'NONE')
if timeout == 'NONE':
timeout = getattr(bases[0], '__timeout__', None)
check_totalrefcount = classDict.get('check_totalrefcount')
if check_totalrefcount is None:
check_totalrefcount = getattr(bases[0], 'check_totalrefcount', True)
check_totalrefcount = _get_class_attr(classDict, bases, 'check_totalrefcount', True)
error_fatal = _get_class_attr(classDict, bases, 'error_fatal', True)
for key, value in classDict.items():
if (key.startswith('test_') or key == 'test') and callable(value):
if key.startswith('test') and callable(value):
classDict.pop(key)
value = wrap_timeout(timeout, value)
my_error_fatal = getattr(value, 'error_fatal', None)
if my_error_fatal is None:
my_error_fatal = error_fatal
if my_error_fatal:
value = wrap_error_fatal(value)
value = wrap_restore_handle_error(value)
if check_totalrefcount:
value = wrap_refcount(value)
classDict[key] = value
......@@ -112,8 +154,8 @@ class TestCase(BaseTestCase):
__metaclass__ = TestCaseMetaClass
__timeout__ = 1
switch_expected = 'default'
error_fatal = True
_switch_count = None
check_totalrefcount = True
def __init__(self, *args, **kwargs):
BaseTestCase.__init__(self, *args, **kwargs)
......@@ -126,22 +168,25 @@ class TestCase(BaseTestCase):
return BaseTestCase.run(self, *args, **kwargs)
def setUp(self):
hub = gevent.hub._get_hub()
if hub is not None:
hub.loop.update_now()
self._hub.loop.update_now()
if hasattr(self._hub, 'switch_count'):
self._switch_count = self._hub.switch_count
def tearDown(self):
if hasattr(self, 'cleanup'):
self.cleanup()
try:
if not hasattr(self, 'stderr'):
self.unhook_stderr()
if hasattr(self, 'stderr'):
sys.__stderr__.write(self.stderr)
except:
traceback.print_exc()
if self._switch_count is not None and hasattr(self._hub, 'switch_count'):
msg = ''
if self._hub.switch_count < self._switch_count:
msg = 'hub.switch_count decreased?\n'
elif self._hub.switch_count == self._switch_count:
if self.switch_expected:
msg = '%s.%s did not switch\n' % (type(self).__name__, self.testname)
elif self._hub.switch_count > self._switch_count:
if not self.switch_expected:
msg = '%s.%s switched but expected not to\n' % (type(self).__name__, self.testname)
if msg:
print >> sys.stderr, 'WARNING: ' + msg
@property
def testname(self):
......@@ -163,86 +208,44 @@ class TestCase(BaseTestCase):
def fullname(self):
return splitext(basename(self.modulename))[0] + '.' + self.testcasename
def hook_stderr(self):
if VERBOSE:
return
from cStringIO import StringIO
self.new_stderr = StringIO()
self.old_stderr = sys.stderr
sys.stderr = self.new_stderr
def unhook_stderr(self):
if VERBOSE:
return
try:
value = self.new_stderr.getvalue()
except AttributeError:
return None
sys.stderr = self.old_stderr
self.stderr = value
return value
_none = (None, None, None)
_error = _none
def assert_no_stderr(self):
stderr = self.unhook_stderr()
assert not stderr, 'Expected no stderr, got:\n__________\n%s\n^^^^^^^^^^\n\n' % (stderr, )
def assert_stderr_traceback(self, typ, value=None):
if VERBOSE:
return
if isinstance(typ, Exception):
if value is None:
value = str(typ)
typ = typ.__class__.__name__
else:
typ = getattr(typ, '__name__', typ)
stderr = self.unhook_stderr()
assert stderr is not None, repr(stderr)
traceback_re = '^(Traceback \\(most recent call last\\):\n( +.*?\n)+)?^(?P<type>\w+(\.\w+)*): (?P<value>.*?)$'
self.extract_re(traceback_re, type=typ, value=value)
def assert_stderr(self, message):
if VERBOSE:
return
exact_re = '^' + message + '.*?\n$.*'
if re.match(exact_re, self.stderr):
self.extract_re(exact_re)
else:
words_re = '^' + '.*?'.join(message.split()) + '.*?\n$'
if re.match(words_re, self.stderr):
self.extract_re(words_re)
def expect_one_error(self):
assert self._error == self._none, self._error
self._old_handle_error = self._hub.handle_error
self._hub.handle_error = self._store_error
def _store_error(self, where, type, value, tb):
del tb
if self._error != self._none:
if self._hub is gevent.getcurrent():
self._hub.parent.throw(type, value)
else:
if message.endswith('...'):
another_re = '^' + '.*?'.join(message.split()) + '.*?(\n +.*?$){2,5}\n\n'
self.extract_re(another_re)
self._hub.loop.run_callback(self._hub.parent.throw, type, value)
else:
raise AssertionError('%r did not match:\n%r' % (message, self.stderr))
def assert_mainloop_assertion(self, message=None):
self.assert_stderr_traceback('AssertionError', 'Impossible to call blocking function in the event loop callback')
if message is not None:
self.assert_stderr(message)
def _extract_re(self, regex, type=None, value=None):
assert self.stderr is not None
m = re.search(regex, self.stderr, re.DOTALL | re.M)
if m is None:
raise AssertionError('Cannot find traceback:\nregex: %s\nstderr:\n%s' % (regex, self.stderr))
if type is not None:
if m.group('type') != type and m.group('type').split('.')[-1] != type:
raise AssertionError('Unexpected exception type: %r (expected %r)' % (m.group(type), type))
if value is not None:
self.assertEqual(m.group('value'), value)
if DEBUG:
ate = '\n#ATE#: ' + self.stderr[m.start(0):m.end(0)].replace('\n', '\n#ATE#: ') + '\n'
sys.__stderr__.write(ate)
self.stderr = self.stderr[:m.start(0)] + self.stderr[m.end(0) + 1:]
self._error = (where, type, value)
def extract_re(self, regex, type=None, value=None):
def peek_error(self):
return self._error
def get_error(self):
try:
return self._extract_re(regex, type, value)
except Exception:
print 'failed to process: %r' % (self.stderr, )
raise
return self._error
finally:
self._error = self._none
def assert_error(self, type=None, value=None, error=None):
if error is None:
error = self.get_error()
if type is not None:
assert error[1] is type, error
if value is not None:
if isinstance(value, str):
assert str(error[2]) == value, error
else:
assert error[2] is value, error
return error
main = unittest.main
......
......@@ -26,6 +26,7 @@ def patch_all(timeout=None):
import greentest
unittest.TestCase = greentest.TestCase
unittest.TestCase.check_totalrefcount = False
unittest.TestCase.error_fatal = False
if timeout is not None:
unittest.TestCase.__timeout__ = timeout
......
......@@ -59,7 +59,8 @@ class TestAsyncResult(greentest.TestCase):
g.kill()
class TestAsync_ResultAsLinkTarget(greentest.TestCase):
class TestAsyncResultAsLinkTarget(greentest.TestCase):
error_fatal = False
def test_set(self):
g = gevent.spawn(lambda: 1)
......@@ -73,7 +74,7 @@ class TestAsync_ResultAsLinkTarget(greentest.TestCase):
def test_set_exception(self):
def func():
raise greentest.ExpectedException('TestAsync_ResultAsLinkTarget.test_set_exception')
raise greentest.ExpectedException('TestAsyncResultAsLinkTarget.test_set_exception')
g = gevent.spawn(func)
s1, s2, s3 = AsyncResult(), AsyncResult(), AsyncResult()
g.link(s1)
......
import gevent
import sys
import greentest
from gevent.hub import get_hub
sys.exc_clear()
......@@ -34,11 +33,10 @@ class Test(greentest.TestCase):
try:
raise error
except:
self.hook_stderr()
self.expect_one_error()
g = gevent.spawn(hello)
g.join()
self.assert_stderr_traceback(expected_error)
self.assert_stderr('Ignoring ExpectedError in <Greenlet')
self.assert_error(ExpectedError, expected_error)
if not isinstance(g.exception, ExpectedError):
raise g.exception
try:
......@@ -47,7 +45,7 @@ class Test(greentest.TestCase):
assert ex is error, (ex, error)
def test2(self):
timer = get_hub().loop.timer(0)
timer = self._hub.loop.timer(0)
timer.start(hello2)
gevent.sleep(0.1)
assert sys.exc_info() == (None, None, None), sys.exc_info()
......
......@@ -29,6 +29,7 @@ from gevent.event import AsyncResult
from gevent.queue import Queue
DELAY = 0.01
greentest.TestCase.error_fatal = False
class ExpectedError(Exception):
......
......@@ -53,16 +53,22 @@ class TestExceptionInMainloop(greentest.TestCase):
assert delay >= DELAY * 0.9, 'sleep returned after %s seconds (was scheduled for %s)' % (delay, DELAY)
error = greentest.ExpectedException('TestExceptionInMainloop.test_sleep/fail')
def fail():
raise greentest.ExpectedException('TestExceptionInMainloop.test_sleep/fail')
raise error
t = get_hub().loop.timer(0.001)
t.start(fail)
self.expect_one_error()
start = time.time()
gevent.sleep(DELAY)
delay = time.time() - start
self.assert_error(value=error)
assert delay >= DELAY * 0.9, 'sleep returned after %s seconds (was scheduled for %s)' % (delay, DELAY)
......
......@@ -628,20 +628,25 @@ class TestInputReadlines(TestInputReadline):
class TestError(TestCase):
@staticmethod
def application(env, start_response):
raise greentest.ExpectedException('TestError.application')
error = greentest.ExpectedException('TestError.application')
error_fatal = False
def application(self, env, start_response):
raise self.error
def test(self):
self.expect_one_error()
self.urlopen(code=500)
self.assert_error(greentest.ExpectedException, self.error)
class TestError_after_start_response(TestError):
@staticmethod
def application(env, start_response):
error = greentest.ExpectedException('TestError_after_start_response.application')
def application(self, env, start_response):
start_response('200 OK', [('Content-Type', 'text/plain')])
raise greentest.ExpectedException('TestError_after_start_response.application')
raise self.error
class TestEmptyYield(TestCase):
......@@ -845,7 +850,8 @@ class ChunkedInputTests(TestCase):
read_http(fd, body='this is chunked\nline 2\nline3')
def test_close_before_finished(self):
self.hook_stderr()
if server_implements_chunked:
self.expect_one_error()
body = '4\r\nthi'
req = "POST /short-read HTTP/1.1\r\ntransfer-encoding: Chunked\r\n\r\n" + body
fd = self.connect().makefile(bufsize=1)
......@@ -853,7 +859,7 @@ class ChunkedInputTests(TestCase):
fd.close()
gevent.sleep(0.01)
if server_implements_chunked:
self.assert_stderr_traceback(IOError, 'unexpected end of file while parsing chunked data')
self.assert_error(IOError, 'unexpected end of file while parsing chunked data')
class Expect100ContinueTests(TestCase):
......
......@@ -155,12 +155,11 @@ class TestCase(greentest.TestCase):
def _test_invalid_callback(self):
try:
self.hook_stderr()
self.expect_one_error()
self.server = self.ServerClass(('127.0.0.1', 0), lambda: None)
self.server.start()
self.assert500()
self.assert_stderr_traceback('TypeError')
self.assert_stderr(self.invalid_callback_message)
self.assert_error(TypeError)
finally:
self.server.stop()
......@@ -279,13 +278,11 @@ class TestDefaultSpawn(TestCase):
def test_error_in_spawn(self):
self.init_server()
assert self.server.started
self.hook_stderr()
error = ExpectedError('test_error_in_spawn')
self.server._spawn = lambda *args: gevent.getcurrent().throw(error)
self.expect_one_error()
self.assertAcceptedConnectionError()
self.assert_stderr_traceback(error)
#self.assert_stderr('^WARNING: <SimpleStreamServer .*?>: ignoring test_error_in_spawn \\(sleeping \d.\d+ seconds\\)\n$')
self.assert_stderr('<.*?>: Failed to handle...')
self.assert_error(ExpectedError, error)
return
if Settings.restartable:
assert not self.server.started
......@@ -323,6 +320,8 @@ class TestPoolSpawn(TestDefaultSpawn):
gevent.sleep(0.1)
self.assertRequestSucceeded()
test_pool_full.error_fatal = False
class TestNoneSpawn(TestCase):
......@@ -339,9 +338,9 @@ class TestNoneSpawn(TestCase):
gevent.sleep(0)
self.server = Settings.ServerClass(('127.0.0.1', 0), sleep, spawn=None)
self.server.start()
self.hook_stderr()
self.expect_one_error()
self.assert500()
self.assert_mainloop_assertion(self.invalid_callback_message)
self.assert_error(AssertionError, 'Impossible to call blocking function in the event loop callback')
class ExpectedError(Exception):
......
......@@ -15,6 +15,7 @@ if hasattr(signal, 'SIGALRM'):
class TestSignal(greentest.TestCase):
error_fatal = False
__timeout__ = 5
def test(self):
......
......@@ -13,6 +13,8 @@ MSG = 'should be re-raised and caught'
class Test(greentest.TestCase):
error_fatal = False
def test_sys_exit(self):
self.start(sys.exit, MSG)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment