Commit 6781ec24 authored by Kevin Modzelewski's avatar Kevin Modzelewski

Merge pull request #90 from undingen/lambda_expr2

Implement lambda expressions
parents b2b06576 0688008c
...@@ -66,6 +66,13 @@ public: ...@@ -66,6 +66,13 @@ public:
_doStore(node->name); _doStore(node->name);
return true; return true;
} }
bool visit_lambda(AST_Lambda* node) {
for (auto* d : node->args->defaults)
d->accept(this);
return true;
}
bool visit_name(AST_Name* node) { bool visit_name(AST_Name* node) {
if (node->ctx_type == AST_TYPE::Load) if (node->ctx_type == AST_TYPE::Load)
_doLoad(node->id); _doLoad(node->id);
......
...@@ -246,6 +246,25 @@ public: ...@@ -246,6 +246,25 @@ public:
} }
} }
virtual bool visit_lambda(AST_Lambda* node) {
if (node == orig_node) {
for (AST_expr* e : node->args->args)
e->accept(this);
if (node->args->vararg.size())
doWrite(node->args->vararg);
if (node->args->kwarg.size())
doWrite(node->args->kwarg);
node->body->accept(this);
} else {
for (auto* e : node->args->defaults)
e->accept(this);
(*map)[node] = new ScopingAnalysis::ScopeNameUsage(node, cur);
collect(node, map);
}
return true;
}
virtual bool visit_import(AST_Import* node) { virtual bool visit_import(AST_Import* node) {
for (int i = 0; i < node->names.size(); i++) { for (int i = 0; i < node->names.size(); i++) {
AST_alias* alias = node->names[i]; AST_alias* alias = node->names[i];
...@@ -292,15 +311,13 @@ static std::vector<ScopingAnalysis::ScopeNameUsage*> sortNameUsages(ScopingAnaly ...@@ -292,15 +311,13 @@ static std::vector<ScopingAnalysis::ScopeNameUsage*> sortNameUsages(ScopingAnaly
} }
void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) { void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) {
typedef ScopeNameUsage::StrSet StrSet;
// Resolve name lookups: // Resolve name lookups:
for (const auto& p : *usages) { for (const auto& p : *usages) {
ScopeNameUsage* usage = p.second; ScopeNameUsage* usage = p.second;
for (StrSet::iterator it2 = usage->read.begin(), end2 = usage->read.end(); it2 != end2; ++it2) { for (const auto& name : usage->read) {
if (usage->forced_globals.count(*it2)) if (usage->forced_globals.count(name))
continue; continue;
if (usage->written.count(*it2)) if (usage->written.count(name))
continue; continue;
std::vector<ScopeNameUsage*> intermediate_parents; std::vector<ScopeNameUsage*> intermediate_parents;
...@@ -309,15 +326,15 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) { ...@@ -309,15 +326,15 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) {
while (parent) { while (parent) {
if (parent->node->type == AST_TYPE::ClassDef) { if (parent->node->type == AST_TYPE::ClassDef) {
parent = parent->parent; parent = parent->parent;
} else if (parent->forced_globals.count(*it2)) { } else if (parent->forced_globals.count(name)) {
break; break;
} else if (parent->written.count(*it2)) { } else if (parent->written.count(name)) {
usage->got_from_closure.insert(*it2); usage->got_from_closure.insert(name);
parent->referenced_from_nested.insert(*it2); parent->referenced_from_nested.insert(name);
for (ScopeNameUsage* iparent : intermediate_parents) { for (ScopeNameUsage* iparent : intermediate_parents) {
iparent->referenced_from_nested.insert(*it2); iparent->referenced_from_nested.insert(name);
iparent->got_from_closure.insert(*it2); iparent->got_from_closure.insert(name);
} }
break; break;
...@@ -340,10 +357,9 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) { ...@@ -340,10 +357,9 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) {
ScopeInfo* parent_info = this->scopes[(usage->parent == NULL) ? this->parent_module : usage->parent->node]; ScopeInfo* parent_info = this->scopes[(usage->parent == NULL) ? this->parent_module : usage->parent->node];
switch (node->type) { switch (node->type) {
case AST_TYPE::FunctionDef:
this->scopes[node] = new ScopeInfoBase(parent_info, usage);
break;
case AST_TYPE::ClassDef: case AST_TYPE::ClassDef:
case AST_TYPE::FunctionDef:
case AST_TYPE::Lambda:
this->scopes[node] = new ScopeInfoBase(parent_info, usage); this->scopes[node] = new ScopeInfoBase(parent_info, usage);
break; break;
default: default:
...@@ -375,6 +391,7 @@ ScopeInfo* ScopingAnalysis::getScopeInfoForNode(AST* node) { ...@@ -375,6 +391,7 @@ ScopeInfo* ScopingAnalysis::getScopeInfoForNode(AST* node) {
switch (node->type) { switch (node->type) {
case AST_TYPE::ClassDef: case AST_TYPE::ClassDef:
case AST_TYPE::FunctionDef: case AST_TYPE::FunctionDef:
case AST_TYPE::Lambda:
return analyzeSubtree(node); return analyzeSubtree(node);
// this is handled in the constructor: // this is handled in the constructor:
// case AST_TYPE::Module: // case AST_TYPE::Module:
......
...@@ -351,6 +351,8 @@ private: ...@@ -351,6 +351,8 @@ private:
virtual void* visit_index(AST_Index* node) { return getType(node->value); } virtual void* visit_index(AST_Index* node) { return getType(node->value); }
virtual void* visit_lambda(AST_Lambda* node) { return typeFromClass(function_cls); }
virtual void* visit_langprimitive(AST_LangPrimitive* node) { virtual void* visit_langprimitive(AST_LangPrimitive* node) {
switch (node->opcode) { switch (node->opcode) {
case AST_LangPrimitive::ISINSTANCE: case AST_LangPrimitive::ISINSTANCE:
......
...@@ -51,6 +51,11 @@ SourceInfo::ArgNames::ArgNames(AST* ast) { ...@@ -51,6 +51,11 @@ SourceInfo::ArgNames::ArgNames(AST* ast) {
args = &f->args->args; args = &f->args->args;
vararg = &f->args->vararg; vararg = &f->args->vararg;
kwarg = &f->args->kwarg; kwarg = &f->args->kwarg;
} else if (ast->type == AST_TYPE::Lambda) {
AST_Lambda* l = ast_cast<AST_Lambda>(ast);
args = &l->args->args;
vararg = &l->args->vararg;
kwarg = &l->args->kwarg;
} else { } else {
RELEASE_ASSERT(0, "%d", ast->type); RELEASE_ASSERT(0, "%d", ast->type);
} }
...@@ -63,6 +68,8 @@ const std::string SourceInfo::getName() { ...@@ -63,6 +68,8 @@ const std::string SourceInfo::getName() {
return ast_cast<AST_ClassDef>(ast)->name; return ast_cast<AST_ClassDef>(ast)->name;
case AST_TYPE::FunctionDef: case AST_TYPE::FunctionDef:
return ast_cast<AST_FunctionDef>(ast)->name; return ast_cast<AST_FunctionDef>(ast)->name;
case AST_TYPE::Lambda:
return "<lambda>";
case AST_TYPE::Module: case AST_TYPE::Module:
return this->parent_module->name(); return this->parent_module->name();
default: default:
...@@ -70,20 +77,6 @@ const std::string SourceInfo::getName() { ...@@ -70,20 +77,6 @@ const std::string SourceInfo::getName() {
} }
} }
const std::vector<AST_stmt*>& SourceInfo::getBody() {
assert(ast);
switch (ast->type) {
case AST_TYPE::ClassDef:
return ast_cast<AST_ClassDef>(ast)->body;
case AST_TYPE::FunctionDef:
return ast_cast<AST_FunctionDef>(ast)->body;
case AST_TYPE::Module:
return ast_cast<AST_Module>(ast)->body;
default:
RELEASE_ASSERT(0, "%d", ast->type);
}
}
EffortLevel::EffortLevel initialEffort() { EffortLevel::EffortLevel initialEffort() {
if (FORCE_OPTIMIZE) if (FORCE_OPTIMIZE)
return EffortLevel::MAXIMAL; return EffortLevel::MAXIMAL;
...@@ -169,7 +162,7 @@ CompiledFunction* compileFunction(CLFunction* f, FunctionSpecialization* spec, E ...@@ -169,7 +162,7 @@ CompiledFunction* compileFunction(CLFunction* f, FunctionSpecialization* spec, E
// Do the analysis now if we had deferred it earlier: // Do the analysis now if we had deferred it earlier:
if (source->cfg == NULL) { if (source->cfg == NULL) {
assert(source->ast); assert(source->ast);
source->cfg = computeCFG(source, source->getBody()); source->cfg = computeCFG(source, source->body);
source->liveness = computeLivenessInfo(source->cfg); source->liveness = computeLivenessInfo(source->cfg);
source->phis = computeRequiredPhis(source->arg_names, source->cfg, source->liveness, source->phis = computeRequiredPhis(source->arg_names, source->cfg, source->liveness,
source->scoping->getScopeInfoForNode(source->ast)); source->scoping->getScopeInfoForNode(source->ast));
...@@ -231,7 +224,7 @@ void compileAndRunModule(AST_Module* m, BoxedModule* bm) { ...@@ -231,7 +224,7 @@ void compileAndRunModule(AST_Module* m, BoxedModule* bm) {
ScopingAnalysis* scoping = runScopingAnalysis(m); ScopingAnalysis* scoping = runScopingAnalysis(m);
SourceInfo* si = new SourceInfo(bm, scoping, m); SourceInfo* si = new SourceInfo(bm, scoping, m, m->body);
si->cfg = computeCFG(si, m->body); si->cfg = computeCFG(si, m->body);
si->liveness = computeLivenessInfo(si->cfg); si->liveness = computeLivenessInfo(si->cfg);
si->phis = computeRequiredPhis(si->arg_names, si->cfg, si->liveness, si->scoping->getScopeInfoForNode(si->ast)); si->phis = computeRequiredPhis(si->arg_names, si->cfg, si->liveness, si->scoping->getScopeInfoForNode(si->ast));
......
...@@ -757,6 +757,21 @@ private: ...@@ -757,6 +757,21 @@ private:
return evalExpr(node->value, exc_info); return evalExpr(node->value, exc_info);
} }
CompilerVariable* evalLambda(AST_Lambda* node, ExcInfo exc_info) {
assert(state != PARTIAL);
AST_Return* expr = new AST_Return();
expr->value = node->body;
std::vector<AST_stmt*> body = { expr };
CompilerVariable* func = _createFunction(node, exc_info, node->args, body);
ConcreteCompilerVariable* converted = func->makeConverted(emitter, func->getBoxType());
func->decvref(emitter);
return converted;
}
CompilerVariable* evalList(AST_List* node, ExcInfo exc_info) { CompilerVariable* evalList(AST_List* node, ExcInfo exc_info) {
assert(state != PARTIAL); assert(state != PARTIAL);
...@@ -1052,6 +1067,9 @@ private: ...@@ -1052,6 +1067,9 @@ private:
case AST_TYPE::Index: case AST_TYPE::Index:
rtn = evalIndex(ast_cast<AST_Index>(node), exc_info); rtn = evalIndex(ast_cast<AST_Index>(node), exc_info);
break; break;
case AST_TYPE::Lambda:
rtn = evalLambda(ast_cast<AST_Lambda>(node), exc_info);
break;
case AST_TYPE::List: case AST_TYPE::List:
rtn = evalList(ast_cast<AST_List>(node), exc_info); rtn = evalList(ast_cast<AST_List>(node), exc_info);
break; break;
...@@ -1417,10 +1435,9 @@ private: ...@@ -1417,10 +1435,9 @@ private:
ConcreteCompilerVariable* converted_base = base->makeConverted(emitter, base->getBoxType()); ConcreteCompilerVariable* converted_base = base->makeConverted(emitter, base->getBoxType());
base->decvref(emitter); base->decvref(emitter);
CLFunction* cl = _wrapClassDef(node); CLFunction* cl = _wrapFunction(node, nullptr, node->body);
// TODO duplication with doFunctionDef: // TODO duplication with _createFunction:
CompilerVariable* created_closure = NULL; CompilerVariable* created_closure = NULL;
if (scope_info->takesClosure()) { if (scope_info->takesClosure()) {
created_closure = _getFake(CREATED_CLOSURE_NAME, false); created_closure = _getFake(CREATED_CLOSURE_NAME, false);
...@@ -1495,46 +1512,29 @@ private: ...@@ -1495,46 +1512,29 @@ private:
converted_slice->decvref(emitter); converted_slice->decvref(emitter);
} }
CLFunction* _wrapFunction(AST_FunctionDef* node) { CLFunction* _wrapFunction(AST* node, AST_arguments* args, const std::vector<AST_stmt*>& body) {
// Different compilations of the parent scope of a functiondef should lead // Different compilations of the parent scope of a functiondef should lead
// to the same CLFunction* being used: // to the same CLFunction* being used:
static std::unordered_map<AST_FunctionDef*, CLFunction*> made; static std::unordered_map<AST*, CLFunction*> made;
CLFunction*& cl = made[node]; CLFunction*& cl = made[node];
if (cl == NULL) { if (cl == NULL) {
SourceInfo* si SourceInfo* source = irstate->getSourceInfo();
= new SourceInfo(irstate->getSourceInfo()->parent_module, irstate->getSourceInfo()->scoping, node); SourceInfo* si = new SourceInfo(source->parent_module, source->scoping, node, body);
si->ast = node; if (args)
cl = new CLFunction(node->args->args.size(), node->args->defaults.size(), node->args->vararg.size(), cl = new CLFunction(args->args.size(), args->defaults.size(), args->vararg.size(), args->kwarg.size(), si);
node->args->kwarg.size(), si); else
} cl = new CLFunction(0, 0, 0, 0, si);
return cl;
}
CLFunction* _wrapClassDef(AST_ClassDef* node) {
// TODO duplication with _wrapFunction
static std::unordered_map<AST_ClassDef*, CLFunction*> made;
CLFunction*& cl = made[node];
if (cl == NULL) {
SourceInfo* si
= new SourceInfo(irstate->getSourceInfo()->parent_module, irstate->getSourceInfo()->scoping, node);
si->ast = node;
cl = new CLFunction(0, 0, 0, 0, si);
} }
return cl; return cl;
} }
void doFunctionDef(AST_FunctionDef* node, ExcInfo exc_info) { CompilerVariable* _createFunction(AST* node, ExcInfo exc_info, AST_arguments* args,
if (state == PARTIAL) const std::vector<AST_stmt*>& body) {
return; CLFunction* cl = this->_wrapFunction(node, args, body);
assert(!node->decorator_list.size());
CLFunction* cl = this->_wrapFunction(node);
std::vector<ConcreteCompilerVariable*> defaults; std::vector<ConcreteCompilerVariable*> defaults;
for (auto d : node->args->defaults) { for (auto d : args->defaults) {
CompilerVariable* e = evalExpr(d, exc_info); CompilerVariable* e = evalExpr(d, exc_info);
ConcreteCompilerVariable* converted = e->makeConverted(emitter, e->getBoxType()); ConcreteCompilerVariable* converted = e->makeConverted(emitter, e->getBoxType());
e->decvref(emitter); e->decvref(emitter);
...@@ -1558,7 +1558,16 @@ private: ...@@ -1558,7 +1558,16 @@ private:
// llvm::Value *boxed = emitter.getBuilder()->CreateCall(g.funcs.boxCLFunction, embedConstantPtr(cl, // llvm::Value *boxed = emitter.getBuilder()->CreateCall(g.funcs.boxCLFunction, embedConstantPtr(cl,
// boxCLFuncArgType)); // boxCLFuncArgType));
// CompilerVariable *func = new ConcreteCompilerVariable(typeFromClass(function_cls), boxed, true); // CompilerVariable *func = new ConcreteCompilerVariable(typeFromClass(function_cls), boxed, true);
return func;
}
void doFunctionDef(AST_FunctionDef* node, ExcInfo exc_info) {
if (state == PARTIAL)
return;
assert(!node->decorator_list.size());
CompilerVariable* func = _createFunction(node, exc_info, node->args, node->body);
_doSet(node->name, func, exc_info); _doSet(node->name, func, exc_info);
func->decvref(emitter); func->decvref(emitter);
} }
......
...@@ -457,6 +457,16 @@ AST_keyword* read_keyword(BufferedReader* reader) { ...@@ -457,6 +457,16 @@ AST_keyword* read_keyword(BufferedReader* reader) {
return rtn; return rtn;
} }
AST_Lambda* read_lambda(BufferedReader* reader) {
AST_Lambda* rtn = new AST_Lambda();
rtn->args = ast_cast<AST_arguments>(readASTMisc(reader));
rtn->body = readASTExpr(reader);
rtn->col_offset = readColOffset(reader);
rtn->lineno = reader->readULL();
return rtn;
}
AST_List* read_list(BufferedReader* reader) { AST_List* read_list(BufferedReader* reader) {
AST_List* rtn = new AST_List(); AST_List* rtn = new AST_List();
...@@ -696,6 +706,8 @@ AST_expr* readASTExpr(BufferedReader* reader) { ...@@ -696,6 +706,8 @@ AST_expr* readASTExpr(BufferedReader* reader) {
return read_ifexp(reader); return read_ifexp(reader);
case AST_TYPE::Index: case AST_TYPE::Index:
return read_index(reader); return read_index(reader);
case AST_TYPE::Lambda:
return read_lambda(reader);
case AST_TYPE::List: case AST_TYPE::List:
return read_list(reader); return read_list(reader);
case AST_TYPE::ListComp: case AST_TYPE::ListComp:
......
...@@ -567,6 +567,19 @@ void AST_keyword::accept(ASTVisitor* v) { ...@@ -567,6 +567,19 @@ void AST_keyword::accept(ASTVisitor* v) {
value->accept(v); value->accept(v);
} }
void AST_Lambda::accept(ASTVisitor* v) {
bool skip = v->visit_lambda(this);
if (skip)
return;
args->accept(v);
body->accept(v);
}
void* AST_Lambda::accept_expr(ExprVisitor* v) {
return v->visit_lambda(this);
}
void AST_LangPrimitive::accept(ASTVisitor* v) { void AST_LangPrimitive::accept(ASTVisitor* v) {
bool skip = v->visit_langprimitive(this); bool skip = v->visit_langprimitive(this);
if (skip) if (skip)
...@@ -1272,6 +1285,14 @@ bool PrintVisitor::visit_invoke(AST_Invoke* node) { ...@@ -1272,6 +1285,14 @@ bool PrintVisitor::visit_invoke(AST_Invoke* node) {
return true; return true;
} }
bool PrintVisitor::visit_lambda(AST_Lambda* node) {
printf("lambda ");
node->args->accept(this);
printf(": ");
node->body->accept(this);
return true;
}
bool PrintVisitor::visit_langprimitive(AST_LangPrimitive* node) { bool PrintVisitor::visit_langprimitive(AST_LangPrimitive* node) {
printf(":"); printf(":");
switch (node->opcode) { switch (node->opcode) {
...@@ -1726,6 +1747,10 @@ public: ...@@ -1726,6 +1747,10 @@ public:
output->push_back(node); output->push_back(node);
return false; return false;
} }
virtual bool visit_lambda(AST_Lambda* node) {
output->push_back(node);
return !expand_scopes;
}
virtual bool visit_langprimitive(AST_LangPrimitive* node) { virtual bool visit_langprimitive(AST_LangPrimitive* node) {
output->push_back(node); output->push_back(node);
return false; return false;
......
...@@ -532,6 +532,19 @@ public: ...@@ -532,6 +532,19 @@ public:
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::keyword; static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::keyword;
}; };
class AST_Lambda : public AST_expr {
public:
AST_arguments* args;
AST_expr* body;
virtual void accept(ASTVisitor* v);
virtual void* accept_expr(ExprVisitor* v);
AST_Lambda() : AST_expr(AST_TYPE::Lambda) {}
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::Lambda;
};
class AST_List : public AST_expr { class AST_List : public AST_expr {
public: public:
std::vector<AST_expr*> elts; std::vector<AST_expr*> elts;
...@@ -918,6 +931,7 @@ public: ...@@ -918,6 +931,7 @@ public:
virtual bool visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_invoke(AST_Invoke* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_invoke(AST_Invoke* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_keyword(AST_keyword* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_keyword(AST_keyword* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_lambda(AST_Lambda* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_langprimitive(AST_LangPrimitive* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_langprimitive(AST_LangPrimitive* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_list(AST_List* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_list(AST_List* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_listcomp(AST_ListComp* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_listcomp(AST_ListComp* node) { RELEASE_ASSERT(0, ""); }
...@@ -980,6 +994,7 @@ public: ...@@ -980,6 +994,7 @@ public:
virtual bool visit_index(AST_Index* node) { return false; } virtual bool visit_index(AST_Index* node) { return false; }
virtual bool visit_invoke(AST_Invoke* node) { return false; } virtual bool visit_invoke(AST_Invoke* node) { return false; }
virtual bool visit_keyword(AST_keyword* node) { return false; } virtual bool visit_keyword(AST_keyword* node) { return false; }
virtual bool visit_lambda(AST_Lambda* node) { return false; }
virtual bool visit_langprimitive(AST_LangPrimitive* node) { return false; } virtual bool visit_langprimitive(AST_LangPrimitive* node) { return false; }
virtual bool visit_list(AST_List* node) { return false; } virtual bool visit_list(AST_List* node) { return false; }
virtual bool visit_listcomp(AST_ListComp* node) { return false; } virtual bool visit_listcomp(AST_ListComp* node) { return false; }
...@@ -1022,6 +1037,7 @@ public: ...@@ -1022,6 +1037,7 @@ public:
virtual void* visit_dictcomp(AST_DictComp* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_dictcomp(AST_DictComp* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_ifexp(AST_IfExp* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_ifexp(AST_IfExp* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_lambda(AST_Lambda* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_langprimitive(AST_LangPrimitive* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_langprimitive(AST_LangPrimitive* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_list(AST_List* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_list(AST_List* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_listcomp(AST_ListComp* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_listcomp(AST_ListComp* node) { RELEASE_ASSERT(0, ""); }
...@@ -1110,6 +1126,7 @@ public: ...@@ -1110,6 +1126,7 @@ public:
virtual bool visit_index(AST_Index* node); virtual bool visit_index(AST_Index* node);
virtual bool visit_invoke(AST_Invoke* node); virtual bool visit_invoke(AST_Invoke* node);
virtual bool visit_keyword(AST_keyword* node); virtual bool visit_keyword(AST_keyword* node);
virtual bool visit_lambda(AST_Lambda* node);
virtual bool visit_langprimitive(AST_LangPrimitive* node); virtual bool visit_langprimitive(AST_LangPrimitive* node);
virtual bool visit_list(AST_List* node); virtual bool visit_list(AST_List* node);
virtual bool visit_listcomp(AST_ListComp* node); virtual bool visit_listcomp(AST_ListComp* node);
......
...@@ -621,6 +621,27 @@ private: ...@@ -621,6 +621,27 @@ private:
return rtn; return rtn;
} }
AST_expr* remapLambda(AST_Lambda* node) {
if (node->args->defaults.empty()) {
return node;
}
AST_Lambda* rtn = new AST_Lambda();
rtn->lineno = node->lineno;
rtn->col_offset = node->col_offset;
rtn->args = new AST_arguments();
rtn->args->args = node->args->args;
rtn->args->vararg = node->args->vararg;
rtn->args->kwarg = node->args->kwarg;
for (auto d : node->args->defaults) {
rtn->args->defaults.push_back(remapExpr(d));
}
rtn->body = node->body;
return rtn;
}
AST_expr* remapLangPrimitive(AST_LangPrimitive* node) { AST_expr* remapLangPrimitive(AST_LangPrimitive* node) {
AST_LangPrimitive* rtn = new AST_LangPrimitive(node->opcode); AST_LangPrimitive* rtn = new AST_LangPrimitive(node->opcode);
for (AST_expr* arg : node->args) { for (AST_expr* arg : node->args) {
...@@ -732,6 +753,9 @@ private: ...@@ -732,6 +753,9 @@ private:
case AST_TYPE::Index: case AST_TYPE::Index:
rtn = remapIndex(ast_cast<AST_Index>(node)); rtn = remapIndex(ast_cast<AST_Index>(node));
break; break;
case AST_TYPE::Lambda:
rtn = remapLambda(ast_cast<AST_Lambda>(node));
break;
case AST_TYPE::LangPrimitive: case AST_TYPE::LangPrimitive:
rtn = remapLangPrimitive(ast_cast<AST_LangPrimitive>(node)); rtn = remapLangPrimitive(ast_cast<AST_LangPrimitive>(node));
break; break;
...@@ -1103,7 +1127,7 @@ public: ...@@ -1103,7 +1127,7 @@ public:
} }
virtual bool visit_return(AST_Return* node) { virtual bool visit_return(AST_Return* node) {
if (root_type != AST_TYPE::FunctionDef) { if (root_type != AST_TYPE::FunctionDef && root_type != AST_TYPE::Lambda) {
fprintf(stderr, "SyntaxError: 'return' outside function\n"); fprintf(stderr, "SyntaxError: 'return' outside function\n");
exit(1); exit(1);
} }
......
...@@ -218,13 +218,13 @@ public: ...@@ -218,13 +218,13 @@ public:
}; };
ArgNames arg_names; ArgNames arg_names;
const std::vector<AST_stmt*> body;
const std::string getName(); const std::string getName();
// AST_arguments* getArgsAST();
const std::vector<AST_stmt*>& getBody();
SourceInfo(BoxedModule* m, ScopingAnalysis* scoping, AST* ast) SourceInfo(BoxedModule* m, ScopingAnalysis* scoping, AST* ast, const std::vector<AST_stmt*>& body)
: parent_module(m), scoping(scoping), ast(ast), cfg(NULL), liveness(NULL), phis(NULL), arg_names(ast) {} : parent_module(m), scoping(scoping), ast(ast), cfg(NULL), liveness(NULL), phis(NULL), arg_names(ast),
body(body) {}
}; };
typedef std::vector<CompiledFunction*> FunctionList; typedef std::vector<CompiledFunction*> FunctionList;
......
s = lambda x=5: x**2
print s(8), s(100), s()
for i in range(10):
print (lambda x, y: x < y)(i, 5)
t = lambda s: " ".join(s.split())
print t("test \tstr\ni\n ng")
def T(y):
return (lambda x: x < y)
print T(10)(1), T(10)(20)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment