Commit 662db125 authored by Serhiy Storchaka's avatar Serhiy Storchaka Committed by GitHub

bpo-37685: Fixed __eq__, __lt__ etc implementations in some classes. (GH-14952)

They now return NotImplemented for unsupported type of the other operand.
parent 4c69be22
......@@ -119,20 +119,24 @@ class TimerHandle(Handle):
return hash(self._when)
def __lt__(self, other):
return self._when < other._when
if isinstance(other, TimerHandle):
return self._when < other._when
return NotImplemented
def __le__(self, other):
if self._when < other._when:
return True
return self.__eq__(other)
if isinstance(other, TimerHandle):
return self._when < other._when or self.__eq__(other)
return NotImplemented
def __gt__(self, other):
return self._when > other._when
if isinstance(other, TimerHandle):
return self._when > other._when
return NotImplemented
def __ge__(self, other):
if self._when > other._when:
return True
return self.__eq__(other)
if isinstance(other, TimerHandle):
return self._when > other._when or self.__eq__(other)
return NotImplemented
def __eq__(self, other):
if isinstance(other, TimerHandle):
......@@ -142,10 +146,6 @@ class TimerHandle(Handle):
self._cancelled == other._cancelled)
return NotImplemented
def __ne__(self, other):
equal = self.__eq__(other)
return NotImplemented if equal is NotImplemented else not equal
def cancel(self):
if not self._cancelled:
self._loop._timer_handle_cancelled(self)
......
......@@ -45,6 +45,14 @@ class VersionTestCase(unittest.TestCase):
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = StrictVersion(v1)._cmp(v2)
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = StrictVersion(v1)._cmp(object())
self.assertIs(res, NotImplemented,
'cmp(%s, %s) should be NotImplemented, got %s' %
(v1, v2, res))
def test_cmp(self):
......@@ -63,6 +71,14 @@ class VersionTestCase(unittest.TestCase):
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = LooseVersion(v1)._cmp(v2)
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = LooseVersion(v1)._cmp(object())
self.assertIs(res, NotImplemented,
'cmp(%s, %s) should be NotImplemented, got %s' %
(v1, v2, res))
def test_suite():
return unittest.makeSuite(VersionTestCase)
......
......@@ -166,6 +166,8 @@ class StrictVersion (Version):
def _cmp (self, other):
if isinstance(other, str):
other = StrictVersion(other)
elif not isinstance(other, StrictVersion):
return NotImplemented
if self.version != other.version:
# numeric versions don't match
......@@ -331,6 +333,8 @@ class LooseVersion (Version):
def _cmp (self, other):
if isinstance(other, str):
other = LooseVersion(other)
elif not isinstance(other, LooseVersion):
return NotImplemented
if self.version == other.version:
return 0
......
......@@ -97,8 +97,8 @@ class Address:
return self.addr_spec
def __eq__(self, other):
if type(other) != type(self):
return False
if not isinstance(other, Address):
return NotImplemented
return (self.display_name == other.display_name and
self.username == other.username and
self.domain == other.domain)
......@@ -150,8 +150,8 @@ class Group:
return "{}:{};".format(disp, adrstr)
def __eq__(self, other):
if type(other) != type(self):
return False
if not isinstance(other, Group):
return NotImplemented
return (self.display_name == other.display_name and
self.addresses == other.addresses)
......
......@@ -371,7 +371,7 @@ class ModuleSpec:
self.cached == other.cached and
self.has_location == other.has_location)
except AttributeError:
return False
return NotImplemented
@property
def cached(self):
......
......@@ -32,6 +32,7 @@ from asyncio import proactor_events
from asyncio import selector_events
from test.test_asyncio import utils as test_utils
from test import support
from test.support import ALWAYS_EQ, LARGEST, SMALLEST
def tearDownModule():
......@@ -2364,6 +2365,28 @@ class TimerTests(unittest.TestCase):
self.assertIs(NotImplemented, h1.__eq__(h3))
self.assertIs(NotImplemented, h1.__ne__(h3))
with self.assertRaises(TypeError):
h1 < ()
with self.assertRaises(TypeError):
h1 > ()
with self.assertRaises(TypeError):
h1 <= ()
with self.assertRaises(TypeError):
h1 >= ()
self.assertFalse(h1 == ())
self.assertTrue(h1 != ())
self.assertTrue(h1 == ALWAYS_EQ)
self.assertFalse(h1 != ALWAYS_EQ)
self.assertTrue(h1 < LARGEST)
self.assertFalse(h1 > LARGEST)
self.assertTrue(h1 <= LARGEST)
self.assertFalse(h1 >= LARGEST)
self.assertFalse(h1 < SMALLEST)
self.assertTrue(h1 > SMALLEST)
self.assertFalse(h1 <= SMALLEST)
self.assertTrue(h1 >= SMALLEST)
class AbstractEventLoopTests(unittest.TestCase):
......
......@@ -7,6 +7,7 @@ from email.message import Message
from test.test_email import TestEmailBase, parameterize
from email import headerregistry
from email.headerregistry import Address, Group
from test.support import ALWAYS_EQ
DITTO = object()
......@@ -1525,6 +1526,24 @@ class TestAddressAndGroup(TestEmailBase):
self.assertEqual(m['to'], 'foo bar:;')
self.assertEqual(m['to'].addresses, g.addresses)
def test_address_comparison(self):
a = Address('foo', 'bar', 'example.com')
self.assertEqual(Address('foo', 'bar', 'example.com'), a)
self.assertNotEqual(Address('baz', 'bar', 'example.com'), a)
self.assertNotEqual(Address('foo', 'baz', 'example.com'), a)
self.assertNotEqual(Address('foo', 'bar', 'baz'), a)
self.assertFalse(a == object())
self.assertTrue(a == ALWAYS_EQ)
def test_group_comparison(self):
a = Address('foo', 'bar', 'example.com')
g = Group('foo bar', [a])
self.assertEqual(Group('foo bar', (a,)), g)
self.assertNotEqual(Group('baz', [a]), g)
self.assertNotEqual(Group('foo bar', []), g)
self.assertFalse(g == object())
self.assertTrue(g == ALWAYS_EQ)
class TestFolding(TestHeaderBase):
......
......@@ -7,7 +7,7 @@ import sys
import unittest
import re
from test import support
from test.support import TESTFN, Error, captured_output, unlink, cpython_only
from test.support import TESTFN, Error, captured_output, unlink, cpython_only, ALWAYS_EQ
from test.support.script_helper import assert_python_ok
import textwrap
......@@ -887,6 +887,8 @@ class TestFrame(unittest.TestCase):
# operator fallbacks to FrameSummary.__eq__.
self.assertEqual(tuple(f), f)
self.assertIsNone(f.locals)
self.assertNotEqual(f, object())
self.assertEqual(f, ALWAYS_EQ)
def test_lazy_lines(self):
linecache.clearcache()
......@@ -1083,6 +1085,18 @@ class TestTracebackException(unittest.TestCase):
self.assertEqual(exc_info[0], exc.exc_type)
self.assertEqual(str(exc_info[1]), str(exc))
def test_comparison(self):
try:
1/0
except Exception:
exc_info = sys.exc_info()
exc = traceback.TracebackException(*exc_info)
exc2 = traceback.TracebackException(*exc_info)
self.assertIsNot(exc, exc2)
self.assertEqual(exc, exc2)
self.assertNotEqual(exc, object())
self.assertEqual(exc, ALWAYS_EQ)
def test_unhashable(self):
class UnhashableException(Exception):
def __eq__(self, other):
......
......@@ -11,7 +11,7 @@ import time
import random
from test import support
from test.support import script_helper
from test.support import script_helper, ALWAYS_EQ
# Used in ReferencesTestCase.test_ref_created_during_del() .
ref_from_del = None
......@@ -794,6 +794,10 @@ class ReferencesTestCase(TestBase):
self.assertTrue(a != c)
self.assertTrue(a == d)
self.assertFalse(a != d)
self.assertFalse(a == x)
self.assertTrue(a != x)
self.assertTrue(a == ALWAYS_EQ)
self.assertFalse(a != ALWAYS_EQ)
del x, y, z
gc.collect()
for r in a, b, c:
......@@ -1102,6 +1106,9 @@ class WeakMethodTestCase(unittest.TestCase):
_ne(a, f)
_ne(b, e)
_ne(b, f)
# Compare with different types
_ne(a, x.some_method)
_eq(a, ALWAYS_EQ)
del x, y, z
gc.collect()
# Dead WeakMethods compare by identity
......
......@@ -15,6 +15,7 @@ import re
import io
import contextlib
from test import support
from test.support import ALWAYS_EQ, LARGEST, SMALLEST
try:
import gzip
......@@ -530,14 +531,10 @@ class DateTimeTestCase(unittest.TestCase):
# some other types
dbytes = dstr.encode('ascii')
dtuple = now.timetuple()
with self.assertRaises(TypeError):
dtime == 1970
with self.assertRaises(TypeError):
dtime != dbytes
with self.assertRaises(TypeError):
dtime == bytearray(dbytes)
with self.assertRaises(TypeError):
dtime != dtuple
self.assertFalse(dtime == 1970)
self.assertTrue(dtime != dbytes)
self.assertFalse(dtime == bytearray(dbytes))
self.assertTrue(dtime != dtuple)
with self.assertRaises(TypeError):
dtime < float(1970)
with self.assertRaises(TypeError):
......@@ -547,6 +544,18 @@ class DateTimeTestCase(unittest.TestCase):
with self.assertRaises(TypeError):
dtime >= dtuple
self.assertTrue(dtime == ALWAYS_EQ)
self.assertFalse(dtime != ALWAYS_EQ)
self.assertTrue(dtime < LARGEST)
self.assertFalse(dtime > LARGEST)
self.assertTrue(dtime <= LARGEST)
self.assertFalse(dtime >= LARGEST)
self.assertFalse(dtime < SMALLEST)
self.assertTrue(dtime > SMALLEST)
self.assertFalse(dtime <= SMALLEST)
self.assertTrue(dtime >= SMALLEST)
class BinaryTestCase(unittest.TestCase):
# XXX What should str(Binary(b"\xff")) return? I'm chosing "\xff"
......
......@@ -484,6 +484,8 @@ class Variable:
Note: if the Variable's master matters to behavior
also compare self._master == other._master
"""
if not isinstance(other, Variable):
return NotImplemented
return self.__class__.__name__ == other.__class__.__name__ \
and self._name == other._name
......
......@@ -101,7 +101,9 @@ class Font:
return self.name
def __eq__(self, other):
return isinstance(other, Font) and self.name == other.name
if not isinstance(other, Font):
return NotImplemented
return self.name == other.name
def __getitem__(self, key):
return self.cget(key)
......
import unittest
import tkinter
from tkinter import font
from test.support import requires, run_unittest, gc_collect
from test.support import requires, run_unittest, gc_collect, ALWAYS_EQ
from tkinter.test.support import AbstractTkTest
requires('gui')
......@@ -70,6 +70,7 @@ class FontTest(AbstractTkTest, unittest.TestCase):
self.assertEqual(font1, font2)
self.assertNotEqual(font1, font1.copy())
self.assertNotEqual(font1, 0)
self.assertEqual(font1, ALWAYS_EQ)
def test_measure(self):
self.assertIsInstance(self.font.measure('abc'), int)
......
......@@ -2,6 +2,7 @@ import unittest
import gc
from tkinter import (Variable, StringVar, IntVar, DoubleVar, BooleanVar, Tcl,
TclError)
from test.support import ALWAYS_EQ
class Var(Variable):
......@@ -59,11 +60,17 @@ class TestVariable(TestBase):
# values doesn't matter, only class and name are checked
v1 = Variable(self.root, name="abc")
v2 = Variable(self.root, name="abc")
self.assertIsNot(v1, v2)
self.assertEqual(v1, v2)
v3 = Variable(self.root, name="abc")
v4 = StringVar(self.root, name="abc")
self.assertNotEqual(v3, v4)
v3 = StringVar(self.root, name="abc")
self.assertNotEqual(v1, v3)
V = type('Variable', (), {})
self.assertNotEqual(v1, V())
self.assertNotEqual(v1, object())
self.assertEqual(v1, ALWAYS_EQ)
def test_invalid_name(self):
with self.assertRaises(TypeError):
......
......@@ -538,7 +538,9 @@ class TracebackException:
self.__cause__._load_lines()
def __eq__(self, other):
return self.__dict__ == other.__dict__
if isinstance(other, TracebackException):
return self.__dict__ == other.__dict__
return NotImplemented
def __str__(self):
return self._str
......
......@@ -43,6 +43,8 @@ class Statistic:
return hash((self.traceback, self.size, self.count))
def __eq__(self, other):
if not isinstance(other, Statistic):
return NotImplemented
return (self.traceback == other.traceback
and self.size == other.size
and self.count == other.count)
......@@ -84,6 +86,8 @@ class StatisticDiff:
self.count, self.count_diff))
def __eq__(self, other):
if not isinstance(other, StatisticDiff):
return NotImplemented
return (self.traceback == other.traceback
and self.size == other.size
and self.size_diff == other.size_diff
......@@ -153,9 +157,13 @@ class Frame:
return self._frame[1]
def __eq__(self, other):
if not isinstance(other, Frame):
return NotImplemented
return (self._frame == other._frame)
def __lt__(self, other):
if not isinstance(other, Frame):
return NotImplemented
return (self._frame < other._frame)
def __hash__(self):
......@@ -200,9 +208,13 @@ class Traceback(Sequence):
return hash(self._frames)
def __eq__(self, other):
if not isinstance(other, Traceback):
return NotImplemented
return (self._frames == other._frames)
def __lt__(self, other):
if not isinstance(other, Traceback):
return NotImplemented
return (self._frames < other._frames)
def __str__(self):
......@@ -271,6 +283,8 @@ class Trace:
return Traceback(self._trace[2])
def __eq__(self, other):
if not isinstance(other, Trace):
return NotImplemented
return (self._trace == other._trace)
def __hash__(self):
......@@ -303,6 +317,8 @@ class _Traces(Sequence):
return trace._trace in self._traces
def __eq__(self, other):
if not isinstance(other, _Traces):
return NotImplemented
return (self._traces == other._traces)
def __repr__(self):
......
......@@ -2358,12 +2358,10 @@ class _Call(tuple):
def __eq__(self, other):
if other is ANY:
return True
try:
len_other = len(other)
except TypeError:
return False
return NotImplemented
self_name = ''
if len(self) == 2:
......
......@@ -3,6 +3,7 @@ import re
import sys
import tempfile
from test.support import ALWAYS_EQ
import unittest
from unittest.test.testmock.support import is_instance
from unittest import mock
......@@ -322,6 +323,8 @@ class MockTest(unittest.TestCase):
self.assertFalse(mm != mock.ANY)
self.assertTrue(mock.ANY == mm)
self.assertFalse(mock.ANY != mm)
self.assertTrue(mm == ALWAYS_EQ)
self.assertFalse(mm != ALWAYS_EQ)
call1 = mock.call(mock.MagicMock())
call2 = mock.call(mock.ANY)
......@@ -330,6 +333,11 @@ class MockTest(unittest.TestCase):
self.assertTrue(call2 == call1)
self.assertFalse(call2 != call1)
self.assertTrue(call1 == ALWAYS_EQ)
self.assertFalse(call1 != ALWAYS_EQ)
self.assertFalse(call1 == 1)
self.assertTrue(call1 != 1)
def test_assert_called_with(self):
mock = Mock()
......
......@@ -75,14 +75,14 @@ class WeakMethod(ref):
if not self._alive or not other._alive:
return self is other
return ref.__eq__(self, other) and self._func_ref == other._func_ref
return False
return NotImplemented
def __ne__(self, other):
if isinstance(other, WeakMethod):
if not self._alive or not other._alive:
return self is not other
return ref.__ne__(self, other) or self._func_ref != other._func_ref
return True
return NotImplemented
__hash__ = ref.__hash__
......
......@@ -313,31 +313,38 @@ class DateTime:
s = self.timetuple()
o = other.timetuple()
else:
otype = (hasattr(other, "__class__")
and other.__class__.__name__
or type(other))
raise TypeError("Can't compare %s and %s" %
(self.__class__.__name__, otype))
s = self
o = NotImplemented
return s, o
def __lt__(self, other):
s, o = self.make_comparable(other)
if o is NotImplemented:
return NotImplemented
return s < o
def __le__(self, other):
s, o = self.make_comparable(other)
if o is NotImplemented:
return NotImplemented
return s <= o
def __gt__(self, other):
s, o = self.make_comparable(other)
if o is NotImplemented:
return NotImplemented
return s > o
def __ge__(self, other):
s, o = self.make_comparable(other)
if o is NotImplemented:
return NotImplemented
return s >= o
def __eq__(self, other):
s, o = self.make_comparable(other)
if o is NotImplemented:
return NotImplemented
return s == o
def timetuple(self):
......
Fixed ``__eq__``, ``__lt__`` etc implementations in some classes. They now
return :data:`NotImplemented` for unsupported type of the other operand.
This allows the other operand to play role (for example the equality
comparison with :data:`~unittest.mock.ANY` will return ``True``).
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -281,10 +281,14 @@ class PopupViewer:
self.__window.deiconify()
def __eq__(self, other):
return self.__menutext == other.__menutext
if isinstance(self, PopupViewer):
return self.__menutext == other.__menutext
return NotImplemented
def __lt__(self, other):
return self.__menutext < other.__menutext
if isinstance(self, PopupViewer):
return self.__menutext < other.__menutext
return NotImplemented
def make_view_popups(switchboard, root, extrapath):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment