Commit d1f7e31b authored by Kevin Modzelewski's avatar Kevin Modzelewski

dicts check by identity in addition to equality

This isn't just a performance optimization, but is actually important
since people put things in dicts that don't __eq__ themselves.
The common example (and test here) is NaN, but in particular sqlalchemy
does this with clauses (where __eq__ returns a new sql clause).

Also, handle cases that __eq__ returns non-bool
parent 50d735fc
......@@ -250,10 +250,7 @@ extern "C" PyObject* PyDict_GetItem(PyObject* dict, PyObject* key) noexcept {
ASSERT(isSubclass(dict->cls, dict_cls) || dict->cls == attrwrapper_cls, "%s", getTypeName(dict));
if (isSubclass(dict->cls, dict_cls)) {
BoxedDict* d = static_cast<BoxedDict*>(dict);
auto it = d->d.find(key);
if (it != d->d.end())
return it->second;
return NULL;
return d->getOrNull(key);
}
// This path doesn't exist in CPython; we have it to support extension modules that do
......
......@@ -141,6 +141,9 @@ size_t PyHasher::operator()(Box* b) const {
bool PyEq::operator()(Box* lhs, Box* rhs) const {
STAT_TIMER(t0, "us_timer_PyEq");
if (lhs == rhs)
return true;
if (lhs->cls == rhs->cls) {
if (lhs->cls == str_cls) {
return static_cast<BoxedString*>(lhs)->s == static_cast<BoxedString*>(rhs)->s;
......@@ -149,8 +152,7 @@ bool PyEq::operator()(Box* lhs, Box* rhs) const {
// TODO fix this
Box* cmp = compareInternal(lhs, rhs, AST_TYPE::Eq, NULL);
assert(cmp->cls == bool_cls);
return cmp == True;
return cmp->nonzeroIC();
}
bool PyLt::operator()(Box* lhs, Box* rhs) const {
......@@ -158,8 +160,7 @@ bool PyLt::operator()(Box* lhs, Box* rhs) const {
// TODO fix this
Box* cmp = compareInternal(lhs, rhs, AST_TYPE::Lt, NULL);
assert(cmp->cls == bool_cls);
return cmp == True;
return cmp->nonzeroIC();
}
extern "C" Box* deopt(AST_expr* expr, Box* value) {
......
......@@ -70,6 +70,7 @@ for fn in test_files:
mname = fn[len(SQLALCHEMY_DIR) + 1:-3].replace('/', '.')
if mname not in MODULES_TO_TEST:
continue
print
print mname
try:
......@@ -91,11 +92,9 @@ for fn in test_files:
except Exception:
print mname, "FAILED"
traceback.print_exc()
print
failed.append(mname)
else:
print mname, "PASSED"
print
passed.append(mname)
print "passing:", passed
......
......@@ -96,3 +96,20 @@ class EqOnly(object):
print EqOnly() == 1
print EqOnly() != 1
class NonboolEq(object):
def __init__(self, n):
self.n = n
def __eq__(self, rhs):
return 2 if self.n == rhs.n else ()
def __hash__(self):
return 0
print NonboolEq(1) == NonboolEq(2)
print NonboolEq(1) == NonboolEq(True)
d = {}
for i in xrange(20):
d[NonboolEq(i % 10)] = i
print len(d), sorted(d.values())
......@@ -28,3 +28,13 @@ d.__setitem__(c2, 2)
d.__setitem__(c3, 3)
print d
# dicts need to check identify and not just equality.
# This is important for sqlalchemy where equality constructs a sql equals clause and doesn't
# do comparison of the objects at hand.
d = {}
nan = float('nan')
d[nan] = "hello world"
print d[nan]
......@@ -129,3 +129,7 @@ for i in xrange(10):
for s1 in set(range(5)), frozenset(range(5)):
for s2 in compare_to:
print type(s2), sorted(s2), s.issubset(s2), s.issuperset(s2), s == s2, s != s2, s.difference(s2), s.isdisjoint(s2), sorted(s1.union(s2)), sorted(s1.intersection(s2))
f = float('nan')
s = set([f])
print f in s, f == list(s)[0]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment