Commit ddabda9a authored by Kevin Modzelewski's avatar Kevin Modzelewski

Implement closures

Implementation is pretty straightforward for now:
- find all names that get accessed from a nested function
- if any, create a closure object at function entry
- any time we set a name accessed from a nested function,
  update its value in the closure
- when evaluating a functiondef that needs a closure, attach
  the created closure to the created function object.

Closures are currently passed as an extra argument before any
python-level args, which I'm not convinced is the right strategy.
It's works out fine but it feels messy to say that functions
can have different C-level calling conventions.
It felt worse to include the closure as part of the python-level arg
passing.
Maybe it should be passed after all the other arguments?

Closures are currently just simple objects, on which we set and get
Python-level attributes.  The performance (which I haven't tested)
relies on attribute access being made fast through the hidden-class
inline caches.

There are a number of ways that this could be improved:
- be smarter about when we create the closure object, or when we
  update it.
- not create empty pass-through closures
- give the closures a pre-defined shape, since we know at irgen-time
  what names can get set.  could probably avoid the inline cache
  machinery and also have better code.
parent 7853bcf2
...@@ -76,15 +76,9 @@ public: ...@@ -76,15 +76,9 @@ public:
virtual ScopeInfo* getParent() { return parent; } virtual ScopeInfo* getParent() { return parent; }
virtual bool createsClosure() { virtual bool createsClosure() { return usage->referenced_from_nested.size() > 0; }
assert(0);
return usage->referenced_from_nested.size() > 0;
}
virtual bool takesClosure() { virtual bool takesClosure() { return usage->got_from_closure.size() > 0; }
assert(0);
return false;
}
virtual bool refersToGlobal(const std::string& name) { virtual bool refersToGlobal(const std::string& name) {
// HAX // HAX
...@@ -143,7 +137,6 @@ public: ...@@ -143,7 +137,6 @@ public:
return true; return true;
} }
virtual bool visit_arguments(AST_arguments* node) { return false; }
virtual bool visit_assert(AST_Assert* node) { return false; } virtual bool visit_assert(AST_Assert* node) { return false; }
virtual bool visit_assign(AST_Assign* node) { return false; } virtual bool visit_assign(AST_Assign* node) { return false; }
virtual bool visit_augassign(AST_AugAssign* node) { return false; } virtual bool visit_augassign(AST_AugAssign* node) { return false; }
...@@ -201,8 +194,15 @@ public: ...@@ -201,8 +194,15 @@ public:
virtual bool visit_classdef(AST_ClassDef* node) { virtual bool visit_classdef(AST_ClassDef* node) {
if (node == orig_node) { if (node == orig_node) {
return false; for (AST_stmt* s : node->body)
s->accept(this);
return true;
} else { } else {
for (auto* e : node->bases)
e->accept(this);
for (auto* e : node->decorator_list)
e->accept(this);
doWrite(node->name); doWrite(node->name);
(*map)[node] = new ScopingAnalysis::ScopeNameUsage(node, cur); (*map)[node] = new ScopingAnalysis::ScopeNameUsage(node, cur);
collect(node, map); collect(node, map);
...@@ -212,8 +212,21 @@ public: ...@@ -212,8 +212,21 @@ public:
virtual bool visit_functiondef(AST_FunctionDef* node) { virtual bool visit_functiondef(AST_FunctionDef* node) {
if (node == orig_node) { if (node == orig_node) {
return false; for (AST_expr* e : node->args->args)
e->accept(this);
if (node->args->vararg.size())
doWrite(node->args->vararg);
if (node->args->kwarg.size())
doWrite(node->args->kwarg);
for (AST_stmt* s : node->body)
s->accept(this);
return true;
} else { } else {
for (auto* e : node->args->defaults)
e->accept(this);
for (auto* e : node->decorator_list)
e->accept(this);
doWrite(node->name); doWrite(node->name);
(*map)[node] = new ScopingAnalysis::ScopeNameUsage(node, cur); (*map)[node] = new ScopingAnalysis::ScopeNameUsage(node, cur);
collect(node, map); collect(node, map);
...@@ -278,6 +291,8 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) { ...@@ -278,6 +291,8 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) {
if (usage->written.count(*it2)) if (usage->written.count(*it2))
continue; continue;
std::vector<ScopeNameUsage*> intermediate_parents;
ScopeNameUsage* parent = usage->parent; ScopeNameUsage* parent = usage->parent;
while (parent) { while (parent) {
if (parent->node->type == AST_TYPE::ClassDef) { if (parent->node->type == AST_TYPE::ClassDef) {
...@@ -287,8 +302,15 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) { ...@@ -287,8 +302,15 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) {
} else if (parent->written.count(*it2)) { } else if (parent->written.count(*it2)) {
usage->got_from_closure.insert(*it2); usage->got_from_closure.insert(*it2);
parent->referenced_from_nested.insert(*it2); parent->referenced_from_nested.insert(*it2);
for (ScopeNameUsage* iparent : intermediate_parents) {
iparent->referenced_from_nested.insert(*it2);
iparent->got_from_closure.insert(*it2);
}
break; break;
} else { } else {
intermediate_parents.push_back(parent);
parent = parent->parent; parent = parent->parent;
} }
} }
...@@ -320,19 +342,6 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) { ...@@ -320,19 +342,6 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) {
} }
ScopeInfo* ScopingAnalysis::analyzeSubtree(AST* node) { ScopeInfo* ScopingAnalysis::analyzeSubtree(AST* node) {
#ifndef NDEBUG
std::vector<AST*> flattened;
flatten(parent_module->body, flattened, false);
bool found = 0;
for (AST* n : flattened) {
if (n == node) {
found = true;
break;
}
}
assert(found);
#endif
NameUsageMap usages; NameUsageMap usages;
usages[node] = new ScopeNameUsage(node, NULL); usages[node] = new ScopeNameUsage(node, NULL);
NameCollectorVisitor::collect(node, &usages); NameCollectorVisitor::collect(node, &usages);
......
...@@ -70,7 +70,7 @@ struct GlobalState { ...@@ -70,7 +70,7 @@ struct GlobalState {
llvm::Type* llvm_flavor_type, *llvm_flavor_type_ptr; llvm::Type* llvm_flavor_type, *llvm_flavor_type_ptr;
llvm::Type* llvm_opaque_type; llvm::Type* llvm_opaque_type;
llvm::Type* llvm_str_type_ptr; llvm::Type* llvm_str_type_ptr;
llvm::Type* llvm_clfunction_type_ptr; llvm::Type* llvm_clfunction_type_ptr, *llvm_closure_type_ptr;
llvm::Type* llvm_module_type_ptr, *llvm_bool_type_ptr; llvm::Type* llvm_module_type_ptr, *llvm_bool_type_ptr;
llvm::Type* i1, *i8, *i8_ptr, *i32, *i64, *void_, *double_; llvm::Type* i1, *i8, *i8_ptr, *i32, *i64, *void_, *double_;
llvm::Type* vector_ptr; llvm::Type* vector_ptr;
......
...@@ -518,11 +518,24 @@ ConcreteCompilerVariable* UnknownType::nonzero(IREmitter& emitter, const OpInfo& ...@@ -518,11 +518,24 @@ ConcreteCompilerVariable* UnknownType::nonzero(IREmitter& emitter, const OpInfo&
return new ConcreteCompilerVariable(BOOL, rtn_val, true); return new ConcreteCompilerVariable(BOOL, rtn_val, true);
} }
CompilerVariable* makeFunction(IREmitter& emitter, CLFunction* f) { CompilerVariable* makeFunction(IREmitter& emitter, CLFunction* f, CompilerVariable* closure) {
// Unlike the CLFunction*, which can be shared between recompilations, the Box* around it // Unlike the CLFunction*, which can be shared between recompilations, the Box* around it
// should be created anew every time the functiondef is encountered // should be created anew every time the functiondef is encountered
llvm::Value* boxed
= emitter.getBuilder()->CreateCall(g.funcs.boxCLFunction, embedConstantPtr(f, g.llvm_clfunction_type_ptr)); llvm::Value* closure_v;
ConcreteCompilerVariable* converted = NULL;
if (closure) {
converted = closure->makeConverted(emitter, closure->getConcreteType());
closure_v = converted->getValue();
} else {
closure_v = embedConstantPtr(nullptr, g.llvm_closure_type_ptr);
}
llvm::Value* boxed = emitter.getBuilder()->CreateCall2(g.funcs.boxCLFunction,
embedConstantPtr(f, g.llvm_clfunction_type_ptr), closure_v);
if (converted)
converted->decvref(emitter);
return new ConcreteCompilerVariable(typeFromClass(function_cls), boxed, true); return new ConcreteCompilerVariable(typeFromClass(function_cls), boxed, true);
} }
...@@ -1147,6 +1160,30 @@ public: ...@@ -1147,6 +1160,30 @@ public:
std::unordered_map<BoxedClass*, NormalObjectType*> NormalObjectType::made; std::unordered_map<BoxedClass*, NormalObjectType*> NormalObjectType::made;
ConcreteCompilerType* STR, *BOXED_INT, *BOXED_FLOAT, *BOXED_BOOL, *NONE; ConcreteCompilerType* STR, *BOXED_INT, *BOXED_FLOAT, *BOXED_BOOL, *NONE;
class ClosureType : public ConcreteCompilerType {
public:
llvm::Type* llvmType() { return g.llvm_closure_type_ptr; }
std::string debugName() { return "closure"; }
CompilerVariable* getattr(IREmitter& emitter, const OpInfo& info, ConcreteCompilerVariable* var,
const std::string* attr, bool cls_only) {
assert(!cls_only);
llvm::Value* bitcast = emitter.getBuilder()->CreateBitCast(var->getValue(), g.llvm_value_type_ptr);
return ConcreteCompilerVariable(UNKNOWN, bitcast, true).getattr(emitter, info, attr, cls_only);
}
void setattr(IREmitter& emitter, const OpInfo& info, ConcreteCompilerVariable* var, const std::string* attr,
CompilerVariable* v) {
llvm::Value* bitcast = emitter.getBuilder()->CreateBitCast(var->getValue(), g.llvm_value_type_ptr);
ConcreteCompilerVariable(UNKNOWN, bitcast, true).setattr(emitter, info, attr, v);
}
virtual ConcreteCompilerType* getConcreteType() { return this; }
// Shouldn't call this:
virtual ConcreteCompilerType* getBoxType() { RELEASE_ASSERT(0, ""); }
} _CLOSURE;
ConcreteCompilerType* CLOSURE = &_CLOSURE;
class StrConstantType : public ValuedCompilerType<std::string*> { class StrConstantType : public ValuedCompilerType<std::string*> {
public: public:
std::string debugName() { return "str_constant"; } std::string debugName() { return "str_constant"; }
......
...@@ -29,7 +29,7 @@ class CompilerType; ...@@ -29,7 +29,7 @@ class CompilerType;
class IREmitter; class IREmitter;
extern ConcreteCompilerType* INT, *BOXED_INT, *FLOAT, *BOXED_FLOAT, *VOID, *UNKNOWN, *BOOL, *STR, *NONE, *LIST, *SLICE, extern ConcreteCompilerType* INT, *BOXED_INT, *FLOAT, *BOXED_FLOAT, *VOID, *UNKNOWN, *BOOL, *STR, *NONE, *LIST, *SLICE,
*MODULE, *DICT, *BOOL, *BOXED_BOOL, *BOXED_TUPLE, *SET; *MODULE, *DICT, *BOOL, *BOXED_BOOL, *BOXED_TUPLE, *SET, *CLOSURE;
extern CompilerType* UNDEF; extern CompilerType* UNDEF;
class CompilerType { class CompilerType {
...@@ -316,7 +316,7 @@ ConcreteCompilerVariable* makeInt(int64_t); ...@@ -316,7 +316,7 @@ ConcreteCompilerVariable* makeInt(int64_t);
ConcreteCompilerVariable* makeFloat(double); ConcreteCompilerVariable* makeFloat(double);
ConcreteCompilerVariable* makeBool(bool); ConcreteCompilerVariable* makeBool(bool);
CompilerVariable* makeStr(std::string*); CompilerVariable* makeStr(std::string*);
CompilerVariable* makeFunction(IREmitter& emitter, CLFunction*); CompilerVariable* makeFunction(IREmitter& emitter, CLFunction*, CompilerVariable* closure);
ConcreteCompilerVariable* undefVariable(); ConcreteCompilerVariable* undefVariable();
CompilerVariable* makeTuple(const std::vector<CompilerVariable*>& elts); CompilerVariable* makeTuple(const std::vector<CompilerVariable*>& elts);
......
...@@ -563,7 +563,8 @@ static void emitBBs(IRGenState* irstate, const char* bb_type, GuardList& out_gua ...@@ -563,7 +563,8 @@ static void emitBBs(IRGenState* irstate, const char* bb_type, GuardList& out_gua
emitter->getBuilder()->SetInsertPoint(llvm_entry_blocks[source->cfg->getStartingBlock()]); emitter->getBuilder()->SetInsertPoint(llvm_entry_blocks[source->cfg->getStartingBlock()]);
} }
generator->unpackArguments(arg_names, cf->spec->arg_types);
generator->doFunctionEntry(arg_names, cf->spec->arg_types);
// Function-entry safepoint: // Function-entry safepoint:
// TODO might be more efficient to do post-call safepoints? // TODO might be more efficient to do post-call safepoints?
...@@ -915,6 +916,9 @@ CompiledFunction* doCompile(SourceInfo* source, const OSREntryDescriptor* entry_ ...@@ -915,6 +916,9 @@ CompiledFunction* doCompile(SourceInfo* source, const OSREntryDescriptor* entry_
ASSERT(nargs == spec->arg_types.size(), "%d %ld", nargs, spec->arg_types.size()); ASSERT(nargs == spec->arg_types.size(), "%d %ld", nargs, spec->arg_types.size());
std::vector<llvm::Type*> llvm_arg_types; std::vector<llvm::Type*> llvm_arg_types;
if (source->scoping->getScopeInfoForNode(source->ast)->takesClosure())
llvm_arg_types.push_back(g.llvm_closure_type_ptr);
if (entry_descriptor == NULL) { if (entry_descriptor == NULL) {
for (int i = 0; i < nargs; i++) { for (int i = 0; i < nargs; i++) {
if (i == 3) { if (i == 3) {
......
...@@ -262,7 +262,7 @@ void compileAndRunModule(AST_Module* m, BoxedModule* bm) { ...@@ -262,7 +262,7 @@ void compileAndRunModule(AST_Module* m, BoxedModule* bm) {
_t.end(); _t.end();
if (cf->is_interpreted) if (cf->is_interpreted)
interpretFunction(cf->func, 0, NULL, NULL, NULL, NULL); interpretFunction(cf->func, 0, NULL, NULL, NULL, NULL, NULL);
else else
((void (*)())cf->code)(); ((void (*)())cf->code)();
} }
......
This diff is collapsed.
...@@ -183,8 +183,9 @@ public: ...@@ -183,8 +183,9 @@ public:
virtual ~IRGenerator() {} virtual ~IRGenerator() {}
virtual void unpackArguments(const std::vector<AST_expr*>& arg_names, virtual void doFunctionEntry(const std::vector<AST_expr*>& arg_names,
const std::vector<ConcreteCompilerType*>& arg_types) = 0; const std::vector<ConcreteCompilerType*>& arg_types) = 0;
virtual void giveLocalSymbol(const std::string& name, CompilerVariable* var) = 0; virtual void giveLocalSymbol(const std::string& name, CompilerVariable* var) = 0;
virtual void copySymbolsFrom(SymbolTable* st) = 0; virtual void copySymbolsFrom(SymbolTable* st) = 0;
virtual void run(const CFGBlock* block) = 0; virtual void run(const CFGBlock* block) = 0;
......
...@@ -257,7 +257,7 @@ const LineInfo* getLineInfoForInterpretedFrame(void* frame_ptr) { ...@@ -257,7 +257,7 @@ const LineInfo* getLineInfoForInterpretedFrame(void* frame_ptr) {
} }
} }
Box* interpretFunction(llvm::Function* f, int nargs, Box* arg1, Box* arg2, Box* arg3, Box** args) { Box* interpretFunction(llvm::Function* f, int nargs, Box* closure, Box* arg1, Box* arg2, Box* arg3, Box** args) {
assert(f); assert(f);
#ifdef TIME_INTERPRETS #ifdef TIME_INTERPRETS
...@@ -280,18 +280,21 @@ Box* interpretFunction(llvm::Function* f, int nargs, Box* arg1, Box* arg2, Box* ...@@ -280,18 +280,21 @@ Box* interpretFunction(llvm::Function* f, int nargs, Box* arg1, Box* arg2, Box*
UnregisterHelper helper(frame_ptr); UnregisterHelper helper(frame_ptr);
int arg_num = -1; int arg_num = -1;
int closure_indicator = closure ? 1 : 0;
for (llvm::Argument& arg : f->args()) { for (llvm::Argument& arg : f->args()) {
arg_num++; arg_num++;
if (arg_num == 0) if (arg_num == 0 && closure)
symbols.insert(std::make_pair(static_cast<llvm::Value*>(&arg), Val(closure)));
else if (arg_num == 0 + closure_indicator)
symbols.insert(std::make_pair(static_cast<llvm::Value*>(&arg), Val(arg1))); symbols.insert(std::make_pair(static_cast<llvm::Value*>(&arg), Val(arg1)));
else if (arg_num == 1) else if (arg_num == 1 + closure_indicator)
symbols.insert(std::make_pair(static_cast<llvm::Value*>(&arg), Val(arg2))); symbols.insert(std::make_pair(static_cast<llvm::Value*>(&arg), Val(arg2)));
else if (arg_num == 2) else if (arg_num == 2 + closure_indicator)
symbols.insert(std::make_pair(static_cast<llvm::Value*>(&arg), Val(arg3))); symbols.insert(std::make_pair(static_cast<llvm::Value*>(&arg), Val(arg3)));
else { else {
assert(arg_num == 3); assert(arg_num == 3 + closure_indicator);
assert(f->getArgumentList().size() == 4); assert(f->getArgumentList().size() == 4 + closure_indicator);
assert(f->getArgumentList().back().getType() == g.llvm_value_type_ptr->getPointerTo()); assert(f->getArgumentList().back().getType() == g.llvm_value_type_ptr->getPointerTo());
symbols.insert(std::make_pair(static_cast<llvm::Value*>(&arg), Val((int64_t)args))); symbols.insert(std::make_pair(static_cast<llvm::Value*>(&arg), Val((int64_t)args)));
// printf("loading %%4 with %p\n", (void*)args); // printf("loading %%4 with %p\n", (void*)args);
......
...@@ -25,7 +25,7 @@ class Box; ...@@ -25,7 +25,7 @@ class Box;
class GCVisitor; class GCVisitor;
class LineInfo; class LineInfo;
Box* interpretFunction(llvm::Function* f, int nargs, Box* arg1, Box* arg2, Box* arg3, Box** args); Box* interpretFunction(llvm::Function* f, int nargs, Box* closure, Box* arg1, Box* arg2, Box* arg3, Box** args);
void gatherInterpreterRoots(GCVisitor* visitor); void gatherInterpreterRoots(GCVisitor* visitor);
const LineInfo* getLineInfoForInterpretedFrame(void* frame_ptr); const LineInfo* getLineInfoForInterpretedFrame(void* frame_ptr);
......
...@@ -137,6 +137,9 @@ void initGlobalFuncs(GlobalState& g) { ...@@ -137,6 +137,9 @@ void initGlobalFuncs(GlobalState& g) {
assert(vector_type); assert(vector_type);
g.vector_ptr = vector_type->getPointerTo(); g.vector_ptr = vector_type->getPointerTo();
g.llvm_closure_type_ptr = g.stdlib_module->getTypeByName("class.pyston::BoxedClosure")->getPointerTo();
assert(g.llvm_closure_type_ptr);
#define GET(N) g.funcs.N = getFunc((void*)N, STRINGIFY(N)) #define GET(N) g.funcs.N = getFunc((void*)N, STRINGIFY(N))
g.funcs.printf = addFunc((void*)printf, g.i8_ptr, true); g.funcs.printf = addFunc((void*)printf, g.i8_ptr, true);
...@@ -161,6 +164,7 @@ void initGlobalFuncs(GlobalState& g) { ...@@ -161,6 +164,7 @@ void initGlobalFuncs(GlobalState& g) {
GET(createList); GET(createList);
GET(createDict); GET(createDict);
GET(createSlice); GET(createSlice);
GET(createClosure);
GET(getattr); GET(getattr);
GET(setattr); GET(setattr);
......
...@@ -32,7 +32,7 @@ struct GlobalFuncs { ...@@ -32,7 +32,7 @@ struct GlobalFuncs {
llvm::Value* boxInt, *unboxInt, *boxFloat, *unboxFloat, *boxStringPtr, *boxCLFunction, *unboxCLFunction, llvm::Value* boxInt, *unboxInt, *boxFloat, *unboxFloat, *boxStringPtr, *boxCLFunction, *unboxCLFunction,
*boxInstanceMethod, *boxBool, *unboxBool, *createTuple, *createDict, *createList, *createSlice, *boxInstanceMethod, *boxBool, *unboxBool, *createTuple, *createDict, *createList, *createSlice,
*createUserClass; *createUserClass, *createClosure;
llvm::Value* getattr, *setattr, *print, *nonzero, *binop, *compare, *augbinop, *unboxedLen, *getitem, *getclsattr, llvm::Value* getattr, *setattr, *print, *nonzero, *binop, *compare, *augbinop, *unboxedLen, *getitem, *getclsattr,
*getGlobal, *setitem, *delitem, *unaryop, *import, *repr, *isinstance; *getGlobal, *setitem, *delitem, *unaryop, *import, *repr, *isinstance;
......
...@@ -165,6 +165,7 @@ struct FunctionSpecialization { ...@@ -165,6 +165,7 @@ struct FunctionSpecialization {
: rtn_type(rtn_type), arg_types(arg_types) {} : rtn_type(rtn_type), arg_types(arg_types) {}
}; };
class BoxedClosure;
struct CompiledFunction { struct CompiledFunction {
private: private:
public: public:
...@@ -176,6 +177,7 @@ public: ...@@ -176,6 +177,7 @@ public:
union { union {
Box* (*call)(Box*, Box*, Box*, Box**); Box* (*call)(Box*, Box*, Box*, Box**);
Box* (*closure_call)(BoxedClosure*, Box*, Box*, Box*, Box**);
void* code; void* code;
}; };
llvm::Value* llvm_code; // the llvm callable. llvm::Value* llvm_code; // the llvm callable.
......
...@@ -422,7 +422,6 @@ static Block** freeChain(Block** head) { ...@@ -422,7 +422,6 @@ static Block** freeChain(Block** head) {
} }
void Heap::freeUnmarked() { void Heap::freeUnmarked() {
Timer _t("looking at the thread caches");
thread_caches.forEachValue([this](ThreadBlockCache* cache) { thread_caches.forEachValue([this](ThreadBlockCache* cache) {
for (int bidx = 0; bidx < NUM_BUCKETS; bidx++) { for (int bidx = 0; bidx < NUM_BUCKETS; bidx++) {
Block* h = cache->cache_free_heads[bidx]; Block* h = cache->cache_free_heads[bidx];
...@@ -452,7 +451,6 @@ void Heap::freeUnmarked() { ...@@ -452,7 +451,6 @@ void Heap::freeUnmarked() {
} }
} }
}); });
_t.end();
for (int bidx = 0; bidx < NUM_BUCKETS; bidx++) { for (int bidx = 0; bidx < NUM_BUCKETS; bidx++) {
Block** chain_end = freeChain(&heads[bidx]); Block** chain_end = freeChain(&heads[bidx]);
......
...@@ -57,6 +57,7 @@ void force() { ...@@ -57,6 +57,7 @@ void force() {
FORCE(createList); FORCE(createList);
FORCE(createSlice); FORCE(createSlice);
FORCE(createUserClass); FORCE(createUserClass);
FORCE(createClosure);
FORCE(getattr); FORCE(getattr);
FORCE(setattr); FORCE(setattr);
......
...@@ -795,6 +795,26 @@ Box* getattr_internal(Box* obj, const std::string& attr, bool check_cls, bool al ...@@ -795,6 +795,26 @@ Box* getattr_internal(Box* obj, const std::string& attr, bool check_cls, bool al
} }
} }
// TODO closures should get their own treatment, but now just piggy-back on the
// normal hidden-class IC logic.
// Can do better since we don't need to guard on the cls (always going to be closure)
if (obj->cls == closure_cls) {
BoxedClosure* closure = static_cast<BoxedClosure*>(obj);
if (closure->parent) {
if (rewrite_args) {
rewrite_args->obj = rewrite_args->obj.getAttr(offsetof(BoxedClosure, parent), -1);
}
if (rewrite_args2) {
rewrite_args2->obj
= rewrite_args2->obj.getAttr(offsetof(BoxedClosure, parent), RewriterVarUsage2::Kill);
}
return getattr_internal(closure->parent, attr, false, false, rewrite_args, rewrite_args2);
}
raiseExcHelper(NameError, "free variable '%s' referenced before assignment in enclosing scope", attr.c_str());
}
if (allow_custom) { if (allow_custom) {
// Don't need to pass icentry args, since we special-case __getattribtue__ and __getattr__ to use // Don't need to pass icentry args, since we special-case __getattribtue__ and __getattr__ to use
// invalidation rather than guards // invalidation rather than guards
...@@ -1655,6 +1675,7 @@ Box* callFunc(BoxedFunction* func, CallRewriteArgs* rewrite_args, ArgPassSpec ar ...@@ -1655,6 +1675,7 @@ Box* callFunc(BoxedFunction* func, CallRewriteArgs* rewrite_args, ArgPassSpec ar
int num_output_args = f->numReceivedArgs(); int num_output_args = f->numReceivedArgs();
int num_passed_args = argspec.totalPassed(); int num_passed_args = argspec.totalPassed();
BoxedClosure* closure = func->closure;
if (argspec.has_starargs || argspec.has_kwargs || f->takes_kwargs) if (argspec.has_starargs || argspec.has_kwargs || f->takes_kwargs)
rewrite_args = NULL; rewrite_args = NULL;
...@@ -1690,14 +1711,21 @@ Box* callFunc(BoxedFunction* func, CallRewriteArgs* rewrite_args, ArgPassSpec ar ...@@ -1690,14 +1711,21 @@ Box* callFunc(BoxedFunction* func, CallRewriteArgs* rewrite_args, ArgPassSpec ar
} }
if (rewrite_args) { if (rewrite_args) {
int closure_indicator = closure ? 1 : 0;
if (num_passed_args >= 1) if (num_passed_args >= 1)
rewrite_args->arg1 = rewrite_args->arg1.move(0); rewrite_args->arg1 = rewrite_args->arg1.move(0 + closure_indicator);
if (num_passed_args >= 2) if (num_passed_args >= 2)
rewrite_args->arg2 = rewrite_args->arg2.move(1); rewrite_args->arg2 = rewrite_args->arg2.move(1 + closure_indicator);
if (num_passed_args >= 3) if (num_passed_args >= 3)
rewrite_args->arg3 = rewrite_args->arg3.move(2); rewrite_args->arg3 = rewrite_args->arg3.move(2 + closure_indicator);
if (num_passed_args >= 4) if (num_passed_args >= 4)
rewrite_args->args = rewrite_args->args.move(3); rewrite_args->args = rewrite_args->args.move(3 + closure_indicator);
// TODO this kind of embedded reference needs to be tracked by the GC somehow?
// Or maybe it's ok, since we've guarded on the function object?
if (closure)
rewrite_args->rewriter->loadConst(0, (intptr_t)closure);
// We might have trouble if we have more output args than input args, // We might have trouble if we have more output args than input args,
// such as if we need more space to pass defaults. // such as if we need more space to pass defaults.
...@@ -1898,7 +1926,7 @@ Box* callFunc(BoxedFunction* func, CallRewriteArgs* rewrite_args, ArgPassSpec ar ...@@ -1898,7 +1926,7 @@ Box* callFunc(BoxedFunction* func, CallRewriteArgs* rewrite_args, ArgPassSpec ar
assert(chosen_cf->is_interpreted == (chosen_cf->code == NULL)); assert(chosen_cf->is_interpreted == (chosen_cf->code == NULL));
if (chosen_cf->is_interpreted) { if (chosen_cf->is_interpreted) {
return interpretFunction(chosen_cf->func, num_output_args, oarg1, oarg2, oarg3, oargs); return interpretFunction(chosen_cf->func, num_output_args, func->closure, oarg1, oarg2, oarg3, oargs);
} else { } else {
if (rewrite_args) { if (rewrite_args) {
rewrite_args->rewriter->addDependenceOn(chosen_cf->dependent_callsites); rewrite_args->rewriter->addDependenceOn(chosen_cf->dependent_callsites);
...@@ -1908,7 +1936,11 @@ Box* callFunc(BoxedFunction* func, CallRewriteArgs* rewrite_args, ArgPassSpec ar ...@@ -1908,7 +1936,11 @@ Box* callFunc(BoxedFunction* func, CallRewriteArgs* rewrite_args, ArgPassSpec ar
rewrite_args->out_rtn = var; rewrite_args->out_rtn = var;
rewrite_args->out_success = true; rewrite_args->out_success = true;
} }
return chosen_cf->call(oarg1, oarg2, oarg3, oargs);
if (closure)
return chosen_cf->closure_call(closure, oarg1, oarg2, oarg3, oargs);
else
return chosen_cf->call(oarg1, oarg2, oarg3, oargs);
} }
} }
......
...@@ -69,6 +69,7 @@ extern "C" void checkUnpackingLength(i64 expected, i64 given); ...@@ -69,6 +69,7 @@ extern "C" void checkUnpackingLength(i64 expected, i64 given);
extern "C" void assertNameDefined(bool b, const char* name); extern "C" void assertNameDefined(bool b, const char* name);
extern "C" void assertFail(BoxedModule* inModule, Box* msg); extern "C" void assertFail(BoxedModule* inModule, Box* msg);
extern "C" bool isSubclass(BoxedClass* child, BoxedClass* parent); extern "C" bool isSubclass(BoxedClass* child, BoxedClass* parent);
extern "C" BoxedClosure* createClosure(BoxedClosure* parent_closure);
class BinopRewriteArgs; class BinopRewriteArgs;
extern "C" Box* binopInternal(Box* lhs, Box* rhs, int op_type, bool inplace, BinopRewriteArgs* rewrite_args); extern "C" Box* binopInternal(Box* lhs, Box* rhs, int op_type, bool inplace, BinopRewriteArgs* rewrite_args);
......
...@@ -65,7 +65,7 @@ llvm::iterator_range<BoxIterator> Box::pyElements() { ...@@ -65,7 +65,7 @@ llvm::iterator_range<BoxIterator> Box::pyElements() {
} }
extern "C" BoxedFunction::BoxedFunction(CLFunction* f) extern "C" BoxedFunction::BoxedFunction(CLFunction* f)
: Box(&function_flavor, function_cls), f(f), ndefaults(0), defaults(NULL) { : Box(&function_flavor, function_cls), f(f), closure(NULL), ndefaults(0), defaults(NULL) {
if (f->source) { if (f->source) {
assert(f->source->ast); assert(f->source->ast);
// this->giveAttr("__name__", boxString(&f->source->ast->name)); // this->giveAttr("__name__", boxString(&f->source->ast->name));
...@@ -78,13 +78,15 @@ extern "C" BoxedFunction::BoxedFunction(CLFunction* f) ...@@ -78,13 +78,15 @@ extern "C" BoxedFunction::BoxedFunction(CLFunction* f)
assert(f->num_defaults == ndefaults); assert(f->num_defaults == ndefaults);
} }
extern "C" BoxedFunction::BoxedFunction(CLFunction* f, std::initializer_list<Box*> defaults) extern "C" BoxedFunction::BoxedFunction(CLFunction* f, std::initializer_list<Box*> defaults, BoxedClosure* closure)
: Box(&function_flavor, function_cls), f(f), ndefaults(0), defaults(NULL) { : Box(&function_flavor, function_cls), f(f), closure(closure), ndefaults(0), defaults(NULL) {
// make sure to initialize defaults first, since the GC behavior is triggered by ndefaults, if (defaults.size()) {
// and a GC can happen within this constructor: // make sure to initialize defaults first, since the GC behavior is triggered by ndefaults,
this->defaults = new (defaults.size()) GCdArray(); // and a GC can happen within this constructor:
memcpy(this->defaults->elts, defaults.begin(), defaults.size() * sizeof(Box*)); this->defaults = new (defaults.size()) GCdArray();
this->ndefaults = defaults.size(); memcpy(this->defaults->elts, defaults.begin(), defaults.size() * sizeof(Box*));
this->ndefaults = defaults.size();
}
if (f->source) { if (f->source) {
assert(f->source->ast); assert(f->source->ast);
...@@ -104,6 +106,9 @@ extern "C" void functionGCHandler(GCVisitor* v, void* p) { ...@@ -104,6 +106,9 @@ extern "C" void functionGCHandler(GCVisitor* v, void* p) {
BoxedFunction* f = (BoxedFunction*)p; BoxedFunction* f = (BoxedFunction*)p;
if (f->closure)
v->visit(f->closure);
// It's ok for f->defaults to be NULL here even if f->ndefaults isn't, // It's ok for f->defaults to be NULL here even if f->ndefaults isn't,
// since we could be collecting from inside a BoxedFunction constructor // since we could be collecting from inside a BoxedFunction constructor
if (f->ndefaults) { if (f->ndefaults) {
...@@ -130,8 +135,11 @@ std::string BoxedModule::name() { ...@@ -130,8 +135,11 @@ std::string BoxedModule::name() {
} }
} }
extern "C" Box* boxCLFunction(CLFunction* f) { extern "C" Box* boxCLFunction(CLFunction* f, BoxedClosure* closure) {
return new BoxedFunction(f); if (closure)
assert(closure->cls == closure_cls);
return new BoxedFunction(f, {}, closure);
} }
extern "C" CLFunction* unboxCLFunction(Box* b) { extern "C" CLFunction* unboxCLFunction(Box* b) {
...@@ -245,9 +253,18 @@ extern "C" void conservativeGCHandler(GCVisitor* v, void* p) { ...@@ -245,9 +253,18 @@ extern "C" void conservativeGCHandler(GCVisitor* v, void* p) {
v->visitPotentialRange(start, start + (size / sizeof(void*))); v->visitPotentialRange(start, start + (size / sizeof(void*)));
} }
extern "C" void closureGCHandler(GCVisitor* v, void* p) {
boxGCHandler(v, p);
BoxedClosure* c = (BoxedClosure*)v;
if (c->parent)
v->visit(c->parent);
}
extern "C" { extern "C" {
BoxedClass* object_cls, *type_cls, *none_cls, *bool_cls, *int_cls, *float_cls, *str_cls, *function_cls, BoxedClass* object_cls, *type_cls, *none_cls, *bool_cls, *int_cls, *float_cls, *str_cls, *function_cls,
*instancemethod_cls, *list_cls, *slice_cls, *module_cls, *dict_cls, *tuple_cls, *file_cls, *member_cls; *instancemethod_cls, *list_cls, *slice_cls, *module_cls, *dict_cls, *tuple_cls, *file_cls, *member_cls,
*closure_cls;
const ObjectFlavor object_flavor(&boxGCHandler, NULL); const ObjectFlavor object_flavor(&boxGCHandler, NULL);
const ObjectFlavor type_flavor(&typeGCHandler, NULL); const ObjectFlavor type_flavor(&typeGCHandler, NULL);
...@@ -265,6 +282,7 @@ const ObjectFlavor dict_flavor(&dictGCHandler, NULL); ...@@ -265,6 +282,7 @@ const ObjectFlavor dict_flavor(&dictGCHandler, NULL);
const ObjectFlavor tuple_flavor(&tupleGCHandler, NULL); const ObjectFlavor tuple_flavor(&tupleGCHandler, NULL);
const ObjectFlavor file_flavor(&boxGCHandler, NULL); const ObjectFlavor file_flavor(&boxGCHandler, NULL);
const ObjectFlavor member_flavor(&boxGCHandler, NULL); const ObjectFlavor member_flavor(&boxGCHandler, NULL);
const ObjectFlavor closure_flavor(&closureGCHandler, NULL);
const AllocationKind untracked_kind(NULL, NULL); const AllocationKind untracked_kind(NULL, NULL);
const AllocationKind hc_kind(&hcGCHandler, NULL); const AllocationKind hc_kind(&hcGCHandler, NULL);
...@@ -356,6 +374,12 @@ extern "C" Box* createSlice(Box* start, Box* stop, Box* step) { ...@@ -356,6 +374,12 @@ extern "C" Box* createSlice(Box* start, Box* stop, Box* step) {
return rtn; return rtn;
} }
extern "C" BoxedClosure* createClosure(BoxedClosure* parent_closure) {
if (parent_closure)
assert(parent_closure->cls == closure_cls);
return new BoxedClosure(parent_closure);
}
extern "C" Box* sliceNew(Box* cls, Box* start, Box* stop, Box** args) { extern "C" Box* sliceNew(Box* cls, Box* start, Box* stop, Box** args) {
RELEASE_ASSERT(cls == slice_cls, ""); RELEASE_ASSERT(cls == slice_cls, "");
Box* step = args[0]; Box* step = args[0];
...@@ -483,6 +507,7 @@ void setupRuntime() { ...@@ -483,6 +507,7 @@ void setupRuntime() {
file_cls = new BoxedClass(object_cls, 0, sizeof(BoxedFile), false); file_cls = new BoxedClass(object_cls, 0, sizeof(BoxedFile), false);
set_cls = new BoxedClass(object_cls, 0, sizeof(BoxedSet), false); set_cls = new BoxedClass(object_cls, 0, sizeof(BoxedSet), false);
member_cls = new BoxedClass(object_cls, 0, sizeof(BoxedMemberDescriptor), false); member_cls = new BoxedClass(object_cls, 0, sizeof(BoxedMemberDescriptor), false);
closure_cls = new BoxedClass(object_cls, offsetof(BoxedClosure, attrs), sizeof(BoxedClosure), false);
STR = typeFromClass(str_cls); STR = typeFromClass(str_cls);
BOXED_INT = typeFromClass(int_cls); BOXED_INT = typeFromClass(int_cls);
...@@ -524,6 +549,9 @@ void setupRuntime() { ...@@ -524,6 +549,9 @@ void setupRuntime() {
member_cls->giveAttr("__name__", boxStrConstant("member")); member_cls->giveAttr("__name__", boxStrConstant("member"));
member_cls->freeze(); member_cls->freeze();
closure_cls->giveAttr("__name__", boxStrConstant("closure"));
closure_cls->freeze();
setupBool(); setupBool();
setupInt(); setupInt();
setupFloat(); setupFloat();
......
...@@ -27,6 +27,7 @@ class BoxedList; ...@@ -27,6 +27,7 @@ class BoxedList;
class BoxedDict; class BoxedDict;
class BoxedTuple; class BoxedTuple;
class BoxedFile; class BoxedFile;
class BoxedClosure;
void setupInt(); void setupInt();
void teardownInt(); void teardownInt();
...@@ -62,12 +63,13 @@ BoxedList* getSysPath(); ...@@ -62,12 +63,13 @@ BoxedList* getSysPath();
extern "C" { extern "C" {
extern BoxedClass* object_cls, *type_cls, *bool_cls, *int_cls, *float_cls, *str_cls, *function_cls, *none_cls, extern BoxedClass* object_cls, *type_cls, *bool_cls, *int_cls, *float_cls, *str_cls, *function_cls, *none_cls,
*instancemethod_cls, *list_cls, *slice_cls, *module_cls, *dict_cls, *tuple_cls, *file_cls, *xrange_cls, *member_cls; *instancemethod_cls, *list_cls, *slice_cls, *module_cls, *dict_cls, *tuple_cls, *file_cls, *xrange_cls, *member_cls,
*closure_cls;
} }
extern "C" { extern "C" {
extern const ObjectFlavor object_flavor, type_flavor, bool_flavor, int_flavor, float_flavor, str_flavor, extern const ObjectFlavor object_flavor, type_flavor, bool_flavor, int_flavor, float_flavor, str_flavor,
function_flavor, none_flavor, instancemethod_flavor, list_flavor, slice_flavor, module_flavor, dict_flavor, function_flavor, none_flavor, instancemethod_flavor, list_flavor, slice_flavor, module_flavor, dict_flavor,
tuple_flavor, file_flavor, xrange_flavor, member_flavor; tuple_flavor, file_flavor, xrange_flavor, member_flavor, closure_flavor;
} }
extern "C" { extern Box* None, *NotImplemented, *True, *False; } extern "C" { extern Box* None, *NotImplemented, *True, *False; }
extern "C" { extern "C" {
...@@ -86,7 +88,7 @@ Box* boxString(const std::string& s); ...@@ -86,7 +88,7 @@ Box* boxString(const std::string& s);
extern "C" BoxedString* boxStrConstant(const char* chars); extern "C" BoxedString* boxStrConstant(const char* chars);
extern "C" void listAppendInternal(Box* self, Box* v); extern "C" void listAppendInternal(Box* self, Box* v);
extern "C" void listAppendArrayInternal(Box* self, Box** v, int nelts); extern "C" void listAppendArrayInternal(Box* self, Box** v, int nelts);
extern "C" Box* boxCLFunction(CLFunction* f); extern "C" Box* boxCLFunction(CLFunction* f, BoxedClosure* closure);
extern "C" CLFunction* unboxCLFunction(Box* b); extern "C" CLFunction* unboxCLFunction(Box* b);
extern "C" Box* createUserClass(std::string* name, Box* base, BoxedModule* parent_module); extern "C" Box* createUserClass(std::string* name, Box* base, BoxedModule* parent_module);
extern "C" double unboxFloat(Box* b); extern "C" double unboxFloat(Box* b);
...@@ -272,12 +274,13 @@ class BoxedFunction : public Box { ...@@ -272,12 +274,13 @@ class BoxedFunction : public Box {
public: public:
HCAttrs attrs; HCAttrs attrs;
CLFunction* f; CLFunction* f;
BoxedClosure* closure;
int ndefaults; int ndefaults;
GCdArray* defaults; GCdArray* defaults;
BoxedFunction(CLFunction* f); BoxedFunction(CLFunction* f);
BoxedFunction(CLFunction* f, std::initializer_list<Box*> defaults); BoxedFunction(CLFunction* f, std::initializer_list<Box*> defaults, BoxedClosure* closure = NULL);
}; };
class BoxedModule : public Box { class BoxedModule : public Box {
...@@ -307,6 +310,15 @@ public: ...@@ -307,6 +310,15 @@ public:
BoxedMemberDescriptor(MemberType type, int offset) : Box(&member_flavor, member_cls), type(type), offset(offset) {} BoxedMemberDescriptor(MemberType type, int offset) : Box(&member_flavor, member_cls), type(type), offset(offset) {}
}; };
// TODO is there any particular reason to make this a Box, ie a python-level object?
class BoxedClosure : public Box {
public:
HCAttrs attrs;
BoxedClosure* parent;
BoxedClosure(BoxedClosure* parent) : Box(&closure_flavor, closure_cls), parent(parent) {}
};
extern "C" void boxGCHandler(GCVisitor* v, void* p); extern "C" void boxGCHandler(GCVisitor* v, void* p);
Box* exceptionNew1(BoxedClass* cls); Box* exceptionNew1(BoxedClass* cls);
......
# closure tests
# simple closure:
def make_adder(x):
def g(y):
return x + y
return g
a = make_adder(1)
print a(5)
print map(a, range(5))
def make_adder2(x2):
# f takes a closure since it needs to get a reference to x2 to pass to g
def f():
def g(y2):
return x2 + y2
return g
r = f()
print r(1)
x2 += 1
print r(1)
return r
a = make_adder2(2)
print a(5)
print map(a, range(5))
def make_addr3(x3):
# this function doesn't take a closure:
def f1():
return 2
f1()
def g(y3):
return x3 + y3
# this function does a different closure
def f2():
print f1()
f2()
return g
print make_addr3(10)(2)
def bad_addr3(_x):
if 0:
x3 = _
def g(y3):
return x3 + y3
return g
print bad_addr3(1)(2)
# expected: fail
# - varargs, kwarg
# Regression test: make sure that args and kw get properly treated as potentially-saved-in-closure
def f1(*args):
def inner():
return args
return inner
print f1()()
print f1(1, 2, 3, "a")()
def f2(**kw):
def inner():
return kw
return inner
print f2()()
print f2(a=1)()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment