Commit 932346f5 authored by Nick Coghlan's avatar Nick Coghlan

Issue #18805: better netmask validation in ipaddress

parent 578c6777
...@@ -456,8 +456,8 @@ class _IPAddressBase(_TotalOrderingMixin): ...@@ -456,8 +456,8 @@ class _IPAddressBase(_TotalOrderingMixin):
raise AddressValueError(msg % (address, address_len, raise AddressValueError(msg % (address, address_len,
expected_len, self._version)) expected_len, self._version))
def _ip_int_from_prefix(self, prefixlen=None): def _ip_int_from_prefix(self, prefixlen):
"""Turn the prefix length netmask into a int for comparison. """Turn the prefix length into a bitwise netmask
Args: Args:
prefixlen: An integer, the prefix length. prefixlen: An integer, the prefix length.
...@@ -466,36 +466,92 @@ class _IPAddressBase(_TotalOrderingMixin): ...@@ -466,36 +466,92 @@ class _IPAddressBase(_TotalOrderingMixin):
An integer. An integer.
""" """
if prefixlen is None:
prefixlen = self._prefixlen
return self._ALL_ONES ^ (self._ALL_ONES >> prefixlen) return self._ALL_ONES ^ (self._ALL_ONES >> prefixlen)
def _prefix_from_ip_int(self, ip_int, mask=32): def _prefix_from_ip_int(self, ip_int):
"""Return prefix length from the decimal netmask. """Return prefix length from the bitwise netmask.
Args: Args:
ip_int: An integer, the IP address. ip_int: An integer, the netmask in axpanded bitwise format
mask: The netmask. Defaults to 32.
Returns:
An integer, the prefix length.
Raises:
ValueError: If the input intermingles zeroes & ones
"""
trailing_zeroes = _count_righthand_zero_bits(ip_int,
self._max_prefixlen)
prefixlen = self._max_prefixlen - trailing_zeroes
leading_ones = ip_int >> trailing_zeroes
all_ones = (1 << prefixlen) - 1
if leading_ones != all_ones:
byteslen = self._max_prefixlen // 8
details = ip_int.to_bytes(byteslen, 'big')
msg = 'Netmask pattern %r mixes zeroes & ones'
raise ValueError(msg % details)
return prefixlen
def _report_invalid_netmask(self, netmask_str):
msg = '%r is not a valid netmask' % netmask_str
raise NetmaskValueError(msg) from None
def _prefix_from_prefix_string(self, prefixlen_str):
"""Return prefix length from a numeric string
Args:
prefixlen_str: The string to be converted
Returns: Returns:
An integer, the prefix length. An integer, the prefix length.
Raises:
NetmaskValueError: If the input is not a valid netmask
""" """
return mask - _count_righthand_zero_bits(ip_int, mask) # int allows a leading +/- as well as surrounding whitespace,
# so we ensure that isn't the case
if not _BaseV4._DECIMAL_DIGITS.issuperset(prefixlen_str):
self._report_invalid_netmask(prefixlen_str)
try:
prefixlen = int(prefixlen_str)
except ValueError:
self._report_invalid_netmask(prefixlen_str)
if not (0 <= prefixlen <= self._max_prefixlen):
self._report_invalid_netmask(prefixlen_str)
return prefixlen
def _ip_string_from_prefix(self, prefixlen=None): def _prefix_from_ip_string(self, ip_str):
"""Turn a prefix length into a dotted decimal string. """Turn a netmask/hostmask string into a prefix length
Args: Args:
prefixlen: An integer, the netmask prefix length. ip_str: The netmask/hostmask to be converted
Returns: Returns:
A string, the dotted decimal netmask string. An integer, the prefix length.
Raises:
NetmaskValueError: If the input is not a valid netmask/hostmask
""" """
if not prefixlen: # Parse the netmask/hostmask like an IP address.
prefixlen = self._prefixlen try:
return self._string_from_ip_int(self._ip_int_from_prefix(prefixlen)) ip_int = self._ip_int_from_string(ip_str)
except AddressValueError:
self._report_invalid_netmask(ip_str)
# Try matching a netmask (this would be /1*0*/ as a bitwise regexp).
# Note that the two ambiguous cases (all-ones and all-zeroes) are
# treated as netmasks.
try:
return self._prefix_from_ip_int(ip_int)
except ValueError:
pass
# Invert the bits, and try matching a /0+1+/ hostmask instead.
ip_int ^= self._ALL_ONES
try:
return self._prefix_from_ip_int(ip_int)
except ValueError:
self._report_invalid_netmask(ip_str)
class _BaseAddress(_IPAddressBase): class _BaseAddress(_IPAddressBase):
...@@ -504,7 +560,6 @@ class _BaseAddress(_IPAddressBase): ...@@ -504,7 +560,6 @@ class _BaseAddress(_IPAddressBase):
This IP class contains the version independent methods which are This IP class contains the version independent methods which are
used by single IP addresses. used by single IP addresses.
""" """
def __init__(self, address): def __init__(self, address):
...@@ -873,7 +928,7 @@ class _BaseNetwork(_IPAddressBase): ...@@ -873,7 +928,7 @@ class _BaseNetwork(_IPAddressBase):
raise ValueError('prefix length diff must be > 0') raise ValueError('prefix length diff must be > 0')
new_prefixlen = self._prefixlen + prefixlen_diff new_prefixlen = self._prefixlen + prefixlen_diff
if not self._is_valid_netmask(str(new_prefixlen)): if new_prefixlen > self._max_prefixlen:
raise ValueError( raise ValueError(
'prefix length diff %d is invalid for netblock %s' % ( 'prefix length diff %d is invalid for netblock %s' % (
new_prefixlen, self)) new_prefixlen, self))
...@@ -1428,33 +1483,16 @@ class IPv4Network(_BaseV4, _BaseNetwork): ...@@ -1428,33 +1483,16 @@ class IPv4Network(_BaseV4, _BaseNetwork):
self.network_address = IPv4Address(self._ip_int_from_string(addr[0])) self.network_address = IPv4Address(self._ip_int_from_string(addr[0]))
if len(addr) == 2: if len(addr) == 2:
mask = addr[1].split('.') try:
# Check for a netmask in prefix length form
if len(mask) == 4: self._prefixlen = self._prefix_from_prefix_string(addr[1])
# We have dotted decimal netmask. except NetmaskValueError:
if self._is_valid_netmask(addr[1]): # Check for a netmask or hostmask in dotted-quad form.
self.netmask = IPv4Address(self._ip_int_from_string( # This may raise NetmaskValueError.
addr[1])) self._prefixlen = self._prefix_from_ip_string(addr[1])
elif self._is_hostmask(addr[1]):
self.netmask = IPv4Address(
self._ip_int_from_string(addr[1]) ^ self._ALL_ONES)
else:
raise NetmaskValueError('%r is not a valid netmask'
% addr[1])
self._prefixlen = self._prefix_from_ip_int(int(self.netmask))
else:
# We have a netmask in prefix length form.
if not self._is_valid_netmask(addr[1]):
raise NetmaskValueError('%r is not a valid netmask'
% addr[1])
self._prefixlen = int(addr[1])
self.netmask = IPv4Address(self._ip_int_from_prefix(
self._prefixlen))
else: else:
self._prefixlen = self._max_prefixlen self._prefixlen = self._max_prefixlen
self.netmask = IPv4Address(self._ip_int_from_prefix( self.netmask = IPv4Address(self._ip_int_from_prefix(self._prefixlen))
self._prefixlen))
if strict: if strict:
if (IPv4Address(int(self.network_address) & int(self.netmask)) != if (IPv4Address(int(self.network_address) & int(self.netmask)) !=
...@@ -2042,11 +2080,8 @@ class IPv6Network(_BaseV6, _BaseNetwork): ...@@ -2042,11 +2080,8 @@ class IPv6Network(_BaseV6, _BaseNetwork):
self.network_address = IPv6Address(self._ip_int_from_string(addr[0])) self.network_address = IPv6Address(self._ip_int_from_string(addr[0]))
if len(addr) == 2: if len(addr) == 2:
if self._is_valid_netmask(addr[1]): # This may raise NetmaskValueError
self._prefixlen = int(addr[1]) self._prefixlen = self._prefix_from_prefix_string(addr[1])
else:
raise NetmaskValueError('%r is not a valid netmask'
% addr[1])
else: else:
self._prefixlen = self._max_prefixlen self._prefixlen = self._max_prefixlen
...@@ -2061,23 +2096,6 @@ class IPv6Network(_BaseV6, _BaseNetwork): ...@@ -2061,23 +2096,6 @@ class IPv6Network(_BaseV6, _BaseNetwork):
if self._prefixlen == (self._max_prefixlen - 1): if self._prefixlen == (self._max_prefixlen - 1):
self.hosts = self.__iter__ self.hosts = self.__iter__
def _is_valid_netmask(self, prefixlen):
"""Verify that the netmask/prefixlen is valid.
Args:
prefixlen: A string, the netmask in prefix length format.
Returns:
A boolean, True if the prefix represents a valid IPv6
netmask.
"""
try:
prefixlen = int(prefixlen)
except ValueError:
return False
return 0 <= prefixlen <= self._max_prefixlen
@property @property
def is_site_local(self): def is_site_local(self):
"""Test if the address is reserved for site-local. """Test if the address is reserved for site-local.
......
...@@ -398,18 +398,47 @@ class NetmaskTestMixin_v4(CommonTestMixin_v4): ...@@ -398,18 +398,47 @@ class NetmaskTestMixin_v4(CommonTestMixin_v4):
assertBadAddress("::1.2.3.4", "Only decimal digits") assertBadAddress("::1.2.3.4", "Only decimal digits")
assertBadAddress("1.2.3.256", re.escape("256 (> 255)")) assertBadAddress("1.2.3.256", re.escape("256 (> 255)"))
def test_valid_netmask(self):
self.assertEqual(str(self.factory('192.0.2.0/255.255.255.0')),
'192.0.2.0/24')
for i in range(0, 33):
# Generate and re-parse the CIDR format (trivial).
net_str = '0.0.0.0/%d' % i
net = self.factory(net_str)
self.assertEqual(str(net), net_str)
# Generate and re-parse the expanded netmask.
self.assertEqual(
str(self.factory('0.0.0.0/%s' % net.netmask)), net_str)
# Zero prefix is treated as decimal.
self.assertEqual(str(self.factory('0.0.0.0/0%d' % i)), net_str)
# Generate and re-parse the expanded hostmask. The ambiguous
# cases (/0 and /32) are treated as netmasks.
if i in (32, 0):
net_str = '0.0.0.0/%d' % (32 - i)
self.assertEqual(
str(self.factory('0.0.0.0/%s' % net.hostmask)), net_str)
def test_netmask_errors(self): def test_netmask_errors(self):
def assertBadNetmask(addr, netmask): def assertBadNetmask(addr, netmask):
msg = "%r is not a valid netmask" msg = "%r is not a valid netmask" % netmask
with self.assertNetmaskError(msg % netmask): with self.assertNetmaskError(re.escape(msg)):
self.factory("%s/%s" % (addr, netmask)) self.factory("%s/%s" % (addr, netmask))
assertBadNetmask("1.2.3.4", "") assertBadNetmask("1.2.3.4", "")
assertBadNetmask("1.2.3.4", "-1")
assertBadNetmask("1.2.3.4", "+1")
assertBadNetmask("1.2.3.4", " 1 ")
assertBadNetmask("1.2.3.4", "0x1")
assertBadNetmask("1.2.3.4", "33") assertBadNetmask("1.2.3.4", "33")
assertBadNetmask("1.2.3.4", "254.254.255.256") assertBadNetmask("1.2.3.4", "254.254.255.256")
assertBadNetmask("1.2.3.4", "1.a.2.3")
assertBadNetmask("1.1.1.1", "254.xyz.2.3") assertBadNetmask("1.1.1.1", "254.xyz.2.3")
assertBadNetmask("1.1.1.1", "240.255.0.0") assertBadNetmask("1.1.1.1", "240.255.0.0")
assertBadNetmask("1.1.1.1", "255.254.128.0")
assertBadNetmask("1.1.1.1", "0.1.127.255")
assertBadNetmask("1.1.1.1", "pudding") assertBadNetmask("1.1.1.1", "pudding")
assertBadNetmask("1.1.1.1", "::")
class InterfaceTestCase_v4(BaseTestCase, NetmaskTestMixin_v4): class InterfaceTestCase_v4(BaseTestCase, NetmaskTestMixin_v4):
factory = ipaddress.IPv4Interface factory = ipaddress.IPv4Interface
...@@ -438,17 +467,34 @@ class NetmaskTestMixin_v6(CommonTestMixin_v6): ...@@ -438,17 +467,34 @@ class NetmaskTestMixin_v6(CommonTestMixin_v6):
assertBadAddress("10/8", "At least 3 parts") assertBadAddress("10/8", "At least 3 parts")
assertBadAddress("1234:axy::b", "Only hex digits") assertBadAddress("1234:axy::b", "Only hex digits")
def test_valid_netmask(self):
# We only support CIDR for IPv6, because expanded netmasks are not
# standard notation.
self.assertEqual(str(self.factory('2001:db8::/32')), '2001:db8::/32')
for i in range(0, 129):
# Generate and re-parse the CIDR format (trivial).
net_str = '::/%d' % i
self.assertEqual(str(self.factory(net_str)), net_str)
# Zero prefix is treated as decimal.
self.assertEqual(str(self.factory('::/0%d' % i)), net_str)
def test_netmask_errors(self): def test_netmask_errors(self):
def assertBadNetmask(addr, netmask): def assertBadNetmask(addr, netmask):
msg = "%r is not a valid netmask" msg = "%r is not a valid netmask" % netmask
with self.assertNetmaskError(msg % netmask): with self.assertNetmaskError(re.escape(msg)):
self.factory("%s/%s" % (addr, netmask)) self.factory("%s/%s" % (addr, netmask))
assertBadNetmask("::1", "") assertBadNetmask("::1", "")
assertBadNetmask("::1", "::1") assertBadNetmask("::1", "::1")
assertBadNetmask("::1", "1::") assertBadNetmask("::1", "1::")
assertBadNetmask("::1", "-1")
assertBadNetmask("::1", "+1")
assertBadNetmask("::1", " 1 ")
assertBadNetmask("::1", "0x1")
assertBadNetmask("::1", "129") assertBadNetmask("::1", "129")
assertBadNetmask("::1", "1.2.3.4")
assertBadNetmask("::1", "pudding") assertBadNetmask("::1", "pudding")
assertBadNetmask("::", "::")
class InterfaceTestCase_v6(BaseTestCase, NetmaskTestMixin_v6): class InterfaceTestCase_v6(BaseTestCase, NetmaskTestMixin_v6):
factory = ipaddress.IPv6Interface factory = ipaddress.IPv6Interface
...@@ -694,16 +740,14 @@ class IpaddrUnitTest(unittest.TestCase): ...@@ -694,16 +740,14 @@ class IpaddrUnitTest(unittest.TestCase):
def testZeroNetmask(self): def testZeroNetmask(self):
ipv4_zero_netmask = ipaddress.IPv4Interface('1.2.3.4/0') ipv4_zero_netmask = ipaddress.IPv4Interface('1.2.3.4/0')
self.assertEqual(int(ipv4_zero_netmask.network.netmask), 0) self.assertEqual(int(ipv4_zero_netmask.network.netmask), 0)
self.assertTrue(ipv4_zero_netmask.network._is_valid_netmask( self.assertEqual(ipv4_zero_netmask._prefix_from_prefix_string('0'), 0)
str(0)))
self.assertTrue(ipv4_zero_netmask._is_valid_netmask('0')) self.assertTrue(ipv4_zero_netmask._is_valid_netmask('0'))
self.assertTrue(ipv4_zero_netmask._is_valid_netmask('0.0.0.0')) self.assertTrue(ipv4_zero_netmask._is_valid_netmask('0.0.0.0'))
self.assertFalse(ipv4_zero_netmask._is_valid_netmask('invalid')) self.assertFalse(ipv4_zero_netmask._is_valid_netmask('invalid'))
ipv6_zero_netmask = ipaddress.IPv6Interface('::1/0') ipv6_zero_netmask = ipaddress.IPv6Interface('::1/0')
self.assertEqual(int(ipv6_zero_netmask.network.netmask), 0) self.assertEqual(int(ipv6_zero_netmask.network.netmask), 0)
self.assertTrue(ipv6_zero_netmask.network._is_valid_netmask( self.assertEqual(ipv6_zero_netmask._prefix_from_prefix_string('0'), 0)
str(0)))
def testIPv4NetAndHostmasks(self): def testIPv4NetAndHostmasks(self):
net = self.ipv4_network net = self.ipv4_network
...@@ -719,7 +763,7 @@ class IpaddrUnitTest(unittest.TestCase): ...@@ -719,7 +763,7 @@ class IpaddrUnitTest(unittest.TestCase):
self.assertFalse(net._is_hostmask('1.2.3.4')) self.assertFalse(net._is_hostmask('1.2.3.4'))
net = ipaddress.IPv4Network('127.0.0.0/0.0.0.255') net = ipaddress.IPv4Network('127.0.0.0/0.0.0.255')
self.assertEqual(24, net.prefixlen) self.assertEqual(net.prefixlen, 24)
def testGetBroadcast(self): def testGetBroadcast(self):
self.assertEqual(int(self.ipv4_network.broadcast_address), 16909311) self.assertEqual(int(self.ipv4_network.broadcast_address), 16909311)
...@@ -1271,11 +1315,6 @@ class IpaddrUnitTest(unittest.TestCase): ...@@ -1271,11 +1315,6 @@ class IpaddrUnitTest(unittest.TestCase):
self.assertEqual(ipaddress.IPv6Interface('::1:0:0:0:0').packed, self.assertEqual(ipaddress.IPv6Interface('::1:0:0:0:0').packed,
b'\x00' * 6 + b'\x00\x01' + b'\x00' * 8) b'\x00' * 6 + b'\x00\x01' + b'\x00' * 8)
def testIpStrFromPrefixlen(self):
ipv4 = ipaddress.IPv4Interface('1.2.3.4/24')
self.assertEqual(ipv4._ip_string_from_prefix(), '255.255.255.0')
self.assertEqual(ipv4._ip_string_from_prefix(28), '255.255.255.240')
def testIpType(self): def testIpType(self):
ipv4net = ipaddress.ip_network('1.2.3.4') ipv4net = ipaddress.ip_network('1.2.3.4')
ipv4addr = ipaddress.ip_address('1.2.3.4') ipv4addr = ipaddress.ip_address('1.2.3.4')
...@@ -1467,14 +1506,8 @@ class IpaddrUnitTest(unittest.TestCase): ...@@ -1467,14 +1506,8 @@ class IpaddrUnitTest(unittest.TestCase):
def testIPBases(self): def testIPBases(self):
net = self.ipv4_network net = self.ipv4_network
self.assertEqual('1.2.3.0/24', net.compressed) self.assertEqual('1.2.3.0/24', net.compressed)
self.assertEqual(
net._ip_int_from_prefix(24),
net._ip_int_from_prefix(None))
net = self.ipv6_network net = self.ipv6_network
self.assertRaises(ValueError, net._string_from_ip_int, 2**128 + 1) self.assertRaises(ValueError, net._string_from_ip_int, 2**128 + 1)
self.assertEqual(
self.ipv6_address._string_from_ip_int(self.ipv6_address._ip),
self.ipv6_address._string_from_ip_int(None))
def testIPv6NetworkHelpers(self): def testIPv6NetworkHelpers(self):
net = self.ipv6_network net = self.ipv6_network
......
...@@ -48,6 +48,9 @@ Core and Builtins ...@@ -48,6 +48,9 @@ Core and Builtins
Library Library
------- -------
- Issue #18805: the netmask/hostmask parsing in ipaddress now more reliably
filters out illegal values
- Issue #17369: get_filename was raising an exception if the filename - Issue #17369: get_filename was raising an exception if the filename
parameter's RFC2231 encoding was broken in certain ways. This was parameter's RFC2231 encoding was broken in certain ways. This was
a regression relative to python2. a regression relative to python2.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment