Commit d1387d74 authored by Marius Wachtler's avatar Marius Wachtler

Add with statement support for old style classes

and let the interpreter actually do a cls lookup when encountering a ClsAttribute node...
parent 71e74437
......@@ -484,7 +484,7 @@ Value ASTInterpreter::visit_invoke(AST_Invoke* node) {
}
Value ASTInterpreter::visit_clsAttribute(AST_ClsAttribute* node) {
return getattr(visit_expr(node->value).o, node->attr.c_str());
return getclsattr(visit_expr(node->value).o, node->attr.c_str());
}
Value ASTInterpreter::visit_augBinOp(AST_AugBinOp* node) {
......
......@@ -1207,6 +1207,13 @@ extern "C" Box* getclsattr(Box* obj, const char* attr) {
Box* gotten;
if (attr[0] == '_' && attr[1] == '_' && PyInstance_Check(obj)) {
// __enter__ and __exit__ need special treatment.
static std::string enter_str("__enter__"), exit_str("__exit__");
if (attr == enter_str || attr == exit_str)
return getattr(obj, attr);
}
#if 0
std::unique_ptr<Rewriter> rewriter(Rewriter::createRewriter(__builtin_extract_return_addr(__builtin_return_address(0)), 2, 1, "getclsattr"));
......@@ -2508,6 +2515,12 @@ extern "C" Box* callattr(Box* obj, const std::string* attr, CallattrFlags flags,
LookupScope scope = flags.cls_only ? CLASS_ONLY : CLASS_OR_INST;
if ((*attr)[0] == '_' && (*attr)[1] == '_' && PyInstance_Check(obj)) {
// __enter__ and __exit__ need special treatment.
if (*attr == "__enter__" || *attr == "__exit__")
scope = CLASS_OR_INST;
}
if (rewriter.get()) {
// TODO feel weird about doing this; it either isn't necessary
// or this kind of thing is necessary in a lot more places
......
# Make sure __exit__ gets called in various exit scenarios:
class C(object):
class NewC(object):
def __enter__(self):
print "__enter__"
def __exit__(self, type, val, tb):
print "__exit__"
def f():
class OldC:
def __enter__(self):
print "__enter__"
def __exit__(self, type, val, tb):
print "__exit__"
def f(C):
with C():
pass
with C() as n:
......@@ -42,9 +49,10 @@ def f():
with C() as o:
return
f()
f(NewC)
f(OldC)
def f2(b):
def f2(b, C):
n = 2
while n:
print n
......@@ -56,6 +64,6 @@ def f2(b):
else:
return "b false"
print f2(False)
print f2(True)
print f2(False, NewC), f2(False, OldC)
print f2(True, NewC), f2(True, OldC)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment