Commit 4e81296b authored by Eric V. Smith's avatar Eric V. Smith Committed by GitHub

bpo-33536: Validate make_dataclass() field names. (GH-6906)

parent 5db5c066
...@@ -3,6 +3,7 @@ import sys ...@@ -3,6 +3,7 @@ import sys
import copy import copy
import types import types
import inspect import inspect
import keyword
__all__ = ['dataclass', __all__ = ['dataclass',
'field', 'field',
...@@ -1100,6 +1101,9 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True, ...@@ -1100,6 +1101,9 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
# Copy namespace since we're going to mutate it. # Copy namespace since we're going to mutate it.
namespace = namespace.copy() namespace = namespace.copy()
# While we're looking through the field names, validate that they
# are identifiers, are not keywords, and not duplicates.
seen = set()
anns = {} anns = {}
for item in fields: for item in fields:
if isinstance(item, str): if isinstance(item, str):
...@@ -1110,6 +1114,17 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True, ...@@ -1110,6 +1114,17 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
elif len(item) == 3: elif len(item) == 3:
name, tp, spec = item name, tp, spec = item
namespace[name] = spec namespace[name] = spec
else:
raise TypeError(f'Invalid field: {item!r}')
if not isinstance(name, str) or not name.isidentifier():
raise TypeError(f'Field names must be valid identifers: {name!r}')
if keyword.iskeyword(name):
raise TypeError(f'Field names must not be keywords: {name!r}')
if name in seen:
raise TypeError(f'Field name duplicated: {name!r}')
seen.add(name)
anns[name] = tp anns[name] = tp
namespace['__annotations__'] = anns namespace['__annotations__'] = anns
......
...@@ -1826,114 +1826,6 @@ class TestCase(unittest.TestCase): ...@@ -1826,114 +1826,6 @@ class TestCase(unittest.TestCase):
self.assertEqual(new_sample.x, another_new_sample.x) self.assertEqual(new_sample.x, another_new_sample.x)
self.assertEqual(sample.y, another_new_sample.y) self.assertEqual(sample.y, another_new_sample.y)
def test_helper_make_dataclass(self):
C = make_dataclass('C',
[('x', int),
('y', int, field(default=5))],
namespace={'add_one': lambda self: self.x + 1})
c = C(10)
self.assertEqual((c.x, c.y), (10, 5))
self.assertEqual(c.add_one(), 11)
def test_helper_make_dataclass_no_mutate_namespace(self):
# Make sure a provided namespace isn't mutated.
ns = {}
C = make_dataclass('C',
[('x', int),
('y', int, field(default=5))],
namespace=ns)
self.assertEqual(ns, {})
def test_helper_make_dataclass_base(self):
class Base1:
pass
class Base2:
pass
C = make_dataclass('C',
[('x', int)],
bases=(Base1, Base2))
c = C(2)
self.assertIsInstance(c, C)
self.assertIsInstance(c, Base1)
self.assertIsInstance(c, Base2)
def test_helper_make_dataclass_base_dataclass(self):
@dataclass
class Base1:
x: int
class Base2:
pass
C = make_dataclass('C',
[('y', int)],
bases=(Base1, Base2))
with self.assertRaisesRegex(TypeError, 'required positional'):
c = C(2)
c = C(1, 2)
self.assertIsInstance(c, C)
self.assertIsInstance(c, Base1)
self.assertIsInstance(c, Base2)
self.assertEqual((c.x, c.y), (1, 2))
def test_helper_make_dataclass_init_var(self):
def post_init(self, y):
self.x *= y
C = make_dataclass('C',
[('x', int),
('y', InitVar[int]),
],
namespace={'__post_init__': post_init},
)
c = C(2, 3)
self.assertEqual(vars(c), {'x': 6})
self.assertEqual(len(fields(c)), 1)
def test_helper_make_dataclass_class_var(self):
C = make_dataclass('C',
[('x', int),
('y', ClassVar[int], 10),
('z', ClassVar[int], field(default=20)),
])
c = C(1)
self.assertEqual(vars(c), {'x': 1})
self.assertEqual(len(fields(c)), 1)
self.assertEqual(C.y, 10)
self.assertEqual(C.z, 20)
def test_helper_make_dataclass_other_params(self):
C = make_dataclass('C',
[('x', int),
('y', ClassVar[int], 10),
('z', ClassVar[int], field(default=20)),
],
init=False)
# Make sure we have a repr, but no init.
self.assertNotIn('__init__', vars(C))
self.assertIn('__repr__', vars(C))
# Make sure random other params don't work.
with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
C = make_dataclass('C',
[],
xxinit=False)
def test_helper_make_dataclass_no_types(self):
C = make_dataclass('Point', ['x', 'y', 'z'])
c = C(1, 2, 3)
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
'y': 'typing.Any',
'z': 'typing.Any'})
C = make_dataclass('Point', ['x', ('y', int), 'z'])
c = C(1, 2, 3)
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
'y': int,
'z': 'typing.Any'})
class TestFieldNoAnnotation(unittest.TestCase): class TestFieldNoAnnotation(unittest.TestCase):
def test_field_without_annotation(self): def test_field_without_annotation(self):
...@@ -2947,5 +2839,170 @@ class TestStringAnnotations(unittest.TestCase): ...@@ -2947,5 +2839,170 @@ class TestStringAnnotations(unittest.TestCase):
self.assertNotIn('not_iv4', c.__dict__) self.assertNotIn('not_iv4', c.__dict__)
class TestMakeDataclass(unittest.TestCase):
def test_simple(self):
C = make_dataclass('C',
[('x', int),
('y', int, field(default=5))],
namespace={'add_one': lambda self: self.x + 1})
c = C(10)
self.assertEqual((c.x, c.y), (10, 5))
self.assertEqual(c.add_one(), 11)
def test_no_mutate_namespace(self):
# Make sure a provided namespace isn't mutated.
ns = {}
C = make_dataclass('C',
[('x', int),
('y', int, field(default=5))],
namespace=ns)
self.assertEqual(ns, {})
def test_base(self):
class Base1:
pass
class Base2:
pass
C = make_dataclass('C',
[('x', int)],
bases=(Base1, Base2))
c = C(2)
self.assertIsInstance(c, C)
self.assertIsInstance(c, Base1)
self.assertIsInstance(c, Base2)
def test_base_dataclass(self):
@dataclass
class Base1:
x: int
class Base2:
pass
C = make_dataclass('C',
[('y', int)],
bases=(Base1, Base2))
with self.assertRaisesRegex(TypeError, 'required positional'):
c = C(2)
c = C(1, 2)
self.assertIsInstance(c, C)
self.assertIsInstance(c, Base1)
self.assertIsInstance(c, Base2)
self.assertEqual((c.x, c.y), (1, 2))
def test_init_var(self):
def post_init(self, y):
self.x *= y
C = make_dataclass('C',
[('x', int),
('y', InitVar[int]),
],
namespace={'__post_init__': post_init},
)
c = C(2, 3)
self.assertEqual(vars(c), {'x': 6})
self.assertEqual(len(fields(c)), 1)
def test_class_var(self):
C = make_dataclass('C',
[('x', int),
('y', ClassVar[int], 10),
('z', ClassVar[int], field(default=20)),
])
c = C(1)
self.assertEqual(vars(c), {'x': 1})
self.assertEqual(len(fields(c)), 1)
self.assertEqual(C.y, 10)
self.assertEqual(C.z, 20)
def test_other_params(self):
C = make_dataclass('C',
[('x', int),
('y', ClassVar[int], 10),
('z', ClassVar[int], field(default=20)),
],
init=False)
# Make sure we have a repr, but no init.
self.assertNotIn('__init__', vars(C))
self.assertIn('__repr__', vars(C))
# Make sure random other params don't work.
with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
C = make_dataclass('C',
[],
xxinit=False)
def test_no_types(self):
C = make_dataclass('Point', ['x', 'y', 'z'])
c = C(1, 2, 3)
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
'y': 'typing.Any',
'z': 'typing.Any'})
C = make_dataclass('Point', ['x', ('y', int), 'z'])
c = C(1, 2, 3)
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
'y': int,
'z': 'typing.Any'})
def test_invalid_type_specification(self):
for bad_field in [(),
(1, 2, 3, 4),
]:
with self.subTest(bad_field=bad_field):
with self.assertRaisesRegex(TypeError, r'Invalid field: '):
make_dataclass('C', ['a', bad_field])
# And test for things with no len().
for bad_field in [float,
lambda x:x,
]:
with self.subTest(bad_field=bad_field):
with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
make_dataclass('C', ['a', bad_field])
def test_duplicate_field_names(self):
for field in ['a', 'ab']:
with self.subTest(field=field):
with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
make_dataclass('C', [field, 'a', field])
def test_keyword_field_names(self):
for field in ['for', 'async', 'await', 'as']:
with self.subTest(field=field):
with self.assertRaisesRegex(TypeError, 'must not be keywords'):
make_dataclass('C', ['a', field])
with self.assertRaisesRegex(TypeError, 'must not be keywords'):
make_dataclass('C', [field])
with self.assertRaisesRegex(TypeError, 'must not be keywords'):
make_dataclass('C', [field, 'a'])
def test_non_identifier_field_names(self):
for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
with self.subTest(field=field):
with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
make_dataclass('C', ['a', field])
with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
make_dataclass('C', [field])
with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
make_dataclass('C', [field, 'a'])
def test_underscore_field_names(self):
# Unlike namedtuple, it's okay if dataclass field names have
# an underscore.
make_dataclass('C', ['_', '_a', 'a_a', 'a_'])
def test_funny_class_names_names(self):
# No reason to prevent weird class names, since
# types.new_class allows them.
for classname in ['()', 'x,y', '*', '2@3', '']:
with self.subTest(classname=classname):
C = make_dataclass(classname, ['a', 'b'])
self.assertEqual(C.__name__, classname)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
dataclasses.make_dataclass now checks for invalid field names and duplicate
fields. Also, added a check for invalid field specifications.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment