Commit 44ec4eef authored by Marius Wachtler's avatar Marius Wachtler

Implement lambda expressions

parent ddabda9a
......@@ -63,6 +63,9 @@ public:
_doStore(node->name);
return true;
}
bool visit_lambda(AST_Lambda* node) { return true; }
bool visit_name(AST_Name* node) {
if (node->ctx_type == AST_TYPE::Load)
_doLoad(node->id);
......
......@@ -234,6 +234,19 @@ public:
}
}
virtual bool visit_lambda(AST_Lambda* node) {
assert(node == orig_node);
for (AST_expr* e : node->args->args)
e->accept(this);
if (node->args->vararg.size())
doWrite(node->args->vararg);
if (node->args->kwarg.size())
doWrite(node->args->kwarg);
node->body->accept(this);
return true;
}
virtual bool visit_import(AST_Import* node) {
for (int i = 0; i < node->names.size(); i++) {
AST_alias* alias = node->names[i];
......@@ -328,10 +341,9 @@ void ScopingAnalysis::processNameUsages(ScopingAnalysis::NameUsageMap* usages) {
ScopeInfo* parent_info = this->scopes[(usage->parent == NULL) ? this->parent_module : usage->parent->node];
switch (node->type) {
case AST_TYPE::FunctionDef:
this->scopes[node] = new ScopeInfoBase(parent_info, usage);
break;
case AST_TYPE::ClassDef:
case AST_TYPE::FunctionDef:
case AST_TYPE::Lambda:
this->scopes[node] = new ScopeInfoBase(parent_info, usage);
break;
default:
......@@ -363,6 +375,7 @@ ScopeInfo* ScopingAnalysis::getScopeInfoForNode(AST* node) {
switch (node->type) {
case AST_TYPE::ClassDef:
case AST_TYPE::FunctionDef:
case AST_TYPE::Lambda:
return analyzeSubtree(node);
// this is handled in the constructor:
// case AST_TYPE::Module:
......
......@@ -340,6 +340,8 @@ private:
virtual void* visit_index(AST_Index* node) { return getType(node->value); }
virtual void* visit_lambda(AST_Lambda* node) { return typeFromClass(function_cls); }
virtual void* visit_langprimitive(AST_LangPrimitive* node) {
switch (node->opcode) {
case AST_LangPrimitive::ISINSTANCE:
......
......@@ -47,6 +47,8 @@ const std::string SourceInfo::getName() {
switch (ast->type) {
case AST_TYPE::FunctionDef:
return ast_cast<AST_FunctionDef>(ast)->name;
case AST_TYPE::Lambda:
return "<lambda>";
case AST_TYPE::Module:
return this->parent_module->name();
default:
......@@ -59,6 +61,8 @@ AST_arguments* SourceInfo::getArgsAST() {
switch (ast->type) {
case AST_TYPE::FunctionDef:
return ast_cast<AST_FunctionDef>(ast)->args;
case AST_TYPE::Lambda:
return ast_cast<AST_Lambda>(ast)->args;
case AST_TYPE::Module:
return NULL;
default:
......@@ -81,18 +85,6 @@ const std::vector<AST_expr*>* CLFunction::getArgNames() {
return &source->getArgNames();
}
const std::vector<AST_stmt*>& SourceInfo::getBody() {
assert(ast);
switch (ast->type) {
case AST_TYPE::FunctionDef:
return ast_cast<AST_FunctionDef>(ast)->body;
case AST_TYPE::Module:
return ast_cast<AST_Module>(ast)->body;
default:
RELEASE_ASSERT(0, "%d", ast->type);
}
}
EffortLevel::EffortLevel initialEffort() {
if (FORCE_OPTIMIZE)
return EffortLevel::MAXIMAL;
......@@ -187,7 +179,7 @@ CompiledFunction* compileFunction(CLFunction* f, FunctionSpecialization* spec, E
// Do the analysis now if we had deferred it earlier:
if (source->cfg == NULL) {
assert(source->ast);
source->cfg = computeCFG(source->ast->type, source->getBody());
source->cfg = computeCFG(source->ast->type, source->body);
source->liveness = computeLivenessInfo(source->cfg);
source->phis = computeRequiredPhis(args, source->cfg, source->liveness,
source->scoping->getScopeInfoForNode(source->ast));
......@@ -246,9 +238,8 @@ void compileAndRunModule(AST_Module* m, BoxedModule* bm) {
ScopingAnalysis* scoping = runScopingAnalysis(m);
SourceInfo* si = new SourceInfo(bm, scoping);
SourceInfo* si = new SourceInfo(bm, scoping, m, m->body);
si->cfg = computeCFG(AST_TYPE::Module, m->body);
si->ast = m;
si->liveness = computeLivenessInfo(si->cfg);
si->phis = computeRequiredPhis(NULL, si->cfg, si->liveness, si->scoping->getScopeInfoForNode(si->ast));
......
......@@ -756,6 +756,24 @@ private:
return evalExpr(node->value, exc_info);
}
CompilerVariable* evalLambda(AST_Lambda* node, ExcInfo exc_info) {
assert(state != PARTIAL);
AST_Return* expr = new AST_Return();
expr->value = node->body;
SourceInfo* si = new SourceInfo(irstate->getSourceInfo()->parent_module, irstate->getSourceInfo()->scoping,
node, { expr });
CLFunction* cl = new CLFunction(node->args->args.size(), node->args->defaults.size(), node->args->vararg.size(),
node->args->kwarg.size(), si);
CompilerVariable* func = makeFunction(emitter, cl, NULL);
ConcreteCompilerVariable* converted = func->makeConverted(emitter, func->getBoxType());
func->decvref(emitter);
return converted;
}
CompilerVariable* evalList(AST_List* node, ExcInfo exc_info) {
assert(state != PARTIAL);
......@@ -1027,6 +1045,9 @@ private:
case AST_TYPE::Index:
rtn = evalIndex(ast_cast<AST_Index>(node), exc_info);
break;
case AST_TYPE::Lambda:
rtn = evalLambda(ast_cast<AST_Lambda>(node), exc_info);
break;
case AST_TYPE::List:
rtn = evalList(ast_cast<AST_List>(node), exc_info);
break;
......@@ -1476,8 +1497,8 @@ private:
CLFunction*& cl = made[node];
if (cl == NULL) {
SourceInfo* si = new SourceInfo(irstate->getSourceInfo()->parent_module, irstate->getSourceInfo()->scoping);
si->ast = node;
SourceInfo* si = new SourceInfo(irstate->getSourceInfo()->parent_module, irstate->getSourceInfo()->scoping,
node, node->body);
cl = new CLFunction(node->args->args.size(), node->args->defaults.size(), node->args->vararg.size(),
node->args->kwarg.size(), si);
}
......
......@@ -457,6 +457,16 @@ AST_keyword* read_keyword(BufferedReader* reader) {
return rtn;
}
AST_Lambda* read_lambda(BufferedReader* reader) {
AST_Lambda* rtn = new AST_Lambda();
rtn->args = ast_cast<AST_arguments>(readASTMisc(reader));
rtn->body = readASTExpr(reader);
rtn->col_offset = readColOffset(reader);
rtn->lineno = reader->readULL();
return rtn;
}
AST_List* read_list(BufferedReader* reader) {
AST_List* rtn = new AST_List();
......@@ -696,6 +706,8 @@ AST_expr* readASTExpr(BufferedReader* reader) {
return read_ifexp(reader);
case AST_TYPE::Index:
return read_index(reader);
case AST_TYPE::Lambda:
return read_lambda(reader);
case AST_TYPE::List:
return read_list(reader);
case AST_TYPE::ListComp:
......
......@@ -567,6 +567,19 @@ void AST_keyword::accept(ASTVisitor* v) {
value->accept(v);
}
void AST_Lambda::accept(ASTVisitor* v) {
bool skip = v->visit_lambda(this);
if (skip)
return;
args->accept(v);
body->accept(v);
}
void* AST_Lambda::accept_expr(ExprVisitor* v) {
return v->visit_lambda(this);
}
void AST_LangPrimitive::accept(ASTVisitor* v) {
bool skip = v->visit_langprimitive(this);
if (skip)
......@@ -1272,6 +1285,14 @@ bool PrintVisitor::visit_invoke(AST_Invoke* node) {
return true;
}
bool PrintVisitor::visit_lambda(AST_Lambda* node) {
printf("lambda ");
node->args->accept(this);
printf(": ");
node->body->accept(this);
return true;
}
bool PrintVisitor::visit_langprimitive(AST_LangPrimitive* node) {
printf(":");
switch (node->opcode) {
......@@ -1726,6 +1747,10 @@ public:
output->push_back(node);
return false;
}
virtual bool visit_lambda(AST_Lambda* node) {
output->push_back(node);
return !expand_scopes;
}
virtual bool visit_langprimitive(AST_LangPrimitive* node) {
output->push_back(node);
return false;
......
......@@ -532,6 +532,19 @@ public:
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::keyword;
};
class AST_Lambda : public AST_expr {
public:
AST_arguments* args;
AST_expr* body;
virtual void accept(ASTVisitor* v);
virtual void* accept_expr(ExprVisitor* v);
AST_Lambda() : AST_expr(AST_TYPE::Lambda) {}
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::Lambda;
};
class AST_List : public AST_expr {
public:
std::vector<AST_expr*> elts;
......@@ -916,6 +929,7 @@ public:
virtual bool visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_invoke(AST_Invoke* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_keyword(AST_keyword* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_lambda(AST_Lambda* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_langprimitive(AST_LangPrimitive* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_list(AST_List* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_listcomp(AST_ListComp* node) { RELEASE_ASSERT(0, ""); }
......@@ -978,6 +992,7 @@ public:
virtual bool visit_index(AST_Index* node) { return false; }
virtual bool visit_invoke(AST_Invoke* node) { return false; }
virtual bool visit_keyword(AST_keyword* node) { return false; }
virtual bool visit_lambda(AST_Lambda* node) { return false; }
virtual bool visit_langprimitive(AST_LangPrimitive* node) { return false; }
virtual bool visit_list(AST_List* node) { return false; }
virtual bool visit_listcomp(AST_ListComp* node) { return false; }
......@@ -1020,6 +1035,7 @@ public:
virtual void* visit_dictcomp(AST_DictComp* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_ifexp(AST_IfExp* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_lambda(AST_Lambda* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_langprimitive(AST_LangPrimitive* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_list(AST_List* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_listcomp(AST_ListComp* node) { RELEASE_ASSERT(0, ""); }
......@@ -1108,6 +1124,7 @@ public:
virtual bool visit_index(AST_Index* node);
virtual bool visit_invoke(AST_Invoke* node);
virtual bool visit_keyword(AST_keyword* node);
virtual bool visit_lambda(AST_Lambda* node);
virtual bool visit_langprimitive(AST_LangPrimitive* node);
virtual bool visit_list(AST_List* node);
virtual bool visit_listcomp(AST_ListComp* node);
......
......@@ -565,6 +565,21 @@ private:
return rtn;
}
AST_expr* remapLambda(AST_Lambda* node) {
AST_Lambda* rtn = new AST_Lambda();
rtn->lineno = node->lineno;
rtn->col_offset = node->col_offset;
rtn->args = node->args;
// remap default arguments
rtn->args->defaults.clear();
for (auto& e : node->args->defaults)
rtn->args->defaults.push_back(remapExpr(e));
rtn->body = node->body;
return rtn;
}
AST_expr* remapLangPrimitive(AST_LangPrimitive* node) {
AST_LangPrimitive* rtn = new AST_LangPrimitive(node->opcode);
for (AST_expr* arg : node->args) {
......@@ -676,6 +691,9 @@ private:
case AST_TYPE::Index:
rtn = remapIndex(ast_cast<AST_Index>(node));
break;
case AST_TYPE::Lambda:
rtn = remapLambda(ast_cast<AST_Lambda>(node));
break;
case AST_TYPE::LangPrimitive:
rtn = remapLangPrimitive(ast_cast<AST_LangPrimitive>(node));
break;
......@@ -1023,7 +1041,7 @@ public:
}
virtual bool visit_return(AST_Return* node) {
if (root_type != AST_TYPE::FunctionDef) {
if (root_type != AST_TYPE::FunctionDef && root_type != AST_TYPE::Lambda) {
fprintf(stderr, "SyntaxError: 'return' outside function\n");
exit(1);
}
......
......@@ -203,14 +203,14 @@ public:
CFG* cfg;
LivenessAnalysis* liveness;
PhiAnalysis* phis;
const std::vector<AST_stmt*> body;
const std::string getName();
AST_arguments* getArgsAST();
const std::vector<AST_expr*>& getArgNames();
const std::vector<AST_stmt*>& getBody();
SourceInfo(BoxedModule* m, ScopingAnalysis* scoping)
: parent_module(m), scoping(scoping), ast(NULL), cfg(NULL), liveness(NULL), phis(NULL) {}
SourceInfo(BoxedModule* m, ScopingAnalysis* scoping, AST* ast, const std::vector<AST_stmt*>& body)
: parent_module(m), scoping(scoping), ast(ast), cfg(NULL), liveness(NULL), phis(NULL), body(body) {}
};
typedef std::vector<CompiledFunction*> FunctionList;
......
s = lambda x: x**2
print s(8), s(100)
for i in range(10):
print (lambda x, y: x < y)(i, 5)
t = lambda s: " ".join(s.split())
print t("test \tstr\ni\n ng")
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment