Commit ffd48c9e authored by Serhiy Storchaka's avatar Serhiy Storchaka

Issue #23268: Fixed bugs in the comparison of ipaddress classes.

parents 34af5023 f186e128
...@@ -382,40 +382,7 @@ def get_mixed_type_key(obj): ...@@ -382,40 +382,7 @@ def get_mixed_type_key(obj):
return NotImplemented return NotImplemented
class _TotalOrderingMixin: class _IPAddressBase:
# Helper that derives the other comparison operations from
# __lt__ and __eq__
# We avoid functools.total_ordering because it doesn't handle
# NotImplemented correctly yet (http://bugs.python.org/issue10042)
def __eq__(self, other):
raise NotImplementedError
def __ne__(self, other):
equal = self.__eq__(other)
if equal is NotImplemented:
return NotImplemented
return not equal
def __lt__(self, other):
raise NotImplementedError
def __le__(self, other):
less = self.__lt__(other)
if less is NotImplemented or not less:
return self.__eq__(other)
return less
def __gt__(self, other):
less = self.__lt__(other)
if less is NotImplemented:
return NotImplemented
equal = self.__eq__(other)
if equal is NotImplemented:
return NotImplemented
return not (less or equal)
def __ge__(self, other):
less = self.__lt__(other)
if less is NotImplemented:
return NotImplemented
return not less
class _IPAddressBase(_TotalOrderingMixin):
"""The mother class.""" """The mother class."""
...@@ -567,6 +534,7 @@ class _IPAddressBase(_TotalOrderingMixin): ...@@ -567,6 +534,7 @@ class _IPAddressBase(_TotalOrderingMixin):
return self.__class__, (str(self),) return self.__class__, (str(self),)
@functools.total_ordering
class _BaseAddress(_IPAddressBase): class _BaseAddress(_IPAddressBase):
"""A generic IP object. """A generic IP object.
...@@ -586,12 +554,11 @@ class _BaseAddress(_IPAddressBase): ...@@ -586,12 +554,11 @@ class _BaseAddress(_IPAddressBase):
return NotImplemented return NotImplemented
def __lt__(self, other): def __lt__(self, other):
if not isinstance(other, _BaseAddress):
return NotImplemented
if self._version != other._version: if self._version != other._version:
raise TypeError('%s and %s are not of the same version' % ( raise TypeError('%s and %s are not of the same version' % (
self, other)) self, other))
if not isinstance(other, _BaseAddress):
raise TypeError('%s and %s are not of the same type' % (
self, other))
if self._ip != other._ip: if self._ip != other._ip:
return self._ip < other._ip return self._ip < other._ip
return False return False
...@@ -624,6 +591,7 @@ class _BaseAddress(_IPAddressBase): ...@@ -624,6 +591,7 @@ class _BaseAddress(_IPAddressBase):
return self.__class__, (self._ip,) return self.__class__, (self._ip,)
@functools.total_ordering
class _BaseNetwork(_IPAddressBase): class _BaseNetwork(_IPAddressBase):
"""A generic IP network object. """A generic IP network object.
...@@ -673,12 +641,11 @@ class _BaseNetwork(_IPAddressBase): ...@@ -673,12 +641,11 @@ class _BaseNetwork(_IPAddressBase):
return self._address_class(broadcast + n) return self._address_class(broadcast + n)
def __lt__(self, other): def __lt__(self, other):
if not isinstance(other, _BaseNetwork):
return NotImplemented
if self._version != other._version: if self._version != other._version:
raise TypeError('%s and %s are not of the same version' % ( raise TypeError('%s and %s are not of the same version' % (
self, other)) self, other))
if not isinstance(other, _BaseNetwork):
raise TypeError('%s and %s are not of the same type' % (
self, other))
if self.network_address != other.network_address: if self.network_address != other.network_address:
return self.network_address < other.network_address return self.network_address < other.network_address
if self.netmask != other.netmask: if self.netmask != other.netmask:
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
import unittest import unittest
import re import re
import contextlib import contextlib
import functools
import operator import operator
import pickle import pickle
import ipaddress import ipaddress
...@@ -552,6 +553,20 @@ class FactoryFunctionErrors(BaseTestCase): ...@@ -552,6 +553,20 @@ class FactoryFunctionErrors(BaseTestCase):
self.assertFactoryError(ipaddress.ip_network, "network") self.assertFactoryError(ipaddress.ip_network, "network")
@functools.total_ordering
class LargestObject:
def __eq__(self, other):
return isinstance(other, LargestObject)
def __lt__(self, other):
return False
@functools.total_ordering
class SmallestObject:
def __eq__(self, other):
return isinstance(other, SmallestObject)
def __gt__(self, other):
return False
class ComparisonTests(unittest.TestCase): class ComparisonTests(unittest.TestCase):
v4addr = ipaddress.IPv4Address(1) v4addr = ipaddress.IPv4Address(1)
...@@ -605,6 +620,28 @@ class ComparisonTests(unittest.TestCase): ...@@ -605,6 +620,28 @@ class ComparisonTests(unittest.TestCase):
self.assertRaises(TypeError, lambda: lhs <= rhs) self.assertRaises(TypeError, lambda: lhs <= rhs)
self.assertRaises(TypeError, lambda: lhs >= rhs) self.assertRaises(TypeError, lambda: lhs >= rhs)
def test_foreign_type_ordering(self):
other = object()
smallest = SmallestObject()
largest = LargestObject()
for obj in self.objects:
with self.assertRaises(TypeError):
obj < other
with self.assertRaises(TypeError):
obj > other
with self.assertRaises(TypeError):
obj <= other
with self.assertRaises(TypeError):
obj >= other
self.assertTrue(obj < largest)
self.assertFalse(obj > largest)
self.assertTrue(obj <= largest)
self.assertFalse(obj >= largest)
self.assertFalse(obj < smallest)
self.assertTrue(obj > smallest)
self.assertFalse(obj <= smallest)
self.assertTrue(obj >= smallest)
def test_mixed_type_key(self): def test_mixed_type_key(self):
# with get_mixed_type_key, you can sort addresses and network. # with get_mixed_type_key, you can sort addresses and network.
v4_ordered = [self.v4addr, self.v4net, self.v4intf] v4_ordered = [self.v4addr, self.v4net, self.v4intf]
...@@ -625,7 +662,7 @@ class ComparisonTests(unittest.TestCase): ...@@ -625,7 +662,7 @@ class ComparisonTests(unittest.TestCase):
v4addr = ipaddress.ip_address('1.1.1.1') v4addr = ipaddress.ip_address('1.1.1.1')
v4net = ipaddress.ip_network('1.1.1.1') v4net = ipaddress.ip_network('1.1.1.1')
v6addr = ipaddress.ip_address('::1') v6addr = ipaddress.ip_address('::1')
v6net = ipaddress.ip_address('::1') v6net = ipaddress.ip_network('::1')
self.assertRaises(TypeError, v4addr.__lt__, v6addr) self.assertRaises(TypeError, v4addr.__lt__, v6addr)
self.assertRaises(TypeError, v4addr.__gt__, v6addr) self.assertRaises(TypeError, v4addr.__gt__, v6addr)
...@@ -1383,10 +1420,10 @@ class IpaddrUnitTest(unittest.TestCase): ...@@ -1383,10 +1420,10 @@ class IpaddrUnitTest(unittest.TestCase):
unsorted = [ip4, ip1, ip3, ip2] unsorted = [ip4, ip1, ip3, ip2]
unsorted.sort() unsorted.sort()
self.assertEqual(sorted, unsorted) self.assertEqual(sorted, unsorted)
self.assertRaises(TypeError, ip1.__lt__, self.assertIs(ip1.__lt__(ipaddress.ip_address('10.10.10.0')),
ipaddress.ip_address('10.10.10.0')) NotImplemented)
self.assertRaises(TypeError, ip2.__lt__, self.assertIs(ip2.__lt__(ipaddress.ip_address('10.10.10.0')),
ipaddress.ip_address('10.10.10.0')) NotImplemented)
# <=, >= # <=, >=
self.assertTrue(ipaddress.ip_network('1.1.1.1') <= self.assertTrue(ipaddress.ip_network('1.1.1.1') <=
......
...@@ -218,6 +218,8 @@ Core and Builtins ...@@ -218,6 +218,8 @@ Core and Builtins
Library Library
------- -------
- Issue #23268: Fixed bugs in the comparison of ipaddress classes.
- Issue #21408: Removed incorrect implementations of __ne__() which didn't - Issue #21408: Removed incorrect implementations of __ne__() which didn't
returned NotImplemented if __eq__() returned NotImplemented. The default returned NotImplemented if __eq__() returned NotImplemented. The default
__ne__() now works correctly. __ne__() now works correctly.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment