Commit 23f9fc34 authored by Raymond Hettinger's avatar Raymond Hettinger

Issue #10042: Fixed the total_ordering decorator to handle cross-type

comparisons that could lead to infinite recursion.
parent 06ec45e2
...@@ -68,17 +68,17 @@ def wraps(wrapped, ...@@ -68,17 +68,17 @@ def wraps(wrapped,
def total_ordering(cls): def total_ordering(cls):
"""Class decorator that fills in missing ordering methods""" """Class decorator that fills in missing ordering methods"""
convert = { convert = {
'__lt__': [('__gt__', lambda self, other: other < self), '__lt__': [('__gt__', lambda self, other: not (self < other or self == other)),
('__le__', lambda self, other: not other < self), ('__le__', lambda self, other: self < other or self == other),
('__ge__', lambda self, other: not self < other)], ('__ge__', lambda self, other: not self < other)],
'__le__': [('__ge__', lambda self, other: other <= self), '__le__': [('__ge__', lambda self, other: not self <= other or self == other),
('__lt__', lambda self, other: not other <= self), ('__lt__', lambda self, other: self <= other and not self == other),
('__gt__', lambda self, other: not self <= other)], ('__gt__', lambda self, other: not self <= other)],
'__gt__': [('__lt__', lambda self, other: other > self), '__gt__': [('__lt__', lambda self, other: not (self > other or self == other)),
('__ge__', lambda self, other: not other > self), ('__ge__', lambda self, other: self > other or self == other),
('__le__', lambda self, other: not self > other)], ('__le__', lambda self, other: not self > other)],
'__ge__': [('__le__', lambda self, other: other >= self), '__ge__': [('__le__', lambda self, other: (not self >= other) or self == other),
('__gt__', lambda self, other: not other >= self), ('__gt__', lambda self, other: self >= other and not self == other),
('__lt__', lambda self, other: not self >= other)] ('__lt__', lambda self, other: not self >= other)]
} }
# Find user-defined comparisons (not those inherited from object). # Find user-defined comparisons (not those inherited from object).
......
...@@ -457,6 +457,8 @@ class TestTotalOrdering(unittest.TestCase): ...@@ -457,6 +457,8 @@ class TestTotalOrdering(unittest.TestCase):
self.value = value self.value = value
def __lt__(self, other): def __lt__(self, other):
return self.value < other.value return self.value < other.value
def __eq__(self, other):
return self.value == other.value
self.assertTrue(A(1) < A(2)) self.assertTrue(A(1) < A(2))
self.assertTrue(A(2) > A(1)) self.assertTrue(A(2) > A(1))
self.assertTrue(A(1) <= A(2)) self.assertTrue(A(1) <= A(2))
...@@ -471,6 +473,8 @@ class TestTotalOrdering(unittest.TestCase): ...@@ -471,6 +473,8 @@ class TestTotalOrdering(unittest.TestCase):
self.value = value self.value = value
def __le__(self, other): def __le__(self, other):
return self.value <= other.value return self.value <= other.value
def __eq__(self, other):
return self.value == other.value
self.assertTrue(A(1) < A(2)) self.assertTrue(A(1) < A(2))
self.assertTrue(A(2) > A(1)) self.assertTrue(A(2) > A(1))
self.assertTrue(A(1) <= A(2)) self.assertTrue(A(1) <= A(2))
...@@ -485,6 +489,8 @@ class TestTotalOrdering(unittest.TestCase): ...@@ -485,6 +489,8 @@ class TestTotalOrdering(unittest.TestCase):
self.value = value self.value = value
def __gt__(self, other): def __gt__(self, other):
return self.value > other.value return self.value > other.value
def __eq__(self, other):
return self.value == other.value
self.assertTrue(A(1) < A(2)) self.assertTrue(A(1) < A(2))
self.assertTrue(A(2) > A(1)) self.assertTrue(A(2) > A(1))
self.assertTrue(A(1) <= A(2)) self.assertTrue(A(1) <= A(2))
...@@ -499,6 +505,8 @@ class TestTotalOrdering(unittest.TestCase): ...@@ -499,6 +505,8 @@ class TestTotalOrdering(unittest.TestCase):
self.value = value self.value = value
def __ge__(self, other): def __ge__(self, other):
return self.value >= other.value return self.value >= other.value
def __eq__(self, other):
return self.value == other.value
self.assertTrue(A(1) < A(2)) self.assertTrue(A(1) < A(2))
self.assertTrue(A(2) > A(1)) self.assertTrue(A(2) > A(1))
self.assertTrue(A(1) <= A(2)) self.assertTrue(A(1) <= A(2))
...@@ -524,6 +532,22 @@ class TestTotalOrdering(unittest.TestCase): ...@@ -524,6 +532,22 @@ class TestTotalOrdering(unittest.TestCase):
class A: class A:
pass pass
def test_bug_10042(self):
@functools.total_ordering
class TestTO:
def __init__(self, value):
self.value = value
def __eq__(self, other):
if isinstance(other, TestTO):
return self.value == other.value
return False
def __lt__(self, other):
if isinstance(other, TestTO):
return self.value < other.value
raise TypeError
with self.assertRaises(TypeError):
TestTO(8) <= ()
class TestLRU(unittest.TestCase): class TestLRU(unittest.TestCase):
def test_lru(self): def test_lru(self):
......
...@@ -700,6 +700,7 @@ Bernhard Reiter ...@@ -700,6 +700,7 @@ Bernhard Reiter
Steven Reiz Steven Reiz
Roeland Rengelink Roeland Rengelink
Tim Rice Tim Rice
Francesco Ricciardi
Jan Pieter Riegel Jan Pieter Riegel
Armin Rigo Armin Rigo
Nicholas Riley Nicholas Riley
......
...@@ -40,6 +40,9 @@ Core and Builtins ...@@ -40,6 +40,9 @@ Core and Builtins
Library Library
------- -------
- Issue #10042: Fixed the total_ordering decorator to handle cross-type
comparisons that could lead to infinite recursion.
- Issue #10686: the email package now :rfc:`2047`\ -encodes headers with - Issue #10686: the email package now :rfc:`2047`\ -encodes headers with
non-ASCII bytes (parsed by a Bytes Parser) when doing conversion to non-ASCII bytes (parsed by a Bytes Parser) when doing conversion to
7bit-clean presentation, instead of replacing them with ?s. 7bit-clean presentation, instead of replacing them with ?s.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment