Commit e89bbb2d authored by Nick Coghlan's avatar Nick Coghlan

Issue #14814: In the spirit of TOOWTDI, ditch the redundant version parameter...

Issue #14814: In the spirit of TOOWTDI, ditch the redundant version parameter to the factory functions by using the appropriate direct class references instead
parent 3feb8a79
...@@ -36,34 +36,22 @@ class NetmaskValueError(ValueError): ...@@ -36,34 +36,22 @@ class NetmaskValueError(ValueError):
"""A Value Error related to the netmask.""" """A Value Error related to the netmask."""
def ip_address(address, version=None): def ip_address(address):
"""Take an IP string/int and return an object of the correct type. """Take an IP string/int and return an object of the correct type.
Args: Args:
address: A string or integer, the IP address. Either IPv4 or address: A string or integer, the IP address. Either IPv4 or
IPv6 addresses may be supplied; integers less than 2**32 will IPv6 addresses may be supplied; integers less than 2**32 will
be considered to be IPv4 by default. be considered to be IPv4 by default.
version: An integer, 4 or 6. If set, don't try to automatically
determine what the IP address type is. Important for things
like ip_address(1), which could be IPv4, '192.0.2.1', or IPv6,
'2001:db8::1'.
Returns: Returns:
An IPv4Address or IPv6Address object. An IPv4Address or IPv6Address object.
Raises: Raises:
ValueError: if the *address* passed isn't either a v4 or a v6 ValueError: if the *address* passed isn't either a v4 or a v6
address, or if the version is not None, 4, or 6. address
""" """
if version is not None:
if version == 4:
return IPv4Address(address)
elif version == 6:
return IPv6Address(address)
else:
raise ValueError()
try: try:
return IPv4Address(address) return IPv4Address(address)
except (AddressValueError, NetmaskValueError): except (AddressValueError, NetmaskValueError):
...@@ -78,35 +66,22 @@ def ip_address(address, version=None): ...@@ -78,35 +66,22 @@ def ip_address(address, version=None):
address) address)
def ip_network(address, version=None, strict=True): def ip_network(address, strict=True):
"""Take an IP string/int and return an object of the correct type. """Take an IP string/int and return an object of the correct type.
Args: Args:
address: A string or integer, the IP network. Either IPv4 or address: A string or integer, the IP network. Either IPv4 or
IPv6 networks may be supplied; integers less than 2**32 will IPv6 networks may be supplied; integers less than 2**32 will
be considered to be IPv4 by default. be considered to be IPv4 by default.
version: An integer, 4 or 6. If set, don't try to automatically
determine what the IP address type is. Important for things
like ip_network(1), which could be IPv4, '192.0.2.1/32', or IPv6,
'2001:db8::1/128'.
Returns: Returns:
An IPv4Network or IPv6Network object. An IPv4Network or IPv6Network object.
Raises: Raises:
ValueError: if the string passed isn't either a v4 or a v6 ValueError: if the string passed isn't either a v4 or a v6
address. Or if the network has host bits set. Or if the version address. Or if the network has host bits set.
is not None, 4, or 6.
""" """
if version is not None:
if version == 4:
return IPv4Network(address, strict)
elif version == 6:
return IPv6Network(address, strict)
else:
raise ValueError()
try: try:
return IPv4Network(address, strict) return IPv4Network(address, strict)
except (AddressValueError, NetmaskValueError): except (AddressValueError, NetmaskValueError):
...@@ -121,24 +96,20 @@ def ip_network(address, version=None, strict=True): ...@@ -121,24 +96,20 @@ def ip_network(address, version=None, strict=True):
address) address)
def ip_interface(address, version=None): def ip_interface(address):
"""Take an IP string/int and return an object of the correct type. """Take an IP string/int and return an object of the correct type.
Args: Args:
address: A string or integer, the IP address. Either IPv4 or address: A string or integer, the IP address. Either IPv4 or
IPv6 addresses may be supplied; integers less than 2**32 will IPv6 addresses may be supplied; integers less than 2**32 will
be considered to be IPv4 by default. be considered to be IPv4 by default.
version: An integer, 4 or 6. If set, don't try to automatically
determine what the IP address type is. Important for things
like ip_interface(1), which could be IPv4, '192.0.2.1/32', or IPv6,
'2001:db8::1/128'.
Returns: Returns:
An IPv4Interface or IPv6Interface object. An IPv4Interface or IPv6Interface object.
Raises: Raises:
ValueError: if the string passed isn't either a v4 or a v6 ValueError: if the string passed isn't either a v4 or a v6
address. Or if the version is not None, 4, or 6. address.
Notes: Notes:
The IPv?Interface classes describe an Address on a particular The IPv?Interface classes describe an Address on a particular
...@@ -146,14 +117,6 @@ def ip_interface(address, version=None): ...@@ -146,14 +117,6 @@ def ip_interface(address, version=None):
and Network classes. and Network classes.
""" """
if version is not None:
if version == 4:
return IPv4Interface(address)
elif version == 6:
return IPv6Interface(address)
else:
raise ValueError()
try: try:
return IPv4Interface(address) return IPv4Interface(address)
except (AddressValueError, NetmaskValueError): except (AddressValueError, NetmaskValueError):
...@@ -281,7 +244,7 @@ def summarize_address_range(first, last): ...@@ -281,7 +244,7 @@ def summarize_address_range(first, last):
If the first and last objects are not the same version. If the first and last objects are not the same version.
ValueError: ValueError:
If the last object is not greater than the first. If the last object is not greater than the first.
If the version is not 4 or 6. If the version of the first address is not 4 or 6.
""" """
if (not (isinstance(first, _BaseAddress) and if (not (isinstance(first, _BaseAddress) and
...@@ -318,7 +281,7 @@ def summarize_address_range(first, last): ...@@ -318,7 +281,7 @@ def summarize_address_range(first, last):
if current == ip._ALL_ONES: if current == ip._ALL_ONES:
break break
first_int = current + 1 first_int = current + 1
first = ip_address(first_int, version=first._version) first = first.__class__(first_int)
def _collapse_addresses_recursive(addresses): def _collapse_addresses_recursive(addresses):
...@@ -586,12 +549,12 @@ class _BaseAddress(_IPAddressBase): ...@@ -586,12 +549,12 @@ class _BaseAddress(_IPAddressBase):
def __add__(self, other): def __add__(self, other):
if not isinstance(other, int): if not isinstance(other, int):
return NotImplemented return NotImplemented
return ip_address(int(self) + other, version=self._version) return self.__class__(int(self) + other)
def __sub__(self, other): def __sub__(self, other):
if not isinstance(other, int): if not isinstance(other, int):
return NotImplemented return NotImplemented
return ip_address(int(self) - other, version=self._version) return self.__class__(int(self) - other)
def __repr__(self): def __repr__(self):
return '%s(%r)' % (self.__class__.__name__, str(self)) return '%s(%r)' % (self.__class__.__name__, str(self))
...@@ -612,13 +575,12 @@ class _BaseAddress(_IPAddressBase): ...@@ -612,13 +575,12 @@ class _BaseAddress(_IPAddressBase):
class _BaseNetwork(_IPAddressBase): class _BaseNetwork(_IPAddressBase):
"""A generic IP object. """A generic IP network object.
This IP class contains the version independent methods which are This IP class contains the version independent methods which are
used by networks. used by networks.
""" """
def __init__(self, address): def __init__(self, address):
self._cache = {} self._cache = {}
...@@ -642,14 +604,14 @@ class _BaseNetwork(_IPAddressBase): ...@@ -642,14 +604,14 @@ class _BaseNetwork(_IPAddressBase):
bcast = int(self.broadcast_address) - 1 bcast = int(self.broadcast_address) - 1
while cur <= bcast: while cur <= bcast:
cur += 1 cur += 1
yield ip_address(cur - 1, version=self._version) yield self._address_class(cur - 1)
def __iter__(self): def __iter__(self):
cur = int(self.network_address) cur = int(self.network_address)
bcast = int(self.broadcast_address) bcast = int(self.broadcast_address)
while cur <= bcast: while cur <= bcast:
cur += 1 cur += 1
yield ip_address(cur - 1, version=self._version) yield self._address_class(cur - 1)
def __getitem__(self, n): def __getitem__(self, n):
network = int(self.network_address) network = int(self.network_address)
...@@ -657,12 +619,12 @@ class _BaseNetwork(_IPAddressBase): ...@@ -657,12 +619,12 @@ class _BaseNetwork(_IPAddressBase):
if n >= 0: if n >= 0:
if network + n > broadcast: if network + n > broadcast:
raise IndexError raise IndexError
return ip_address(network + n, version=self._version) return self._address_class(network + n)
else: else:
n += 1 n += 1
if broadcast + n < network: if broadcast + n < network:
raise IndexError raise IndexError
return ip_address(broadcast + n, version=self._version) return self._address_class(broadcast + n)
def __lt__(self, other): def __lt__(self, other):
if self._version != other._version: if self._version != other._version:
...@@ -746,8 +708,8 @@ class _BaseNetwork(_IPAddressBase): ...@@ -746,8 +708,8 @@ class _BaseNetwork(_IPAddressBase):
def broadcast_address(self): def broadcast_address(self):
x = self._cache.get('broadcast_address') x = self._cache.get('broadcast_address')
if x is None: if x is None:
x = ip_address(int(self.network_address) | int(self.hostmask), x = self._address_class(int(self.network_address) |
version=self._version) int(self.hostmask))
self._cache['broadcast_address'] = x self._cache['broadcast_address'] = x
return x return x
...@@ -755,15 +717,15 @@ class _BaseNetwork(_IPAddressBase): ...@@ -755,15 +717,15 @@ class _BaseNetwork(_IPAddressBase):
def hostmask(self): def hostmask(self):
x = self._cache.get('hostmask') x = self._cache.get('hostmask')
if x is None: if x is None:
x = ip_address(int(self.netmask) ^ self._ALL_ONES, x = self._address_class(int(self.netmask) ^ self._ALL_ONES)
version=self._version)
self._cache['hostmask'] = x self._cache['hostmask'] = x
return x return x
@property @property
def network(self): def network(self):
return ip_network('%s/%d' % (str(self.network_address), # XXX (ncoghlan): This is redundant now and will likely be removed
self.prefixlen)) return self.__class__('%s/%d' % (str(self.network_address),
self.prefixlen))
@property @property
def with_prefixlen(self): def with_prefixlen(self):
...@@ -786,6 +748,10 @@ class _BaseNetwork(_IPAddressBase): ...@@ -786,6 +748,10 @@ class _BaseNetwork(_IPAddressBase):
def version(self): def version(self):
raise NotImplementedError('BaseNet has no version') raise NotImplementedError('BaseNet has no version')
@property
def _address_class(self):
raise NotImplementedError('BaseNet has no associated address class')
@property @property
def prefixlen(self): def prefixlen(self):
return self._prefixlen return self._prefixlen
...@@ -840,9 +806,8 @@ class _BaseNetwork(_IPAddressBase): ...@@ -840,9 +806,8 @@ class _BaseNetwork(_IPAddressBase):
raise StopIteration raise StopIteration
# Make sure we're comparing the network of other. # Make sure we're comparing the network of other.
other = ip_network('%s/%s' % (str(other.network_address), other = other.__class__('%s/%s' % (str(other.network_address),
str(other.prefixlen)), str(other.prefixlen)))
version=other._version)
s1, s2 = self.subnets() s1, s2 = self.subnets()
while s1 != other and s2 != other: while s1 != other and s2 != other:
...@@ -973,9 +938,9 @@ class _BaseNetwork(_IPAddressBase): ...@@ -973,9 +938,9 @@ class _BaseNetwork(_IPAddressBase):
'prefix length diff %d is invalid for netblock %s' % ( 'prefix length diff %d is invalid for netblock %s' % (
new_prefixlen, str(self))) new_prefixlen, str(self)))
first = ip_network('%s/%s' % (str(self.network_address), first = self.__class__('%s/%s' %
str(self._prefixlen + prefixlen_diff)), (str(self.network_address),
version=self._version) str(self._prefixlen + prefixlen_diff)))
yield first yield first
current = first current = first
...@@ -983,16 +948,17 @@ class _BaseNetwork(_IPAddressBase): ...@@ -983,16 +948,17 @@ class _BaseNetwork(_IPAddressBase):
broadcast = current.broadcast_address broadcast = current.broadcast_address
if broadcast == self.broadcast_address: if broadcast == self.broadcast_address:
return return
new_addr = ip_address(int(broadcast) + 1, version=self._version) new_addr = self._address_class(int(broadcast) + 1)
current = ip_network('%s/%s' % (str(new_addr), str(new_prefixlen)), current = self.__class__('%s/%s' % (str(new_addr),
version=self._version) str(new_prefixlen)))
yield current yield current
def masked(self): def masked(self):
"""Return the network object with the host bits masked out.""" """Return the network object with the host bits masked out."""
return ip_network('%s/%d' % (self.network_address, self._prefixlen), # XXX (ncoghlan): This is redundant now and will likely be removed
version=self._version) return self.__class__('%s/%d' % (self.network_address,
self._prefixlen))
def supernet(self, prefixlen_diff=1, new_prefix=None): def supernet(self, prefixlen_diff=1, new_prefix=None):
"""The supernet containing the current network. """The supernet containing the current network.
...@@ -1030,11 +996,10 @@ class _BaseNetwork(_IPAddressBase): ...@@ -1030,11 +996,10 @@ class _BaseNetwork(_IPAddressBase):
'current prefixlen is %d, cannot have a prefixlen_diff of %d' % 'current prefixlen is %d, cannot have a prefixlen_diff of %d' %
(self.prefixlen, prefixlen_diff)) (self.prefixlen, prefixlen_diff))
# TODO (pmoody): optimize this. # TODO (pmoody): optimize this.
t = ip_network('%s/%d' % (str(self.network_address), t = self.__class__('%s/%d' % (str(self.network_address),
self.prefixlen - prefixlen_diff), self.prefixlen - prefixlen_diff),
version=self._version, strict=False) strict=False)
return ip_network('%s/%d' % (str(t.network_address), t.prefixlen), return t.__class__('%s/%d' % (str(t.network_address), t.prefixlen))
version=t._version)
class _BaseV4(object): class _BaseV4(object):
...@@ -1391,6 +1356,9 @@ class IPv4Network(_BaseV4, _BaseNetwork): ...@@ -1391,6 +1356,9 @@ class IPv4Network(_BaseV4, _BaseNetwork):
.prefixlen: 27 .prefixlen: 27
""" """
# Class to use when creating address objects
# TODO (ncoghlan): Investigate using IPv4Interface instead
_address_class = IPv4Address
# the valid octets for host and netmasks. only useful for IPv4. # the valid octets for host and netmasks. only useful for IPv4.
_valid_mask_octets = set((255, 254, 252, 248, 240, 224, 192, 128, 0)) _valid_mask_octets = set((255, 254, 252, 248, 240, 224, 192, 128, 0))
...@@ -2071,6 +2039,10 @@ class IPv6Network(_BaseV6, _BaseNetwork): ...@@ -2071,6 +2039,10 @@ class IPv6Network(_BaseV6, _BaseNetwork):
""" """
# Class to use when creating address objects
# TODO (ncoghlan): Investigate using IPv6Interface instead
_address_class = IPv6Address
def __init__(self, address, strict=True): def __init__(self, address, strict=True):
"""Instantiate a new IPv6 Network object. """Instantiate a new IPv6 Network object.
......
...@@ -780,12 +780,6 @@ class IpaddrUnitTest(unittest.TestCase): ...@@ -780,12 +780,6 @@ class IpaddrUnitTest(unittest.TestCase):
self.assertEqual(self.ipv4_address.version, 4) self.assertEqual(self.ipv4_address.version, 4)
self.assertEqual(self.ipv6_address.version, 6) self.assertEqual(self.ipv6_address.version, 6)
with self.assertRaises(ValueError):
ipaddress.ip_address('1', version=[])
with self.assertRaises(ValueError):
ipaddress.ip_address('1', version=5)
def testMaxPrefixLength(self): def testMaxPrefixLength(self):
self.assertEqual(self.ipv4_interface.max_prefixlen, 32) self.assertEqual(self.ipv4_interface.max_prefixlen, 32)
self.assertEqual(self.ipv6_interface.max_prefixlen, 128) self.assertEqual(self.ipv6_interface.max_prefixlen, 128)
...@@ -1052,12 +1046,7 @@ class IpaddrUnitTest(unittest.TestCase): ...@@ -1052,12 +1046,7 @@ class IpaddrUnitTest(unittest.TestCase):
def testForceVersion(self): def testForceVersion(self):
self.assertEqual(ipaddress.ip_network(1).version, 4) self.assertEqual(ipaddress.ip_network(1).version, 4)
self.assertEqual(ipaddress.ip_network(1, version=6).version, 6) self.assertEqual(ipaddress.IPv6Network(1).version, 6)
with self.assertRaises(ValueError):
ipaddress.ip_network(1, version='l')
with self.assertRaises(ValueError):
ipaddress.ip_network(1, version=3)
def testWithStar(self): def testWithStar(self):
self.assertEqual(str(self.ipv4_interface.with_prefixlen), "1.2.3.4/24") self.assertEqual(str(self.ipv4_interface.with_prefixlen), "1.2.3.4/24")
...@@ -1148,13 +1137,6 @@ class IpaddrUnitTest(unittest.TestCase): ...@@ -1148,13 +1137,6 @@ class IpaddrUnitTest(unittest.TestCase):
sixtofouraddr.sixtofour) sixtofouraddr.sixtofour)
self.assertFalse(bad_addr.sixtofour) self.assertFalse(bad_addr.sixtofour)
def testIpInterfaceVersion(self):
with self.assertRaises(ValueError):
ipaddress.ip_interface(1, version=123)
with self.assertRaises(ValueError):
ipaddress.ip_interface(1, version='')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment