Commit 2b75fc2b authored by Eric V. Smith's avatar Eric V. Smith Committed by GitHub

Minor fixes to dataclass tests. (GH-6243)

 Also, re-enable a test for ClassVars with default_factory.
parent dfb6e54d
...@@ -133,8 +133,8 @@ class TestCase(unittest.TestCase): ...@@ -133,8 +133,8 @@ class TestCase(unittest.TestCase):
self.assertEqual(hash(C(10)), hash((10,))) self.assertEqual(hash(C(10)), hash((10,)))
# Creating this class should generate an exception, because # Creating this class should generate an exception, because
# __hash__ exists and is not None, which it would be if it had # __hash__ exists and is not None, which it would be if it
# been auto-generated do due __eq__ being defined. # had been auto-generated due to __eq__ being defined.
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __hash__'): 'Cannot overwrite attribute __hash__'):
@dataclass(unsafe_hash=True) @dataclass(unsafe_hash=True)
...@@ -145,7 +145,6 @@ class TestCase(unittest.TestCase): ...@@ -145,7 +145,6 @@ class TestCase(unittest.TestCase):
def __hash__(self): def __hash__(self):
pass pass
def test_overwrite_fields_in_derived_class(self): def test_overwrite_fields_in_derived_class(self):
# Note that x from C1 replaces x in Base, but the order remains # Note that x from C1 replaces x in Base, but the order remains
# the same as defined in Base. # the same as defined in Base.
...@@ -624,7 +623,7 @@ class TestCase(unittest.TestCase): ...@@ -624,7 +623,7 @@ class TestCase(unittest.TestCase):
self.assertIs(o1.x, o2.x) self.assertIs(o1.x, o2.x)
def test_no_options(self): def test_no_options(self):
# call with dataclass() # Call with dataclass().
@dataclass() @dataclass()
class C: class C:
x: int x: int
...@@ -639,7 +638,7 @@ class TestCase(unittest.TestCase): ...@@ -639,7 +638,7 @@ class TestCase(unittest.TestCase):
y: int y: int
self.assertNotEqual(Point(1, 2), (1, 2)) self.assertNotEqual(Point(1, 2), (1, 2))
# And that we can't compare to another unrelated dataclass # And that we can't compare to another unrelated dataclass.
@dataclass @dataclass
class C: class C:
x: int x: int
...@@ -664,7 +663,7 @@ class TestCase(unittest.TestCase): ...@@ -664,7 +663,7 @@ class TestCase(unittest.TestCase):
self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3)) self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3)) self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
# Make sure we can't unpack # Make sure we can't unpack.
with self.assertRaisesRegex(TypeError, 'unpack'): with self.assertRaisesRegex(TypeError, 'unpack'):
x, y, z = Point3D(4, 5, 6) x, y, z = Point3D(4, 5, 6)
...@@ -695,7 +694,7 @@ class TestCase(unittest.TestCase): ...@@ -695,7 +694,7 @@ class TestCase(unittest.TestCase):
# Verify __init__. # Verify __init__.
signature = inspect.signature(cls.__init__) signature = inspect.signature(cls.__init__)
# Check the return type, should be None # Check the return type, should be None.
self.assertIs(signature.return_annotation, None) self.assertIs(signature.return_annotation, None)
# Check each parameter. # Check each parameter.
...@@ -716,12 +715,12 @@ class TestCase(unittest.TestCase): ...@@ -716,12 +715,12 @@ class TestCase(unittest.TestCase):
param = next(params) param = next(params)
self.assertEqual(param.name, 'k') self.assertEqual(param.name, 'k')
self.assertIs (param.annotation, F) self.assertIs (param.annotation, F)
# Don't test for the default, since it's set to MISSING # Don't test for the default, since it's set to MISSING.
self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
param = next(params) param = next(params)
self.assertEqual(param.name, 'l') self.assertEqual(param.name, 'l')
self.assertIs (param.annotation, float) self.assertIs (param.annotation, float)
# Don't test for the default, since it's set to MISSING # Don't test for the default, since it's set to MISSING.
self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
self.assertRaises(StopIteration, next, params) self.assertRaises(StopIteration, next, params)
...@@ -867,7 +866,7 @@ class TestCase(unittest.TestCase): ...@@ -867,7 +866,7 @@ class TestCase(unittest.TestCase):
self.assertEqual(C().x, 5) self.assertEqual(C().x, 5)
# Now call super(), and it will raise # Now call super(), and it will raise.
@dataclass @dataclass
class C(B): class C(B):
def __post_init__(self): def __post_init__(self):
...@@ -928,8 +927,8 @@ class TestCase(unittest.TestCase): ...@@ -928,8 +927,8 @@ class TestCase(unittest.TestCase):
c = C(5) c = C(5)
self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)') self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
self.assertEqual(len(fields(C)), 2) # We have 2 fields self.assertEqual(len(fields(C)), 2) # We have 2 fields.
self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars.
self.assertEqual(c.z, 1000) self.assertEqual(c.z, 1000)
self.assertEqual(c.w, 2000) self.assertEqual(c.w, 2000)
self.assertEqual(c.t, 3000) self.assertEqual(c.t, 3000)
...@@ -1205,14 +1204,13 @@ class TestCase(unittest.TestCase): ...@@ -1205,14 +1204,13 @@ class TestCase(unittest.TestCase):
d = D(4, 5) d = D(4, 5)
self.assertEqual((d.x, d.z), (4, 5)) self.assertEqual((d.x, d.z), (4, 5))
def test_classvar_default_factory(self):
def x_test_classvar_default_factory(self): # It's an error for a ClassVar to have a factory function.
# XXX: it's an error for a ClassVar to have a factory function with self.assertRaisesRegex(TypeError,
@dataclass 'cannot have a default factory'):
class C: @dataclass
x: ClassVar[int] = field(default_factory=int) class C:
x: ClassVar[int] = field(default_factory=int)
self.assertIs(C().x, int)
def test_is_dataclass(self): def test_is_dataclass(self):
class NotDataClass: class NotDataClass:
...@@ -1264,7 +1262,7 @@ class TestCase(unittest.TestCase): ...@@ -1264,7 +1262,7 @@ class TestCase(unittest.TestCase):
fields(C()) fields(C())
def test_helper_asdict(self): def test_helper_asdict(self):
# Basic tests for asdict(), it should return a new dictionary # Basic tests for asdict(), it should return a new dictionary.
@dataclass @dataclass
class C: class C:
x: int x: int
...@@ -1279,7 +1277,7 @@ class TestCase(unittest.TestCase): ...@@ -1279,7 +1277,7 @@ class TestCase(unittest.TestCase):
self.assertIs(type(asdict(c)), dict) self.assertIs(type(asdict(c)), dict)
def test_helper_asdict_raises_on_classes(self): def test_helper_asdict_raises_on_classes(self):
# asdict() should raise on a class object # asdict() should raise on a class object.
@dataclass @dataclass
class C: class C:
x: int x: int
...@@ -1377,7 +1375,7 @@ class TestCase(unittest.TestCase): ...@@ -1377,7 +1375,7 @@ class TestCase(unittest.TestCase):
self.assertIs(type(d), OrderedDict) self.assertIs(type(d), OrderedDict)
def test_helper_astuple(self): def test_helper_astuple(self):
# Basic tests for astuple(), it should return a new tuple # Basic tests for astuple(), it should return a new tuple.
@dataclass @dataclass
class C: class C:
x: int x: int
...@@ -1392,7 +1390,7 @@ class TestCase(unittest.TestCase): ...@@ -1392,7 +1390,7 @@ class TestCase(unittest.TestCase):
self.assertIs(type(astuple(c)), tuple) self.assertIs(type(astuple(c)), tuple)
def test_helper_astuple_raises_on_classes(self): def test_helper_astuple_raises_on_classes(self):
# astuple() should raise on a class object # astuple() should raise on a class object.
@dataclass @dataclass
class C: class C:
x: int x: int
...@@ -1489,7 +1487,7 @@ class TestCase(unittest.TestCase): ...@@ -1489,7 +1487,7 @@ class TestCase(unittest.TestCase):
self.assertIs(type(t), NT) self.assertIs(type(t), NT)
def test_dynamic_class_creation(self): def test_dynamic_class_creation(self):
cls_dict = {'__annotations__': {'x':int, 'y':int}, cls_dict = {'__annotations__': {'x': int, 'y': int},
} }
# Create the class. # Create the class.
...@@ -1502,7 +1500,7 @@ class TestCase(unittest.TestCase): ...@@ -1502,7 +1500,7 @@ class TestCase(unittest.TestCase):
self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2}) self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
def test_dynamic_class_creation_using_field(self): def test_dynamic_class_creation_using_field(self):
cls_dict = {'__annotations__': {'x':int, 'y':int}, cls_dict = {'__annotations__': {'x': int, 'y': int},
'y': field(default=5), 'y': field(default=5),
} }
...@@ -1569,8 +1567,8 @@ class TestCase(unittest.TestCase): ...@@ -1569,8 +1567,8 @@ class TestCase(unittest.TestCase):
def test_alternate_classmethod_constructor(self): def test_alternate_classmethod_constructor(self):
# Since __post_init__ can't take params, use a classmethod # Since __post_init__ can't take params, use a classmethod
# alternate constructor. This is mostly an example to show how # alternate constructor. This is mostly an example to show
# to use this technique. # how to use this technique.
@dataclass @dataclass
class C: class C:
x: int x: int
...@@ -1604,7 +1602,7 @@ class TestCase(unittest.TestCase): ...@@ -1604,7 +1602,7 @@ class TestCase(unittest.TestCase):
class C: class C:
i: int = field(metadata=0) i: int = field(metadata=0)
# Make sure an empty dict works # Make sure an empty dict works.
@dataclass @dataclass
class C: class C:
i: int = field(metadata={}) i: int = field(metadata={})
...@@ -1666,7 +1664,7 @@ class TestCase(unittest.TestCase): ...@@ -1666,7 +1664,7 @@ class TestCase(unittest.TestCase):
self.assertEqual(box.content, 42) self.assertEqual(box.content, 42)
self.assertEqual(box.label, '<unknown>') self.assertEqual(box.label, '<unknown>')
# subscripting the resulting class should work, etc. # Subscripting the resulting class should work, etc.
Alias = List[LabeledBox[int]] Alias = List[LabeledBox[int]]
def test_generic_extending(self): def test_generic_extending(self):
...@@ -1931,7 +1929,7 @@ class TestFieldNoAnnotation(unittest.TestCase): ...@@ -1931,7 +1929,7 @@ class TestFieldNoAnnotation(unittest.TestCase):
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
"'f' is a field but has no type annotation"): "'f' is a field but has no type annotation"):
# This is still an error: make sure we don't pick up the # This is still an error: make sure we don't pick up the
# type annotation in the base class. # type annotation in the base class.
@dataclass @dataclass
class C(B): class C(B):
f = field() f = field()
...@@ -1944,7 +1942,7 @@ class TestFieldNoAnnotation(unittest.TestCase): ...@@ -1944,7 +1942,7 @@ class TestFieldNoAnnotation(unittest.TestCase):
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
"'f' is a field but has no type annotation"): "'f' is a field but has no type annotation"):
# This is still an error: make sure we don't pick up the # This is still an error: make sure we don't pick up the
# type annotation in the base class. # type annotation in the base class.
@dataclass @dataclass
class C(B): class C(B):
f = field() f = field()
...@@ -2178,7 +2176,7 @@ class TestRepr(unittest.TestCase): ...@@ -2178,7 +2176,7 @@ class TestRepr(unittest.TestCase):
class TestFrozen(unittest.TestCase): class TestFrozen(unittest.TestCase):
def test_overwriting_frozen(self): def test_overwriting_frozen(self):
# frozen uses __setattr__ and __delattr__ # frozen uses __setattr__ and __delattr__.
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
'Cannot overwrite attribute __setattr__'): 'Cannot overwrite attribute __setattr__'):
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -2473,16 +2471,16 @@ class TestHash(unittest.TestCase): ...@@ -2473,16 +2471,16 @@ class TestHash(unittest.TestCase):
def test_hash_no_args(self): def test_hash_no_args(self):
# Test dataclasses with no hash= argument. This exists to # Test dataclasses with no hash= argument. This exists to
# make sure that if the @dataclass parameter name is changed # make sure that if the @dataclass parameter name is changed
# or the non-default hashing behavior changes, the default # or the non-default hashing behavior changes, the default
# hashability keeps working the same way. # hashability keeps working the same way.
class Base: class Base:
def __hash__(self): def __hash__(self):
return 301 return 301
# If frozen or eq is None, then use the default value (do not # If frozen or eq is None, then use the default value (do not
# specify any value in the decorator). # specify any value in the decorator).
for frozen, eq, base, expected in [ for frozen, eq, base, expected in [
(None, None, object, 'unhashable'), (None, None, object, 'unhashable'),
(None, None, Base, 'unhashable'), (None, None, Base, 'unhashable'),
...@@ -2534,9 +2532,9 @@ class TestHash(unittest.TestCase): ...@@ -2534,9 +2532,9 @@ class TestHash(unittest.TestCase):
elif expected == 'object': elif expected == 'object':
# I'm not sure what test to use here. object's # I'm not sure what test to use here. object's
# hash isn't based on id(), so calling hash() # hash isn't based on id(), so calling hash()
# won't tell us much. So, just check the function # won't tell us much. So, just check the
# used is object's. # function used is object's.
self.assertIs(C.__hash__, object.__hash__) self.assertIs(C.__hash__, object.__hash__)
elif expected == 'tuple': elif expected == 'tuple':
...@@ -2665,8 +2663,9 @@ class TestSlots(unittest.TestCase): ...@@ -2665,8 +2663,9 @@ class TestSlots(unittest.TestCase):
__slots__ = ('x',) __slots__ = ('x',)
x: Any x: Any
# There was a bug where a variable in a slot was assumed # There was a bug where a variable in a slot was assumed to
# to also have a default value (of type types.MemberDescriptorType). # also have a default value (of type
# types.MemberDescriptorType).
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
r"__init__\(\) missing 1 required positional argument: 'x'"): r"__init__\(\) missing 1 required positional argument: 'x'"):
C() C()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment