Commit 71195a57 authored by Kevin Modzelewski's avatar Kevin Modzelewski

Merge pull request #1145 from undingen/fix_set_insert_order

fix set.add() for existing keys and fix set ast node when encountering keys with same hashes
parents bacf0cfa aea4f11f
...@@ -1590,20 +1590,20 @@ Value ASTInterpreter::visit_dict(AST_Dict* node) { ...@@ -1590,20 +1590,20 @@ Value ASTInterpreter::visit_dict(AST_Dict* node) {
} }
Value ASTInterpreter::visit_set(AST_Set* node) { Value ASTInterpreter::visit_set(AST_Set* node) {
llvm::SmallVector<RewriterVar*, 8> items; try {
// insert the elements in reverse like cpython does
BoxedSet::Set set; // important for {1, 1L}
for (AST_expr* e : node->elts) { llvm::SmallVector<RewriterVar*, 8> items;
Value v = visit_expr(e); BoxedSet* set = (BoxedSet*)createSet();
auto&& p = set.insert(v.o); for (auto it = node->elts.rbegin(), it_end = node->elts.rend(); it != it_end; ++it) {
if (!p.second /* already exists */) { Value v = visit_expr(*it);
Py_DECREF(p.first->value); _setAddStolen(set, v.o);
*p.first = v.o; items.push_back(v);
} }
items.push_back(v); return Value(set, jit ? jit->emitCreateSet(items) : NULL);
} catch (ExcInfo e) {
RELEASE_ASSERT(0, "this leaks in case of an exception");
} }
return Value(new BoxedSet(std::move(set)), jit ? jit->emitCreateSet(items) : NULL);
} }
Value ASTInterpreter::visit_str(AST_Str* node) { Value ASTInterpreter::visit_str(AST_Str* node) {
......
...@@ -871,15 +871,15 @@ Box* JitFragmentWriter::createListHelper(uint64_t num, Box** data) { ...@@ -871,15 +871,15 @@ Box* JitFragmentWriter::createListHelper(uint64_t num, Box** data) {
} }
Box* JitFragmentWriter::createSetHelper(uint64_t num, Box** data) { Box* JitFragmentWriter::createSetHelper(uint64_t num, Box** data) {
BoxedSet* set = (BoxedSet*)createSet(); try {
for (int i = 0; i < num; ++i) { BoxedSet* set = (BoxedSet*)createSet();
auto&& p = set->s.insert(data[i]); for (int i = 0; i < num; ++i) {
if (!p.second /* already exists */) { _setAddStolen(set, data[i]);
Py_DECREF(p.first->value);
*p.first = data[i];
} }
return set;
} catch (ExcInfo e) {
RELEASE_ASSERT(0, "this leaks in case of an exception");
} }
return set;
} }
Box* JitFragmentWriter::createTupleHelper(uint64_t num, Box** data) { Box* JitFragmentWriter::createTupleHelper(uint64_t num, Box** data) {
......
...@@ -1445,8 +1445,10 @@ private: ...@@ -1445,8 +1445,10 @@ private:
static BoxedString* add_str = getStaticString("add"); static BoxedString* add_str = getStaticString("add");
for (int i = 0; i < node->elts.size(); i++) { // insert the elements in reverse like cpython does
CompilerVariable* elt = elts[i]; // important for {1, 1L}
for (auto it = elts.rbegin(), it_end = elts.rend(); it != it_end; ++it) {
CompilerVariable* elt = *it;
CallattrFlags flags = {.cls_only = true, .null_on_nonexistent = false, .argspec = ArgPassSpec(1) }; CallattrFlags flags = {.cls_only = true, .null_on_nonexistent = false, .argspec = ArgPassSpec(1) };
CompilerVariable* r CompilerVariable* r
= rtn->callattr(emitter, getOpInfoForNode(node, unw_info), add_str, flags, { elt }, NULL); = rtn->callattr(emitter, getOpInfoForNode(node, unw_info), add_str, flags, { elt }, NULL);
......
...@@ -24,6 +24,14 @@ extern "C" Box* createSet() { ...@@ -24,6 +24,14 @@ extern "C" Box* createSet() {
return new BoxedSet(); return new BoxedSet();
} }
void _setAddStolen(BoxedSet* self, STOLEN(BoxAndHash) val) {
auto&& p = self->s.insert(val);
if (!p.second /* already exists */) {
// keep the original key
Py_DECREF(val.value);
}
}
namespace set { namespace set {
class BoxedSetIterator : public Box { class BoxedSetIterator : public Box {
...@@ -89,14 +97,6 @@ Box* setiteratorIter(BoxedSetIterator* self) { ...@@ -89,14 +97,6 @@ Box* setiteratorIter(BoxedSetIterator* self) {
return incref(self); return incref(self);
} }
static void _setAddStolen(BoxedSet* self, STOLEN(BoxAndHash) val) {
auto&& p = self->s.insert(val);
if (!p.second /* already exists */) {
Py_DECREF(p.first->value);
*p.first = val;
}
}
static void _setAdd(BoxedSet* self, BoxAndHash val) { static void _setAdd(BoxedSet* self, BoxAndHash val) {
Py_INCREF(val.value); Py_INCREF(val.value);
_setAddStolen(self, val); _setAddStolen(self, val);
......
...@@ -41,6 +41,8 @@ public: ...@@ -41,6 +41,8 @@ public:
static int traverse(Box* self, visitproc visit, void* arg) noexcept; static int traverse(Box* self, visitproc visit, void* arg) noexcept;
static int clear(Box* self) noexcept; static int clear(Box* self) noexcept;
}; };
void _setAddStolen(BoxedSet* self, STOLEN(BoxAndHash) val);
} }
#endif #endif
d = {2:"should get overwritten", 2:2} d = {2:"should get overwritten", 2L:2}
d[1] = 1 d[1] = 1
print d print d
print d[1], d[1L], d[1.0], d[True] print d[1], d[1L], d[1.0], d[True]
......
...@@ -260,3 +260,9 @@ s.remove(1L) ...@@ -260,3 +260,9 @@ s.remove(1L)
s = set([1, 2, 3, 4]) s = set([1, 2, 3, 4])
s2 = set([3L, 4L, 5L, 6L]) s2 = set([3L, 4L, 5L, 6L])
s.symmetric_difference_update(s2) s.symmetric_difference_update(s2)
# make sure we are inserting the tuple elements in reverse:
print {1, 1L}, {1L, 1}, set([1, 1L]), set([1L, 1])
s = {1}
s.add(1L)
print s
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment