Commit 22293b78 authored by Nick Coghlan's avatar Nick Coghlan

Also chain codec exceptions that allow weakrefs

The zlib and hex codecs throw custom exception types with
weakref support if the input type is valid, but the data
fails validation. Make sure the exception chaining in the
codec infrastructure can wrap those as well.
parent c72e4285
...@@ -2402,6 +2402,25 @@ class TransformCodecTest(unittest.TestCase): ...@@ -2402,6 +2402,25 @@ class TransformCodecTest(unittest.TestCase):
self.assertTrue(isinstance(failure.exception.__cause__, self.assertTrue(isinstance(failure.exception.__cause__,
AttributeError)) AttributeError))
def test_custom_zlib_error_is_wrapped(self):
# Check zlib codec gives a good error for malformed input
msg = "^decoding with 'zlib_codec' codec failed"
with self.assertRaisesRegex(Exception, msg) as failure:
b"hello".decode("zlib_codec")
self.assertTrue(isinstance(failure.exception.__cause__,
type(failure.exception)))
def test_custom_hex_error_is_wrapped(self):
# Check hex codec gives a good error for malformed input
msg = "^decoding with 'hex_codec' codec failed"
with self.assertRaisesRegex(Exception, msg) as failure:
b"hello".decode("hex_codec")
self.assertTrue(isinstance(failure.exception.__cause__,
type(failure.exception)))
# Unfortunately, the bz2 module throws OSError, which the codec
# machinery currently can't wrap :(
def test_bad_decoding_output_type(self): def test_bad_decoding_output_type(self):
# Check bytes.decode and bytearray.decode give a good error # Check bytes.decode and bytearray.decode give a good error
# message for binary -> binary codecs # message for binary -> binary codecs
...@@ -2466,15 +2485,15 @@ class ExceptionChainingTest(unittest.TestCase): ...@@ -2466,15 +2485,15 @@ class ExceptionChainingTest(unittest.TestCase):
with self.assertRaisesRegex(exc_type, full_msg) as caught: with self.assertRaisesRegex(exc_type, full_msg) as caught:
yield caught yield caught
def check_wrapped(self, obj_to_raise, msg): def check_wrapped(self, obj_to_raise, msg, exc_type=RuntimeError):
self.set_codec(obj_to_raise) self.set_codec(obj_to_raise)
with self.assertWrapped("encoding", RuntimeError, msg): with self.assertWrapped("encoding", exc_type, msg):
"str_input".encode(self.codec_name) "str_input".encode(self.codec_name)
with self.assertWrapped("encoding", RuntimeError, msg): with self.assertWrapped("encoding", exc_type, msg):
codecs.encode("str_input", self.codec_name) codecs.encode("str_input", self.codec_name)
with self.assertWrapped("decoding", RuntimeError, msg): with self.assertWrapped("decoding", exc_type, msg):
b"bytes input".decode(self.codec_name) b"bytes input".decode(self.codec_name)
with self.assertWrapped("decoding", RuntimeError, msg): with self.assertWrapped("decoding", exc_type, msg):
codecs.decode(b"bytes input", self.codec_name) codecs.decode(b"bytes input", self.codec_name)
def test_raise_by_type(self): def test_raise_by_type(self):
...@@ -2484,6 +2503,18 @@ class ExceptionChainingTest(unittest.TestCase): ...@@ -2484,6 +2503,18 @@ class ExceptionChainingTest(unittest.TestCase):
msg = "This should be wrapped" msg = "This should be wrapped"
self.check_wrapped(RuntimeError(msg), msg) self.check_wrapped(RuntimeError(msg), msg)
def test_raise_grandchild_subclass_exact_size(self):
msg = "This should be wrapped"
class MyRuntimeError(RuntimeError):
__slots__ = ()
self.check_wrapped(MyRuntimeError(msg), msg, MyRuntimeError)
def test_raise_subclass_with_weakref_support(self):
msg = "This should be wrapped"
class MyRuntimeError(RuntimeError):
pass
self.check_wrapped(MyRuntimeError(msg), msg, MyRuntimeError)
@contextlib.contextmanager @contextlib.contextmanager
def assertNotWrapped(self, operation, exc_type, msg_re, msg=None): def assertNotWrapped(self, operation, exc_type, msg_re, msg=None):
if msg is None: if msg is None:
......
...@@ -2630,16 +2630,27 @@ _PyErr_TrySetFromCause(const char *format, ...) ...@@ -2630,16 +2630,27 @@ _PyErr_TrySetFromCause(const char *format, ...)
PyTypeObject *caught_type; PyTypeObject *caught_type;
PyObject **dictptr; PyObject **dictptr;
PyObject *instance_args; PyObject *instance_args;
Py_ssize_t num_args; Py_ssize_t num_args, caught_type_size, base_exc_size;
PyObject *new_exc, *new_val, *new_tb; PyObject *new_exc, *new_val, *new_tb;
va_list vargs; va_list vargs;
int same_basic_size;
PyErr_Fetch(&exc, &val, &tb); PyErr_Fetch(&exc, &val, &tb);
caught_type = (PyTypeObject *)exc; caught_type = (PyTypeObject *)exc;
/* Ensure type info indicates no extra state is stored at the C level */ /* Ensure type info indicates no extra state is stored at the C level
* and that the type can be reinstantiated using PyErr_Format
*/
caught_type_size = caught_type->tp_basicsize;
base_exc_size = _PyExc_BaseException.tp_basicsize;
same_basic_size = (
caught_type_size == base_exc_size ||
(PyType_SUPPORTS_WEAKREFS(caught_type) &&
(caught_type_size == base_exc_size + sizeof(PyObject *))
)
);
if (caught_type->tp_init != (initproc)BaseException_init || if (caught_type->tp_init != (initproc)BaseException_init ||
caught_type->tp_new != BaseException_new || caught_type->tp_new != BaseException_new ||
caught_type->tp_basicsize != _PyExc_BaseException.tp_basicsize || !same_basic_size ||
caught_type->tp_itemsize != _PyExc_BaseException.tp_itemsize) { caught_type->tp_itemsize != _PyExc_BaseException.tp_itemsize) {
/* We can't be sure we can wrap this safely, since it may contain /* We can't be sure we can wrap this safely, since it may contain
* more state than just the exception type. Accordingly, we just * more state than just the exception type. Accordingly, we just
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment