Commit c651275a authored by Ethan Smith's avatar Ethan Smith Committed by Łukasz Langa

bpo-32380: Create functools.singledispatchmethod (#6306)

parent 09c4a7de
......@@ -383,6 +383,52 @@ The :mod:`functools` module defines the following functions:
The :func:`register` attribute supports using type annotations.
.. class:: singledispatchmethod(func)
Transform a method into a :term:`single-dispatch <single
dispatch>` :term:`generic function`.
To define a generic method, decorate it with the ``@singledispatchmethod``
decorator. Note that the dispatch happens on the type of the first non-self
or non-cls argument, create your function accordingly::
class Negator:
@singledispatchmethod
def neg(self, arg):
raise NotImplementedError("Cannot negate a")
@neg.register
def _(self, arg: int):
return -arg
@neg.register
def _(self, arg: bool):
return not arg
``@singledispatchmethod`` supports nesting with other decorators such as
``@classmethod``. Note that to allow for ``dispatcher.register``,
``singledispatchmethod`` must be the *outer most* decorator. Here is the
``Negator`` class with the ``neg`` methods being class bound::
class Negator:
@singledispatchmethod
@classmethod
def neg(cls, arg):
raise NotImplementedError("Cannot negate a")
@neg.register
@classmethod
def _(cls, arg: int):
return -arg
@neg.register
@classmethod
def _(cls, arg: bool):
return not arg
The same pattern can be used for other similar decorators: ``staticmethod``,
``abstractmethod``, and others.
.. function:: update_wrapper(wrapper, wrapped, assigned=WRAPPER_ASSIGNMENTS, updated=WRAPPER_UPDATES)
Update a *wrapper* function to look like the *wrapped* function. The optional
......
......@@ -11,7 +11,7 @@
__all__ = ['update_wrapper', 'wraps', 'WRAPPER_ASSIGNMENTS', 'WRAPPER_UPDATES',
'total_ordering', 'cmp_to_key', 'lru_cache', 'reduce', 'partial',
'partialmethod', 'singledispatch']
'partialmethod', 'singledispatch', 'singledispatchmethod']
try:
from _functools import reduce
......@@ -826,3 +826,40 @@ def singledispatch(func):
wrapper._clear_cache = dispatch_cache.clear
update_wrapper(wrapper, func)
return wrapper
# Descriptor version
class singledispatchmethod:
"""Single-dispatch generic method descriptor.
Supports wrapping existing descriptors and handles non-descriptor
callables as instance methods.
"""
def __init__(self, func):
if not callable(func) and not hasattr(func, "__get__"):
raise TypeError(f"{func!r} is not callable or a descriptor")
self.dispatcher = singledispatch(func)
self.func = func
def register(self, cls, method=None):
"""generic_method.register(cls, func) -> func
Registers a new implementation for the given *cls* on a *generic_method*.
"""
return self.dispatcher.register(cls, func=method)
def __get__(self, obj, cls):
def _method(*args, **kwargs):
method = self.dispatcher.dispatch(args[0].__class__)
return method.__get__(obj, cls)(*args, **kwargs)
_method.__isabstractmethod__ = self.__isabstractmethod__
_method.register = self.register
update_wrapper(_method, self.func)
return _method
@property
def __isabstractmethod__(self):
return getattr(self.func, '__isabstractmethod__', False)
......@@ -2147,6 +2147,124 @@ class TestSingleDispatch(unittest.TestCase):
return self.arg == other
self.assertEqual(i("str"), "str")
def test_method_register(self):
class A:
@functools.singledispatchmethod
def t(self, arg):
self.arg = "base"
@t.register(int)
def _(self, arg):
self.arg = "int"
@t.register(str)
def _(self, arg):
self.arg = "str"
a = A()
a.t(0)
self.assertEqual(a.arg, "int")
aa = A()
self.assertFalse(hasattr(aa, 'arg'))
a.t('')
self.assertEqual(a.arg, "str")
aa = A()
self.assertFalse(hasattr(aa, 'arg'))
a.t(0.0)
self.assertEqual(a.arg, "base")
aa = A()
self.assertFalse(hasattr(aa, 'arg'))
def test_staticmethod_register(self):
class A:
@functools.singledispatchmethod
@staticmethod
def t(arg):
return arg
@t.register(int)
@staticmethod
def _(arg):
return isinstance(arg, int)
@t.register(str)
@staticmethod
def _(arg):
return isinstance(arg, str)
a = A()
self.assertTrue(A.t(0))
self.assertTrue(A.t(''))
self.assertEqual(A.t(0.0), 0.0)
def test_classmethod_register(self):
class A:
def __init__(self, arg):
self.arg = arg
@functools.singledispatchmethod
@classmethod
def t(cls, arg):
return cls("base")
@t.register(int)
@classmethod
def _(cls, arg):
return cls("int")
@t.register(str)
@classmethod
def _(cls, arg):
return cls("str")
self.assertEqual(A.t(0).arg, "int")
self.assertEqual(A.t('').arg, "str")
self.assertEqual(A.t(0.0).arg, "base")
def test_callable_register(self):
class A:
def __init__(self, arg):
self.arg = arg
@functools.singledispatchmethod
@classmethod
def t(cls, arg):
return cls("base")
@A.t.register(int)
@classmethod
def _(cls, arg):
return cls("int")
@A.t.register(str)
@classmethod
def _(cls, arg):
return cls("str")
self.assertEqual(A.t(0).arg, "int")
self.assertEqual(A.t('').arg, "str")
self.assertEqual(A.t(0.0).arg, "base")
def test_abstractmethod_register(self):
class Abstract(abc.ABCMeta):
@functools.singledispatchmethod
@abc.abstractmethod
def add(self, x, y):
pass
self.assertTrue(Abstract.add.__isabstractmethod__)
def test_type_ann_register(self):
class A:
@functools.singledispatchmethod
def t(self, arg):
return "base"
@t.register
def _(self, arg: int):
return "int"
@t.register
def _(self, arg: str):
return "str"
a = A()
self.assertEqual(a.t(0), "int")
self.assertEqual(a.t(''), "str")
self.assertEqual(a.t(0.0), "base")
def test_invalid_registrations(self):
msg_prefix = "Invalid first argument to `register()`: "
msg_suffix = (
......
......@@ -1510,6 +1510,7 @@ Václav Šmilauer
Allen W. Smith
Christopher Smith
Eric V. Smith
Ethan H. Smith
Gregory P. Smith
Mark Smith
Nathaniel J. Smith
......
Create functools.singledispatchmethod to support generic single dispatch on
descriptors and methods.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment