Commit 247c0171 authored by Kevin Modzelewski's avatar Kevin Modzelewski

Support overriding AssertionError

Apparently the 'assert' statement works by looking up the "AssertionError"
name in the global scope, and then raising an exception of that type.

pytest uses this to override assertion behavior.
parent ed98d1bd
...@@ -843,7 +843,10 @@ Value ASTInterpreter::visit_assert(AST_Assert* node) { ...@@ -843,7 +843,10 @@ Value ASTInterpreter::visit_assert(AST_Assert* node) {
Value v = visit_expr(node->test); Value v = visit_expr(node->test);
assert(v.o->cls == int_cls && static_cast<BoxedInt*>(v.o)->n == 0); assert(v.o->cls == int_cls && static_cast<BoxedInt*>(v.o)->n == 0);
#endif #endif
assertFail(source_info->parent_module, node->msg ? visit_expr(node->msg).o : 0);
static std::string AssertionError_str("AssertionError");
Box* assertion_type = getGlobal(globals, &AssertionError_str);
assertFail(assertion_type, node->msg ? visit_expr(node->msg).o : 0);
return Value(); return Value();
} }
......
...@@ -1660,6 +1660,7 @@ private: ...@@ -1660,6 +1660,7 @@ private:
} }
void doAssert(AST_Assert* node, UnwindInfo unw_info) { void doAssert(AST_Assert* node, UnwindInfo unw_info) {
// cfg translates all asserts into only 'assert 0' on the failing path.
AST_expr* test = node->test; AST_expr* test = node->test;
assert(test->type == AST_TYPE::Num); assert(test->type == AST_TYPE::Num);
AST_Num* num = ast_cast<AST_Num>(test); AST_Num* num = ast_cast<AST_Num>(test);
...@@ -1667,7 +1668,12 @@ private: ...@@ -1667,7 +1668,12 @@ private:
assert(num->n_int == 0); assert(num->n_int == 0);
std::vector<llvm::Value*> llvm_args; std::vector<llvm::Value*> llvm_args;
llvm_args.push_back(embedParentModulePtr());
// We could patchpoint this or try to avoid the overhead, but this should only
// happen when the assertion is actually thrown so I don't think it will be necessary.
static std::string AssertionError_str("AssertionError");
llvm_args.push_back(emitter.createCall2(unw_info, g.funcs.getGlobal, embedParentModulePtr(),
embedRelocatablePtr(&AssertionError_str, g.llvm_str_type_ptr)));
ConcreteCompilerVariable* converted_msg = NULL; ConcreteCompilerVariable* converted_msg = NULL;
if (node->msg) { if (node->msg) {
......
...@@ -220,12 +220,13 @@ extern "C" bool isSubclass(BoxedClass* child, BoxedClass* parent) { ...@@ -220,12 +220,13 @@ extern "C" bool isSubclass(BoxedClass* child, BoxedClass* parent) {
return PyType_IsSubtype(child, parent); return PyType_IsSubtype(child, parent);
} }
extern "C" void assertFail(BoxedModule* inModule, Box* msg) { extern "C" void assertFail(Box* assertion_type, Box* msg) {
RELEASE_ASSERT(assertion_type->cls == type_cls, "%s", assertion_type->cls->tp_name);
if (msg) { if (msg) {
BoxedString* tostr = str(msg); BoxedString* tostr = str(msg);
raiseExcHelper(AssertionError, "%s", tostr->data()); raiseExcHelper(static_cast<BoxedClass*>(assertion_type), "%s", tostr->data());
} else { } else {
raiseExcHelper(AssertionError, ""); raiseExcHelper(static_cast<BoxedClass*>(assertion_type), "");
} }
} }
......
...@@ -83,7 +83,7 @@ extern "C" Box* importStar(Box* from_module, BoxedModule* to_module); ...@@ -83,7 +83,7 @@ extern "C" Box* importStar(Box* from_module, BoxedModule* to_module);
extern "C" Box** unpackIntoArray(Box* obj, int64_t expected_size); extern "C" Box** unpackIntoArray(Box* obj, int64_t expected_size);
extern "C" void assertNameDefined(bool b, const char* name, BoxedClass* exc_cls, bool local_var_msg); extern "C" void assertNameDefined(bool b, const char* name, BoxedClass* exc_cls, bool local_var_msg);
extern "C" void assertFailDerefNameDefined(const char* name); extern "C" void assertFailDerefNameDefined(const char* name);
extern "C" void assertFail(BoxedModule* inModule, Box* msg); extern "C" void assertFail(Box* assertion_type, Box* msg);
extern "C" bool isSubclass(BoxedClass* child, BoxedClass* parent); extern "C" bool isSubclass(BoxedClass* child, BoxedClass* parent);
extern "C" BoxedClosure* createClosure(BoxedClosure* parent_closure, size_t size); extern "C" BoxedClosure* createClosure(BoxedClosure* parent_closure, size_t size);
......
orig_ae = AssertionError
class MyAssertionError(Exception):
pass
s = """
try:
assert 0
except Exception as e:
print type(e)
"""
exec s
import __builtin__
__builtin__.AssertionError = MyAssertionError
exec s
class MyAssertionError2(Exception):
pass
AssertionError = MyAssertionError2
exec s
exec s in {}
class MyAssertionError3(Exception):
pass
def f1():
# assert is hardcoded to look up "AssertionError" in the global scope,
# even if there is a store to it locally:
AssertionError = MyAssertionError3
exec s
try:
assert 0
except Exception as e:
print type(e)
f1()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment