Commit ccc8fda6 authored by Kevin Modzelewski's avatar Kevin Modzelewski

Merge pull request #65 from vinzenz/dict-comprehension

Implementation of dict comprehension
parents 694413da 9b38fa1a
...@@ -158,6 +158,7 @@ public: ...@@ -158,6 +158,7 @@ public:
// virtual bool visit_classdef(AST_ClassDef *node) { return false; } // virtual bool visit_classdef(AST_ClassDef *node) { return false; }
virtual bool visit_continue(AST_Continue* node) { return false; } virtual bool visit_continue(AST_Continue* node) { return false; }
virtual bool visit_dict(AST_Dict* node) { return false; } virtual bool visit_dict(AST_Dict* node) { return false; }
virtual bool visit_dictcomp(AST_DictComp* node) { return false; }
virtual bool visit_excepthandler(AST_ExceptHandler* node) { return false; } virtual bool visit_excepthandler(AST_ExceptHandler* node) { return false; }
virtual bool visit_expr(AST_Expr* node) { return false; } virtual bool visit_expr(AST_Expr* node) { return false; }
virtual bool visit_for(AST_For* node) { return false; } virtual bool visit_for(AST_For* node) { return false; }
......
...@@ -328,6 +328,17 @@ AST_Dict* read_dict(BufferedReader* reader) { ...@@ -328,6 +328,17 @@ AST_Dict* read_dict(BufferedReader* reader) {
return rtn; return rtn;
} }
AST_DictComp* read_dictcomp(BufferedReader* reader) {
AST_DictComp* rtn = new AST_DictComp();
rtn->col_offset = readColOffset(reader);
readMiscVector(rtn->generators, reader);
rtn->key = readASTExpr(reader);
rtn->lineno = reader->readULL();
rtn->value = readASTExpr(reader);
return rtn;
}
AST_ExceptHandler* read_excepthandler(BufferedReader* reader) { AST_ExceptHandler* read_excepthandler(BufferedReader* reader) {
AST_ExceptHandler* rtn = new AST_ExceptHandler(); AST_ExceptHandler* rtn = new AST_ExceptHandler();
...@@ -679,6 +690,8 @@ AST_expr* readASTExpr(BufferedReader* reader) { ...@@ -679,6 +690,8 @@ AST_expr* readASTExpr(BufferedReader* reader) {
return read_compare(reader); return read_compare(reader);
case AST_TYPE::Dict: case AST_TYPE::Dict:
return read_dict(reader); return read_dict(reader);
case AST_TYPE::DictComp:
return read_dictcomp(reader);
case AST_TYPE::IfExp: case AST_TYPE::IfExp:
return read_ifexp(reader); return read_ifexp(reader);
case AST_TYPE::Index: case AST_TYPE::Index:
......
...@@ -405,6 +405,23 @@ void* AST_Dict::accept_expr(ExprVisitor* v) { ...@@ -405,6 +405,23 @@ void* AST_Dict::accept_expr(ExprVisitor* v) {
return v->visit_dict(this); return v->visit_dict(this);
} }
void AST_DictComp::accept(ASTVisitor* v) {
bool skip = v->visit_dictcomp(this);
if (skip)
return;
for (auto c : generators) {
c->accept(v);
}
value->accept(v);
key->accept(v);
}
void* AST_DictComp::accept_expr(ExprVisitor* v) {
return v->visit_dictcomp(this);
}
void AST_ExceptHandler::accept(ASTVisitor* v) { void AST_ExceptHandler::accept(ASTVisitor* v) {
bool skip = v->visit_excepthandler(this); bool skip = v->visit_excepthandler(this);
if (skip) if (skip)
...@@ -1100,6 +1117,19 @@ bool PrintVisitor::visit_dict(AST_Dict* node) { ...@@ -1100,6 +1117,19 @@ bool PrintVisitor::visit_dict(AST_Dict* node) {
return true; return true;
} }
bool PrintVisitor::visit_dictcomp(AST_DictComp* node) {
printf("{");
node->key->accept(this);
printf(":");
node->value->accept(this);
for (auto c : node->generators) {
printf(" ");
c->accept(this);
}
printf("}");
return true;
}
bool PrintVisitor::visit_excepthandler(AST_ExceptHandler* node) { bool PrintVisitor::visit_excepthandler(AST_ExceptHandler* node) {
printf("except"); printf("except");
if (node->type) { if (node->type) {
...@@ -1633,6 +1663,10 @@ public: ...@@ -1633,6 +1663,10 @@ public:
output->push_back(node); output->push_back(node);
return false; return false;
} }
virtual bool visit_dictcomp(AST_DictComp* node) {
output->push_back(node);
return false;
}
virtual bool visit_excepthandler(AST_ExceptHandler* node) { virtual bool visit_excepthandler(AST_ExceptHandler* node) {
output->push_back(node); output->push_back(node);
return false; return false;
......
...@@ -364,6 +364,19 @@ public: ...@@ -364,6 +364,19 @@ public:
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::Dict; static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::Dict;
}; };
class AST_DictComp : public AST_expr {
public:
std::vector<AST_comprehension*> generators;
AST_expr* key, *value;
virtual void accept(ASTVisitor* v);
virtual void* accept_expr(ExprVisitor* v);
AST_DictComp() : AST_expr(AST_TYPE::DictComp) {}
const static AST_TYPE::AST_TYPE TYPE = AST_TYPE::DictComp;
};
class AST_Delete : public AST_stmt { class AST_Delete : public AST_stmt {
public: public:
std::vector<AST_expr*> targets; std::vector<AST_expr*> targets;
...@@ -875,6 +888,7 @@ public: ...@@ -875,6 +888,7 @@ public:
virtual bool visit_continue(AST_Continue* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_continue(AST_Continue* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_delete(AST_Delete* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_delete(AST_Delete* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_dict(AST_Dict* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_dict(AST_Dict* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_dictcomp(AST_DictComp* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_excepthandler(AST_ExceptHandler* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_excepthandler(AST_ExceptHandler* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_expr(AST_Expr* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_expr(AST_Expr* node) { RELEASE_ASSERT(0, ""); }
virtual bool visit_for(AST_For* node) { RELEASE_ASSERT(0, ""); } virtual bool visit_for(AST_For* node) { RELEASE_ASSERT(0, ""); }
...@@ -935,6 +949,7 @@ public: ...@@ -935,6 +949,7 @@ public:
virtual bool visit_continue(AST_Continue* node) { return false; } virtual bool visit_continue(AST_Continue* node) { return false; }
virtual bool visit_delete(AST_Delete* node) { return false; } virtual bool visit_delete(AST_Delete* node) { return false; }
virtual bool visit_dict(AST_Dict* node) { return false; } virtual bool visit_dict(AST_Dict* node) { return false; }
virtual bool visit_dictcomp(AST_DictComp* node) { return false; }
virtual bool visit_excepthandler(AST_ExceptHandler* node) { return false; } virtual bool visit_excepthandler(AST_ExceptHandler* node) { return false; }
virtual bool visit_expr(AST_Expr* node) { return false; } virtual bool visit_expr(AST_Expr* node) { return false; }
virtual bool visit_for(AST_For* node) { return false; } virtual bool visit_for(AST_For* node) { return false; }
...@@ -985,6 +1000,7 @@ public: ...@@ -985,6 +1000,7 @@ public:
virtual void* visit_clsattribute(AST_ClsAttribute* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_clsattribute(AST_ClsAttribute* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_compare(AST_Compare* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_compare(AST_Compare* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_dict(AST_Dict* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_dict(AST_Dict* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_dictcomp(AST_DictComp* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_ifexp(AST_IfExp* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_ifexp(AST_IfExp* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_index(AST_Index* node) { RELEASE_ASSERT(0, ""); }
virtual void* visit_langprimitive(AST_LangPrimitive* node) { RELEASE_ASSERT(0, ""); } virtual void* visit_langprimitive(AST_LangPrimitive* node) { RELEASE_ASSERT(0, ""); }
...@@ -1061,6 +1077,7 @@ public: ...@@ -1061,6 +1077,7 @@ public:
virtual bool visit_continue(AST_Continue* node); virtual bool visit_continue(AST_Continue* node);
virtual bool visit_delete(AST_Delete* node); virtual bool visit_delete(AST_Delete* node);
virtual bool visit_dict(AST_Dict* node); virtual bool visit_dict(AST_Dict* node);
virtual bool visit_dictcomp(AST_DictComp* node);
virtual bool visit_excepthandler(AST_ExceptHandler* node); virtual bool visit_excepthandler(AST_ExceptHandler* node);
virtual bool visit_expr(AST_Expr* node); virtual bool visit_expr(AST_Expr* node);
virtual bool visit_for(AST_For* node); virtual bool visit_for(AST_For* node);
......
...@@ -117,6 +117,146 @@ private: ...@@ -117,6 +117,146 @@ private:
return NULL; return NULL;
} }
AST_expr* applyComprehensionCall(AST_DictComp * node, AST_Name* name) {
AST_expr* key = remapExpr(node->key);
AST_expr* value = remapExpr(node->value);
return makeCall(makeLoadAttribute(name, "__setitem__", true), key, value);
}
AST_expr* applyComprehensionCall(AST_ListComp * node, AST_Name* name) {
AST_expr* elt = remapExpr(node->elt);
return makeCall(makeLoadAttribute(name, "append", true), elt);
}
template<typename ResultASTType, typename CompType>
AST_expr* remapComprehension(CompType * node) {
std::string rtn_name = nodeName(node);
push_back(makeAssign(rtn_name, new ResultASTType()));
std::vector<CFGBlock*> exit_blocks;
// Where the current level should jump to after finishing its iteration.
// For the outermost comprehension, this is NULL, and it doesn't jump anywhere;
// for the inner comprehensions, they should jump to the next-outer comprehension
// when they are done iterating.
CFGBlock* finished_block = NULL;
for (int i = 0, n = node->generators.size(); i < n; i++) {
AST_comprehension* c = node->generators[i];
bool is_innermost = (i == n - 1);
AST_expr* remapped_iter = remapExpr(c->iter);
AST_expr* iter_attr = makeLoadAttribute(remapped_iter, "__iter__", true);
AST_expr* iter_call = makeCall(iter_attr);
std::string iter_name = nodeName(node, "iter", i);
AST_stmt* iter_assign = makeAssign(iter_name, iter_call);
push_back(iter_assign);
// TODO bad to save these like this?
AST_expr* hasnext_attr = makeLoadAttribute(makeName(iter_name, AST_TYPE::Load), "__hasnext__", true);
AST_expr* next_attr = makeLoadAttribute(makeName(iter_name, AST_TYPE::Load), "next", true);
AST_Jump* j;
CFGBlock* test_block = cfg->addBlock();
test_block->info = "comprehension_test";
// printf("Test block for comp %d is %d\n", i, test_block->idx);
j = new AST_Jump();
j->target = test_block;
curblock->connectTo(test_block);
push_back(j);
curblock = test_block;
AST_expr* test_call = remapExpr(makeCall(hasnext_attr));
CFGBlock* body_block = cfg->addBlock();
body_block->info = "comprehension_body";
CFGBlock* exit_block = cfg->addDeferredBlock();
exit_block->info = "comprehension_exit";
exit_blocks.push_back(exit_block);
// printf("Body block for comp %d is %d\n", i, body_block->idx);
AST_Branch* br = new AST_Branch();
br->col_offset = node->col_offset;
br->lineno = node->lineno;
br->test = test_call;
br->iftrue = body_block;
br->iffalse = exit_block;
curblock->connectTo(body_block);
curblock->connectTo(exit_block);
push_back(br);
curblock = body_block;
push_back(makeAssign(c->target, makeCall(next_attr)));
for (AST_expr* if_condition : c->ifs) {
AST_expr* remapped = remapExpr(if_condition);
AST_Branch* br = new AST_Branch();
br->test = remapped;
push_back(br);
// Put this below the entire body?
CFGBlock* body_tramp = cfg->addBlock();
body_tramp->info = "comprehension_if_trampoline";
// printf("body_tramp for %d is %d\n", i, body_tramp->idx);
CFGBlock* body_continue = cfg->addBlock();
body_continue->info = "comprehension_if_continue";
// printf("body_continue for %d is %d\n", i, body_continue->idx);
br->iffalse = body_tramp;
curblock->connectTo(body_tramp);
br->iftrue = body_continue;
curblock->connectTo(body_continue);
curblock = body_tramp;
j = new AST_Jump();
j->target = test_block;
push_back(j);
curblock->connectTo(test_block, true);
curblock = body_continue;
}
CFGBlock* body_end = curblock;
assert((finished_block != NULL) == (i != 0));
if (finished_block) {
curblock = exit_block;
j = new AST_Jump();
j->target = finished_block;
curblock->connectTo(finished_block, true);
push_back(j);
}
finished_block = test_block;
curblock = body_end;
if (is_innermost) {
push_back(
makeExpr(applyComprehensionCall(node, makeName(rtn_name, AST_TYPE::Load)))
);
j = new AST_Jump();
j->target = test_block;
curblock->connectTo(test_block, true);
push_back(j);
assert(exit_blocks.size());
curblock = exit_blocks[0];
} else {
// continue onto the next comprehension and add to this body
}
}
// Wait until the end to place the end blocks, so that
// we get a nice nesting structure, that looks similar to what
// you'd get with a nested for loop:
for (int i = exit_blocks.size() - 1; i >= 0; i--) {
cfg->placeBlock(exit_blocks[i]);
// printf("Exit block for comp %d is %d\n", i, exit_blocks[i]->idx);
}
return makeName(rtn_name, AST_TYPE::Load);
}
AST_expr* makeNum(int n) { AST_expr* makeNum(int n) {
...@@ -181,6 +321,18 @@ private: ...@@ -181,6 +321,18 @@ private:
return call; return call;
} }
AST_Call* makeCall(AST_expr* func, AST_expr* arg0, AST_expr* arg1) {
AST_Call* call = new AST_Call();
call->args.push_back(arg0);
call->args.push_back(arg1);
call->starargs = NULL;
call->kwargs = NULL;
call->func = func;
call->col_offset = func->col_offset;
call->lineno = func->lineno;
return call;
}
AST_Name* makeName(const std::string& id, AST_TYPE::AST_TYPE ctx_type, int lineno = -1, int col_offset = -1) { AST_Name* makeName(const std::string& id, AST_TYPE::AST_TYPE ctx_type, int lineno = -1, int col_offset = -1) {
AST_Name* name = new AST_Name(); AST_Name* name = new AST_Name();
name->id = id; name->id = id;
...@@ -441,136 +593,6 @@ private: ...@@ -441,136 +593,6 @@ private:
return rtn; return rtn;
} }
AST_expr* remapListComp(AST_ListComp* node) {
std::string rtn_name = nodeName(node);
push_back(makeAssign(rtn_name, new AST_List()));
std::vector<CFGBlock*> exit_blocks;
// Where the current level should jump to after finishing its iteration.
// For the outermost comprehension, this is NULL, and it doesn't jump anywhere;
// for the inner comprehensions, they should jump to the next-outer comprehension
// when they are done iterating.
CFGBlock* finished_block = NULL;
for (int i = 0, n = node->generators.size(); i < n; i++) {
AST_comprehension* c = node->generators[i];
bool is_innermost = (i == n - 1);
AST_expr* remapped_iter = remapExpr(c->iter);
AST_expr* iter_attr = makeLoadAttribute(remapped_iter, "__iter__", true);
AST_expr* iter_call = makeCall(iter_attr);
std::string iter_name = nodeName(node, "iter", i);
AST_stmt* iter_assign = makeAssign(iter_name, iter_call);
push_back(iter_assign);
// TODO bad to save these like this?
AST_expr* hasnext_attr = makeLoadAttribute(makeName(iter_name, AST_TYPE::Load), "__hasnext__", true);
AST_expr* next_attr = makeLoadAttribute(makeName(iter_name, AST_TYPE::Load), "next", true);
AST_Jump* j;
CFGBlock* test_block = cfg->addBlock();
test_block->info = "listcomp_test";
// printf("Test block for comp %d is %d\n", i, test_block->idx);
j = new AST_Jump();
j->target = test_block;
curblock->connectTo(test_block);
push_back(j);
curblock = test_block;
AST_expr* test_call = remapExpr(makeCall(hasnext_attr));
CFGBlock* body_block = cfg->addBlock();
body_block->info = "listcomp_body";
CFGBlock* exit_block = cfg->addDeferredBlock();
exit_block->info = "listcomp_exit";
exit_blocks.push_back(exit_block);
// printf("Body block for comp %d is %d\n", i, body_block->idx);
AST_Branch* br = new AST_Branch();
br->col_offset = node->col_offset;
br->lineno = node->lineno;
br->test = test_call;
br->iftrue = body_block;
br->iffalse = exit_block;
curblock->connectTo(body_block);
curblock->connectTo(exit_block);
push_back(br);
curblock = body_block;
push_back(makeAssign(c->target, makeCall(next_attr)));
for (AST_expr* if_condition : c->ifs) {
AST_expr* remapped = remapExpr(if_condition);
AST_Branch* br = new AST_Branch();
br->test = remapped;
push_back(br);
// Put this below the entire body?
CFGBlock* body_tramp = cfg->addBlock();
body_tramp->info = "listcomp_if_trampoline";
// printf("body_tramp for %d is %d\n", i, body_tramp->idx);
CFGBlock* body_continue = cfg->addBlock();
body_continue->info = "listcomp_if_continue";
// printf("body_continue for %d is %d\n", i, body_continue->idx);
br->iffalse = body_tramp;
curblock->connectTo(body_tramp);
br->iftrue = body_continue;
curblock->connectTo(body_continue);
curblock = body_tramp;
j = new AST_Jump();
j->target = test_block;
push_back(j);
curblock->connectTo(test_block, true);
curblock = body_continue;
}
CFGBlock* body_end = curblock;
assert((finished_block != NULL) == (i != 0));
if (finished_block) {
curblock = exit_block;
j = new AST_Jump();
j->target = finished_block;
curblock->connectTo(finished_block, true);
push_back(j);
}
finished_block = test_block;
curblock = body_end;
if (is_innermost) {
AST_expr* elt = remapExpr(node->elt);
push_back(
makeExpr(makeCall(makeLoadAttribute(makeName(rtn_name, AST_TYPE::Load), "append", true), elt)));
j = new AST_Jump();
j->target = test_block;
curblock->connectTo(test_block, true);
push_back(j);
assert(exit_blocks.size());
curblock = exit_blocks[0];
} else {
// continue onto the next comprehension and add to this body
}
}
// Wait until the end to place the end blocks, so that
// we get a nice nesting structure, that looks similar to what
// you'd get with a nested for loop:
for (int i = exit_blocks.size() - 1; i >= 0; i--) {
cfg->placeBlock(exit_blocks[i]);
// printf("Exit block for comp %d is %d\n", i, exit_blocks[i]->idx);
}
return makeName(rtn_name, AST_TYPE::Load);
};
AST_expr* remapRepr(AST_Repr* node) { AST_expr* remapRepr(AST_Repr* node) {
AST_Repr* rtn = new AST_Repr(); AST_Repr* rtn = new AST_Repr();
rtn->lineno = node->lineno; rtn->lineno = node->lineno;
...@@ -651,6 +673,9 @@ private: ...@@ -651,6 +673,9 @@ private:
case AST_TYPE::Dict: case AST_TYPE::Dict:
rtn = remapDict(ast_cast<AST_Dict>(node)); rtn = remapDict(ast_cast<AST_Dict>(node));
break; break;
case AST_TYPE::DictComp:
rtn = remapComprehension<AST_Dict>(ast_cast<AST_DictComp>(node));
break;
case AST_TYPE::IfExp: case AST_TYPE::IfExp:
rtn = remapIfExp(ast_cast<AST_IfExp>(node)); rtn = remapIfExp(ast_cast<AST_IfExp>(node));
break; break;
...@@ -664,7 +689,7 @@ private: ...@@ -664,7 +689,7 @@ private:
rtn = remapList(ast_cast<AST_List>(node)); rtn = remapList(ast_cast<AST_List>(node));
break; break;
case AST_TYPE::ListComp: case AST_TYPE::ListComp:
rtn = remapListComp(ast_cast<AST_ListComp>(node)); rtn = remapComprehension<AST_List>(ast_cast<AST_ListComp>(node));
break; break;
case AST_TYPE::Name: case AST_TYPE::Name:
rtn = node; rtn = node;
......
def dict2str(d):
result = ''
for k, v in sorted(d.items()):
if result:
result += ', '
result += '%s: %s' % (str(k), str(v))
return '{%s}' % result
print dict2str({i: j for i in range(4) for j in range(4)})
def f():
print dict2str({i: j for i in range(4) for j in range(4)})
# print i, j
f()
# Combine a list comprehension with a bunch of other control-flow expressions:
def f(x, y):
# TODO make sure to use an 'if' in a comprehension where the if contains control flow
print dict2str({y if i % 3 else y ** 2 + i: (i if i%2 else i/2) for i in (xrange(4 if x else 5) if y else xrange(3))})
f(0, 0)
f(0, 1)
f(1, 0)
f(1, 1)
# TODO: test on ifs
def f():
print dict2str({i : j for (i, j) in sorted({1:2, 3:4, 5:6, 7:8}.items())})
f()
# The expr should not get evaluated if the if-condition fails:
def f():
def p(i):
print i
return i ** 2
def k(i):
print i
return i * 4 + i
print dict2str({k(i):p(i) for i in xrange(50) if i % 5 == 0 if i % 3 == 0})
f()
def f():
print dict2str({i: j for i in xrange(4) for j in xrange(i)})
f()
def f():
j = 1
# The 'if' part of this list comprehension references j;
# the first time through it will use the j above, but later times
# it may-or-may-not use the j from the inner part of the listcomp.
print dict2str({i: j for i in xrange(7) if i % 2 != j % 2 for j in xrange(i)})
f()
def f():
# Checking the order of evaluation of the if conditions:
def c1(x):
print "c1", x
return x % 2 == 0
def c2(x):
print "c2", x
return x % 3 == 0
print dict2str({i : i for i in xrange(20) if c1(i) if c2(i)})
f()
def control_flow_in_listcomp():
print dict2str({(i ** 2 if i > 5 else i ** 2 * -1):(i if i else -1) for i in (xrange(10) if True else []) if (i % 2 == 0 or i % 3 != 0)})
control_flow_in_listcomp()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment