Commit d1f7e31b authored by Kevin Modzelewski's avatar Kevin Modzelewski

dicts check by identity in addition to equality

This isn't just a performance optimization, but is actually important
since people put things in dicts that don't __eq__ themselves.
The common example (and test here) is NaN, but in particular sqlalchemy
does this with clauses (where __eq__ returns a new sql clause).

Also, handle cases that __eq__ returns non-bool
parent 50d735fc
...@@ -250,10 +250,7 @@ extern "C" PyObject* PyDict_GetItem(PyObject* dict, PyObject* key) noexcept { ...@@ -250,10 +250,7 @@ extern "C" PyObject* PyDict_GetItem(PyObject* dict, PyObject* key) noexcept {
ASSERT(isSubclass(dict->cls, dict_cls) || dict->cls == attrwrapper_cls, "%s", getTypeName(dict)); ASSERT(isSubclass(dict->cls, dict_cls) || dict->cls == attrwrapper_cls, "%s", getTypeName(dict));
if (isSubclass(dict->cls, dict_cls)) { if (isSubclass(dict->cls, dict_cls)) {
BoxedDict* d = static_cast<BoxedDict*>(dict); BoxedDict* d = static_cast<BoxedDict*>(dict);
auto it = d->d.find(key); return d->getOrNull(key);
if (it != d->d.end())
return it->second;
return NULL;
} }
// This path doesn't exist in CPython; we have it to support extension modules that do // This path doesn't exist in CPython; we have it to support extension modules that do
......
...@@ -141,6 +141,9 @@ size_t PyHasher::operator()(Box* b) const { ...@@ -141,6 +141,9 @@ size_t PyHasher::operator()(Box* b) const {
bool PyEq::operator()(Box* lhs, Box* rhs) const { bool PyEq::operator()(Box* lhs, Box* rhs) const {
STAT_TIMER(t0, "us_timer_PyEq"); STAT_TIMER(t0, "us_timer_PyEq");
if (lhs == rhs)
return true;
if (lhs->cls == rhs->cls) { if (lhs->cls == rhs->cls) {
if (lhs->cls == str_cls) { if (lhs->cls == str_cls) {
return static_cast<BoxedString*>(lhs)->s == static_cast<BoxedString*>(rhs)->s; return static_cast<BoxedString*>(lhs)->s == static_cast<BoxedString*>(rhs)->s;
...@@ -149,8 +152,7 @@ bool PyEq::operator()(Box* lhs, Box* rhs) const { ...@@ -149,8 +152,7 @@ bool PyEq::operator()(Box* lhs, Box* rhs) const {
// TODO fix this // TODO fix this
Box* cmp = compareInternal(lhs, rhs, AST_TYPE::Eq, NULL); Box* cmp = compareInternal(lhs, rhs, AST_TYPE::Eq, NULL);
assert(cmp->cls == bool_cls); return cmp->nonzeroIC();
return cmp == True;
} }
bool PyLt::operator()(Box* lhs, Box* rhs) const { bool PyLt::operator()(Box* lhs, Box* rhs) const {
...@@ -158,8 +160,7 @@ bool PyLt::operator()(Box* lhs, Box* rhs) const { ...@@ -158,8 +160,7 @@ bool PyLt::operator()(Box* lhs, Box* rhs) const {
// TODO fix this // TODO fix this
Box* cmp = compareInternal(lhs, rhs, AST_TYPE::Lt, NULL); Box* cmp = compareInternal(lhs, rhs, AST_TYPE::Lt, NULL);
assert(cmp->cls == bool_cls); return cmp->nonzeroIC();
return cmp == True;
} }
extern "C" Box* deopt(AST_expr* expr, Box* value) { extern "C" Box* deopt(AST_expr* expr, Box* value) {
......
...@@ -70,6 +70,7 @@ for fn in test_files: ...@@ -70,6 +70,7 @@ for fn in test_files:
mname = fn[len(SQLALCHEMY_DIR) + 1:-3].replace('/', '.') mname = fn[len(SQLALCHEMY_DIR) + 1:-3].replace('/', '.')
if mname not in MODULES_TO_TEST: if mname not in MODULES_TO_TEST:
continue continue
print
print mname print mname
try: try:
...@@ -91,11 +92,9 @@ for fn in test_files: ...@@ -91,11 +92,9 @@ for fn in test_files:
except Exception: except Exception:
print mname, "FAILED" print mname, "FAILED"
traceback.print_exc() traceback.print_exc()
print
failed.append(mname) failed.append(mname)
else: else:
print mname, "PASSED" print mname, "PASSED"
print
passed.append(mname) passed.append(mname)
print "passing:", passed print "passing:", passed
......
...@@ -96,3 +96,20 @@ class EqOnly(object): ...@@ -96,3 +96,20 @@ class EqOnly(object):
print EqOnly() == 1 print EqOnly() == 1
print EqOnly() != 1 print EqOnly() != 1
class NonboolEq(object):
def __init__(self, n):
self.n = n
def __eq__(self, rhs):
return 2 if self.n == rhs.n else ()
def __hash__(self):
return 0
print NonboolEq(1) == NonboolEq(2)
print NonboolEq(1) == NonboolEq(True)
d = {}
for i in xrange(20):
d[NonboolEq(i % 10)] = i
print len(d), sorted(d.values())
...@@ -28,3 +28,13 @@ d.__setitem__(c2, 2) ...@@ -28,3 +28,13 @@ d.__setitem__(c2, 2)
d.__setitem__(c3, 3) d.__setitem__(c3, 3)
print d print d
# dicts need to check identify and not just equality.
# This is important for sqlalchemy where equality constructs a sql equals clause and doesn't
# do comparison of the objects at hand.
d = {}
nan = float('nan')
d[nan] = "hello world"
print d[nan]
...@@ -129,3 +129,7 @@ for i in xrange(10): ...@@ -129,3 +129,7 @@ for i in xrange(10):
for s1 in set(range(5)), frozenset(range(5)): for s1 in set(range(5)), frozenset(range(5)):
for s2 in compare_to: for s2 in compare_to:
print type(s2), sorted(s2), s.issubset(s2), s.issuperset(s2), s == s2, s != s2, s.difference(s2), s.isdisjoint(s2), sorted(s1.union(s2)), sorted(s1.intersection(s2)) print type(s2), sorted(s2), s.issubset(s2), s.issuperset(s2), s == s2, s != s2, s.difference(s2), s.isdisjoint(s2), sorted(s1.union(s2)), sorted(s1.intersection(s2))
f = float('nan')
s = set([f])
print f in s, f == list(s)[0]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment