Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
C
cpython
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
Analytics
Analytics
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Commits
Issue Boards
Open sidebar
Kirill Smelkov
cpython
Commits
f0db54a0
Commit
f0db54a0
authored
Dec 04, 2017
by
Eric V. Smith
Committed by
GitHub
Dec 04, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
bpo-32214: Implement PEP 557: Data Classes (#4704)
parent
1e2fcac4
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
2854 additions
and
0 deletions
+2854
-0
Lib/dataclasses.py
Lib/dataclasses.py
+776
-0
Lib/test/test_dataclasses.py
Lib/test/test_dataclasses.py
+2076
-0
Misc/NEWS.d/next/Library/2017-12-04-15-51-57.bpo-32214.uozdNj.rst
...S.d/next/Library/2017-12-04-15-51-57.bpo-32214.uozdNj.rst
+2
-0
No files found.
Lib/dataclasses.py
0 → 100644
View file @
f0db54a0
import
sys
import
types
from
copy
import
deepcopy
import
collections
import
inspect
__all__
=
[
'dataclass'
,
'field'
,
'FrozenInstanceError'
,
'InitVar'
,
# Helper functions.
'fields'
,
'asdict'
,
'astuple'
,
'make_dataclass'
,
'replace'
,
]
# Raised when an attempt is made to modify a frozen class.
class
FrozenInstanceError
(
AttributeError
):
pass
# A sentinel object for default values to signal that a
# default-factory will be used.
# This is given a nice repr() which will appear in the function
# signature of dataclasses' constructors.
class
_HAS_DEFAULT_FACTORY_CLASS
:
def
__repr__
(
self
):
return
'<factory>'
_HAS_DEFAULT_FACTORY
=
_HAS_DEFAULT_FACTORY_CLASS
()
# A sentinel object to detect if a parameter is supplied or not.
class
_MISSING_FACTORY
:
def
__repr__
(
self
):
return
'<missing>'
_MISSING
=
_MISSING_FACTORY
()
# Since most per-field metadata will be unused, create an empty
# read-only proxy that can be shared among all fields.
_EMPTY_METADATA
=
types
.
MappingProxyType
({})
# Markers for the various kinds of fields and pseudo-fields.
_FIELD
=
object
()
# An actual field.
_FIELD_CLASSVAR
=
object
()
# Not a field, but a ClassVar.
_FIELD_INITVAR
=
object
()
# Not a field, but an InitVar.
# The name of an attribute on the class where we store the Field
# objects. Also used to check if a class is a Data Class.
_MARKER
=
'__dataclass_fields__'
# The name of the function, that if it exists, is called at the end of
# __init__.
_POST_INIT_NAME
=
'__post_init__'
class
_InitVarMeta
(
type
):
def
__getitem__
(
self
,
params
):
return
self
class
InitVar
(
metaclass
=
_InitVarMeta
):
pass
# Instances of Field are only ever created from within this module,
# and only from the field() function, although Field instances are
# exposed externally as (conceptually) read-only objects.
# name and type are filled in after the fact, not in __init__. They're
# not known at the time this class is instantiated, but it's
# convenient if they're available later.
# When cls._MARKER is filled in with a list of Field objects, the name
# and type fields will have been populated.
class
Field
:
__slots__
=
(
'name'
,
'type'
,
'default'
,
'default_factory'
,
'repr'
,
'hash'
,
'init'
,
'compare'
,
'metadata'
,
'_field_type'
,
# Private: not to be used by user code.
)
def
__init__
(
self
,
default
,
default_factory
,
init
,
repr
,
hash
,
compare
,
metadata
):
self
.
name
=
None
self
.
type
=
None
self
.
default
=
default
self
.
default_factory
=
default_factory
self
.
init
=
init
self
.
repr
=
repr
self
.
hash
=
hash
self
.
compare
=
compare
self
.
metadata
=
(
_EMPTY_METADATA
if
metadata
is
None
or
len
(
metadata
)
==
0
else
types
.
MappingProxyType
(
metadata
))
self
.
_field_type
=
None
def
__repr__
(
self
):
return
(
'Field('
f'name=
{
self
.
name
!
r
}
,'
f'type=
{
self
.
type
}
,'
f'default=
{
self
.
default
}
,'
f'default_factory=
{
self
.
default_factory
}
,'
f'init=
{
self
.
init
}
,'
f'repr=
{
self
.
repr
}
,'
f'hash=
{
self
.
hash
}
,'
f'compare=
{
self
.
compare
}
,'
f'metadata=
{
self
.
metadata
}
'
')'
)
# This function is used instead of exposing Field creation directly,
# so that a type checker can be told (via overloads) that this is a
# function whose type depends on its parameters.
def
field
(
*
,
default
=
_MISSING
,
default_factory
=
_MISSING
,
init
=
True
,
repr
=
True
,
hash
=
None
,
compare
=
True
,
metadata
=
None
):
"""Return an object to identify dataclass fields.
default is the default value of the field. default_factory is a
0-argument function called to initialize a field's value. If init
is True, the field will be a parameter to the class's __init__()
function. If repr is True, the field will be included in the
object's repr(). If hash is True, the field will be included in
the object's hash(). If compare is True, the field will be used in
comparison functions. metadata, if specified, must be a mapping
which is stored but not otherwise examined by dataclass.
It is an error to specify both default and default_factory.
"""
if
default
is
not
_MISSING
and
default_factory
is
not
_MISSING
:
raise
ValueError
(
'cannot specify both default and default_factory'
)
return
Field
(
default
,
default_factory
,
init
,
repr
,
hash
,
compare
,
metadata
)
def
_tuple_str
(
obj_name
,
fields
):
# Return a string representing each field of obj_name as a tuple
# member. So, if fields is ['x', 'y'] and obj_name is "self",
# return "(self.x,self.y)".
# Special case for the 0-tuple.
if
len
(
fields
)
==
0
:
return
'()'
# Note the trailing comma, needed if this turns out to be a 1-tuple.
return
f'(
{
","
.
join
([
f"
{
obj_name
}
.
{
f
.
name
}
" for f in fields])
}
,)'
def
_create_fn
(
name
,
args
,
body
,
globals
=
None
,
locals
=
None
,
return_type
=
_MISSING
):
# Note that we mutate locals when exec() is called. Caller beware!
if
locals
is
None
:
locals
=
{}
return_annotation
=
''
if
return_type
is
not
_MISSING
:
locals
[
'_return_type'
]
=
return_type
return_annotation
=
'->_return_type'
args
=
','
.
join
(
args
)
body
=
'
\
n
'
.
join
(
f'
{
b
}
'
for
b
in
body
)
txt
=
f'def
{
name
}
(
{
args
}
)
{
return_annotation
}
:
\
n
{
body
}
'
exec
(
txt
,
globals
,
locals
)
return
locals
[
name
]
def
_field_assign
(
frozen
,
name
,
value
,
self_name
):
# If we're a frozen class, then assign to our fields in __init__
# via object.__setattr__. Otherwise, just use a simple
# assignment.
# self_name is what "self" is called in this function: don't
# hard-code "self", since that might be a field name.
if
frozen
:
return
f'object.__setattr__(
{
self_name
}
,
{
name
!
r
}
,
{
value
}
)'
return
f'
{
self_name
}
.
{
name
}
=
{
value
}
'
def
_field_init
(
f
,
frozen
,
globals
,
self_name
):
# Return the text of the line in the body of __init__ that will
# initialize this field.
default_name
=
f'_dflt_
{
f
.
name
}
'
if
f
.
default_factory
is
not
_MISSING
:
if
f
.
init
:
# This field has a default factory. If a parameter is
# given, use it. If not, call the factory.
globals
[
default_name
]
=
f
.
default_factory
value
=
(
f'
{
default_name
}
() '
f'if
{
f
.
name
}
is _HAS_DEFAULT_FACTORY '
f'else
{
f
.
name
}
'
)
else
:
# This is a field that's not in the __init__ params, but
# has a default factory function. It needs to be
# initialized here by calling the factory function,
# because there's no other way to initialize it.
# For a field initialized with a default=defaultvalue, the
# class dict just has the default value
# (cls.fieldname=defaultvalue). But that won't work for a
# default factory, the factory must be called in __init__
# and we must assign that to self.fieldname. We can't
# fall back to the class dict's value, both because it's
# not set, and because it might be different per-class
# (which, after all, is why we have a factory function!).
globals
[
default_name
]
=
f
.
default_factory
value
=
f'
{
default_name
}
()'
else
:
# No default factory.
if
f
.
init
:
if
f
.
default
is
_MISSING
:
# There's no default, just do an assignment.
value
=
f
.
name
elif
f
.
default
is
not
_MISSING
:
globals
[
default_name
]
=
f
.
default
value
=
f
.
name
else
:
# This field does not need initialization. Signify that to
# the caller by returning None.
return
None
# Only test this now, so that we can create variables for the
# default. However, return None to signify that we're not going
# to actually do the assignment statement for InitVars.
if
f
.
_field_type
==
_FIELD_INITVAR
:
return
None
# Now, actually generate the field assignment.
return
_field_assign
(
frozen
,
f
.
name
,
value
,
self_name
)
def
_init_param
(
f
):
# Return the __init__ parameter string for this field.
# For example, the equivalent of 'x:int=3' (except instead of 'int',
# reference a variable set to int, and instead of '3', reference a
# variable set to 3).
if
f
.
default
is
_MISSING
and
f
.
default_factory
is
_MISSING
:
# There's no default, and no default_factory, just
# output the variable name and type.
default
=
''
elif
f
.
default
is
not
_MISSING
:
# There's a default, this will be the name that's used to look it up.
default
=
f'=_dflt_
{
f
.
name
}
'
elif
f
.
default_factory
is
not
_MISSING
:
# There's a factory function. Set a marker.
default
=
'=_HAS_DEFAULT_FACTORY'
return
f'
{
f
.
name
}
:_type_
{
f
.
name
}{
default
}
'
def
_init_fn
(
fields
,
frozen
,
has_post_init
,
self_name
):
# fields contains both real fields and InitVar pseudo-fields.
# Make sure we don't have fields without defaults following fields
# with defaults. This actually would be caught when exec-ing the
# function source code, but catching it here gives a better error
# message, and future-proofs us in case we build up the function
# using ast.
seen_default
=
False
for
f
in
fields
:
# Only consider fields in the __init__ call.
if
f
.
init
:
if
not
(
f
.
default
is
_MISSING
and
f
.
default_factory
is
_MISSING
):
seen_default
=
True
elif
seen_default
:
raise
TypeError
(
f'non-default argument
{
f
.
name
!
r
}
'
'follows default argument'
)
globals
=
{
'_MISSING'
:
_MISSING
,
'_HAS_DEFAULT_FACTORY'
:
_HAS_DEFAULT_FACTORY
}
body_lines
=
[]
for
f
in
fields
:
# Do not initialize the pseudo-fields, only the real ones.
line
=
_field_init
(
f
,
frozen
,
globals
,
self_name
)
if
line
is
not
None
:
# line is None means that this field doesn't require
# initialization. Just skip it.
body_lines
.
append
(
line
)
# Does this class have a post-init function?
if
has_post_init
:
params_str
=
','
.
join
(
f
.
name
for
f
in
fields
if
f
.
_field_type
is
_FIELD_INITVAR
)
body_lines
+=
[
f'
{
self_name
}
.
{
_POST_INIT_NAME
}
(
{
params_str
}
)'
]
# If no body lines, use 'pass'.
if
len
(
body_lines
)
==
0
:
body_lines
=
[
'pass'
]
locals
=
{
f'_type_
{
f
.
name
}
'
:
f
.
type
for
f
in
fields
}
return
_create_fn
(
'__init__'
,
[
self_name
]
+
[
_init_param
(
f
)
for
f
in
fields
if
f
.
init
],
body_lines
,
locals
=
locals
,
globals
=
globals
,
return_type
=
None
)
def
_repr_fn
(
fields
):
return
_create_fn
(
'__repr__'
,
[
'self'
],
[
'return self.__class__.__qualname__ + f"('
+
', '
.
join
([
f"
{
f
.
name
}
={{self.
{
f
.
name
}
!r}}"
for
f
in
fields
])
+
')"'
])
def
_frozen_setattr
(
self
,
name
,
value
):
raise
FrozenInstanceError
(
f'cannot assign to field
{
name
!
r
}
'
)
def
_frozen_delattr
(
self
,
name
):
raise
FrozenInstanceError
(
f'cannot delete field
{
name
!
r
}
'
)
def
_cmp_fn
(
name
,
op
,
self_tuple
,
other_tuple
):
# Create a comparison function. If the fields in the object are
# named 'x' and 'y', then self_tuple is the string
# '(self.x,self.y)' and other_tuple is the string
# '(other.x,other.y)'.
return
_create_fn
(
name
,
[
'self'
,
'other'
],
[
'if other.__class__ is self.__class__:'
,
f' return
{
self_tuple
}{
op
}{
other_tuple
}
'
,
'return NotImplemented'
])
def
_set_eq_fns
(
cls
,
fields
):
# Create and set the equality comparison methods on cls.
# Pre-compute self_tuple and other_tuple, then re-use them for
# each function.
self_tuple
=
_tuple_str
(
'self'
,
fields
)
other_tuple
=
_tuple_str
(
'other'
,
fields
)
for
name
,
op
in
[(
'__eq__'
,
'=='
),
(
'__ne__'
,
'!='
),
]:
_set_attribute
(
cls
,
name
,
_cmp_fn
(
name
,
op
,
self_tuple
,
other_tuple
))
def
_set_order_fns
(
cls
,
fields
):
# Create and set the ordering methods on cls.
# Pre-compute self_tuple and other_tuple, then re-use them for
# each function.
self_tuple
=
_tuple_str
(
'self'
,
fields
)
other_tuple
=
_tuple_str
(
'other'
,
fields
)
for
name
,
op
in
[(
'__lt__'
,
'<'
),
(
'__le__'
,
'<='
),
(
'__gt__'
,
'>'
),
(
'__ge__'
,
'>='
),
]:
_set_attribute
(
cls
,
name
,
_cmp_fn
(
name
,
op
,
self_tuple
,
other_tuple
))
def
_hash_fn
(
fields
):
self_tuple
=
_tuple_str
(
'self'
,
fields
)
return
_create_fn
(
'__hash__'
,
[
'self'
],
[
f'return hash(
{
self_tuple
}
)'
])
def
_get_field
(
cls
,
a_name
,
a_type
):
# Return a Field object, for this field name and type. ClassVars
# and InitVars are also returned, but marked as such (see
# f._field_type).
# If the default value isn't derived from field, then it's
# only a normal default value. Convert it to a Field().
default
=
getattr
(
cls
,
a_name
,
_MISSING
)
if
isinstance
(
default
,
Field
):
f
=
default
else
:
f
=
field
(
default
=
default
)
# Assume it's a normal field until proven otherwise.
f
.
_field_type
=
_FIELD
# Only at this point do we know the name and the type. Set them.
f
.
name
=
a_name
f
.
type
=
a_type
# If typing has not been imported, then it's impossible for
# any annotation to be a ClassVar. So, only look for ClassVar
# if typing has been imported.
typing
=
sys
.
modules
.
get
(
'typing'
)
if
typing
is
not
None
:
# This test uses a typing internal class, but it's the best
# way to test if this is a ClassVar.
if
type
(
a_type
)
is
typing
.
_ClassVar
:
# This field is a ClassVar, so it's not a field.
f
.
_field_type
=
_FIELD_CLASSVAR
if
f
.
_field_type
is
_FIELD
:
# Check if this is an InitVar.
if
a_type
is
InitVar
:
# InitVars are not fields, either.
f
.
_field_type
=
_FIELD_INITVAR
# Validations for fields. This is delayed until now, instead of
# in the Field() constructor, since only here do we know the field
# name, which allows better error reporting.
# Special restrictions for ClassVar and InitVar.
if
f
.
_field_type
in
(
_FIELD_CLASSVAR
,
_FIELD_INITVAR
):
if
f
.
default_factory
is
not
_MISSING
:
raise
TypeError
(
f'field
{
f
.
name
}
cannot have a '
'default factory'
)
# Should I check for other field settings? default_factory
# seems the most serious to check for. Maybe add others. For
# example, how about init=False (or really,
# init=<not-the-default-init-value>)? It makes no sense for
# ClassVar and InitVar to specify init=<anything>.
# For real fields, disallow mutable defaults for known types.
if
f
.
_field_type
is
_FIELD
and
isinstance
(
f
.
default
,
(
list
,
dict
,
set
)):
raise
ValueError
(
f'mutable default
{
type
(
f
.
default
)
}
for field '
f'
{
f
.
name
}
is not allowed: use default_factory'
)
return
f
def
_find_fields
(
cls
):
# Return a list of Field objects, in order, for this class (and no
# base classes). Fields are found from __annotations__ (which is
# guaranteed to be ordered). Default values are from class
# attributes, if a field has a default. If the default value is
# a Field(), then it contains additional info beyond (and
# possibly including) the actual default value. Pseudo-fields
# ClassVars and InitVars are included, despite the fact that
# they're not real fields. That's deal with later.
annotations
=
getattr
(
cls
,
'__annotations__'
,
{})
return
[
_get_field
(
cls
,
a_name
,
a_type
)
for
a_name
,
a_type
in
annotations
.
items
()]
def
_set_attribute
(
cls
,
name
,
value
):
# Raise TypeError if an attribute by this name already exists.
if
name
in
cls
.
__dict__
:
raise
TypeError
(
f'Cannot overwrite attribute
{
name
}
'
f'in
{
cls
.
__name__
}
'
)
setattr
(
cls
,
name
,
value
)
def
_process_class
(
cls
,
repr
,
eq
,
order
,
hash
,
init
,
frozen
):
# Use an OrderedDict because:
# - Order matters!
# - Derived class fields overwrite base class fields, but the
# order is defined by the base class, which is found first.
fields
=
collections
.
OrderedDict
()
# Find our base classes in reverse MRO order, and exclude
# ourselves. In reversed order so that more derived classes
# override earlier field definitions in base classes.
for
b
in
cls
.
__mro__
[
-
1
:
0
:
-
1
]:
# Only process classes that have been processed by our
# decorator. That is, they have a _MARKER attribute.
base_fields
=
getattr
(
b
,
_MARKER
,
None
)
if
base_fields
:
for
f
in
base_fields
.
values
():
fields
[
f
.
name
]
=
f
# Now find fields in our class. While doing so, validate some
# things, and set the default values (as class attributes)
# where we can.
for
f
in
_find_fields
(
cls
):
fields
[
f
.
name
]
=
f
# If the class attribute (which is the default value for
# this field) exists and is of type 'Field', replace it
# with the real default. This is so that normal class
# introspection sees a real default value, not a Field.
if
isinstance
(
getattr
(
cls
,
f
.
name
,
None
),
Field
):
if
f
.
default
is
_MISSING
:
# If there's no default, delete the class attribute.
# This happens if we specify field(repr=False), for
# example (that is, we specified a field object, but
# no default value). Also if we're using a default
# factory. The class attribute should not be set at
# all in the post-processed class.
delattr
(
cls
,
f
.
name
)
else
:
setattr
(
cls
,
f
.
name
,
f
.
default
)
# Remember all of the fields on our class (including bases). This
# marks this class as being a dataclass.
setattr
(
cls
,
_MARKER
,
fields
)
# We also need to check if a parent class is frozen: frozen has to
# be inherited down.
is_frozen
=
frozen
or
cls
.
__setattr__
is
_frozen_setattr
# If we're generating ordering methods, we must be generating
# the eq methods.
if
order
and
not
eq
:
raise
ValueError
(
'eq must be true if order is true'
)
if
init
:
# Does this class have a post-init function?
has_post_init
=
hasattr
(
cls
,
_POST_INIT_NAME
)
# Include InitVars and regular fields (so, not ClassVars).
_set_attribute
(
cls
,
'__init__'
,
_init_fn
(
list
(
filter
(
lambda
f
:
f
.
_field_type
in
(
_FIELD
,
_FIELD_INITVAR
),
fields
.
values
())),
is_frozen
,
has_post_init
,
# The name to use for the "self" param
# in __init__. Use "self" if possible.
'__dataclass_self__'
if
'self'
in
fields
else
'self'
,
))
# Get the fields as a list, and include only real fields. This is
# used in all of the following methods.
field_list
=
list
(
filter
(
lambda
f
:
f
.
_field_type
is
_FIELD
,
fields
.
values
()))
if
repr
:
_set_attribute
(
cls
,
'__repr__'
,
_repr_fn
(
list
(
filter
(
lambda
f
:
f
.
repr
,
field_list
))))
if
is_frozen
:
_set_attribute
(
cls
,
'__setattr__'
,
_frozen_setattr
)
_set_attribute
(
cls
,
'__delattr__'
,
_frozen_delattr
)
generate_hash
=
False
if
hash
is
None
:
if
eq
and
frozen
:
# Generate a hash function.
generate_hash
=
True
elif
eq
and
not
frozen
:
# Not hashable.
_set_attribute
(
cls
,
'__hash__'
,
None
)
elif
not
eq
:
# Otherwise, use the base class definition of hash(). That is,
# don't set anything on this class.
pass
else
:
assert
"can't get here"
else
:
generate_hash
=
hash
if
generate_hash
:
_set_attribute
(
cls
,
'__hash__'
,
_hash_fn
(
list
(
filter
(
lambda
f
:
f
.
compare
if
f
.
hash
is
None
else
f
.
hash
,
field_list
))))
if
eq
:
# Create and __eq__ and __ne__ methods.
_set_eq_fns
(
cls
,
list
(
filter
(
lambda
f
:
f
.
compare
,
field_list
)))
if
order
:
# Create and __lt__, __le__, __gt__, and __ge__ methods.
# Create and set the comparison functions.
_set_order_fns
(
cls
,
list
(
filter
(
lambda
f
:
f
.
compare
,
field_list
)))
if
not
getattr
(
cls
,
'__doc__'
):
# Create a class doc-string.
cls
.
__doc__
=
(
cls
.
__name__
+
str
(
inspect
.
signature
(
cls
)).
replace
(
' -> None'
,
''
))
return
cls
# _cls should never be specified by keyword, so start it with an
# underscore. The presense of _cls is used to detect if this
# decorator is being called with parameters or not.
def
dataclass
(
_cls
=
None
,
*
,
init
=
True
,
repr
=
True
,
eq
=
True
,
order
=
False
,
hash
=
None
,
frozen
=
False
):
"""Returns the same class as was passed in, with dunder methods
added based on the fields defined in the class.
Examines PEP 526 __annotations__ to determine fields.
If init is true, an __init__() method is added to the class. If
repr is true, a __repr__() method is added. If order is true, rich
comparison dunder methods are added. If hash is true, a __hash__()
method function is added. If frozen is true, fields may not be
assigned to after instance creation.
"""
def
wrap
(
cls
):
return
_process_class
(
cls
,
repr
,
eq
,
order
,
hash
,
init
,
frozen
)
# See if we're being called as @dataclass or @dataclass().
if
_cls
is
None
:
# We're called with parens.
return
wrap
# We're called as @dataclass without parens.
return
wrap
(
_cls
)
def
fields
(
class_or_instance
):
"""Return a tuple describing the fields of this dataclass.
Accepts a dataclass or an instance of one. Tuple elements are of
type Field.
"""
# Might it be worth caching this, per class?
try
:
fields
=
getattr
(
class_or_instance
,
_MARKER
)
except
AttributeError
:
raise
TypeError
(
'must be called with a dataclass type or instance'
)
# Exclude pseudo-fields.
return
tuple
(
f
for
f
in
fields
.
values
()
if
f
.
_field_type
is
_FIELD
)
def
_isdataclass
(
obj
):
"""Returns True if obj is an instance of a dataclass."""
return
not
isinstance
(
obj
,
type
)
and
hasattr
(
obj
,
_MARKER
)
def
asdict
(
obj
,
*
,
dict_factory
=
dict
):
"""Return the fields of a dataclass instance as a new dictionary mapping
field names to field values.
Example usage:
@dataclass
class C:
x: int
y: int
c = C(1, 2)
assert asdict(c) == {'x': 1, 'y': 2}
If given, 'dict_factory' will be used instead of built-in dict.
The function applies recursively to field values that are
dataclass instances. This will also look into built-in containers:
tuples, lists, and dicts.
"""
if
not
_isdataclass
(
obj
):
raise
TypeError
(
"asdict() should be called on dataclass instances"
)
return
_asdict_inner
(
obj
,
dict_factory
)
def
_asdict_inner
(
obj
,
dict_factory
):
if
_isdataclass
(
obj
):
result
=
[]
for
f
in
fields
(
obj
):
value
=
_asdict_inner
(
getattr
(
obj
,
f
.
name
),
dict_factory
)
result
.
append
((
f
.
name
,
value
))
return
dict_factory
(
result
)
elif
isinstance
(
obj
,
(
list
,
tuple
)):
return
type
(
obj
)(
_asdict_inner
(
v
,
dict_factory
)
for
v
in
obj
)
elif
isinstance
(
obj
,
dict
):
return
type
(
obj
)((
_asdict_inner
(
k
,
dict_factory
),
_asdict_inner
(
v
,
dict_factory
))
for
k
,
v
in
obj
.
items
())
else
:
return
deepcopy
(
obj
)
def
astuple
(
obj
,
*
,
tuple_factory
=
tuple
):
"""Return the fields of a dataclass instance as a new tuple of field values.
Example usage::
@dataclass
class C:
x: int
y: int
c = C(1, 2)
assert asdtuple(c) == (1, 2)
If given, 'tuple_factory' will be used instead of built-in tuple.
The function applies recursively to field values that are
dataclass instances. This will also look into built-in containers:
tuples, lists, and dicts.
"""
if
not
_isdataclass
(
obj
):
raise
TypeError
(
"astuple() should be called on dataclass instances"
)
return
_astuple_inner
(
obj
,
tuple_factory
)
def
_astuple_inner
(
obj
,
tuple_factory
):
if
_isdataclass
(
obj
):
result
=
[]
for
f
in
fields
(
obj
):
value
=
_astuple_inner
(
getattr
(
obj
,
f
.
name
),
tuple_factory
)
result
.
append
(
value
)
return
tuple_factory
(
result
)
elif
isinstance
(
obj
,
(
list
,
tuple
)):
return
type
(
obj
)(
_astuple_inner
(
v
,
tuple_factory
)
for
v
in
obj
)
elif
isinstance
(
obj
,
dict
):
return
type
(
obj
)((
_astuple_inner
(
k
,
tuple_factory
),
_astuple_inner
(
v
,
tuple_factory
))
for
k
,
v
in
obj
.
items
())
else
:
return
deepcopy
(
obj
)
def
make_dataclass
(
cls_name
,
fields
,
*
,
bases
=
(),
namespace
=
None
):
"""Return a new dynamically created dataclass.
The dataclass name will be 'cls_name'. 'fields' is an interable
of either (name, type) or (name, type, Field) objects. Field
objects are created by calling 'field(name, type [, Field])'.
C = make_class('C', [('a', int', ('b', int, Field(init=False))], bases=Base)
is equivalent to:
@dataclass
class C(Base):
a: int
b: int = field(init=False)
For the bases and namespace paremeters, see the builtin type() function.
"""
if
namespace
is
None
:
namespace
=
{}
else
:
# Copy namespace since we're going to mutate it.
namespace
=
namespace
.
copy
()
anns
=
collections
.
OrderedDict
((
name
,
tp
)
for
name
,
tp
,
*
_
in
fields
)
namespace
[
'__annotations__'
]
=
anns
for
item
in
fields
:
if
len
(
item
)
==
3
:
name
,
tp
,
spec
=
item
namespace
[
name
]
=
spec
cls
=
type
(
cls_name
,
bases
,
namespace
)
return
dataclass
(
cls
)
def
replace
(
obj
,
**
changes
):
"""Return a new object replacing specified fields with new values.
This is especially useful for frozen classes. Example usage:
@dataclass(frozen=True)
class C:
x: int
y: int
c = C(1, 2)
c1 = replace(c, x=3)
assert c1.x == 3 and c1.y == 2
"""
# We're going to mutate 'changes', but that's okay because it's a new
# dict, even if called with 'replace(obj, **my_changes)'.
if
not
_isdataclass
(
obj
):
raise
TypeError
(
"replace() should be called on dataclass instances"
)
# It's an error to have init=False fields in 'changes'.
# If a field is not in 'changes', read its value from the provided obj.
for
f
in
getattr
(
obj
,
_MARKER
).
values
():
if
not
f
.
init
:
# Error if this field is specified in changes.
if
f
.
name
in
changes
:
raise
ValueError
(
f'field
{
f
.
name
}
is declared with '
'init=False, it cannot be specified with '
'replace()'
)
continue
if
f
.
name
not
in
changes
:
changes
[
f
.
name
]
=
getattr
(
obj
,
f
.
name
)
# Create the new object, which calls __init__() and __post_init__
# (if defined), using all of the init fields we've added and/or
# left in 'changes'.
# If there are values supplied in changes that aren't fields, this
# will correctly raise a TypeError.
return
obj
.
__class__
(
**
changes
)
Lib/test/test_dataclasses.py
0 → 100755
View file @
f0db54a0
from
dataclasses
import
(
dataclass
,
field
,
FrozenInstanceError
,
fields
,
asdict
,
astuple
,
make_dataclass
,
replace
,
InitVar
,
Field
)
import
pickle
import
inspect
import
unittest
from
unittest.mock
import
Mock
from
typing
import
ClassVar
,
Any
,
List
,
Union
,
Tuple
,
Dict
,
Generic
,
TypeVar
from
collections
import
deque
,
OrderedDict
,
namedtuple
# Just any custom exception we can catch.
class
CustomError
(
Exception
):
pass
class
TestCase
(
unittest
.
TestCase
):
def
test_no_fields
(
self
):
@
dataclass
class
C
:
pass
o
=
C
()
self
.
assertEqual
(
len
(
fields
(
C
)),
0
)
def
test_one_field_no_default
(
self
):
@
dataclass
class
C
:
x
:
int
o
=
C
(
42
)
self
.
assertEqual
(
o
.
x
,
42
)
def
test_named_init_params
(
self
):
@
dataclass
class
C
:
x
:
int
o
=
C
(
x
=
32
)
self
.
assertEqual
(
o
.
x
,
32
)
def
test_two_fields_one_default
(
self
):
@
dataclass
class
C
:
x
:
int
y
:
int
=
0
o
=
C
(
3
)
self
.
assertEqual
((
o
.
x
,
o
.
y
),
(
3
,
0
))
# Non-defaults following defaults.
with
self
.
assertRaisesRegex
(
TypeError
,
"non-default argument 'y' follows "
"default argument"
):
@
dataclass
class
C
:
x
:
int
=
0
y
:
int
# A derived class adds a non-default field after a default one.
with
self
.
assertRaisesRegex
(
TypeError
,
"non-default argument 'y' follows "
"default argument"
):
@
dataclass
class
B
:
x
:
int
=
0
@
dataclass
class
C
(
B
):
y
:
int
# Override a base class field and add a default to
# a field which didn't use to have a default.
with
self
.
assertRaisesRegex
(
TypeError
,
"non-default argument 'y' follows "
"default argument"
):
@
dataclass
class
B
:
x
:
int
y
:
int
@
dataclass
class
C
(
B
):
x
:
int
=
0
def
test_overwriting_init
(
self
):
with
self
.
assertRaisesRegex
(
TypeError
,
'Cannot overwrite attribute __init__ '
'in C'
):
@
dataclass
class
C
:
x
:
int
def
__init__
(
self
,
x
):
self
.
x
=
2
*
x
@
dataclass
(
init
=
False
)
class
C
:
x
:
int
def
__init__
(
self
,
x
):
self
.
x
=
2
*
x
self
.
assertEqual
(
C
(
5
).
x
,
10
)
def
test_overwriting_repr
(
self
):
with
self
.
assertRaisesRegex
(
TypeError
,
'Cannot overwrite attribute __repr__ '
'in C'
):
@
dataclass
class
C
:
x
:
int
def
__repr__
(
self
):
pass
@
dataclass
(
repr
=
False
)
class
C
:
x
:
int
def
__repr__
(
self
):
return
'x'
self
.
assertEqual
(
repr
(
C
(
0
)),
'x'
)
def
test_overwriting_cmp
(
self
):
with
self
.
assertRaisesRegex
(
TypeError
,
'Cannot overwrite attribute __eq__ '
'in C'
):
# This will generate the comparison functions, make sure we can't
# overwrite them.
@
dataclass
(
hash
=
False
,
frozen
=
False
)
class
C
:
x
:
int
def
__eq__
(
self
):
pass
@
dataclass
(
order
=
False
,
eq
=
False
)
class
C
:
x
:
int
def
__eq__
(
self
,
other
):
return
True
self
.
assertEqual
(
C
(
0
),
'x'
)
def
test_overwriting_hash
(
self
):
with
self
.
assertRaisesRegex
(
TypeError
,
'Cannot overwrite attribute __hash__ '
'in C'
):
@
dataclass
(
frozen
=
True
)
class
C
:
x
:
int
def
__hash__
(
self
):
pass
@
dataclass
(
frozen
=
True
,
hash
=
False
)
class
C
:
x
:
int
def
__hash__
(
self
):
return
600
self
.
assertEqual
(
hash
(
C
(
0
)),
600
)
with
self
.
assertRaisesRegex
(
TypeError
,
'Cannot overwrite attribute __hash__ '
'in C'
):
@
dataclass
(
frozen
=
True
)
class
C
:
x
:
int
def
__hash__
(
self
):
pass
@
dataclass
(
frozen
=
True
,
hash
=
False
)
class
C
:
x
:
int
def
__hash__
(
self
):
return
600
self
.
assertEqual
(
hash
(
C
(
0
)),
600
)
def
test_overwriting_frozen
(
self
):
# frozen uses __setattr__ and __delattr__
with
self
.
assertRaisesRegex
(
TypeError
,
'Cannot overwrite attribute __setattr__ '
'in C'
):
@
dataclass
(
frozen
=
True
)
class
C
:
x
:
int
def
__setattr__
(
self
):
pass
with
self
.
assertRaisesRegex
(
TypeError
,
'Cannot overwrite attribute __delattr__ '
'in C'
):
@
dataclass
(
frozen
=
True
)
class
C
:
x
:
int
def
__delattr__
(
self
):
pass
@
dataclass
(
frozen
=
False
)
class
C
:
x
:
int
def
__setattr__
(
self
,
name
,
value
):
self
.
__dict__
[
'x'
]
=
value
*
2
self
.
assertEqual
(
C
(
10
).
x
,
20
)
def
test_overwrite_fields_in_derived_class
(
self
):
# Note that x from C1 replaces x in Base, but the order remains
# the same as defined in Base.
@
dataclass
class
Base
:
x
:
Any
=
15.0
y
:
int
=
0
@
dataclass
class
C1
(
Base
):
z
:
int
=
10
x
:
int
=
15
o
=
Base
()
self
.
assertEqual
(
repr
(
o
),
'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)'
)
o
=
C1
()
self
.
assertEqual
(
repr
(
o
),
'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)'
)
o
=
C1
(
x
=
5
)
self
.
assertEqual
(
repr
(
o
),
'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)'
)
def
test_field_named_self
(
self
):
@
dataclass
class
C
:
self
:
str
c
=
C
(
'foo'
)
self
.
assertEqual
(
c
.
self
,
'foo'
)
# Make sure the first parameter is not named 'self'.
sig
=
inspect
.
signature
(
C
.
__init__
)
first
=
next
(
iter
(
sig
.
parameters
))
self
.
assertNotEqual
(
'self'
,
first
)
# But we do use 'self' if no field named self.
@
dataclass
class
C
:
selfx
:
str
# Make sure the first parameter is named 'self'.
sig
=
inspect
.
signature
(
C
.
__init__
)
first
=
next
(
iter
(
sig
.
parameters
))
self
.
assertEqual
(
'self'
,
first
)
def
test_repr
(
self
):
@
dataclass
class
B
:
x
:
int
@
dataclass
class
C
(
B
):
y
:
int
=
10
o
=
C
(
4
)
self
.
assertEqual
(
repr
(
o
),
'TestCase.test_repr.<locals>.C(x=4, y=10)'
)
@
dataclass
class
D
(
C
):
x
:
int
=
20
self
.
assertEqual
(
repr
(
D
()),
'TestCase.test_repr.<locals>.D(x=20, y=10)'
)
@
dataclass
class
C
:
@
dataclass
class
D
:
i
:
int
@
dataclass
class
E
:
pass
self
.
assertEqual
(
repr
(
C
.
D
(
0
)),
'TestCase.test_repr.<locals>.C.D(i=0)'
)
self
.
assertEqual
(
repr
(
C
.
E
()),
'TestCase.test_repr.<locals>.C.E()'
)
def
test_0_field_compare
(
self
):
# Ensure that order=False is the default.
@
dataclass
class
C0
:
pass
@
dataclass
(
order
=
False
)
class
C1
:
pass
for
cls
in
[
C0
,
C1
]:
with
self
.
subTest
(
cls
=
cls
):
self
.
assertEqual
(
cls
(),
cls
())
for
idx
,
fn
in
enumerate
([
lambda
a
,
b
:
a
<
b
,
lambda
a
,
b
:
a
<=
b
,
lambda
a
,
b
:
a
>
b
,
lambda
a
,
b
:
a
>=
b
]):
with
self
.
subTest
(
idx
=
idx
):
with
self
.
assertRaisesRegex
(
TypeError
,
f"not supported between instances of '
{
cls
.
__name__
}
' and '
{
cls
.
__name__
}
'"
):
fn
(
cls
(),
cls
())
@
dataclass
(
order
=
True
)
class
C
:
pass
self
.
assertLessEqual
(
C
(),
C
())
self
.
assertGreaterEqual
(
C
(),
C
())
def
test_1_field_compare
(
self
):
# Ensure that order=False is the default.
@
dataclass
class
C0
:
x
:
int
@
dataclass
(
order
=
False
)
class
C1
:
x
:
int
for
cls
in
[
C0
,
C1
]:
with
self
.
subTest
(
cls
=
cls
):
self
.
assertEqual
(
cls
(
1
),
cls
(
1
))
self
.
assertNotEqual
(
cls
(
0
),
cls
(
1
))
for
idx
,
fn
in
enumerate
([
lambda
a
,
b
:
a
<
b
,
lambda
a
,
b
:
a
<=
b
,
lambda
a
,
b
:
a
>
b
,
lambda
a
,
b
:
a
>=
b
]):
with
self
.
subTest
(
idx
=
idx
):
with
self
.
assertRaisesRegex
(
TypeError
,
f"not supported between instances of '
{
cls
.
__name__
}
' and '
{
cls
.
__name__
}
'"
):
fn
(
cls
(
0
),
cls
(
0
))
@
dataclass
(
order
=
True
)
class
C
:
x
:
int
self
.
assertLess
(
C
(
0
),
C
(
1
))
self
.
assertLessEqual
(
C
(
0
),
C
(
1
))
self
.
assertLessEqual
(
C
(
1
),
C
(
1
))
self
.
assertGreater
(
C
(
1
),
C
(
0
))
self
.
assertGreaterEqual
(
C
(
1
),
C
(
0
))
self
.
assertGreaterEqual
(
C
(
1
),
C
(
1
))
def
test_simple_compare
(
self
):
# Ensure that order=False is the default.
@
dataclass
class
C0
:
x
:
int
y
:
int
@
dataclass
(
order
=
False
)
class
C1
:
x
:
int
y
:
int
for
cls
in
[
C0
,
C1
]:
with
self
.
subTest
(
cls
=
cls
):
self
.
assertEqual
(
cls
(
0
,
0
),
cls
(
0
,
0
))
self
.
assertEqual
(
cls
(
1
,
2
),
cls
(
1
,
2
))
self
.
assertNotEqual
(
cls
(
1
,
0
),
cls
(
0
,
0
))
self
.
assertNotEqual
(
cls
(
1
,
0
),
cls
(
1
,
1
))
for
idx
,
fn
in
enumerate
([
lambda
a
,
b
:
a
<
b
,
lambda
a
,
b
:
a
<=
b
,
lambda
a
,
b
:
a
>
b
,
lambda
a
,
b
:
a
>=
b
]):
with
self
.
subTest
(
idx
=
idx
):
with
self
.
assertRaisesRegex
(
TypeError
,
f"not supported between instances of '
{
cls
.
__name__
}
' and '
{
cls
.
__name__
}
'"
):
fn
(
cls
(
0
,
0
),
cls
(
0
,
0
))
@
dataclass
(
order
=
True
)
class
C
:
x
:
int
y
:
int
for
idx
,
fn
in
enumerate
([
lambda
a
,
b
:
a
==
b
,
lambda
a
,
b
:
a
<=
b
,
lambda
a
,
b
:
a
>=
b
]):
with
self
.
subTest
(
idx
=
idx
):
self
.
assertTrue
(
fn
(
C
(
0
,
0
),
C
(
0
,
0
)))
for
idx
,
fn
in
enumerate
([
lambda
a
,
b
:
a
<
b
,
lambda
a
,
b
:
a
<=
b
,
lambda
a
,
b
:
a
!=
b
]):
with
self
.
subTest
(
idx
=
idx
):
self
.
assertTrue
(
fn
(
C
(
0
,
0
),
C
(
0
,
1
)))
self
.
assertTrue
(
fn
(
C
(
0
,
1
),
C
(
1
,
0
)))
self
.
assertTrue
(
fn
(
C
(
1
,
0
),
C
(
1
,
1
)))
for
idx
,
fn
in
enumerate
([
lambda
a
,
b
:
a
>
b
,
lambda
a
,
b
:
a
>=
b
,
lambda
a
,
b
:
a
!=
b
]):
with
self
.
subTest
(
idx
=
idx
):
self
.
assertTrue
(
fn
(
C
(
0
,
1
),
C
(
0
,
0
)))
self
.
assertTrue
(
fn
(
C
(
1
,
0
),
C
(
0
,
1
)))
self
.
assertTrue
(
fn
(
C
(
1
,
1
),
C
(
1
,
0
)))
def
test_compare_subclasses
(
self
):
# Comparisons fail for subclasses, even if no fields
# are added.
@
dataclass
class
B
:
i
:
int
@
dataclass
class
C
(
B
):
pass
for
idx
,
(
fn
,
expected
)
in
enumerate
([(
lambda
a
,
b
:
a
==
b
,
False
),
(
lambda
a
,
b
:
a
!=
b
,
True
)]):
with
self
.
subTest
(
idx
=
idx
):
self
.
assertEqual
(
fn
(
B
(
0
),
C
(
0
)),
expected
)
for
idx
,
fn
in
enumerate
([
lambda
a
,
b
:
a
<
b
,
lambda
a
,
b
:
a
<=
b
,
lambda
a
,
b
:
a
>
b
,
lambda
a
,
b
:
a
>=
b
]):
with
self
.
subTest
(
idx
=
idx
):
with
self
.
assertRaisesRegex
(
TypeError
,
"not supported between instances of 'B' and 'C'"
):
fn
(
B
(
0
),
C
(
0
))
def
test_0_field_hash
(
self
):
@
dataclass
(
hash
=
True
)
class
C
:
pass
self
.
assertEqual
(
hash
(
C
()),
hash
(()))
def
test_1_field_hash
(
self
):
@
dataclass
(
hash
=
True
)
class
C
:
x
:
int
self
.
assertEqual
(
hash
(
C
(
4
)),
hash
((
4
,)))
self
.
assertEqual
(
hash
(
C
(
42
)),
hash
((
42
,)))
def
test_hash
(
self
):
@
dataclass
(
hash
=
True
)
class
C
:
x
:
int
y
:
str
self
.
assertEqual
(
hash
(
C
(
1
,
'foo'
)),
hash
((
1
,
'foo'
)))
def
test_no_hash
(
self
):
@
dataclass
(
hash
=
None
)
class
C
:
x
:
int
with
self
.
assertRaisesRegex
(
TypeError
,
"unhashable type: 'C'"
):
hash
(
C
(
1
))
def
test_hash_rules
(
self
):
# There are 24 cases of:
# hash=True/False/None
# eq=True/False
# order=True/False
# frozen=True/False
for
(
hash
,
eq
,
order
,
frozen
,
result
)
in
[
(
False
,
False
,
False
,
False
,
'absent'
),
(
False
,
False
,
False
,
True
,
'absent'
),
(
False
,
False
,
True
,
False
,
'exception'
),
(
False
,
False
,
True
,
True
,
'exception'
),
(
False
,
True
,
False
,
False
,
'absent'
),
(
False
,
True
,
False
,
True
,
'absent'
),
(
False
,
True
,
True
,
False
,
'absent'
),
(
False
,
True
,
True
,
True
,
'absent'
),
(
True
,
False
,
False
,
False
,
'fn'
),
(
True
,
False
,
False
,
True
,
'fn'
),
(
True
,
False
,
True
,
False
,
'exception'
),
(
True
,
False
,
True
,
True
,
'exception'
),
(
True
,
True
,
False
,
False
,
'fn'
),
(
True
,
True
,
False
,
True
,
'fn'
),
(
True
,
True
,
True
,
False
,
'fn'
),
(
True
,
True
,
True
,
True
,
'fn'
),
(
None
,
False
,
False
,
False
,
'absent'
),
(
None
,
False
,
False
,
True
,
'absent'
),
(
None
,
False
,
True
,
False
,
'exception'
),
(
None
,
False
,
True
,
True
,
'exception'
),
(
None
,
True
,
False
,
False
,
'none'
),
(
None
,
True
,
False
,
True
,
'fn'
),
(
None
,
True
,
True
,
False
,
'none'
),
(
None
,
True
,
True
,
True
,
'fn'
),
]:
with
self
.
subTest
(
hash
=
hash
,
eq
=
eq
,
order
=
order
,
frozen
=
frozen
):
if
result
==
'exception'
:
with
self
.
assertRaisesRegex
(
ValueError
,
'eq must be true if order is true'
):
@
dataclass
(
hash
=
hash
,
eq
=
eq
,
order
=
order
,
frozen
=
frozen
)
class
C
:
pass
else
:
@
dataclass
(
hash
=
hash
,
eq
=
eq
,
order
=
order
,
frozen
=
frozen
)
class
C
:
pass
# See if the result matches what's expected.
if
result
==
'fn'
:
# __hash__ contains the function we generated.
self
.
assertIn
(
'__hash__'
,
C
.
__dict__
)
self
.
assertIsNotNone
(
C
.
__dict__
[
'__hash__'
])
elif
result
==
'absent'
:
# __hash__ is not present in our class.
self
.
assertNotIn
(
'__hash__'
,
C
.
__dict__
)
elif
result
==
'none'
:
# __hash__ is set to None.
self
.
assertIn
(
'__hash__'
,
C
.
__dict__
)
self
.
assertIsNone
(
C
.
__dict__
[
'__hash__'
])
else
:
assert
False
,
f'unknown result
{
result
!
r
}
'
def
test_eq_order
(
self
):
for
(
eq
,
order
,
result
)
in
[
(
False
,
False
,
'neither'
),
(
False
,
True
,
'exception'
),
(
True
,
False
,
'eq_only'
),
(
True
,
True
,
'both'
),
]:
with
self
.
subTest
(
eq
=
eq
,
order
=
order
):
if
result
==
'exception'
:
with
self
.
assertRaisesRegex
(
ValueError
,
'eq must be true if order is true'
):
@
dataclass
(
eq
=
eq
,
order
=
order
)
class
C
:
pass
else
:
@
dataclass
(
eq
=
eq
,
order
=
order
)
class
C
:
pass
if
result
==
'neither'
:
self
.
assertNotIn
(
'__eq__'
,
C
.
__dict__
)
self
.
assertNotIn
(
'__ne__'
,
C
.
__dict__
)
self
.
assertNotIn
(
'__lt__'
,
C
.
__dict__
)
self
.
assertNotIn
(
'__le__'
,
C
.
__dict__
)
self
.
assertNotIn
(
'__gt__'
,
C
.
__dict__
)
self
.
assertNotIn
(
'__ge__'
,
C
.
__dict__
)
elif
result
==
'both'
:
self
.
assertIn
(
'__eq__'
,
C
.
__dict__
)
self
.
assertIn
(
'__ne__'
,
C
.
__dict__
)
self
.
assertIn
(
'__lt__'
,
C
.
__dict__
)
self
.
assertIn
(
'__le__'
,
C
.
__dict__
)
self
.
assertIn
(
'__gt__'
,
C
.
__dict__
)
self
.
assertIn
(
'__ge__'
,
C
.
__dict__
)
elif
result
==
'eq_only'
:
self
.
assertIn
(
'__eq__'
,
C
.
__dict__
)
self
.
assertIn
(
'__ne__'
,
C
.
__dict__
)
self
.
assertNotIn
(
'__lt__'
,
C
.
__dict__
)
self
.
assertNotIn
(
'__le__'
,
C
.
__dict__
)
self
.
assertNotIn
(
'__gt__'
,
C
.
__dict__
)
self
.
assertNotIn
(
'__ge__'
,
C
.
__dict__
)
else
:
assert
False
,
f'unknown result
{
result
!
r
}
'
def
test_field_no_default
(
self
):
@
dataclass
class
C
:
x
:
int
=
field
()
self
.
assertEqual
(
C
(
5
).
x
,
5
)
with
self
.
assertRaisesRegex
(
TypeError
,
r"__init__\
(
\) missing 1 required "
"positional argument: 'x'"
):
C
()
def
test_field_default
(
self
):
default
=
object
()
@
dataclass
class
C
:
x
:
object
=
field
(
default
=
default
)
self
.
assertIs
(
C
.
x
,
default
)
c
=
C
(
10
)
self
.
assertEqual
(
c
.
x
,
10
)
# If we delete the instance attribute, we should then see the
# class attribute.
del
c
.
x
self
.
assertIs
(
c
.
x
,
default
)
self
.
assertIs
(
C
().
x
,
default
)
def
test_not_in_repr
(
self
):
@
dataclass
class
C
:
x
:
int
=
field
(
repr
=
False
)
with
self
.
assertRaises
(
TypeError
):
C
()
c
=
C
(
10
)
self
.
assertEqual
(
repr
(
c
),
'TestCase.test_not_in_repr.<locals>.C()'
)
@
dataclass
class
C
:
x
:
int
=
field
(
repr
=
False
)
y
:
int
c
=
C
(
10
,
20
)
self
.
assertEqual
(
repr
(
c
),
'TestCase.test_not_in_repr.<locals>.C(y=20)'
)
def
test_not_in_compare
(
self
):
@
dataclass
class
C
:
x
:
int
=
0
y
:
int
=
field
(
compare
=
False
,
default
=
4
)
self
.
assertEqual
(
C
(),
C
(
0
,
20
))
self
.
assertEqual
(
C
(
1
,
10
),
C
(
1
,
20
))
self
.
assertNotEqual
(
C
(
3
),
C
(
4
,
10
))
self
.
assertNotEqual
(
C
(
3
,
10
),
C
(
4
,
10
))
def
test_hash_field_rules
(
self
):
# Test all 6 cases of:
# hash=True/False/None
# compare=True/False
for
(
hash_val
,
compare
,
result
)
in
[
(
True
,
False
,
'field'
),
(
True
,
True
,
'field'
),
(
False
,
False
,
'absent'
),
(
False
,
True
,
'absent'
),
(
None
,
False
,
'absent'
),
(
None
,
True
,
'field'
),
]:
with
self
.
subTest
(
hash_val
=
hash_val
,
compare
=
compare
):
@
dataclass
(
hash
=
True
)
class
C
:
x
:
int
=
field
(
compare
=
compare
,
hash
=
hash_val
,
default
=
5
)
if
result
==
'field'
:
# __hash__ contains the field.
self
.
assertEqual
(
C
(
5
).
__hash__
(),
hash
((
5
,)))
elif
result
==
'absent'
:
# The field is not present in the hash.
self
.
assertEqual
(
C
(
5
).
__hash__
(),
hash
(()))
else
:
assert
False
,
f'unknown result
{
result
!
r
}
'
def
test_init_false_no_default
(
self
):
# If init=False and no default value, then the field won't be
# present in the instance.
@
dataclass
class
C
:
x
:
int
=
field
(
init
=
False
)
self
.
assertNotIn
(
'x'
,
C
().
__dict__
)
@
dataclass
class
C
:
x
:
int
y
:
int
=
0
z
:
int
=
field
(
init
=
False
)
t
:
int
=
10
self
.
assertNotIn
(
'z'
,
C
(
0
).
__dict__
)
self
.
assertEqual
(
vars
(
C
(
5
)),
{
't'
:
10
,
'x'
:
5
,
'y'
:
0
})
def
test_class_marker
(
self
):
@
dataclass
class
C
:
x
:
int
y
:
str
=
field
(
init
=
False
,
default
=
None
)
z
:
str
=
field
(
repr
=
False
)
the_fields
=
fields
(
C
)
# the_fields is a tuple of 3 items, each value
# is in __annotations__.
self
.
assertIsInstance
(
the_fields
,
tuple
)
for
f
in
the_fields
:
self
.
assertIs
(
type
(
f
),
Field
)
self
.
assertIn
(
f
.
name
,
C
.
__annotations__
)
self
.
assertEqual
(
len
(
the_fields
),
3
)
self
.
assertEqual
(
the_fields
[
0
].
name
,
'x'
)
self
.
assertEqual
(
the_fields
[
0
].
type
,
int
)
self
.
assertFalse
(
hasattr
(
C
,
'x'
))
self
.
assertTrue
(
the_fields
[
0
].
init
)
self
.
assertTrue
(
the_fields
[
0
].
repr
)
self
.
assertEqual
(
the_fields
[
1
].
name
,
'y'
)
self
.
assertEqual
(
the_fields
[
1
].
type
,
str
)
self
.
assertIsNone
(
getattr
(
C
,
'y'
))
self
.
assertFalse
(
the_fields
[
1
].
init
)
self
.
assertTrue
(
the_fields
[
1
].
repr
)
self
.
assertEqual
(
the_fields
[
2
].
name
,
'z'
)
self
.
assertEqual
(
the_fields
[
2
].
type
,
str
)
self
.
assertFalse
(
hasattr
(
C
,
'z'
))
self
.
assertTrue
(
the_fields
[
2
].
init
)
self
.
assertFalse
(
the_fields
[
2
].
repr
)
def
test_field_order
(
self
):
@
dataclass
class
B
:
a
:
str
=
'B:a'
b
:
str
=
'B:b'
c
:
str
=
'B:c'
@
dataclass
class
C
(
B
):
b
:
str
=
'C:b'
self
.
assertEqual
([(
f
.
name
,
f
.
default
)
for
f
in
fields
(
C
)],
[(
'a'
,
'B:a'
),
(
'b'
,
'C:b'
),
(
'c'
,
'B:c'
)])
@
dataclass
class
D
(
B
):
c
:
str
=
'D:c'
self
.
assertEqual
([(
f
.
name
,
f
.
default
)
for
f
in
fields
(
D
)],
[(
'a'
,
'B:a'
),
(
'b'
,
'B:b'
),
(
'c'
,
'D:c'
)])
@
dataclass
class
E
(
D
):
a
:
str
=
'E:a'
d
:
str
=
'E:d'
self
.
assertEqual
([(
f
.
name
,
f
.
default
)
for
f
in
fields
(
E
)],
[(
'a'
,
'E:a'
),
(
'b'
,
'B:b'
),
(
'c'
,
'D:c'
),
(
'd'
,
'E:d'
)])
def
test_class_attrs
(
self
):
# We only have a class attribute if a default value is
# specified, either directly or via a field with a default.
default
=
object
()
@
dataclass
class
C
:
x
:
int
y
:
int
=
field
(
repr
=
False
)
z
:
object
=
default
t
:
int
=
field
(
default
=
100
)
self
.
assertFalse
(
hasattr
(
C
,
'x'
))
self
.
assertFalse
(
hasattr
(
C
,
'y'
))
self
.
assertIs
(
C
.
z
,
default
)
self
.
assertEqual
(
C
.
t
,
100
)
def
test_disallowed_mutable_defaults
(
self
):
# For the known types, don't allow mutable default values.
for
typ
,
empty
,
non_empty
in
[(
list
,
[],
[
1
]),
(
dict
,
{},
{
0
:
1
}),
(
set
,
set
(),
set
([
1
])),
]:
with
self
.
subTest
(
typ
=
typ
):
# Can't use a zero-length value.
with
self
.
assertRaisesRegex
(
ValueError
,
f'mutable default
{
typ
}
for field '
'x is not allowed'
):
@
dataclass
class
Point
:
x
:
typ
=
empty
# Nor a non-zero-length value
with
self
.
assertRaisesRegex
(
ValueError
,
f'mutable default
{
typ
}
for field '
'y is not allowed'
):
@
dataclass
class
Point
:
y
:
typ
=
non_empty
# Check subtypes also fail.
class
Subclass
(
typ
):
pass
with
self
.
assertRaisesRegex
(
ValueError
,
f"mutable default .*Subclass'>"
' for field z is not allowed'
):
@
dataclass
class
Point
:
z
:
typ
=
Subclass
()
# Because this is a ClassVar, it can be mutable.
@
dataclass
class
C
:
z
:
ClassVar
[
typ
]
=
typ
()
# Because this is a ClassVar, it can be mutable.
@
dataclass
class
C
:
x
:
ClassVar
[
typ
]
=
Subclass
()
def
test_deliberately_mutable_defaults
(
self
):
# If a mutable default isn't in the known list of
# (list, dict, set), then it's okay.
class
Mutable
:
def
__init__
(
self
):
self
.
l
=
[]
@
dataclass
class
C
:
x
:
Mutable
# These 2 instances will share this value of x.
lst
=
Mutable
()
o1
=
C
(
lst
)
o2
=
C
(
lst
)
self
.
assertEqual
(
o1
,
o2
)
o1
.
x
.
l
.
extend
([
1
,
2
])
self
.
assertEqual
(
o1
,
o2
)
self
.
assertEqual
(
o1
.
x
.
l
,
[
1
,
2
])
self
.
assertIs
(
o1
.
x
,
o2
.
x
)
def
test_no_options
(
self
):
# call with dataclass()
@
dataclass
()
class
C
:
x
:
int
self
.
assertEqual
(
C
(
42
).
x
,
42
)
def
test_not_tuple
(
self
):
# Make sure we can't be compared to a tuple.
@
dataclass
class
Point
:
x
:
int
y
:
int
self
.
assertNotEqual
(
Point
(
1
,
2
),
(
1
,
2
))
# And that we can't compare to another unrelated dataclass
@
dataclass
class
C
:
x
:
int
y
:
int
self
.
assertNotEqual
(
Point
(
1
,
3
),
C
(
1
,
3
))
def
test_base_has_init
(
self
):
class
B
:
def
__init__
(
self
):
pass
# Make sure that declaring this class doesn't raise an error.
# The issue is that we can't override __init__ in our class,
# but it should be okay to add __init__ to us if our base has
# an __init__.
@
dataclass
class
C
(
B
):
x
:
int
=
0
def
test_frozen
(
self
):
@
dataclass
(
frozen
=
True
)
class
C
:
i
:
int
c
=
C
(
10
)
self
.
assertEqual
(
c
.
i
,
10
)
with
self
.
assertRaises
(
FrozenInstanceError
):
c
.
i
=
5
self
.
assertEqual
(
c
.
i
,
10
)
# Check that a derived class is still frozen, even if not
# marked so.
@
dataclass
class
D
(
C
):
pass
d
=
D
(
20
)
self
.
assertEqual
(
d
.
i
,
20
)
with
self
.
assertRaises
(
FrozenInstanceError
):
d
.
i
=
5
self
.
assertEqual
(
d
.
i
,
20
)
def
test_not_tuple
(
self
):
# Test that some of the problems with namedtuple don't happen
# here.
@
dataclass
class
Point3D
:
x
:
int
y
:
int
z
:
int
@
dataclass
class
Date
:
year
:
int
month
:
int
day
:
int
self
.
assertNotEqual
(
Point3D
(
2017
,
6
,
3
),
Date
(
2017
,
6
,
3
))
self
.
assertNotEqual
(
Point3D
(
1
,
2
,
3
),
(
1
,
2
,
3
))
# Make sure we can't unpack
with
self
.
assertRaisesRegex
(
TypeError
,
'is not iterable'
):
x
,
y
,
z
=
Point3D
(
4
,
5
,
6
)
# Maka sure another class with the same field names isn't
# equal.
@
dataclass
class
Point3Dv1
:
x
:
int
=
0
y
:
int
=
0
z
:
int
=
0
self
.
assertNotEqual
(
Point3D
(
0
,
0
,
0
),
Point3Dv1
())
def
test_function_annotations
(
self
):
# Some dummy class and instance to use as a default.
class
F
:
pass
f
=
F
()
def
validate_class
(
cls
):
# First, check __annotations__, even though they're not
# function annotations.
self
.
assertEqual
(
cls
.
__annotations__
[
'i'
],
int
)
self
.
assertEqual
(
cls
.
__annotations__
[
'j'
],
str
)
self
.
assertEqual
(
cls
.
__annotations__
[
'k'
],
F
)
self
.
assertEqual
(
cls
.
__annotations__
[
'l'
],
float
)
self
.
assertEqual
(
cls
.
__annotations__
[
'z'
],
complex
)
# Verify __init__.
signature
=
inspect
.
signature
(
cls
.
__init__
)
# Check the return type, should be None
self
.
assertIs
(
signature
.
return_annotation
,
None
)
# Check each parameter.
params
=
iter
(
signature
.
parameters
.
values
())
param
=
next
(
params
)
# This is testing an internal name, and probably shouldn't be tested.
self
.
assertEqual
(
param
.
name
,
'self'
)
param
=
next
(
params
)
self
.
assertEqual
(
param
.
name
,
'i'
)
self
.
assertIs
(
param
.
annotation
,
int
)
self
.
assertEqual
(
param
.
default
,
inspect
.
Parameter
.
empty
)
self
.
assertEqual
(
param
.
kind
,
inspect
.
Parameter
.
POSITIONAL_OR_KEYWORD
)
param
=
next
(
params
)
self
.
assertEqual
(
param
.
name
,
'j'
)
self
.
assertIs
(
param
.
annotation
,
str
)
self
.
assertEqual
(
param
.
default
,
inspect
.
Parameter
.
empty
)
self
.
assertEqual
(
param
.
kind
,
inspect
.
Parameter
.
POSITIONAL_OR_KEYWORD
)
param
=
next
(
params
)
self
.
assertEqual
(
param
.
name
,
'k'
)
self
.
assertIs
(
param
.
annotation
,
F
)
# Don't test for the default, since it's set to _MISSING
self
.
assertEqual
(
param
.
kind
,
inspect
.
Parameter
.
POSITIONAL_OR_KEYWORD
)
param
=
next
(
params
)
self
.
assertEqual
(
param
.
name
,
'l'
)
self
.
assertIs
(
param
.
annotation
,
float
)
# Don't test for the default, since it's set to _MISSING
self
.
assertEqual
(
param
.
kind
,
inspect
.
Parameter
.
POSITIONAL_OR_KEYWORD
)
self
.
assertRaises
(
StopIteration
,
next
,
params
)
@
dataclass
class
C
:
i
:
int
j
:
str
k
:
F
=
f
l
:
float
=
field
(
default
=
None
)
z
:
complex
=
field
(
default
=
3
+
4j
,
init
=
False
)
validate_class
(
C
)
# Now repeat with __hash__.
@
dataclass
(
frozen
=
True
,
hash
=
True
)
class
C
:
i
:
int
j
:
str
k
:
F
=
f
l
:
float
=
field
(
default
=
None
)
z
:
complex
=
field
(
default
=
3
+
4j
,
init
=
False
)
validate_class
(
C
)
def
test_dont_include_other_annotations
(
self
):
@
dataclass
class
C
:
i
:
int
def
foo
(
self
)
->
int
:
return
4
@
property
def
bar
(
self
)
->
int
:
return
5
self
.
assertEqual
(
list
(
C
.
__annotations__
),
[
'i'
])
self
.
assertEqual
(
C
(
10
).
foo
(),
4
)
self
.
assertEqual
(
C
(
10
).
bar
,
5
)
def
test_post_init
(
self
):
# Just make sure it gets called
@
dataclass
class
C
:
def
__post_init__
(
self
):
raise
CustomError
()
with
self
.
assertRaises
(
CustomError
):
C
()
@
dataclass
class
C
:
i
:
int
=
10
def
__post_init__
(
self
):
if
self
.
i
==
10
:
raise
CustomError
()
with
self
.
assertRaises
(
CustomError
):
C
()
# post-init gets called, but doesn't raise. This is just
# checking that self is used correctly.
C
(
5
)
# If there's not an __init__, then post-init won't get called.
@
dataclass
(
init
=
False
)
class
C
:
def
__post_init__
(
self
):
raise
CustomError
()
# Creating the class won't raise
C
()
@
dataclass
class
C
:
x
:
int
=
0
def
__post_init__
(
self
):
self
.
x
*=
2
self
.
assertEqual
(
C
().
x
,
0
)
self
.
assertEqual
(
C
(
2
).
x
,
4
)
# Make sure that if we'r frozen, post-init can't set
# attributes.
@
dataclass
(
frozen
=
True
)
class
C
:
x
:
int
=
0
def
__post_init__
(
self
):
self
.
x
*=
2
with
self
.
assertRaises
(
FrozenInstanceError
):
C
()
def
test_post_init_super
(
self
):
# Make sure super() post-init isn't called by default.
class
B
:
def
__post_init__
(
self
):
raise
CustomError
()
@
dataclass
class
C
(
B
):
def
__post_init__
(
self
):
self
.
x
=
5
self
.
assertEqual
(
C
().
x
,
5
)
# Now call super(), and it will raise
@
dataclass
class
C
(
B
):
def
__post_init__
(
self
):
super
().
__post_init__
()
with
self
.
assertRaises
(
CustomError
):
C
()
# Make sure post-init is called, even if not defined in our
# class.
@
dataclass
class
C
(
B
):
pass
with
self
.
assertRaises
(
CustomError
):
C
()
def
test_post_init_staticmethod
(
self
):
flag
=
False
@
dataclass
class
C
:
x
:
int
y
:
int
@
staticmethod
def
__post_init__
():
nonlocal
flag
flag
=
True
self
.
assertFalse
(
flag
)
c
=
C
(
3
,
4
)
self
.
assertEqual
((
c
.
x
,
c
.
y
),
(
3
,
4
))
self
.
assertTrue
(
flag
)
def
test_post_init_classmethod
(
self
):
@
dataclass
class
C
:
flag
=
False
x
:
int
y
:
int
@
classmethod
def
__post_init__
(
cls
):
cls
.
flag
=
True
self
.
assertFalse
(
C
.
flag
)
c
=
C
(
3
,
4
)
self
.
assertEqual
((
c
.
x
,
c
.
y
),
(
3
,
4
))
self
.
assertTrue
(
C
.
flag
)
def
test_class_var
(
self
):
# Make sure ClassVars are ignored in __init__, __repr__, etc.
@
dataclass
class
C
:
x
:
int
y
:
int
=
10
z
:
ClassVar
[
int
]
=
1000
w
:
ClassVar
[
int
]
=
2000
t
:
ClassVar
[
int
]
=
3000
c
=
C
(
5
)
self
.
assertEqual
(
repr
(
c
),
'TestCase.test_class_var.<locals>.C(x=5, y=10)'
)
self
.
assertEqual
(
len
(
fields
(
C
)),
2
)
# We have 2 fields
self
.
assertEqual
(
len
(
C
.
__annotations__
),
5
)
# And 3 ClassVars
self
.
assertEqual
(
c
.
z
,
1000
)
self
.
assertEqual
(
c
.
w
,
2000
)
self
.
assertEqual
(
c
.
t
,
3000
)
C
.
z
+=
1
self
.
assertEqual
(
c
.
z
,
1001
)
c
=
C
(
20
)
self
.
assertEqual
((
c
.
x
,
c
.
y
),
(
20
,
10
))
self
.
assertEqual
(
c
.
z
,
1001
)
self
.
assertEqual
(
c
.
w
,
2000
)
self
.
assertEqual
(
c
.
t
,
3000
)
def
test_class_var_no_default
(
self
):
# If a ClassVar has no default value, it should not be set on the class.
@
dataclass
class
C
:
x
:
ClassVar
[
int
]
self
.
assertNotIn
(
'x'
,
C
.
__dict__
)
def
test_class_var_default_factory
(
self
):
# It makes no sense for a ClassVar to have a default factory. When
# would it be called? Call it yourself, since it's class-wide.
with
self
.
assertRaisesRegex
(
TypeError
,
'cannot have a default factory'
):
@
dataclass
class
C
:
x
:
ClassVar
[
int
]
=
field
(
default_factory
=
int
)
self
.
assertNotIn
(
'x'
,
C
.
__dict__
)
def
test_class_var_with_default
(
self
):
# If a ClassVar has a default value, it should be set on the class.
@
dataclass
class
C
:
x
:
ClassVar
[
int
]
=
10
self
.
assertEqual
(
C
.
x
,
10
)
@
dataclass
class
C
:
x
:
ClassVar
[
int
]
=
field
(
default
=
10
)
self
.
assertEqual
(
C
.
x
,
10
)
def
test_class_var_frozen
(
self
):
# Make sure ClassVars work even if we're frozen.
@
dataclass
(
frozen
=
True
)
class
C
:
x
:
int
y
:
int
=
10
z
:
ClassVar
[
int
]
=
1000
w
:
ClassVar
[
int
]
=
2000
t
:
ClassVar
[
int
]
=
3000
c
=
C
(
5
)
self
.
assertEqual
(
repr
(
C
(
5
)),
'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)'
)
self
.
assertEqual
(
len
(
fields
(
C
)),
2
)
# We have 2 fields
self
.
assertEqual
(
len
(
C
.
__annotations__
),
5
)
# And 3 ClassVars
self
.
assertEqual
(
c
.
z
,
1000
)
self
.
assertEqual
(
c
.
w
,
2000
)
self
.
assertEqual
(
c
.
t
,
3000
)
# We can still modify the ClassVar, it's only instances that are
# frozen.
C
.
z
+=
1
self
.
assertEqual
(
c
.
z
,
1001
)
c
=
C
(
20
)
self
.
assertEqual
((
c
.
x
,
c
.
y
),
(
20
,
10
))
self
.
assertEqual
(
c
.
z
,
1001
)
self
.
assertEqual
(
c
.
w
,
2000
)
self
.
assertEqual
(
c
.
t
,
3000
)
def
test_init_var_no_default
(
self
):
# If an InitVar has no default value, it should not be set on the class.
@
dataclass
class
C
:
x
:
InitVar
[
int
]
self
.
assertNotIn
(
'x'
,
C
.
__dict__
)
def
test_init_var_default_factory
(
self
):
# It makes no sense for an InitVar to have a default factory. When
# would it be called? Call it yourself, since it's class-wide.
with
self
.
assertRaisesRegex
(
TypeError
,
'cannot have a default factory'
):
@
dataclass
class
C
:
x
:
InitVar
[
int
]
=
field
(
default_factory
=
int
)
self
.
assertNotIn
(
'x'
,
C
.
__dict__
)
def
test_init_var_with_default
(
self
):
# If an InitVar has a default value, it should be set on the class.
@
dataclass
class
C
:
x
:
InitVar
[
int
]
=
10
self
.
assertEqual
(
C
.
x
,
10
)
@
dataclass
class
C
:
x
:
InitVar
[
int
]
=
field
(
default
=
10
)
self
.
assertEqual
(
C
.
x
,
10
)
def
test_init_var
(
self
):
@
dataclass
class
C
:
x
:
int
=
None
init_param
:
InitVar
[
int
]
=
None
def
__post_init__
(
self
,
init_param
):
if
self
.
x
is
None
:
self
.
x
=
init_param
*
2
c
=
C
(
init_param
=
10
)
self
.
assertEqual
(
c
.
x
,
20
)
def
test_init_var_inheritance
(
self
):
# Note that this deliberately tests that a dataclass need not
# have a __post_init__ function if it has an InitVar field.
# It could just be used in a derived class, as shown here.
@
dataclass
class
Base
:
x
:
int
init_base
:
InitVar
[
int
]
# We can instantiate by passing the InitVar, even though
# it's not used.
b
=
Base
(
0
,
10
)
self
.
assertEqual
(
vars
(
b
),
{
'x'
:
0
})
@
dataclass
class
C
(
Base
):
y
:
int
init_derived
:
InitVar
[
int
]
def
__post_init__
(
self
,
init_base
,
init_derived
):
self
.
x
=
self
.
x
+
init_base
self
.
y
=
self
.
y
+
init_derived
c
=
C
(
10
,
11
,
50
,
51
)
self
.
assertEqual
(
vars
(
c
),
{
'x'
:
21
,
'y'
:
101
})
def
test_default_factory
(
self
):
# Test a factory that returns a new list.
@
dataclass
class
C
:
x
:
int
y
:
list
=
field
(
default_factory
=
list
)
c0
=
C
(
3
)
c1
=
C
(
3
)
self
.
assertEqual
(
c0
.
x
,
3
)
self
.
assertEqual
(
c0
.
y
,
[])
self
.
assertEqual
(
c0
,
c1
)
self
.
assertIsNot
(
c0
.
y
,
c1
.
y
)
self
.
assertEqual
(
astuple
(
C
(
5
,
[
1
])),
(
5
,
[
1
]))
# Test a factory that returns a shared list.
l
=
[]
@
dataclass
class
C
:
x
:
int
y
:
list
=
field
(
default_factory
=
lambda
:
l
)
c0
=
C
(
3
)
c1
=
C
(
3
)
self
.
assertEqual
(
c0
.
x
,
3
)
self
.
assertEqual
(
c0
.
y
,
[])
self
.
assertEqual
(
c0
,
c1
)
self
.
assertIs
(
c0
.
y
,
c1
.
y
)
self
.
assertEqual
(
astuple
(
C
(
5
,
[
1
])),
(
5
,
[
1
]))
# Test various other field flags.
# repr
@
dataclass
class
C
:
x
:
list
=
field
(
default_factory
=
list
,
repr
=
False
)
self
.
assertEqual
(
repr
(
C
()),
'TestCase.test_default_factory.<locals>.C()'
)
self
.
assertEqual
(
C
().
x
,
[])
# hash
@
dataclass
(
hash
=
True
)
class
C
:
x
:
list
=
field
(
default_factory
=
list
,
hash
=
False
)
self
.
assertEqual
(
astuple
(
C
()),
([],))
self
.
assertEqual
(
hash
(
C
()),
hash
(()))
# init (see also test_default_factory_with_no_init)
@
dataclass
class
C
:
x
:
list
=
field
(
default_factory
=
list
,
init
=
False
)
self
.
assertEqual
(
astuple
(
C
()),
([],))
# compare
@
dataclass
class
C
:
x
:
list
=
field
(
default_factory
=
list
,
compare
=
False
)
self
.
assertEqual
(
C
(),
C
([
1
]))
def
test_default_factory_with_no_init
(
self
):
# We need a factory with a side effect.
factory
=
Mock
()
@
dataclass
class
C
:
x
:
list
=
field
(
default_factory
=
factory
,
init
=
False
)
# Make sure the default factory is called for each new instance.
C
().
x
self
.
assertEqual
(
factory
.
call_count
,
1
)
C
().
x
self
.
assertEqual
(
factory
.
call_count
,
2
)
def
test_default_factory_not_called_if_value_given
(
self
):
# We need a factory that we can test if it's been called.
factory
=
Mock
()
@
dataclass
class
C
:
x
:
int
=
field
(
default_factory
=
factory
)
# Make sure that if a field has a default factory function,
# it's not called if a value is specified.
C
().
x
self
.
assertEqual
(
factory
.
call_count
,
1
)
self
.
assertEqual
(
C
(
10
).
x
,
10
)
self
.
assertEqual
(
factory
.
call_count
,
1
)
C
().
x
self
.
assertEqual
(
factory
.
call_count
,
2
)
def
x_test_classvar_default_factory
(
self
):
# XXX: it's an error for a ClassVar to have a factory function
@
dataclass
class
C
:
x
:
ClassVar
[
int
]
=
field
(
default_factory
=
int
)
self
.
assertIs
(
C
().
x
,
int
)
def
test_isdataclass
(
self
):
# There is no isdataclass() helper any more, but the PEP
# describes how to write it, so make sure that works. Note
# that this version returns True for both classes and
# instances.
def
isdataclass
(
obj
):
try
:
fields
(
obj
)
return
True
except
TypeError
:
return
False
self
.
assertFalse
(
isdataclass
(
0
))
self
.
assertFalse
(
isdataclass
(
int
))
@
dataclass
class
C
:
x
:
int
self
.
assertTrue
(
isdataclass
(
C
))
self
.
assertTrue
(
isdataclass
(
C
(
0
)))
def
test_helper_fields_with_class_instance
(
self
):
# Check that we can call fields() on either a class or instance,
# and get back the same thing.
@
dataclass
class
C
:
x
:
int
y
:
float
self
.
assertEqual
(
fields
(
C
),
fields
(
C
(
0
,
0.0
)))
def
test_helper_fields_exception
(
self
):
# Check that TypeError is raised if not passed a dataclass or
# instance.
with
self
.
assertRaisesRegex
(
TypeError
,
'dataclass type or instance'
):
fields
(
0
)
class
C
:
pass
with
self
.
assertRaisesRegex
(
TypeError
,
'dataclass type or instance'
):
fields
(
C
)
with
self
.
assertRaisesRegex
(
TypeError
,
'dataclass type or instance'
):
fields
(
C
())
def
test_helper_asdict
(
self
):
# Basic tests for asdict(), it should return a new dictionary
@
dataclass
class
C
:
x
:
int
y
:
int
c
=
C
(
1
,
2
)
self
.
assertEqual
(
asdict
(
c
),
{
'x'
:
1
,
'y'
:
2
})
self
.
assertEqual
(
asdict
(
c
),
asdict
(
c
))
self
.
assertIsNot
(
asdict
(
c
),
asdict
(
c
))
c
.
x
=
42
self
.
assertEqual
(
asdict
(
c
),
{
'x'
:
42
,
'y'
:
2
})
self
.
assertIs
(
type
(
asdict
(
c
)),
dict
)
def
test_helper_asdict_raises_on_classes
(
self
):
# asdict() should raise on a class object
@
dataclass
class
C
:
x
:
int
y
:
int
with
self
.
assertRaisesRegex
(
TypeError
,
'dataclass instance'
):
asdict
(
C
)
with
self
.
assertRaisesRegex
(
TypeError
,
'dataclass instance'
):
asdict
(
int
)
def
test_helper_asdict_copy_values
(
self
):
@
dataclass
class
C
:
x
:
int
y
:
List
[
int
]
=
field
(
default_factory
=
list
)
initial
=
[]
c
=
C
(
1
,
initial
)
d
=
asdict
(
c
)
self
.
assertEqual
(
d
[
'y'
],
initial
)
self
.
assertIsNot
(
d
[
'y'
],
initial
)
c
=
C
(
1
)
d
=
asdict
(
c
)
d
[
'y'
].
append
(
1
)
self
.
assertEqual
(
c
.
y
,
[])
def
test_helper_asdict_nested
(
self
):
@
dataclass
class
UserId
:
token
:
int
group
:
int
@
dataclass
class
User
:
name
:
str
id
:
UserId
u
=
User
(
'Joe'
,
UserId
(
123
,
1
))
d
=
asdict
(
u
)
self
.
assertEqual
(
d
,
{
'name'
:
'Joe'
,
'id'
:
{
'token'
:
123
,
'group'
:
1
}})
self
.
assertIsNot
(
asdict
(
u
),
asdict
(
u
))
u
.
id
.
group
=
2
self
.
assertEqual
(
asdict
(
u
),
{
'name'
:
'Joe'
,
'id'
:
{
'token'
:
123
,
'group'
:
2
}})
def
test_helper_asdict_builtin_containers
(
self
):
@
dataclass
class
User
:
name
:
str
id
:
int
@
dataclass
class
GroupList
:
id
:
int
users
:
List
[
User
]
@
dataclass
class
GroupTuple
:
id
:
int
users
:
Tuple
[
User
,
...]
@
dataclass
class
GroupDict
:
id
:
int
users
:
Dict
[
str
,
User
]
a
=
User
(
'Alice'
,
1
)
b
=
User
(
'Bob'
,
2
)
gl
=
GroupList
(
0
,
[
a
,
b
])
gt
=
GroupTuple
(
0
,
(
a
,
b
))
gd
=
GroupDict
(
0
,
{
'first'
:
a
,
'second'
:
b
})
self
.
assertEqual
(
asdict
(
gl
),
{
'id'
:
0
,
'users'
:
[{
'name'
:
'Alice'
,
'id'
:
1
},
{
'name'
:
'Bob'
,
'id'
:
2
}]})
self
.
assertEqual
(
asdict
(
gt
),
{
'id'
:
0
,
'users'
:
({
'name'
:
'Alice'
,
'id'
:
1
},
{
'name'
:
'Bob'
,
'id'
:
2
})})
self
.
assertEqual
(
asdict
(
gd
),
{
'id'
:
0
,
'users'
:
{
'first'
:
{
'name'
:
'Alice'
,
'id'
:
1
},
'second'
:
{
'name'
:
'Bob'
,
'id'
:
2
}}})
def
test_helper_asdict_builtin_containers
(
self
):
@
dataclass
class
Child
:
d
:
object
@
dataclass
class
Parent
:
child
:
Child
self
.
assertEqual
(
asdict
(
Parent
(
Child
([
1
]))),
{
'child'
:
{
'd'
:
[
1
]}})
self
.
assertEqual
(
asdict
(
Parent
(
Child
({
1
:
2
}))),
{
'child'
:
{
'd'
:
{
1
:
2
}}})
def
test_helper_asdict_factory
(
self
):
@
dataclass
class
C
:
x
:
int
y
:
int
c
=
C
(
1
,
2
)
d
=
asdict
(
c
,
dict_factory
=
OrderedDict
)
self
.
assertEqual
(
d
,
OrderedDict
([(
'x'
,
1
),
(
'y'
,
2
)]))
self
.
assertIsNot
(
d
,
asdict
(
c
,
dict_factory
=
OrderedDict
))
c
.
x
=
42
d
=
asdict
(
c
,
dict_factory
=
OrderedDict
)
self
.
assertEqual
(
d
,
OrderedDict
([(
'x'
,
42
),
(
'y'
,
2
)]))
self
.
assertIs
(
type
(
d
),
OrderedDict
)
def
test_helper_astuple
(
self
):
# Basic tests for astuple(), it should return a new tuple
@
dataclass
class
C
:
x
:
int
y
:
int
=
0
c
=
C
(
1
)
self
.
assertEqual
(
astuple
(
c
),
(
1
,
0
))
self
.
assertEqual
(
astuple
(
c
),
astuple
(
c
))
self
.
assertIsNot
(
astuple
(
c
),
astuple
(
c
))
c
.
y
=
42
self
.
assertEqual
(
astuple
(
c
),
(
1
,
42
))
self
.
assertIs
(
type
(
astuple
(
c
)),
tuple
)
def
test_helper_astuple_raises_on_classes
(
self
):
# astuple() should raise on a class object
@
dataclass
class
C
:
x
:
int
y
:
int
with
self
.
assertRaisesRegex
(
TypeError
,
'dataclass instance'
):
astuple
(
C
)
with
self
.
assertRaisesRegex
(
TypeError
,
'dataclass instance'
):
astuple
(
int
)
def
test_helper_astuple_copy_values
(
self
):
@
dataclass
class
C
:
x
:
int
y
:
List
[
int
]
=
field
(
default_factory
=
list
)
initial
=
[]
c
=
C
(
1
,
initial
)
t
=
astuple
(
c
)
self
.
assertEqual
(
t
[
1
],
initial
)
self
.
assertIsNot
(
t
[
1
],
initial
)
c
=
C
(
1
)
t
=
astuple
(
c
)
t
[
1
].
append
(
1
)
self
.
assertEqual
(
c
.
y
,
[])
def
test_helper_astuple_nested
(
self
):
@
dataclass
class
UserId
:
token
:
int
group
:
int
@
dataclass
class
User
:
name
:
str
id
:
UserId
u
=
User
(
'Joe'
,
UserId
(
123
,
1
))
t
=
astuple
(
u
)
self
.
assertEqual
(
t
,
(
'Joe'
,
(
123
,
1
)))
self
.
assertIsNot
(
astuple
(
u
),
astuple
(
u
))
u
.
id
.
group
=
2
self
.
assertEqual
(
astuple
(
u
),
(
'Joe'
,
(
123
,
2
)))
def
test_helper_astuple_builtin_containers
(
self
):
@
dataclass
class
User
:
name
:
str
id
:
int
@
dataclass
class
GroupList
:
id
:
int
users
:
List
[
User
]
@
dataclass
class
GroupTuple
:
id
:
int
users
:
Tuple
[
User
,
...]
@
dataclass
class
GroupDict
:
id
:
int
users
:
Dict
[
str
,
User
]
a
=
User
(
'Alice'
,
1
)
b
=
User
(
'Bob'
,
2
)
gl
=
GroupList
(
0
,
[
a
,
b
])
gt
=
GroupTuple
(
0
,
(
a
,
b
))
gd
=
GroupDict
(
0
,
{
'first'
:
a
,
'second'
:
b
})
self
.
assertEqual
(
astuple
(
gl
),
(
0
,
[(
'Alice'
,
1
),
(
'Bob'
,
2
)]))
self
.
assertEqual
(
astuple
(
gt
),
(
0
,
((
'Alice'
,
1
),
(
'Bob'
,
2
))))
self
.
assertEqual
(
astuple
(
gd
),
(
0
,
{
'first'
:
(
'Alice'
,
1
),
'second'
:
(
'Bob'
,
2
)}))
def
test_helper_astuple_builtin_containers
(
self
):
@
dataclass
class
Child
:
d
:
object
@
dataclass
class
Parent
:
child
:
Child
self
.
assertEqual
(
astuple
(
Parent
(
Child
([
1
]))),
(([
1
],),))
self
.
assertEqual
(
astuple
(
Parent
(
Child
({
1
:
2
}))),
(({
1
:
2
},),))
def
test_helper_astuple_factory
(
self
):
@
dataclass
class
C
:
x
:
int
y
:
int
NT
=
namedtuple
(
'NT'
,
'x y'
)
def
nt
(
lst
):
return
NT
(
*
lst
)
c
=
C
(
1
,
2
)
t
=
astuple
(
c
,
tuple_factory
=
nt
)
self
.
assertEqual
(
t
,
NT
(
1
,
2
))
self
.
assertIsNot
(
t
,
astuple
(
c
,
tuple_factory
=
nt
))
c
.
x
=
42
t
=
astuple
(
c
,
tuple_factory
=
nt
)
self
.
assertEqual
(
t
,
NT
(
42
,
2
))
self
.
assertIs
(
type
(
t
),
NT
)
def
test_dynamic_class_creation
(
self
):
cls_dict
=
{
'__annotations__'
:
OrderedDict
(
x
=
int
,
y
=
int
),
}
# Create the class.
cls
=
type
(
'C'
,
(),
cls_dict
)
# Make it a dataclass.
cls1
=
dataclass
(
cls
)
self
.
assertEqual
(
cls1
,
cls
)
self
.
assertEqual
(
asdict
(
cls
(
1
,
2
)),
{
'x'
:
1
,
'y'
:
2
})
def
test_dynamic_class_creation_using_field
(
self
):
cls_dict
=
{
'__annotations__'
:
OrderedDict
(
x
=
int
,
y
=
int
),
'y'
:
field
(
default
=
5
),
}
# Create the class.
cls
=
type
(
'C'
,
(),
cls_dict
)
# Make it a dataclass.
cls1
=
dataclass
(
cls
)
self
.
assertEqual
(
cls1
,
cls
)
self
.
assertEqual
(
asdict
(
cls1
(
1
)),
{
'x'
:
1
,
'y'
:
5
})
def
test_init_in_order
(
self
):
@
dataclass
class
C
:
a
:
int
b
:
int
=
field
()
c
:
list
=
field
(
default_factory
=
list
,
init
=
False
)
d
:
list
=
field
(
default_factory
=
list
)
e
:
int
=
field
(
default
=
4
,
init
=
False
)
f
:
int
=
4
calls
=
[]
def
setattr
(
self
,
name
,
value
):
calls
.
append
((
name
,
value
))
C
.
__setattr__
=
setattr
c
=
C
(
0
,
1
)
self
.
assertEqual
((
'a'
,
0
),
calls
[
0
])
self
.
assertEqual
((
'b'
,
1
),
calls
[
1
])
self
.
assertEqual
((
'c'
,
[]),
calls
[
2
])
self
.
assertEqual
((
'd'
,
[]),
calls
[
3
])
self
.
assertNotIn
((
'e'
,
4
),
calls
)
self
.
assertEqual
((
'f'
,
4
),
calls
[
4
])
def
test_items_in_dicts
(
self
):
@
dataclass
class
C
:
a
:
int
b
:
list
=
field
(
default_factory
=
list
,
init
=
False
)
c
:
list
=
field
(
default_factory
=
list
)
d
:
int
=
field
(
default
=
4
,
init
=
False
)
e
:
int
=
0
c
=
C
(
0
)
# Class dict
self
.
assertNotIn
(
'a'
,
C
.
__dict__
)
self
.
assertNotIn
(
'b'
,
C
.
__dict__
)
self
.
assertNotIn
(
'c'
,
C
.
__dict__
)
self
.
assertIn
(
'd'
,
C
.
__dict__
)
self
.
assertEqual
(
C
.
d
,
4
)
self
.
assertIn
(
'e'
,
C
.
__dict__
)
self
.
assertEqual
(
C
.
e
,
0
)
# Instance dict
self
.
assertIn
(
'a'
,
c
.
__dict__
)
self
.
assertEqual
(
c
.
a
,
0
)
self
.
assertIn
(
'b'
,
c
.
__dict__
)
self
.
assertEqual
(
c
.
b
,
[])
self
.
assertIn
(
'c'
,
c
.
__dict__
)
self
.
assertEqual
(
c
.
c
,
[])
self
.
assertNotIn
(
'd'
,
c
.
__dict__
)
self
.
assertIn
(
'e'
,
c
.
__dict__
)
self
.
assertEqual
(
c
.
e
,
0
)
def
test_alternate_classmethod_constructor
(
self
):
# Since __post_init__ can't take params, use a classmethod
# alternate constructor. This is mostly an example to show how
# to use this technique.
@
dataclass
class
C
:
x
:
int
@
classmethod
def
from_file
(
cls
,
filename
):
# In a real example, create a new instance
# and populate 'x' from contents of a file.
value_in_file
=
20
return
cls
(
value_in_file
)
self
.
assertEqual
(
C
.
from_file
(
'filename'
).
x
,
20
)
def
test_field_metadata_default
(
self
):
# Make sure the default metadata is read-only and of
# zero length.
@
dataclass
class
C
:
i
:
int
self
.
assertFalse
(
fields
(
C
)[
0
].
metadata
)
self
.
assertEqual
(
len
(
fields
(
C
)[
0
].
metadata
),
0
)
with
self
.
assertRaisesRegex
(
TypeError
,
'does not support item assignment'
):
fields
(
C
)[
0
].
metadata
[
'test'
]
=
3
def
test_field_metadata_mapping
(
self
):
# Make sure only a mapping can be passed as metadata
# zero length.
with
self
.
assertRaises
(
TypeError
):
@
dataclass
class
C
:
i
:
int
=
field
(
metadata
=
0
)
# Make sure an empty dict works
@
dataclass
class
C
:
i
:
int
=
field
(
metadata
=
{})
self
.
assertFalse
(
fields
(
C
)[
0
].
metadata
)
self
.
assertEqual
(
len
(
fields
(
C
)[
0
].
metadata
),
0
)
with
self
.
assertRaisesRegex
(
TypeError
,
'does not support item assignment'
):
fields
(
C
)[
0
].
metadata
[
'test'
]
=
3
# Make sure a non-empty dict works.
@
dataclass
class
C
:
i
:
int
=
field
(
metadata
=
{
'test'
:
10
,
'bar'
:
'42'
,
3
:
'three'
})
self
.
assertEqual
(
len
(
fields
(
C
)[
0
].
metadata
),
3
)
self
.
assertEqual
(
fields
(
C
)[
0
].
metadata
[
'test'
],
10
)
self
.
assertEqual
(
fields
(
C
)[
0
].
metadata
[
'bar'
],
'42'
)
self
.
assertEqual
(
fields
(
C
)[
0
].
metadata
[
3
],
'three'
)
with
self
.
assertRaises
(
KeyError
):
# Non-existent key.
fields
(
C
)[
0
].
metadata
[
'baz'
]
with
self
.
assertRaisesRegex
(
TypeError
,
'does not support item assignment'
):
fields
(
C
)[
0
].
metadata
[
'test'
]
=
3
def
test_field_metadata_custom_mapping
(
self
):
# Try a custom mapping.
class
SimpleNameSpace
:
def
__init__
(
self
,
**
kw
):
self
.
__dict__
.
update
(
kw
)
def
__getitem__
(
self
,
item
):
if
item
==
'xyzzy'
:
return
'plugh'
return
getattr
(
self
,
item
)
def
__len__
(
self
):
return
self
.
__dict__
.
__len__
()
@
dataclass
class
C
:
i
:
int
=
field
(
metadata
=
SimpleNameSpace
(
a
=
10
))
self
.
assertEqual
(
len
(
fields
(
C
)[
0
].
metadata
),
1
)
self
.
assertEqual
(
fields
(
C
)[
0
].
metadata
[
'a'
],
10
)
with
self
.
assertRaises
(
AttributeError
):
fields
(
C
)[
0
].
metadata
[
'b'
]
# Make sure we're still talking to our custom mapping.
self
.
assertEqual
(
fields
(
C
)[
0
].
metadata
[
'xyzzy'
],
'plugh'
)
def
test_generic_dataclasses
(
self
):
T
=
TypeVar
(
'T'
)
@
dataclass
class
LabeledBox
(
Generic
[
T
]):
content
:
T
label
:
str
=
'<unknown>'
box
=
LabeledBox
(
42
)
self
.
assertEqual
(
box
.
content
,
42
)
self
.
assertEqual
(
box
.
label
,
'<unknown>'
)
# subscripting the resulting class should work, etc.
Alias
=
List
[
LabeledBox
[
int
]]
def
test_generic_extending
(
self
):
S
=
TypeVar
(
'S'
)
T
=
TypeVar
(
'T'
)
@
dataclass
class
Base
(
Generic
[
T
,
S
]):
x
:
T
y
:
S
@
dataclass
class
DataDerived
(
Base
[
int
,
T
]):
new_field
:
str
Alias
=
DataDerived
[
str
]
c
=
Alias
(
0
,
'test1'
,
'test2'
)
self
.
assertEqual
(
astuple
(
c
),
(
0
,
'test1'
,
'test2'
))
class
NonDataDerived
(
Base
[
int
,
T
]):
def
new_method
(
self
):
return
self
.
y
Alias
=
NonDataDerived
[
float
]
c
=
Alias
(
10
,
1.0
)
self
.
assertEqual
(
c
.
new_method
(),
1.0
)
def
test_helper_replace
(
self
):
@
dataclass
(
frozen
=
True
)
class
C
:
x
:
int
y
:
int
c
=
C
(
1
,
2
)
c1
=
replace
(
c
,
x
=
3
)
self
.
assertEqual
(
c1
.
x
,
3
)
self
.
assertEqual
(
c1
.
y
,
2
)
def
test_helper_replace_frozen
(
self
):
@
dataclass
(
frozen
=
True
)
class
C
:
x
:
int
y
:
int
z
:
int
=
field
(
init
=
False
,
default
=
10
)
t
:
int
=
field
(
init
=
False
,
default
=
100
)
c
=
C
(
1
,
2
)
c1
=
replace
(
c
,
x
=
3
)
self
.
assertEqual
((
c
.
x
,
c
.
y
,
c
.
z
,
c
.
t
),
(
1
,
2
,
10
,
100
))
self
.
assertEqual
((
c1
.
x
,
c1
.
y
,
c1
.
z
,
c1
.
t
),
(
3
,
2
,
10
,
100
))
with
self
.
assertRaisesRegex
(
ValueError
,
'init=False'
):
replace
(
c
,
x
=
3
,
z
=
20
,
t
=
50
)
with
self
.
assertRaisesRegex
(
ValueError
,
'init=False'
):
replace
(
c
,
z
=
20
)
replace
(
c
,
x
=
3
,
z
=
20
,
t
=
50
)
# Make sure the result is still frozen.
with
self
.
assertRaisesRegex
(
FrozenInstanceError
,
"cannot assign to field 'x'"
):
c1
.
x
=
3
# Make sure we can't replace an attribute that doesn't exist,
# if we're also replacing one that does exist. Test this
# here, because setting attributes on frozen instances is
# handled slightly differently from non-frozen ones.
with
self
.
assertRaisesRegex
(
TypeError
,
"__init__
\
(
\
) got an unexpected "
"keyword argument 'a'"
):
c1
=
replace
(
c
,
x
=
20
,
a
=
5
)
def
test_helper_replace_invalid_field_name
(
self
):
@
dataclass
(
frozen
=
True
)
class
C
:
x
:
int
y
:
int
c
=
C
(
1
,
2
)
with
self
.
assertRaisesRegex
(
TypeError
,
"__init__
\
(
\
) got an unexpected "
"keyword argument 'z'"
):
c1
=
replace
(
c
,
z
=
3
)
def
test_helper_replace_invalid_object
(
self
):
@
dataclass
(
frozen
=
True
)
class
C
:
x
:
int
y
:
int
with
self
.
assertRaisesRegex
(
TypeError
,
'dataclass instance'
):
replace
(
C
,
x
=
3
)
with
self
.
assertRaisesRegex
(
TypeError
,
'dataclass instance'
):
replace
(
0
,
x
=
3
)
def
test_helper_replace_no_init
(
self
):
@
dataclass
class
C
:
x
:
int
y
:
int
=
field
(
init
=
False
,
default
=
10
)
c
=
C
(
1
)
c
.
y
=
20
# Make sure y gets the default value.
c1
=
replace
(
c
,
x
=
5
)
self
.
assertEqual
((
c1
.
x
,
c1
.
y
),
(
5
,
10
))
# Trying to replace y is an error.
with
self
.
assertRaisesRegex
(
ValueError
,
'init=False'
):
replace
(
c
,
x
=
2
,
y
=
30
)
with
self
.
assertRaisesRegex
(
ValueError
,
'init=False'
):
replace
(
c
,
y
=
30
)
def
test_dataclassses_pickleable
(
self
):
global
P
,
Q
,
R
@
dataclass
class
P
:
x
:
int
y
:
int
=
0
@
dataclass
class
Q
:
x
:
int
y
:
int
=
field
(
default
=
0
,
init
=
False
)
@
dataclass
class
R
:
x
:
int
y
:
List
[
int
]
=
field
(
default_factory
=
list
)
q
=
Q
(
1
)
q
.
y
=
2
samples
=
[
P
(
1
),
P
(
1
,
2
),
Q
(
1
),
q
,
R
(
1
),
R
(
1
,
[
2
,
3
,
4
])]
for
sample
in
samples
:
for
proto
in
range
(
pickle
.
HIGHEST_PROTOCOL
+
1
):
with
self
.
subTest
(
sample
=
sample
,
proto
=
proto
):
new_sample
=
pickle
.
loads
(
pickle
.
dumps
(
sample
,
proto
))
self
.
assertEqual
(
sample
.
x
,
new_sample
.
x
)
self
.
assertEqual
(
sample
.
y
,
new_sample
.
y
)
self
.
assertIsNot
(
sample
,
new_sample
)
new_sample
.
x
=
42
another_new_sample
=
pickle
.
loads
(
pickle
.
dumps
(
new_sample
,
proto
))
self
.
assertEqual
(
new_sample
.
x
,
another_new_sample
.
x
)
self
.
assertEqual
(
sample
.
y
,
another_new_sample
.
y
)
def
test_helper_make_dataclass
(
self
):
C
=
make_dataclass
(
'C'
,
[(
'x'
,
int
),
(
'y'
,
int
,
field
(
default
=
5
))],
namespace
=
{
'add_one'
:
lambda
self
:
self
.
x
+
1
})
c
=
C
(
10
)
self
.
assertEqual
((
c
.
x
,
c
.
y
),
(
10
,
5
))
self
.
assertEqual
(
c
.
add_one
(),
11
)
def
test_helper_make_dataclass_no_mutate_namespace
(
self
):
# Make sure a provided namespace isn't mutated.
ns
=
{}
C
=
make_dataclass
(
'C'
,
[(
'x'
,
int
),
(
'y'
,
int
,
field
(
default
=
5
))],
namespace
=
ns
)
self
.
assertEqual
(
ns
,
{})
def
test_helper_make_dataclass_base
(
self
):
class
Base1
:
pass
class
Base2
:
pass
C
=
make_dataclass
(
'C'
,
[(
'x'
,
int
)],
bases
=
(
Base1
,
Base2
))
c
=
C
(
2
)
self
.
assertIsInstance
(
c
,
C
)
self
.
assertIsInstance
(
c
,
Base1
)
self
.
assertIsInstance
(
c
,
Base2
)
def
test_helper_make_dataclass_base_dataclass
(
self
):
@
dataclass
class
Base1
:
x
:
int
class
Base2
:
pass
C
=
make_dataclass
(
'C'
,
[(
'y'
,
int
)],
bases
=
(
Base1
,
Base2
))
with
self
.
assertRaisesRegex
(
TypeError
,
'required positional'
):
c
=
C
(
2
)
c
=
C
(
1
,
2
)
self
.
assertIsInstance
(
c
,
C
)
self
.
assertIsInstance
(
c
,
Base1
)
self
.
assertIsInstance
(
c
,
Base2
)
self
.
assertEqual
((
c
.
x
,
c
.
y
),
(
1
,
2
))
def
test_helper_make_dataclass_init_var
(
self
):
def
post_init
(
self
,
y
):
self
.
x
*=
y
C
=
make_dataclass
(
'C'
,
[(
'x'
,
int
),
(
'y'
,
InitVar
[
int
]),
],
namespace
=
{
'__post_init__'
:
post_init
},
)
c
=
C
(
2
,
3
)
self
.
assertEqual
(
vars
(
c
),
{
'x'
:
6
})
self
.
assertEqual
(
len
(
fields
(
c
)),
1
)
def
test_helper_make_dataclass_class_var
(
self
):
C
=
make_dataclass
(
'C'
,
[(
'x'
,
int
),
(
'y'
,
ClassVar
[
int
],
10
),
(
'z'
,
ClassVar
[
int
],
field
(
default
=
20
)),
])
c
=
C
(
1
)
self
.
assertEqual
(
vars
(
c
),
{
'x'
:
1
})
self
.
assertEqual
(
len
(
fields
(
c
)),
1
)
self
.
assertEqual
(
C
.
y
,
10
)
self
.
assertEqual
(
C
.
z
,
20
)
class
TestDocString
(
unittest
.
TestCase
):
def
assertDocStrEqual
(
self
,
a
,
b
):
# Because 3.6 and 3.7 differ in how inspect.signature work
# (see bpo #32108), for the time being just compare them with
# whitespace stripped.
self
.
assertEqual
(
a
.
replace
(
' '
,
''
),
b
.
replace
(
' '
,
''
))
def
test_existing_docstring_not_overridden
(
self
):
@
dataclass
class
C
:
"""Lorem ipsum"""
x
:
int
self
.
assertEqual
(
C
.
__doc__
,
"Lorem ipsum"
)
def
test_docstring_no_fields
(
self
):
@
dataclass
class
C
:
pass
self
.
assertDocStrEqual
(
C
.
__doc__
,
"C()"
)
def
test_docstring_one_field
(
self
):
@
dataclass
class
C
:
x
:
int
self
.
assertDocStrEqual
(
C
.
__doc__
,
"C(x:int)"
)
def
test_docstring_two_fields
(
self
):
@
dataclass
class
C
:
x
:
int
y
:
int
self
.
assertDocStrEqual
(
C
.
__doc__
,
"C(x:int, y:int)"
)
def
test_docstring_three_fields
(
self
):
@
dataclass
class
C
:
x
:
int
y
:
int
z
:
str
self
.
assertDocStrEqual
(
C
.
__doc__
,
"C(x:int, y:int, z:str)"
)
def
test_docstring_one_field_with_default
(
self
):
@
dataclass
class
C
:
x
:
int
=
3
self
.
assertDocStrEqual
(
C
.
__doc__
,
"C(x:int=3)"
)
def
test_docstring_one_field_with_default_none
(
self
):
@
dataclass
class
C
:
x
:
Union
[
int
,
type
(
None
)]
=
None
self
.
assertDocStrEqual
(
C
.
__doc__
,
"C(x:Union[int, NoneType]=None)"
)
def
test_docstring_list_field
(
self
):
@
dataclass
class
C
:
x
:
List
[
int
]
self
.
assertDocStrEqual
(
C
.
__doc__
,
"C(x:List[int])"
)
def
test_docstring_list_field_with_default_factory
(
self
):
@
dataclass
class
C
:
x
:
List
[
int
]
=
field
(
default_factory
=
list
)
self
.
assertDocStrEqual
(
C
.
__doc__
,
"C(x:List[int]=<factory>)"
)
def
test_docstring_deque_field
(
self
):
@
dataclass
class
C
:
x
:
deque
self
.
assertDocStrEqual
(
C
.
__doc__
,
"C(x:collections.deque)"
)
def
test_docstring_deque_field_with_default_factory
(
self
):
@
dataclass
class
C
:
x
:
deque
=
field
(
default_factory
=
deque
)
self
.
assertDocStrEqual
(
C
.
__doc__
,
"C(x:collections.deque=<factory>)"
)
if
__name__
==
'__main__'
:
unittest
.
main
()
Misc/NEWS.d/next/Library/2017-12-04-15-51-57.bpo-32214.uozdNj.rst
0 → 100644
View file @
f0db54a0
PEP 557, Data Classes. Provides a decorator which adds boilerplate methods
to classes which use type annotations so specify fields.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment