diff --git a/src/analysis/scoping_analysis.cpp b/src/analysis/scoping_analysis.cpp index aae431b75313d82e9cc0899e97b3ba2950dfd745..b8b1f289b12bb1ab6facc883f8db690c0152d9f6 100644 --- a/src/analysis/scoping_analysis.cpp +++ b/src/analysis/scoping_analysis.cpp @@ -76,15 +76,9 @@ public: virtual ScopeInfo* getParent() { return parent; } - virtual bool createsClosure() { - assert(0); - return usage->referenced_from_nested.size() > 0; - } + virtual bool createsClosure() { return usage->referenced_from_nested.size() > 0; } - virtual bool takesClosure() { - assert(0); - return false; - } + virtual bool takesClosure() { return usage->got_from_closure.size() > 0; } virtual bool refersToGlobal(const std::string& name) { // HAX @@ -143,7 +137,6 @@ public: return true; } - virtual bool visit_arguments(AST_arguments* node) { return false; } virtual bool visit_assert(AST_Assert* node) { return false; } virtual bool visit_assign(AST_Assign* node) { return false; } virtual bool visit_augassign(AST_AugAssign* node) { return false; } @@ -201,8 +194,15 @@ public: virtual bool visit_classdef(AST_ClassDef* node) { if (node == orig_node) { - return false; + for (AST_stmt* s : node->body) + s->accept(this); + return true; } else { + for (auto* e : node->bases) + e->accept(this); + for (auto* e : node->decorator_list) + e->accept(this); + doWrite(node->name); (*map)[node] = new ScopingAnalysis::ScopeNameUsage(node, cur); collect(node, map); @@ -212,8 +212,21 @@ public: virtual bool visit_functiondef(AST_FunctionDef* node) { if (node == orig_node) { - return false; + for (AST_expr* e : node->args->args) + e->accept(this); + if (node->args->vararg.size()) + doWrite(node->args->vararg); + if (node->args->kwarg.size()) + doWrite(node->args->kwarg); + for (AST_stmt* s : node->body) + s->accept(this); + return true; } else { + for (auto* e : node->args->defaults) + e->accept(this); + for (auto* e : node->decorator_list) + e->accept(this); + doWrite(node->name); (*map)[node] = new 
ScopingAnalysis::ScopeNameUsage(node, cur); collect(node, map); @@ -278,6 +291,8 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) { if (usage->written.count(*it2)) continue; + std::vector<ScopeNameUsage*> intermediate_parents; + ScopeNameUsage* parent = usage->parent; while (parent) { if (parent->node->type == AST_TYPE::ClassDef) { @@ -287,8 +302,15 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) { } else if (parent->written.count(*it2)) { usage->got_from_closure.insert(*it2); parent->referenced_from_nested.insert(*it2); + + for (ScopeNameUsage* iparent : intermediate_parents) { + iparent->referenced_from_nested.insert(*it2); + iparent->got_from_closure.insert(*it2); + } + break; } else { + intermediate_parents.push_back(parent); parent = parent->parent; } } @@ -320,19 +342,6 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) { } ScopeInfo* ScopingAnalysis::analyzeSubtree(AST* node) { -#ifndef NDEBUG - std::vector<AST*> flattened; - flatten(parent_module->body, flattened, false); - bool found = 0; - for (AST* n : flattened) { - if (n == node) { - found = true; - break; - } - } - assert(found); -#endif - NameUsageMap usages; usages[node] = new ScopeNameUsage(node, NULL); NameCollectorVisitor::collect(node, &usages); diff --git a/src/codegen/codegen.h b/src/codegen/codegen.h index 761ac93dd4af7dfdbf13b3ff9c45018140cb8478..bed65b2bab64e547a2e6d8a9663bdc590ecf148e 100644 --- a/src/codegen/codegen.h +++ b/src/codegen/codegen.h @@ -70,7 +70,7 @@ struct GlobalState { llvm::Type* llvm_flavor_type, *llvm_flavor_type_ptr; llvm::Type* llvm_opaque_type; llvm::Type* llvm_str_type_ptr; - llvm::Type* llvm_clfunction_type_ptr; + llvm::Type* llvm_clfunction_type_ptr, *llvm_closure_type_ptr; llvm::Type* llvm_module_type_ptr, *llvm_bool_type_ptr; llvm::Type* i1, *i8, *i8_ptr, *i32, *i64, *void_, *double_; llvm::Type* vector_ptr; diff --git a/src/codegen/compvars.cpp 
b/src/codegen/compvars.cpp index 3fc6b98c4922c1f00a1f819d9b837dbe0f528e8f..0931587bcedeb7811654a5119e44bd841ed468e8 100644 --- a/src/codegen/compvars.cpp +++ b/src/codegen/compvars.cpp @@ -518,11 +518,24 @@ ConcreteCompilerVariable* UnknownType::nonzero(IREmitter& emitter, const OpInfo& return new ConcreteCompilerVariable(BOOL, rtn_val, true); } -CompilerVariable* makeFunction(IREmitter& emitter, CLFunction* f) { +CompilerVariable* makeFunction(IREmitter& emitter, CLFunction* f, CompilerVariable* closure) { // Unlike the CLFunction*, which can be shared between recompilations, the Box* around it // should be created anew every time the functiondef is encountered - llvm::Value* boxed - = emitter.getBuilder()->CreateCall(g.funcs.boxCLFunction, embedConstantPtr(f, g.llvm_clfunction_type_ptr)); + + llvm::Value* closure_v; + ConcreteCompilerVariable* converted = NULL; + if (closure) { + converted = closure->makeConverted(emitter, closure->getConcreteType()); + closure_v = converted->getValue(); + } else { + closure_v = embedConstantPtr(nullptr, g.llvm_closure_type_ptr); + } + + llvm::Value* boxed = emitter.getBuilder()->CreateCall2(g.funcs.boxCLFunction, + embedConstantPtr(f, g.llvm_clfunction_type_ptr), closure_v); + + if (converted) + converted->decvref(emitter); return new ConcreteCompilerVariable(typeFromClass(function_cls), boxed, true); } @@ -1147,6 +1160,30 @@ public: std::unordered_map<BoxedClass*, NormalObjectType*> NormalObjectType::made; ConcreteCompilerType* STR, *BOXED_INT, *BOXED_FLOAT, *BOXED_BOOL, *NONE; +class ClosureType : public ConcreteCompilerType { +public: + llvm::Type* llvmType() { return g.llvm_closure_type_ptr; } + std::string debugName() { return "closure"; } + + CompilerVariable* getattr(IREmitter& emitter, const OpInfo& info, ConcreteCompilerVariable* var, + const std::string* attr, bool cls_only) { + assert(!cls_only); + llvm::Value* bitcast = emitter.getBuilder()->CreateBitCast(var->getValue(), g.llvm_value_type_ptr); + return 
ConcreteCompilerVariable(UNKNOWN, bitcast, true).getattr(emitter, info, attr, cls_only); + } + + void setattr(IREmitter& emitter, const OpInfo& info, ConcreteCompilerVariable* var, const std::string* attr, + CompilerVariable* v) { + llvm::Value* bitcast = emitter.getBuilder()->CreateBitCast(var->getValue(), g.llvm_value_type_ptr); + ConcreteCompilerVariable(UNKNOWN, bitcast, true).setattr(emitter, info, attr, v); + } + + virtual ConcreteCompilerType* getConcreteType() { return this; } + // Shouldn't call this: + virtual ConcreteCompilerType* getBoxType() { RELEASE_ASSERT(0, ""); } +} _CLOSURE; +ConcreteCompilerType* CLOSURE = &_CLOSURE; + class StrConstantType : public ValuedCompilerType<std::string*> { public: std::string debugName() { return "str_constant"; } diff --git a/src/codegen/compvars.h b/src/codegen/compvars.h index 0ccc9b629f10deb9b2b86b4e51305133ecec0a92..915ef5ac02546aa446a303531d15872113b80b2e 100644 --- a/src/codegen/compvars.h +++ b/src/codegen/compvars.h @@ -29,7 +29,7 @@ class CompilerType; class IREmitter; extern ConcreteCompilerType* INT, *BOXED_INT, *FLOAT, *BOXED_FLOAT, *VOID, *UNKNOWN, *BOOL, *STR, *NONE, *LIST, *SLICE, - *MODULE, *DICT, *BOOL, *BOXED_BOOL, *BOXED_TUPLE, *SET; + *MODULE, *DICT, *BOOL, *BOXED_BOOL, *BOXED_TUPLE, *SET, *CLOSURE; extern CompilerType* UNDEF; class CompilerType { @@ -316,7 +316,7 @@ ConcreteCompilerVariable* makeInt(int64_t); ConcreteCompilerVariable* makeFloat(double); ConcreteCompilerVariable* makeBool(bool); CompilerVariable* makeStr(std::string*); -CompilerVariable* makeFunction(IREmitter& emitter, CLFunction*); +CompilerVariable* makeFunction(IREmitter& emitter, CLFunction*, CompilerVariable* closure); ConcreteCompilerVariable* undefVariable(); CompilerVariable* makeTuple(const std::vector<CompilerVariable*>& elts); diff --git a/src/codegen/irgen.cpp b/src/codegen/irgen.cpp index e20277a56f404e500e6cb7a203fdcb1a61735bc2..fa24bc01c48fdc10964ffd1ce76d390933580b9e 100644 --- a/src/codegen/irgen.cpp +++ 
b/src/codegen/irgen.cpp @@ -563,7 +563,8 @@ static void emitBBs(IRGenState* irstate, const char* bb_type, GuardList& out_gua emitter->getBuilder()->SetInsertPoint(llvm_entry_blocks[source->cfg->getStartingBlock()]); } - generator->unpackArguments(arg_names, cf->spec->arg_types); + + generator->doFunctionEntry(arg_names, cf->spec->arg_types); // Function-entry safepoint: // TODO might be more efficient to do post-call safepoints? @@ -915,6 +916,9 @@ CompiledFunction* doCompile(SourceInfo* source, const OSREntryDescriptor* entry_ ASSERT(nargs == spec->arg_types.size(), "%d %ld", nargs, spec->arg_types.size()); std::vector<llvm::Type*> llvm_arg_types; + if (source->scoping->getScopeInfoForNode(source->ast)->takesClosure()) + llvm_arg_types.push_back(g.llvm_closure_type_ptr); + if (entry_descriptor == NULL) { for (int i = 0; i < nargs; i++) { if (i == 3) { diff --git a/src/codegen/irgen/hooks.cpp b/src/codegen/irgen/hooks.cpp index 7fe8e14dcbe775af42fa7c5dcf6e6a83edaf1b81..a3c62a13dca136549a7df77f1f1a80c377054fe4 100644 --- a/src/codegen/irgen/hooks.cpp +++ b/src/codegen/irgen/hooks.cpp @@ -262,7 +262,7 @@ void compileAndRunModule(AST_Module* m, BoxedModule* bm) { _t.end(); if (cf->is_interpreted) - interpretFunction(cf->func, 0, NULL, NULL, NULL, NULL); + interpretFunction(cf->func, 0, NULL, NULL, NULL, NULL, NULL); else ((void (*)())cf->code)(); } diff --git a/src/codegen/irgen/irgenerator.cpp b/src/codegen/irgen/irgenerator.cpp index 73851017678173fba256e2bee3b8a7918c5b442b..09ce3725a35d19b88a3f490af713053d2090f9d2 100644 --- a/src/codegen/irgen/irgenerator.cpp +++ b/src/codegen/irgen/irgenerator.cpp @@ -201,6 +201,9 @@ static std::vector<const std::string*>* getKeywordNameStorage(AST_Call* node) { return rtn; } +static const std::string CREATED_CLOSURE_NAME = "!created_closure"; +static const std::string PASSED_CLOSURE_NAME = "!passed_closure"; + class IRGeneratorImpl : public IRGenerator { private: IRGenState* irstate; @@ -790,7 +793,9 @@ private: 
CompilerVariable* evalName(AST_Name* node, ExcInfo exc_info) { assert(state != PARTIAL); - if (irstate->getScopeInfo()->refersToGlobal(node->id)) { + auto scope_info = irstate->getScopeInfo(); + + if (scope_info->refersToGlobal(node->id)) { if (1) { // Method 1: calls into the runtime getGlobal(), which handles things like falling back to builtins // or raising the correct error message. @@ -825,6 +830,13 @@ private: mod->decvref(emitter); return attr; } + } else if (scope_info->refersToClosure(node->id)) { + assert(scope_info->takesClosure()); + + CompilerVariable* closure = _getFake(PASSED_CLOSURE_NAME, false); + assert(closure); + + return closure->getattr(emitter, getEmptyOpInfo(exc_info), &node->id, false); } else { if (symbol_table.find(node->id) == symbol_table.end()) { // TODO should mark as DEAD here, though we won't end up setting all the names appropriately @@ -836,7 +848,7 @@ private: } std::string defined_name = _getFakeName("is_defined", node->id.c_str()); - ConcreteCompilerVariable* is_defined = static_cast<ConcreteCompilerVariable*>(_getFake(defined_name, true)); + ConcreteCompilerVariable* is_defined = static_cast<ConcreteCompilerVariable*>(_popFake(defined_name, true)); if (is_defined) { emitter.createCall2(exc_info, g.funcs.assertNameDefined, is_defined->getValue(), getStringConstantPtr(node->id + '\0')); @@ -1203,18 +1215,29 @@ private: symbol_table.erase(name); return rtn; } + CompilerVariable* _getFake(std::string name, bool allow_missing = false) { assert(name[0] == '!'); CompilerVariable* rtn = symbol_table[name]; if (!allow_missing) assert(rtn != NULL); + return rtn; + } + CompilerVariable* _popFake(std::string name, bool allow_missing = false) { + CompilerVariable* rtn = _getFake(name, allow_missing); symbol_table.erase(name); return rtn; } void _doSet(const std::string& name, CompilerVariable* val, ExcInfo exc_info) { assert(name != "None"); - if (irstate->getScopeInfo()->refersToGlobal(name)) { + + auto scope_info = 
irstate->getScopeInfo(); + assert(!scope_info->refersToClosure(name)); + + if (scope_info->refersToGlobal(name)) { + assert(!scope_info->saveInClosure(name)); + // TODO do something special here so that it knows to only emit a monomorphic inline cache? ConcreteCompilerVariable* module = new ConcreteCompilerVariable( MODULE, embedConstantPtr(irstate->getSourceInfo()->parent_module, g.llvm_value_type_ptr), false); @@ -1231,7 +1254,14 @@ private: // Clear out the is_defined name since it is now definitely defined: assert(!startswith(name, "!is_defined")); std::string defined_name = _getFakeName("is_defined", name.c_str()); - _getFake(defined_name, true); + _popFake(defined_name, true); + + if (scope_info->saveInClosure(name)) { + CompilerVariable* closure = _getFake(CREATED_CLOSURE_NAME, false); + assert(closure); + + closure->setattr(emitter, getEmptyOpInfo(ExcInfo::none()), &name, val); + } } } @@ -1351,7 +1381,10 @@ private: if (state == PARTIAL) return; + assert(node->type == AST_TYPE::ClassDef); ScopeInfo* scope_info = irstate->getSourceInfo()->scoping->getScopeInfoForNode(node); + assert(scope_info); + assert(!scope_info->takesClosure()); RELEASE_ASSERT(node->bases.size() == 1, ""); @@ -1376,8 +1409,11 @@ private: continue; } else if (type == AST_TYPE::FunctionDef) { AST_FunctionDef* fdef = ast_cast<AST_FunctionDef>(node->body[i]); + ScopeInfo* scope_info = irstate->getSourceInfo()->scoping->getScopeInfoForNode(fdef); + CLFunction* cl = this->_wrapFunction(fdef); - CompilerVariable* func = makeFunction(emitter, cl); + assert(!scope_info->takesClosure()); + CompilerVariable* func = makeFunction(emitter, cl, NULL); cls->setattr(emitter, getEmptyOpInfo(exc_info), &fdef->name, func); func->decvref(emitter); } else { @@ -1448,12 +1484,22 @@ private: return cl; } - void doFunction(AST_FunctionDef* node, ExcInfo exc_info) { + void doFunctionDef(AST_FunctionDef* node, ExcInfo exc_info) { if (state == PARTIAL) return; + assert(!node->args->defaults.size()); + CLFunction* 
cl = this->_wrapFunction(node); - CompilerVariable* func = makeFunction(emitter, cl); + + CompilerVariable* created_closure = NULL; + ScopeInfo* scope_info = irstate->getSourceInfo()->scoping->getScopeInfoForNode(node); + if (scope_info->takesClosure()) { + created_closure = _getFake(CREATED_CLOSURE_NAME, false); + assert(created_closure); + } + + CompilerVariable* func = makeFunction(emitter, cl, created_closure); // llvm::Type* boxCLFuncArgType = g.funcs.boxCLFunction->arg_begin()->getType(); // llvm::Value *boxed = emitter.getBuilder()->CreateCall(g.funcs.boxCLFunction, embedConstantPtr(cl, @@ -1829,7 +1875,7 @@ private: doExpr(ast_cast<AST_Expr>(node), exc_info); break; case AST_TYPE::FunctionDef: - doFunction(ast_cast<AST_FunctionDef>(node), exc_info); + doFunctionDef(ast_cast<AST_FunctionDef>(node), exc_info); break; // case AST_TYPE::If: // doIf(ast_cast<AST_If>(node)); @@ -1892,6 +1938,10 @@ private: var->decvref(emitter); } + bool allowableFakeEndingSymbol(const std::string& name) { + return startswith(name, "!is_defined") || name == PASSED_CLOSURE_NAME || name == CREATED_CLOSURE_NAME; + } + void endBlock(State new_state) { assert(state == RUNNING); @@ -1901,7 +1951,7 @@ private: ScopeInfo* scope_info = irstate->getScopeInfo(); for (SymbolTable::iterator it = symbol_table.begin(); it != symbol_table.end();) { - if (startswith(it->first, "!is_defined")) { + if (allowableFakeEndingSymbol(it->first)) { ++it; continue; } @@ -1950,7 +2000,7 @@ private: // printf("defined on this path; "); ConcreteCompilerVariable* is_defined - = static_cast<ConcreteCompilerVariable*>(_getFake(defined_name, true)); + = static_cast<ConcreteCompilerVariable*>(_popFake(defined_name, true)); if (source->phis->isPotentiallyUndefinedAfter(*it, myblock)) { // printf("is potentially undefined later, so marking it defined\n"); @@ -1996,6 +2046,8 @@ public: } if (myblock->successors.size() == 0) { + st->erase(CREATED_CLOSURE_NAME); + st->erase(PASSED_CLOSURE_NAME); assert(st->size() == 
0); // shouldn't have anything live if there are no successors! return EndingState(st, phi_st, curblock); } else if (myblock->successors.size() > 1) { @@ -2015,14 +2067,18 @@ public: } for (SymbolTable::iterator it = st->begin(); it != st->end();) { - if (startswith(it->first, "!is_defined") || source->phis->isRequiredAfter(it->first, myblock)) { - assert(it->second->isGrabbed()); + if (allowableFakeEndingSymbol(it->first) || source->phis->isRequiredAfter(it->first, myblock)) { + ASSERT(it->second->isGrabbed(), "%s", it->first.c_str()); assert(it->second->getVrefs() == 1); // this conversion should have already happened... should refactor this. ConcreteCompilerType* ending_type; if (startswith(it->first, "!is_defined")) { assert(it->second->getType() == BOOL); ending_type = BOOL; + } else if (it->first == PASSED_CLOSURE_NAME) { + ending_type = getPassedClosureType(); + } else if (it->first == CREATED_CLOSURE_NAME) { + ending_type = getCreatedClosureType(); } else { ending_type = types->getTypeAtBlockEnd(it->first, myblock); } @@ -2057,14 +2113,43 @@ public: } } - void unpackArguments(const std::vector<AST_expr*>& arg_names, + ConcreteCompilerType* getPassedClosureType() { + // TODO could know the exact closure shape + return CLOSURE; + } + + ConcreteCompilerType* getCreatedClosureType() { + // TODO could know the exact closure shape + return CLOSURE; + } + + void doFunctionEntry(const std::vector<AST_expr*>& arg_names, const std::vector<ConcreteCompilerType*>& arg_types) override { + auto scope_info = irstate->getScopeInfo(); + + llvm::Value* passed_closure = NULL; + llvm::Function::arg_iterator AI = irstate->getLLVMFunction()->arg_begin(); + if (scope_info->takesClosure()) { + passed_closure = AI; + _setFake(PASSED_CLOSURE_NAME, new ConcreteCompilerVariable(getPassedClosureType(), AI, true)); + ++AI; + } + + if (scope_info->createsClosure()) { + if (!passed_closure) + passed_closure = embedConstantPtr(nullptr, g.llvm_closure_type_ptr); + + llvm::Value* new_closure 
= emitter.getBuilder()->CreateCall(g.funcs.createClosure, passed_closure); + _setFake(CREATED_CLOSURE_NAME, new ConcreteCompilerVariable(getCreatedClosureType(), new_closure, true)); + } + + int i = 0; llvm::Value* argarray = NULL; - for (llvm::Function::arg_iterator AI = irstate->getLLVMFunction()->arg_begin(); - AI != irstate->getLLVMFunction()->arg_end(); AI++, i++) { + for (; AI != irstate->getLLVMFunction()->arg_end(); ++AI, i++) { if (i == 3) { argarray = AI; + assert(++AI == irstate->getLLVMFunction()->arg_end()); break; } loadArgument(arg_names[i], arg_types[i], AI, ExcInfo::none()); diff --git a/src/codegen/irgen/irgenerator.h b/src/codegen/irgen/irgenerator.h index 443537ffaa4527ca25d7049e66c9375f6e0a3310..cc6e3c15dc1384d5450c126ffa94f6032d134308 100644 --- a/src/codegen/irgen/irgenerator.h +++ b/src/codegen/irgen/irgenerator.h @@ -183,8 +183,9 @@ public: virtual ~IRGenerator() {} - virtual void unpackArguments(const std::vector<AST_expr*>& arg_names, + virtual void doFunctionEntry(const std::vector<AST_expr*>& arg_names, const std::vector<ConcreteCompilerType*>& arg_types) = 0; + virtual void giveLocalSymbol(const std::string& name, CompilerVariable* var) = 0; virtual void copySymbolsFrom(SymbolTable* st) = 0; virtual void run(const CFGBlock* block) = 0; diff --git a/src/codegen/llvm_interpreter.cpp b/src/codegen/llvm_interpreter.cpp index 8da9ee889d7ded03abf6adfbffcf136790f7b243..cf5ddc9bf0e2d60b533238f57abeb66b2a590334 100644 --- a/src/codegen/llvm_interpreter.cpp +++ b/src/codegen/llvm_interpreter.cpp @@ -257,7 +257,7 @@ const LineInfo* getLineInfoForInterpretedFrame(void* frame_ptr) { } } -Box* interpretFunction(llvm::Function* f, int nargs, Box* arg1, Box* arg2, Box* arg3, Box** args) { +Box* interpretFunction(llvm::Function* f, int nargs, Box* closure, Box* arg1, Box* arg2, Box* arg3, Box** args) { assert(f); #ifdef TIME_INTERPRETS @@ -280,18 +280,21 @@ Box* interpretFunction(llvm::Function* f, int nargs, Box* arg1, Box* arg2, Box* UnregisterHelper 
helper(frame_ptr); int arg_num = -1; + int closure_indicator = closure ? 1 : 0; for (llvm::Argument& arg : f->args()) { arg_num++; - if (arg_num == 0) + if (arg_num == 0 && closure) + symbols.insert(std::make_pair(static_cast<llvm::Value*>(&arg), Val(closure))); + else if (arg_num == 0 + closure_indicator) symbols.insert(std::make_pair(static_cast<llvm::Value*>(&arg), Val(arg1))); - else if (arg_num == 1) + else if (arg_num == 1 + closure_indicator) symbols.insert(std::make_pair(static_cast<llvm::Value*>(&arg), Val(arg2))); - else if (arg_num == 2) + else if (arg_num == 2 + closure_indicator) symbols.insert(std::make_pair(static_cast<llvm::Value*>(&arg), Val(arg3))); else { - assert(arg_num == 3); - assert(f->getArgumentList().size() == 4); + assert(arg_num == 3 + closure_indicator); + assert(f->getArgumentList().size() == 4 + closure_indicator); assert(f->getArgumentList().back().getType() == g.llvm_value_type_ptr->getPointerTo()); symbols.insert(std::make_pair(static_cast<llvm::Value*>(&arg), Val((int64_t)args))); // printf("loading %%4 with %p\n", (void*)args); diff --git a/src/codegen/llvm_interpreter.h b/src/codegen/llvm_interpreter.h index 60ac07e9b6347b2cf198666f94aa0201eda0c4c1..4f88bbcd970953d11524ade9700854ed42db6912 100644 --- a/src/codegen/llvm_interpreter.h +++ b/src/codegen/llvm_interpreter.h @@ -25,7 +25,7 @@ class Box; class GCVisitor; class LineInfo; -Box* interpretFunction(llvm::Function* f, int nargs, Box* arg1, Box* arg2, Box* arg3, Box** args); +Box* interpretFunction(llvm::Function* f, int nargs, Box* closure, Box* arg1, Box* arg2, Box* arg3, Box** args); void gatherInterpreterRoots(GCVisitor* visitor); const LineInfo* getLineInfoForInterpretedFrame(void* frame_ptr); diff --git a/src/codegen/runtime_hooks.cpp b/src/codegen/runtime_hooks.cpp index ee3d6fdcc446a179147f077dfd04d4fd89fad8fc..eec4c44276b8392e2a59698122df385692bf9acc 100644 --- a/src/codegen/runtime_hooks.cpp +++ b/src/codegen/runtime_hooks.cpp @@ -137,6 +137,9 @@ void 
initGlobalFuncs(GlobalState& g) { assert(vector_type); g.vector_ptr = vector_type->getPointerTo(); + g.llvm_closure_type_ptr = g.stdlib_module->getTypeByName("class.pyston::BoxedClosure")->getPointerTo(); + assert(g.llvm_closure_type_ptr); + #define GET(N) g.funcs.N = getFunc((void*)N, STRINGIFY(N)) g.funcs.printf = addFunc((void*)printf, g.i8_ptr, true); @@ -161,6 +164,7 @@ void initGlobalFuncs(GlobalState& g) { GET(createList); GET(createDict); GET(createSlice); + GET(createClosure); GET(getattr); GET(setattr); diff --git a/src/codegen/runtime_hooks.h b/src/codegen/runtime_hooks.h index 568d202d431a873db7043d0787b1a5d84bfcc3e9..b177fe3ebd355ac0dfab43ba16dbb0eec5902ac9 100644 --- a/src/codegen/runtime_hooks.h +++ b/src/codegen/runtime_hooks.h @@ -32,7 +32,7 @@ struct GlobalFuncs { llvm::Value* boxInt, *unboxInt, *boxFloat, *unboxFloat, *boxStringPtr, *boxCLFunction, *unboxCLFunction, *boxInstanceMethod, *boxBool, *unboxBool, *createTuple, *createDict, *createList, *createSlice, - *createUserClass; + *createUserClass, *createClosure; llvm::Value* getattr, *setattr, *print, *nonzero, *binop, *compare, *augbinop, *unboxedLen, *getitem, *getclsattr, *getGlobal, *setitem, *delitem, *unaryop, *import, *repr, *isinstance; diff --git a/src/core/types.h b/src/core/types.h index 6e83a6b0b7a35777791134a58015d9dd0b085a61..50192e0dce2ea12aedbc10ddd815a294a5f42c9c 100644 --- a/src/core/types.h +++ b/src/core/types.h @@ -165,6 +165,7 @@ struct FunctionSpecialization { : rtn_type(rtn_type), arg_types(arg_types) {} }; +class BoxedClosure; struct CompiledFunction { private: public: @@ -176,6 +177,7 @@ public: union { Box* (*call)(Box*, Box*, Box*, Box**); + Box* (*closure_call)(BoxedClosure*, Box*, Box*, Box*, Box**); void* code; }; llvm::Value* llvm_code; // the llvm callable. 
diff --git a/src/gc/heap.cpp b/src/gc/heap.cpp index ecb139145457859ba47c133179acd1c48e0917e4..983a577b28cd17e6669a8474dcd2221f8fa7c404 100644 --- a/src/gc/heap.cpp +++ b/src/gc/heap.cpp @@ -422,7 +422,6 @@ static Block** freeChain(Block** head) { } void Heap::freeUnmarked() { - Timer _t("looking at the thread caches"); thread_caches.forEachValue([this](ThreadBlockCache* cache) { for (int bidx = 0; bidx < NUM_BUCKETS; bidx++) { Block* h = cache->cache_free_heads[bidx]; @@ -452,7 +451,6 @@ void Heap::freeUnmarked() { } } }); - _t.end(); for (int bidx = 0; bidx < NUM_BUCKETS; bidx++) { Block** chain_end = freeChain(&heads[bidx]); diff --git a/src/runtime/inline/link_forcer.cpp b/src/runtime/inline/link_forcer.cpp index 6888700d6ff8fbea9c48ef79bec8987c2f1ce470..371bd036d6b194c8b55c0bee3051b36fb83d2f7d 100644 --- a/src/runtime/inline/link_forcer.cpp +++ b/src/runtime/inline/link_forcer.cpp @@ -57,6 +57,7 @@ void force() { FORCE(createList); FORCE(createSlice); FORCE(createUserClass); + FORCE(createClosure); FORCE(getattr); FORCE(setattr); diff --git a/src/runtime/objmodel.cpp b/src/runtime/objmodel.cpp index 370712511bfa3d13c152ea6a5dd750397f2ed046..286a72e7223aabd5abae5db457fe89752436edb8 100644 --- a/src/runtime/objmodel.cpp +++ b/src/runtime/objmodel.cpp @@ -795,6 +795,26 @@ Box* getattr_internal(Box* obj, const std::string& attr, bool check_cls, bool al } } + + // TODO closures should get their own treatment, but now just piggy-back on the + // normal hidden-class IC logic. 
+ // Can do better since we don't need to guard on the cls (always going to be closure) + if (obj->cls == closure_cls) { + BoxedClosure* closure = static_cast<BoxedClosure*>(obj); + if (closure->parent) { + if (rewrite_args) { + rewrite_args->obj = rewrite_args->obj.getAttr(offsetof(BoxedClosure, parent), -1); + } + if (rewrite_args2) { + rewrite_args2->obj + = rewrite_args2->obj.getAttr(offsetof(BoxedClosure, parent), RewriterVarUsage2::Kill); + } + return getattr_internal(closure->parent, attr, false, false, rewrite_args, rewrite_args2); + } + raiseExcHelper(NameError, "free variable '%s' referenced before assignment in enclosing scope", attr.c_str()); + } + + if (allow_custom) { // Don't need to pass icentry args, since we special-case __getattribtue__ and __getattr__ to use // invalidation rather than guards @@ -1655,6 +1675,7 @@ Box* callFunc(BoxedFunction* func, CallRewriteArgs* rewrite_args, ArgPassSpec ar int num_output_args = f->numReceivedArgs(); int num_passed_args = argspec.totalPassed(); + BoxedClosure* closure = func->closure; if (argspec.has_starargs || argspec.has_kwargs || f->takes_kwargs) rewrite_args = NULL; @@ -1690,14 +1711,21 @@ Box* callFunc(BoxedFunction* func, CallRewriteArgs* rewrite_args, ArgPassSpec ar } if (rewrite_args) { + int closure_indicator = closure ? 1 : 0; + if (num_passed_args >= 1) - rewrite_args->arg1 = rewrite_args->arg1.move(0); + rewrite_args->arg1 = rewrite_args->arg1.move(0 + closure_indicator); if (num_passed_args >= 2) - rewrite_args->arg2 = rewrite_args->arg2.move(1); + rewrite_args->arg2 = rewrite_args->arg2.move(1 + closure_indicator); if (num_passed_args >= 3) - rewrite_args->arg3 = rewrite_args->arg3.move(2); + rewrite_args->arg3 = rewrite_args->arg3.move(2 + closure_indicator); if (num_passed_args >= 4) - rewrite_args->args = rewrite_args->args.move(3); + rewrite_args->args = rewrite_args->args.move(3 + closure_indicator); + + // TODO this kind of embedded reference needs to be tracked by the GC somehow? 
+ // Or maybe it's ok, since we've guarded on the function object? + if (closure) + rewrite_args->rewriter->loadConst(0, (intptr_t)closure); // We might have trouble if we have more output args than input args, // such as if we need more space to pass defaults. @@ -1898,7 +1926,7 @@ Box* callFunc(BoxedFunction* func, CallRewriteArgs* rewrite_args, ArgPassSpec ar assert(chosen_cf->is_interpreted == (chosen_cf->code == NULL)); if (chosen_cf->is_interpreted) { - return interpretFunction(chosen_cf->func, num_output_args, oarg1, oarg2, oarg3, oargs); + return interpretFunction(chosen_cf->func, num_output_args, func->closure, oarg1, oarg2, oarg3, oargs); } else { if (rewrite_args) { rewrite_args->rewriter->addDependenceOn(chosen_cf->dependent_callsites); @@ -1908,7 +1936,11 @@ Box* callFunc(BoxedFunction* func, CallRewriteArgs* rewrite_args, ArgPassSpec ar rewrite_args->out_rtn = var; rewrite_args->out_success = true; } - return chosen_cf->call(oarg1, oarg2, oarg3, oargs); + + if (closure) + return chosen_cf->closure_call(closure, oarg1, oarg2, oarg3, oargs); + else + return chosen_cf->call(oarg1, oarg2, oarg3, oargs); } } diff --git a/src/runtime/objmodel.h b/src/runtime/objmodel.h index 3b537ac63896648637e695a4de292a4e936e05ea..c39966bc16083489a5cbe4907be6d2501217a80f 100644 --- a/src/runtime/objmodel.h +++ b/src/runtime/objmodel.h @@ -69,6 +69,7 @@ extern "C" void checkUnpackingLength(i64 expected, i64 given); extern "C" void assertNameDefined(bool b, const char* name); extern "C" void assertFail(BoxedModule* inModule, Box* msg); extern "C" bool isSubclass(BoxedClass* child, BoxedClass* parent); +extern "C" BoxedClosure* createClosure(BoxedClosure* parent_closure); class BinopRewriteArgs; extern "C" Box* binopInternal(Box* lhs, Box* rhs, int op_type, bool inplace, BinopRewriteArgs* rewrite_args); diff --git a/src/runtime/types.cpp b/src/runtime/types.cpp index 3d42d783362699bc677f054cf751d5681596fb08..2d26f1a83a767465e29283e9e4c26db3300968cb 100644 --- 
a/src/runtime/types.cpp +++ b/src/runtime/types.cpp @@ -65,7 +65,7 @@ llvm::iterator_range<BoxIterator> Box::pyElements() { } extern "C" BoxedFunction::BoxedFunction(CLFunction* f) - : Box(&function_flavor, function_cls), f(f), ndefaults(0), defaults(NULL) { + : Box(&function_flavor, function_cls), f(f), closure(NULL), ndefaults(0), defaults(NULL) { if (f->source) { assert(f->source->ast); // this->giveAttr("__name__", boxString(&f->source->ast->name)); @@ -78,13 +78,15 @@ extern "C" BoxedFunction::BoxedFunction(CLFunction* f) assert(f->num_defaults == ndefaults); } -extern "C" BoxedFunction::BoxedFunction(CLFunction* f, std::initializer_list<Box*> defaults) - : Box(&function_flavor, function_cls), f(f), ndefaults(0), defaults(NULL) { - // make sure to initialize defaults first, since the GC behavior is triggered by ndefaults, - // and a GC can happen within this constructor: - this->defaults = new (defaults.size()) GCdArray(); - memcpy(this->defaults->elts, defaults.begin(), defaults.size() * sizeof(Box*)); - this->ndefaults = defaults.size(); +extern "C" BoxedFunction::BoxedFunction(CLFunction* f, std::initializer_list<Box*> defaults, BoxedClosure* closure) + : Box(&function_flavor, function_cls), f(f), closure(closure), ndefaults(0), defaults(NULL) { + if (defaults.size()) { + // make sure to initialize defaults first, since the GC behavior is triggered by ndefaults, + // and a GC can happen within this constructor: + this->defaults = new (defaults.size()) GCdArray(); + memcpy(this->defaults->elts, defaults.begin(), defaults.size() * sizeof(Box*)); + this->ndefaults = defaults.size(); + } if (f->source) { assert(f->source->ast); @@ -104,6 +106,9 @@ extern "C" void functionGCHandler(GCVisitor* v, void* p) { BoxedFunction* f = (BoxedFunction*)p; + if (f->closure) + v->visit(f->closure); + // It's ok for f->defaults to be NULL here even if f->ndefaults isn't, // since we could be collecting from inside a BoxedFunction constructor if (f->ndefaults) { @@ -130,8 
+135,11 @@ std::string BoxedModule::name() { } } -extern "C" Box* boxCLFunction(CLFunction* f) { - return new BoxedFunction(f); +extern "C" Box* boxCLFunction(CLFunction* f, BoxedClosure* closure) { + if (closure) + assert(closure->cls == closure_cls); + + return new BoxedFunction(f, {}, closure); } extern "C" CLFunction* unboxCLFunction(Box* b) { @@ -245,9 +253,18 @@ extern "C" void conservativeGCHandler(GCVisitor* v, void* p) { v->visitPotentialRange(start, start + (size / sizeof(void*))); } +extern "C" void closureGCHandler(GCVisitor* v, void* p) { + boxGCHandler(v, p); + + BoxedClosure* c = (BoxedClosure*)p; + if (c->parent) + v->visit(c->parent); +} + extern "C" { BoxedClass* object_cls, *type_cls, *none_cls, *bool_cls, *int_cls, *float_cls, *str_cls, *function_cls, - *instancemethod_cls, *list_cls, *slice_cls, *module_cls, *dict_cls, *tuple_cls, *file_cls, *member_cls; + *instancemethod_cls, *list_cls, *slice_cls, *module_cls, *dict_cls, *tuple_cls, *file_cls, *member_cls, + *closure_cls; const ObjectFlavor object_flavor(&boxGCHandler, NULL); const ObjectFlavor type_flavor(&typeGCHandler, NULL); @@ -265,6 +282,7 @@ const ObjectFlavor dict_flavor(&dictGCHandler, NULL); const ObjectFlavor tuple_flavor(&tupleGCHandler, NULL); const ObjectFlavor file_flavor(&boxGCHandler, NULL); const ObjectFlavor member_flavor(&boxGCHandler, NULL); +const ObjectFlavor closure_flavor(&closureGCHandler, NULL); const AllocationKind untracked_kind(NULL, NULL); const AllocationKind hc_kind(&hcGCHandler, NULL); @@ -356,6 +374,12 @@ extern "C" Box* createSlice(Box* start, Box* stop, Box* step) { return rtn; } +extern "C" BoxedClosure* createClosure(BoxedClosure* parent_closure) { + if (parent_closure) + assert(parent_closure->cls == closure_cls); + return new BoxedClosure(parent_closure); +} + extern "C" Box* sliceNew(Box* cls, Box* start, Box* stop, Box** args) { RELEASE_ASSERT(cls == slice_cls, ""); Box* step = args[0]; @@ -483,6 +507,7 @@ void setupRuntime() { file_cls = new 
BoxedClass(object_cls, 0, sizeof(BoxedFile), false); set_cls = new BoxedClass(object_cls, 0, sizeof(BoxedSet), false); member_cls = new BoxedClass(object_cls, 0, sizeof(BoxedMemberDescriptor), false); + closure_cls = new BoxedClass(object_cls, offsetof(BoxedClosure, attrs), sizeof(BoxedClosure), false); STR = typeFromClass(str_cls); BOXED_INT = typeFromClass(int_cls); @@ -524,6 +549,9 @@ void setupRuntime() { member_cls->giveAttr("__name__", boxStrConstant("member")); member_cls->freeze(); + closure_cls->giveAttr("__name__", boxStrConstant("closure")); + closure_cls->freeze(); + setupBool(); setupInt(); setupFloat(); diff --git a/src/runtime/types.h b/src/runtime/types.h index 81babd51536e66896fe272a35f2f4df2eb296745..1c589e002c5b13b06f6bafa2e3ad97016a0c0563 100644 --- a/src/runtime/types.h +++ b/src/runtime/types.h @@ -27,6 +27,7 @@ class BoxedList; class BoxedDict; class BoxedTuple; class BoxedFile; +class BoxedClosure; void setupInt(); void teardownInt(); @@ -62,12 +63,13 @@ BoxedList* getSysPath(); extern "C" { extern BoxedClass* object_cls, *type_cls, *bool_cls, *int_cls, *float_cls, *str_cls, *function_cls, *none_cls, - *instancemethod_cls, *list_cls, *slice_cls, *module_cls, *dict_cls, *tuple_cls, *file_cls, *xrange_cls, *member_cls; + *instancemethod_cls, *list_cls, *slice_cls, *module_cls, *dict_cls, *tuple_cls, *file_cls, *xrange_cls, *member_cls, + *closure_cls; } extern "C" { extern const ObjectFlavor object_flavor, type_flavor, bool_flavor, int_flavor, float_flavor, str_flavor, function_flavor, none_flavor, instancemethod_flavor, list_flavor, slice_flavor, module_flavor, dict_flavor, - tuple_flavor, file_flavor, xrange_flavor, member_flavor; + tuple_flavor, file_flavor, xrange_flavor, member_flavor, closure_flavor; } extern "C" { extern Box* None, *NotImplemented, *True, *False; } extern "C" { @@ -86,7 +88,7 @@ Box* boxString(const std::string& s); extern "C" BoxedString* boxStrConstant(const char* chars); extern "C" void listAppendInternal(Box* self, 
Box* v); extern "C" void listAppendArrayInternal(Box* self, Box** v, int nelts); -extern "C" Box* boxCLFunction(CLFunction* f); +extern "C" Box* boxCLFunction(CLFunction* f, BoxedClosure* closure); extern "C" CLFunction* unboxCLFunction(Box* b); extern "C" Box* createUserClass(std::string* name, Box* base, BoxedModule* parent_module); extern "C" double unboxFloat(Box* b); @@ -272,12 +274,13 @@ class BoxedFunction : public Box { public: HCAttrs attrs; CLFunction* f; + BoxedClosure* closure; int ndefaults; GCdArray* defaults; BoxedFunction(CLFunction* f); - BoxedFunction(CLFunction* f, std::initializer_list<Box*> defaults); + BoxedFunction(CLFunction* f, std::initializer_list<Box*> defaults, BoxedClosure* closure = NULL); }; class BoxedModule : public Box { @@ -307,6 +310,15 @@ public: BoxedMemberDescriptor(MemberType type, int offset) : Box(&member_flavor, member_cls), type(type), offset(offset) {} }; +// TODO is there any particular reason to make this a Box, ie a python-level object? +class BoxedClosure : public Box { +public: + HCAttrs attrs; + BoxedClosure* parent; + + BoxedClosure(BoxedClosure* parent) : Box(&closure_flavor, closure_cls), parent(parent) {} +}; + extern "C" void boxGCHandler(GCVisitor* v, void* p); Box* exceptionNew1(BoxedClass* cls); diff --git a/test/tests/closure_test.py b/test/tests/closure_test.py new file mode 100644 index 0000000000000000000000000000000000000000..624aba59b6047224247c929fa2409b21649363b0 --- /dev/null +++ b/test/tests/closure_test.py @@ -0,0 +1,53 @@ +# closure tests + +# simple closure: +def make_adder(x): + def g(y): + return x + y + return g + +a = make_adder(1) +print a(5) +print map(a, range(5)) + +def make_adder2(x2): + # f takes a closure since it needs to get a reference to x2 to pass to g + def f(): + def g(y2): + return x2 + y2 + return g + + r = f() + print r(1) + x2 += 1 + print r(1) + return r + +a = make_adder2(2) +print a(5) +print map(a, range(5)) + +def make_addr3(x3): + # this function doesn't take a 
closure: + def f1(): + return 2 + f1() + + def g(y3): + return x3 + y3 + + # this function does a different closure + def f2(): + print f1() + f2() + return g +print make_addr3(10)(2) + +def bad_addr3(_x): + if 0: + x3 = _ + + def g(y3): + return x3 + y3 + return g +print bad_addr3(1)(2) diff --git a/test/tests/closure_varargs_test.py b/test/tests/closure_varargs_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5d6ba353e59c7bf6b7a723e61a0e759e83ab9b5f --- /dev/null +++ b/test/tests/closure_varargs_test.py @@ -0,0 +1,20 @@ +# expected: fail +# - varargs, kwarg + +# Regression test: make sure that args and kw get properly treated as potentially-saved-in-closure + +def f1(*args): + def inner(): + return args + return inner + +print f1()() +print f1(1, 2, 3, "a")() + +def f2(**kw): + def inner(): + return kw + return inner + +print f2()() +print f2(a=1)()