Commit 141ac41f authored by Kevin Modzelewski's avatar Kevin Modzelewski

Implement generator expressions by converting to generator functions

parent 934d8f6f
...@@ -212,6 +212,7 @@ public: ...@@ -212,6 +212,7 @@ public:
virtual bool visit_for(AST_For* node) { return false; } virtual bool visit_for(AST_For* node) { return false; }
// virtual bool visit_functiondef(AST_FunctionDef *node) { return false; } // virtual bool visit_functiondef(AST_FunctionDef *node) { return false; }
// virtual bool visit_global(AST_Global *node) { return false; } // virtual bool visit_global(AST_Global *node) { return false; }
virtual bool visit_generatorexp(AST_GeneratorExp* node) { return false; }
virtual bool visit_if(AST_If* node) { return false; } virtual bool visit_if(AST_If* node) { return false; }
virtual bool visit_ifexp(AST_IfExp* node) { return false; } virtual bool visit_ifexp(AST_IfExp* node) { return false; }
virtual bool visit_index(AST_Index* node) { return false; } virtual bool visit_index(AST_Index* node) { return false; }
......
...@@ -386,6 +386,16 @@ AST_FunctionDef* read_functiondef(BufferedReader* reader) { ...@@ -386,6 +386,16 @@ AST_FunctionDef* read_functiondef(BufferedReader* reader) {
return rtn; return rtn;
} }
AST_GeneratorExp* read_generatorexp(BufferedReader* reader) {
AST_GeneratorExp* rtn = new AST_GeneratorExp();
rtn->col_offset = readColOffset(reader);
rtn->elt = readASTExpr(reader);
readMiscVector(rtn->generators, reader);
rtn->lineno = reader->readULL();
return rtn;
}
AST_Global* read_global(BufferedReader* reader) { AST_Global* read_global(BufferedReader* reader) {
AST_Global* rtn = new AST_Global(); AST_Global* rtn = new AST_Global();
...@@ -736,6 +746,8 @@ AST_expr* readASTExpr(BufferedReader* reader) { ...@@ -736,6 +746,8 @@ AST_expr* readASTExpr(BufferedReader* reader) {
return read_dict(reader); return read_dict(reader);
case AST_TYPE::DictComp: case AST_TYPE::DictComp:
return read_dictcomp(reader); return read_dictcomp(reader);
case AST_TYPE::GeneratorExp:
return read_generatorexp(reader);
case AST_TYPE::IfExp: case AST_TYPE::IfExp:
return read_ifexp(reader); return read_ifexp(reader);
case AST_TYPE::Index: case AST_TYPE::Index:
......
...@@ -473,6 +473,22 @@ void AST_FunctionDef::accept_stmt(StmtVisitor* v) { ...@@ -473,6 +473,22 @@ void AST_FunctionDef::accept_stmt(StmtVisitor* v) {
v->visit_functiondef(this); v->visit_functiondef(this);
} }
void AST_GeneratorExp::accept(ASTVisitor* v) {
bool skip = v->visit_generatorexp(this);
if (skip)
return;
for (auto c : generators) {
c->accept(v);
}
elt->accept(v);
}
void* AST_GeneratorExp::accept_expr(ExprVisitor* v) {
return v->visit_generatorexp(this);
}
void AST_Global::accept(ASTVisitor* v) { void AST_Global::accept(ASTVisitor* v) {
bool skip = v->visit_global(this); bool skip = v->visit_global(this);
if (skip) if (skip)
...@@ -1228,6 +1244,17 @@ bool PrintVisitor::visit_functiondef(AST_FunctionDef* node) { ...@@ -1228,6 +1244,17 @@ bool PrintVisitor::visit_functiondef(AST_FunctionDef* node) {
return true; return true;
} }
bool PrintVisitor::visit_generatorexp(AST_GeneratorExp* node) {
printf("[");
node->elt->accept(this);
for (auto c : node->generators) {
printf(" ");
c->accept(this);
}
printf("]");
return true;
}
bool PrintVisitor::visit_global(AST_Global* node) { bool PrintVisitor::visit_global(AST_Global* node) {
printf("global "); printf("global ");
for (int i = 0; i < node->names.size(); i++) { for (int i = 0; i < node->names.size(); i++) {
......
...@@ -444,6 +444,19 @@ public: ...@@ -444,6 +444,19 @@ public:
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::FunctionDef; static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::FunctionDef;
}; };
class AST_GeneratorExp : public AST_expr {
public:
std::vector<AST_comprehension*> generators;
AST_expr* elt;
virtual void accept(ASTVisitor* v);
virtual void* accept_expr(ExprVisitor* v);
AST_GeneratorExp() : AST_expr(AST_TYPE::GeneratorExp) {}
const static AST_TYPE::AST_TYPE TYPE = AST_TYPE::GeneratorExp;
};
class AST_Global : public AST_stmt { class AST_Global : public AST_stmt {
public: public:
std::vector<std::string> names; std::vector<std::string> names;
...@@ -955,6 +968,7 @@ public: ...@@ -955,6 +968,7 @@ public:
virtual bool visit_expr(AST_Expr* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_expr(AST_Expr* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_for(AST_For* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_for(AST_For* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_functiondef(AST_FunctionDef* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_functiondef(AST_FunctionDef* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_generatorexp(AST_GeneratorExp* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_global(AST_Global* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_global(AST_Global* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_if(AST_If* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_if(AST_If* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_ifexp(AST_IfExp* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_ifexp(AST_IfExp* node) { RELEASE_ASSERT(0, ""); }
...@@ -1020,6 +1034,7 @@ public: ...@@ -1020,6 +1034,7 @@ public:
virtual bool visit_expr(AST_Expr* node) { return false; } virtual bool visit_expr(AST_Expr* node) { return false; }
virtual bool visit_for(AST_For* node) { return false; } virtual bool visit_for(AST_For* node) { return false; }
virtual bool visit_functiondef(AST_FunctionDef* node) { return false; } virtual bool visit_functiondef(AST_FunctionDef* node) { return false; }
virtual bool visit_generatorexp(AST_GeneratorExp* node) { return false; }
virtual bool visit_global(AST_Global* node) { return false; } virtual bool visit_global(AST_Global* node) { return false; }
virtual bool visit_if(AST_If* node) { return false; } virtual bool visit_if(AST_If* node) { return false; }
virtual bool visit_ifexp(AST_IfExp* node) { return false; } virtual bool visit_ifexp(AST_IfExp* node) { return false; }
...@@ -1071,6 +1086,7 @@ public: ...@@ -1071,6 +1086,7 @@ public:
virtual void* visit_compare(AST_Compare* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_compare(AST_Compare* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_dict(AST_Dict* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_dict(AST_Dict* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_dictcomp(AST_DictComp* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_dictcomp(AST_DictComp* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_generatorexp(AST_GeneratorExp* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_ifexp(AST_IfExp* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_ifexp(AST_IfExp* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_lambda(AST_Lambda* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_lambda(AST_Lambda* node) { RELEASE_ASSERT(0, ""); }
...@@ -1156,6 +1172,7 @@ public: ...@@ -1156,6 +1172,7 @@ public:
virtual bool visit_expr(AST_Expr* node); virtual bool visit_expr(AST_Expr* node);
virtual bool visit_for(AST_For* node); virtual bool visit_for(AST_For* node);
virtual bool visit_functiondef(AST_FunctionDef* node); virtual bool visit_functiondef(AST_FunctionDef* node);
virtual bool visit_generatorexp(AST_GeneratorExp* node);
virtual bool visit_global(AST_Global* node); virtual bool visit_global(AST_Global* node);
virtual bool visit_if(AST_If* node); virtual bool visit_if(AST_If* node);
virtual bool visit_ifexp(AST_IfExp* node); virtual bool visit_ifexp(AST_IfExp* node);
......
...@@ -267,7 +267,6 @@ private: ...@@ -267,7 +267,6 @@ private:
return makeName(rtn_name, AST_TYPE::Load); return makeName(rtn_name, AST_TYPE::Load);
} }
AST_expr* makeNum(int n) { AST_expr* makeNum(int n) {
AST_Num* node = new AST_Num(); AST_Num* node = new AST_Num();
node->num_type = AST_Num::INT; node->num_type = AST_Num::INT;
...@@ -576,6 +575,67 @@ private: ...@@ -576,6 +575,67 @@ private:
return rtn; return rtn;
}; };
AST_expr* remapGeneratorExp(AST_GeneratorExp* node) {
assert(node->generators.size());
// I don't think it's easy to determine by the user, but it looks like the first generator iterator gets
// evaluated in the parent scope / frame.
AST_expr* first = remapExpr(node->generators[0]->iter);
AST_FunctionDef* func = new AST_FunctionDef();
func->lineno = node->lineno;
func->col_offset = node->col_offset;
func->name = nodeName(func);
func->args = new AST_arguments();
func->args->vararg = "";
func->args->kwarg = "";
std::string first_generator_name = nodeName(node);
func->args->args.push_back(makeName(first_generator_name, AST_TYPE::Param));
std::vector<AST_stmt*>* insert_point = &func->body;
for (int i = 0; i < node->generators.size(); i++) {
AST_comprehension* c = node->generators[i];
AST_For* loop = new AST_For();
loop->target = c->target;
if (i == 0) {
loop->iter = makeName(first_generator_name, AST_TYPE::Load);
} else {
loop->iter = c->iter;
}
insert_point->push_back(loop);
insert_point = &loop->body;
for (AST_expr* if_condition : c->ifs) {
AST_If* if_block = new AST_If();
if_block->test = if_condition;
insert_point->push_back(if_block);
insert_point = &if_block->body;
}
}
AST_Yield* y = new AST_Yield();
y->value = node->elt;
insert_point->push_back(makeExpr(y));
push_back(func);
AST_Call* call = new AST_Call();
call->lineno = node->lineno;
call->col_offset = node->col_offset;
call->starargs = NULL;
call->kwargs = NULL;
call->func = makeName(nodeName(func), AST_TYPE::Load);
call->args.push_back(first);
return call;
};
AST_expr* remapIfExp(AST_IfExp* node) { AST_expr* remapIfExp(AST_IfExp* node) {
std::string rtn_name = nodeName(node); std::string rtn_name = nodeName(node);
...@@ -764,6 +824,9 @@ private: ...@@ -764,6 +824,9 @@ private:
case AST_TYPE::DictComp: case AST_TYPE::DictComp:
rtn = remapComprehension<AST_Dict>(ast_cast<AST_DictComp>(node)); rtn = remapComprehension<AST_Dict>(ast_cast<AST_DictComp>(node));
break; break;
case AST_TYPE::GeneratorExp:
rtn = remapGeneratorExp(ast_cast<AST_GeneratorExp>(node));
break;
case AST_TYPE::IfExp: case AST_TYPE::IfExp:
rtn = remapIfExp(ast_cast<AST_IfExp>(node)); rtn = remapIfExp(ast_cast<AST_IfExp>(node));
break; break;
......
def f(o, msg):
print msg
return o
g1 = (f(i, i) for i in f(xrange(5), "xrange"))
print 1
print g1.next()
print list(g1)
print
def f2():
g2 = (f(i, j) for i in f(xrange(4), "inner xrange") if i != f(2, 2) if i != f(20, 20) for j in f(xrange(4), "outer xrange") if i % 2 == j % 2)
print 1
print g2.next()
print list(g2)
f2()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment