Commit f57db823 authored by Kevin Modzelewski's avatar Kevin Modzelewski

tuple.contains tests identity not just equality

parent f3e03b35
......@@ -2074,6 +2074,23 @@ ConcreteCompilerVariable* makeBool(bool b) {
return new ConcreteCompilerVariable(BOOL, llvm::ConstantInt::get(BOOL->llvmType(), b, false), true);
}
ConcreteCompilerVariable* doIs(IREmitter& emitter, CompilerVariable* lhs, CompilerVariable* rhs, bool negate) {
// TODO: I think we can do better here and not force the types to box themselves
ConcreteCompilerVariable* converted_left = lhs->makeConverted(emitter, UNKNOWN);
ConcreteCompilerVariable* converted_right = rhs->makeConverted(emitter, UNKNOWN);
llvm::Value* cmp;
if (!negate)
cmp = emitter.getBuilder()->CreateICmpEQ(converted_left->getValue(), converted_right->getValue());
else
cmp = emitter.getBuilder()->CreateICmpNE(converted_left->getValue(), converted_right->getValue());
converted_left->decvref(emitter);
converted_right->decvref(emitter);
return boolFromI1(emitter, cmp);
}
ConcreteCompilerType* BOXED_TUPLE;
class TupleType : public ValuedCompilerType<const std::vector<CompilerVariable*>*> {
private:
......@@ -2228,21 +2245,35 @@ public:
llvm::BasicBlock* end = emitter.createBasicBlock();
ConcreteCompilerVariable* converted_lhs = lhs->makeConverted(emitter, lhs->getConcreteType());
for (CompilerVariable* e : *var->getValue()) {
CompilerVariable* eq = lhs->binexp(emitter, info, e, AST_TYPE::Eq, Compare);
// TODO: we could potentially avoid the identity tests if we know that either type has
// an __eq__ that is reflexive (returns True for the same object).
{
ConcreteCompilerVariable* is_same = doIs(emitter, converted_lhs, e, false);
llvm::Value* raw = i1FromBool(emitter, is_same);
phi_incoming.push_back(std::make_pair(emitter.currentBasicBlock(), getConstantInt(1, g.i1)));
llvm::BasicBlock* new_bb = emitter.createBasicBlock();
new_bb->moveAfter(emitter.currentBasicBlock());
emitter.getBuilder()->CreateCondBr(raw, end, new_bb);
emitter.setCurrentBasicBlock(new_bb);
}
{
CompilerVariable* eq = converted_lhs->binexp(emitter, info, e, AST_TYPE::Eq, Compare);
ConcreteCompilerVariable* eq_nonzero = eq->nonzero(emitter, info);
assert(eq_nonzero->getType() == BOOL);
llvm::Value* raw = i1FromBool(emitter, eq_nonzero);
phi_incoming.push_back(std::make_pair(emitter.currentBasicBlock(), getConstantInt(1, g.i1)));
llvm::BasicBlock* new_bb = emitter.createBasicBlock();
new_bb->moveAfter(emitter.currentBasicBlock());
emitter.getBuilder()->CreateCondBr(raw, end, new_bb);
emitter.setCurrentBasicBlock(new_bb);
}
}
// TODO This last block is unnecessary:
phi_incoming.push_back(std::make_pair(emitter.currentBasicBlock(), getConstantInt(0, g.i1)));
......@@ -2255,6 +2286,9 @@ public:
for (auto p : phi_incoming) {
phi->addIncoming(p.second, p.first);
}
converted_lhs->decvref(emitter);
return boolFromI1(emitter, phi);
}
......
......@@ -396,6 +396,9 @@ public:
// assert(value->getType() == type->llvmType());
//}
// Emit the test for whether one variable 'is' another one.
ConcreteCompilerVariable* doIs(IREmitter& emitter, CompilerVariable* lhs, CompilerVariable* rhs, bool negate);
ConcreteCompilerVariable* makeBool(bool);
ConcreteCompilerVariable* makeInt(int64_t);
ConcreteCompilerVariable* makeFloat(double);
......
......@@ -816,17 +816,7 @@ private:
assert(right);
if (node->ops[0] == AST_TYPE::Is || node->ops[0] == AST_TYPE::IsNot) {
// TODO: I think we can do better here and not force the types to box themselves
ConcreteCompilerVariable* converted_left = left->makeConverted(emitter, UNKNOWN);
ConcreteCompilerVariable* converted_right = right->makeConverted(emitter, UNKNOWN);
llvm::Value* cmp;
if (node->ops[0] == AST_TYPE::Is)
cmp = emitter.getBuilder()->CreateICmpEQ(converted_left->getValue(), converted_right->getValue());
else
cmp = emitter.getBuilder()->CreateICmpNE(converted_left->getValue(), converted_right->getValue());
return boolFromI1(emitter, cmp);
return doIs(emitter, left, right, node->ops[0] == AST_TYPE::IsNot);
}
CompilerVariable* rtn = _evalBinExp(node, left, right, node->ops[0], Compare, unw_info);
......
......@@ -221,3 +221,6 @@ try:
print (1, 3, 5, 3).index(2)
except ValueError as e:
print e
n = float('nan')
print n in (n, n)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment