Commit 457fc9a6 authored by Nick Coghlan's avatar Nick Coghlan

Issue #27137: align Python & C implementations of functools.partial

The pure Python fallback implementation of functools.partial
now matches the behaviour of its accelerated C counterpart for
subclassing, pickling and text representation purposes.

Patch by Emanuel Barry and Serhiy Storchaka.
parent eddc4b72
......@@ -21,6 +21,7 @@ from abc import get_cache_token
from collections import namedtuple
from types import MappingProxyType
from weakref import WeakKeyDictionary
from reprlib import recursive_repr
try:
from _thread import RLock
except ImportError:
......@@ -237,11 +238,24 @@ except ImportError:
################################################################################
# Purely functional, no descriptor behaviour
def partial(func, *args, **keywords):
class partial:
"""New function with partial application of the given arguments
and keywords.
"""
if hasattr(func, 'func'):
__slots__ = "func", "args", "keywords", "__dict__", "__weakref__"
def __new__(*args, **keywords):
if not args:
raise TypeError("descriptor '__new__' of partial needs an argument")
if len(args) < 2:
raise TypeError("type 'partial' takes at least one argument")
cls, func, *args = args
if not callable(func):
raise TypeError("the first argument must be callable")
args = tuple(args)
if hasattr(func, "func"):
args = func.args + args
tmpkw = func.keywords.copy()
tmpkw.update(keywords)
......@@ -249,14 +263,58 @@ def partial(func, *args, **keywords):
del tmpkw
func = func.func
def newfunc(*fargs, **fkeywords):
newkeywords = keywords.copy()
newkeywords.update(fkeywords)
return func(*(args + fargs), **newkeywords)
newfunc.func = func
newfunc.args = args
newfunc.keywords = keywords
return newfunc
self = super(partial, cls).__new__(cls)
self.func = func
self.args = args
self.keywords = keywords
return self
def __call__(*args, **keywords):
if not args:
raise TypeError("descriptor '__call__' of partial needs an argument")
self, *args = args
newkeywords = self.keywords.copy()
newkeywords.update(keywords)
return self.func(*self.args, *args, **newkeywords)
@recursive_repr()
def __repr__(self):
qualname = type(self).__qualname__
args = [repr(self.func)]
args.extend(repr(x) for x in self.args)
args.extend(f"{k}={v!r}" for (k, v) in self.keywords.items())
if type(self).__module__ == "functools":
return f"functools.{qualname}({', '.join(args)})"
return f"{qualname}({', '.join(args)})"
def __reduce__(self):
return type(self), (self.func,), (self.func, self.args,
self.keywords or None, self.__dict__ or None)
def __setstate__(self, state):
if not isinstance(state, tuple):
raise TypeError("argument to __setstate__ must be a tuple")
if len(state) != 4:
raise TypeError(f"expected 4 items in state, got {len(state)}")
func, args, kwds, namespace = state
if (not callable(func) or not isinstance(args, tuple) or
(kwds is not None and not isinstance(kwds, dict)) or
(namespace is not None and not isinstance(namespace, dict))):
raise TypeError("invalid partial state")
args = tuple(args) # just in case it's a subclass
if kwds is None:
kwds = {}
elif type(kwds) is not dict: # XXX does it need to be *exactly* dict?
kwds = dict(kwds)
if namespace is None:
namespace = {}
self.__dict__ = namespace
self.func = func
self.args = args
self.keywords = kwds
try:
from _functools import partial
......
......@@ -8,6 +8,7 @@ import sys
from test import support
import unittest
from weakref import proxy
import contextlib
try:
import threading
except ImportError:
......@@ -20,6 +21,14 @@ c_functools = support.import_fresh_module('functools', fresh=['_functools'])
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
@contextlib.contextmanager
def replaced_module(name, replacement):
original_module = sys.modules[name]
sys.modules[name] = replacement
try:
yield
finally:
sys.modules[name] = original_module
def capture(*args, **kw):
"""capture all positional and keyword arguments"""
......@@ -167,58 +176,35 @@ class TestPartial:
p2.new_attr = 'spam'
self.assertEqual(p2.new_attr, 'spam')
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
if c_functools:
partial = c_functools.partial
def test_attributes_unwritable(self):
# attributes should not be writable
p = self.partial(capture, 1, 2, a=10, b=20)
self.assertRaises(AttributeError, setattr, p, 'func', map)
self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
p = self.partial(hex)
try:
del p.__dict__
except TypeError:
pass
else:
self.fail('partial object allowed __dict__ to be deleted')
def test_repr(self):
args = (object(), object())
args_repr = ', '.join(repr(a) for a in args)
kwargs = {'a': object(), 'b': object()}
kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
'b={b!r}, a={a!r}'.format_map(kwargs)]
if self.partial is c_functools.partial:
if self.partial in (c_functools.partial, py_functools.partial):
name = 'functools.partial'
else:
name = self.partial.__name__
f = self.partial(capture)
self.assertEqual('{}({!r})'.format(name, capture),
repr(f))
self.assertEqual(f'{name}({capture!r})', repr(f))
f = self.partial(capture, *args)
self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
repr(f))
self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))
f = self.partial(capture, **kwargs)
self.assertIn(repr(f),
['{}({!r}, {})'.format(name, capture, kwargs_repr)
[f'{name}({capture!r}, {kwargs_repr})'
for kwargs_repr in kwargs_reprs])
f = self.partial(capture, *args, **kwargs)
self.assertIn(repr(f),
['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr)
[f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
for kwargs_repr in kwargs_reprs])
def test_recursive_repr(self):
if self.partial is c_functools.partial:
if self.partial in (c_functools.partial, py_functools.partial):
name = 'functools.partial'
else:
name = self.partial.__name__
......@@ -226,25 +212,26 @@ class TestPartialC(TestPartial, unittest.TestCase):
f = self.partial(capture)
f.__setstate__((f, (), {}, {}))
try:
self.assertEqual(repr(f), '%s(%s(...))' % (name, name))
self.assertEqual(repr(f), '%s(...)' % (name,))
finally:
f.__setstate__((capture, (), {}, {}))
f = self.partial(capture)
f.__setstate__((capture, (f,), {}, {}))
try:
self.assertEqual(repr(f), '%s(%r, %s(...))' % (name, capture, name))
self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
finally:
f.__setstate__((capture, (), {}, {}))
f = self.partial(capture)
f.__setstate__((capture, (), {'a': f}, {}))
try:
self.assertEqual(repr(f), '%s(%r, a=%s(...))' % (name, capture, name))
self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
finally:
f.__setstate__((capture, (), {}, {}))
def test_pickle(self):
with self.AllowPickle():
f = self.partial(signature, ['asdf'], bar=[True])
f.attr = []
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
......@@ -274,11 +261,13 @@ class TestPartialC(TestPartial, unittest.TestCase):
def test_setstate(self):
f = self.partial(signature)
f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))
self.assertEqual(signature(f),
(capture, (1,), dict(a=10), dict(attr=[])))
self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
f.__setstate__((capture, (1,), dict(a=10), None))
self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
......@@ -325,6 +314,7 @@ class TestPartialC(TestPartial, unittest.TestCase):
self.assertIs(type(r[0]), tuple)
def test_recursive_pickle(self):
with self.AllowPickle():
f = self.partial(capture)
f.__setstate__((f, (), {}, {}))
try:
......@@ -375,24 +365,60 @@ class TestPartialC(TestPartial, unittest.TestCase):
f = self.partial(object)
self.assertRaises(TypeError, f.__setstate__, BadSequence())
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
if c_functools:
partial = c_functools.partial
class AllowPickle:
def __enter__(self):
return self
def __exit__(self, type, value, tb):
return False
def test_attributes_unwritable(self):
# attributes should not be writable
p = self.partial(capture, 1, 2, a=10, b=20)
self.assertRaises(AttributeError, setattr, p, 'func', map)
self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
p = self.partial(hex)
try:
del p.__dict__
except TypeError:
pass
else:
self.fail('partial object allowed __dict__ to be deleted')
class TestPartialPy(TestPartial, unittest.TestCase):
partial = staticmethod(py_functools.partial)
partial = py_functools.partial
class AllowPickle:
def __init__(self):
self._cm = replaced_module("functools", py_functools)
def __enter__(self):
return self._cm.__enter__()
def __exit__(self, type, value, tb):
return self._cm.__exit__(type, value, tb)
if c_functools:
class PartialSubclass(c_functools.partial):
class CPartialSubclass(c_functools.partial):
pass
class PyPartialSubclass(py_functools.partial):
pass
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
if c_functools:
partial = PartialSubclass
partial = CPartialSubclass
# partial subclasses are not optimized for nested calls
test_nested_optimization = None
class TestPartialPySubclass(TestPartialPy):
partial = PyPartialSubclass
class TestPartialMethod(unittest.TestCase):
......@@ -683,9 +709,10 @@ class TestWraps(TestUpdateWrapper):
self.assertEqual(wrapper.attr, 'This is a different test')
self.assertEqual(wrapper.dict_attr, f.dict_attr)
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
func = functools.reduce
if c_functools:
func = c_functools.reduce
def test_reduce(self):
class Squares:
......
......@@ -135,6 +135,11 @@ Core and Builtins
Library
-------
- Issue #27137: the pure Python fallback implementation of ``functools.partial``
now matches the behaviour of its accelerated C counterpart for subclassing,
pickling and text representation purposes. Patch by Emanuel Barry and
Serhiy Storchaka.
- Issue #28019: itertools.count() no longer rounds non-integer step in range
between 1.0 and 2.0 to 1.
......
......@@ -229,7 +229,7 @@ partial_repr(partialobject *pto)
if (status != 0) {
if (status < 0)
return NULL;
return PyUnicode_FromFormat("%s(...)", Py_TYPE(pto)->tp_name);
return PyUnicode_FromString("...");
}
arglist = PyUnicode_FromString("");
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment