Unverified Commit 27b5adbb authored by will-ca's avatar will-ca Committed by GitHub

Make fused function dispatch O(n) for `cpdef` functions. (GH-3366)

* Rewrote signature matching for fused cpdef function dispatch to use a pre-built tree index in a mutable default argument and be O(n).

* Added test to ensure proper differentiation between ambiguously compatible and definitely compatible arguments.

* Added test to ensure fused cpdef's can be called by the module itself during import.

* Added test to ensure consistent handling of ambiguous fused cpdef signatures.

* Test for explicitly defined fused cpdef method.

* Add .komodoproject to .gitignore.

* Add /cython_debug/ to .gitignore.

Closes #1385.
parent 4fd90184
......@@ -23,6 +23,7 @@ Demos/*/*.html
/TEST_TMP/
/build/
/cython_build/
/wheelhouse*/
!tests/build/
/dist/
......@@ -52,3 +53,5 @@ MANIFEST
/.idea
/*.iml
# Komodo EDIT/IDE project files
/*.komodoproject
......@@ -584,6 +584,26 @@ class FusedCFuncDefNode(StatListNode):
{{endif}}
""")
def _fused_signature_index(self, pyx_code):
"""
Generate Cython code for constructing a persistent nested dictionary index of
fused type specialization signatures.
"""
pyx_code.put_chunk(
u"""
if not _fused_sigindex:
for sig in <dict>signatures:
sigindex_node = _fused_sigindex
sig_series = sig.strip('()').split('|')
for sig_type in sig_series[:-1]:
if sig_type not in sigindex_node:
sigindex_node[sig_type] = sigindex_node = {}
else:
sigindex_node = sigindex_node[sig_type]
sigindex_node[sig_series[-1]] = sig
"""
)
def make_fused_cpdef(self, orig_py_func, env, is_def):
"""
This creates the function that is indexable from Python and does
......@@ -620,10 +640,14 @@ class FusedCFuncDefNode(StatListNode):
pyx_code.put_chunk(
u"""
def __pyx_fused_cpdef(signatures, args, kwargs, defaults):
def __pyx_fused_cpdef(signatures, args, kwargs, defaults, *, _fused_sigindex={}):
# FIXME: use a typed signature - currently fails badly because
# default arguments inherit the types we specify here!
cdef list search_list
cdef dict sn, sigindex_node
dest_sig = [None] * {{n_fused}}
if kwargs is not None and not kwargs:
......@@ -691,23 +715,36 @@ class FusedCFuncDefNode(StatListNode):
env.use_utility_code(Code.UtilityCode.load_cached("Import", "ImportExport.c"))
env.use_utility_code(Code.UtilityCode.load_cached("ImportNumPyArray", "ImportExport.c"))
self._fused_signature_index(pyx_code)
pyx_code.put_chunk(
u"""
candidates = []
for sig in <dict>signatures:
match_found = False
src_sig = sig.strip('()').split('|')
for i in range(len(dest_sig)):
dst_type = dest_sig[i]
if dst_type is not None:
if src_sig[i] == dst_type:
match_found = True
else:
match_found = False
break
sigindex_matches = []
sigindex_candidates = [_fused_sigindex]
for dst_type in dest_sig:
found_matches = []
found_candidates = []
# Make two seperate lists: One for signature sub-trees
# with at least one definite match, and another for
# signature sub-trees with only ambiguous matches
# (where `dest_sig[i] is None`).
if dst_type is None:
for sn in sigindex_matches:
found_matches.extend(sn.values())
for sn in sigindex_candidates:
found_candidates.extend(sn.values())
else:
for search_list in (sigindex_matches, sigindex_candidates):
for sn in search_list:
if dst_type in sn:
found_matches.append(sn[dst_type])
sigindex_matches = found_matches
sigindex_candidates = found_candidates
if not (found_matches or found_candidates):
break
if match_found:
candidates.append(sig)
candidates = sigindex_matches
if not candidates:
raise TypeError("No matching signature found")
......
# cython: language_level=3
# mode: run
cimport cython
import sys, io
cy = __import__("cython")
cpdef func1(self, cython.integral x):
print "%s," % (self,),
print(f"{self},", end=' ')
if cython.integral is int:
print 'x is int', x, cython.typeof(x)
print('x is int', x, cython.typeof(x))
else:
print 'x is long', x, cython.typeof(x)
print('x is long', x, cython.typeof(x))
class A(object):
......@@ -16,6 +20,18 @@ class A(object):
def __str__(self):
return "A"
cdef class B:
cpdef int meth(self, cython.integral x):
print(f"{self},", end=' ')
if cython.integral is int:
print('x is int', x, cython.typeof(x))
else:
print('x is long', x, cython.typeof(x))
return 0
def __str__(self):
return "B"
pyfunc = func1
def test_fused_cpdef():
......@@ -32,23 +48,71 @@ def test_fused_cpdef():
A, x is long 2 long
A, x is long 2 long
A, x is long 2 long
<BLANKLINE>
B, x is long 2 long
"""
func1[int](None, 2)
func1[long](None, 2)
func1(None, 2)
print
print()
pyfunc[cy.int](None, 2)
pyfunc(None, 2)
print
print()
A.meth[cy.int](A(), 2)
A.meth(A(), 2)
A().meth[cy.long](2)
A().meth(2)
print()
B().meth(2)
midimport_run = io.StringIO()
if sys.version_info.major < 3:
# Monkey-patch midimport_run.write to accept non-unicode strings under Python 2.
midimport_run.write = lambda c: io.StringIO.write(midimport_run, unicode(c))
realstdout = sys.stdout
sys.stdout = midimport_run
try:
# Run `test_fused_cpdef()` during import and save the result for
# `test_midimport_run()`.
test_fused_cpdef()
except Exception as e:
midimport_run.write(f"{e!r}\n")
finally:
sys.stdout = realstdout
def test_midimport_run():
# At one point, dynamically calling fused cpdef functions during import
# would fail because the type signature-matching indices weren't
# yet initialized.
# (See Compiler.FusedNode.FusedCFuncDefNode._fused_signature_index,
# GH-3366.)
"""
>>> test_midimport_run()
None, x is int 2 int
None, x is long 2 long
None, x is long 2 long
<BLANKLINE>
None, x is int 2 int
None, x is long 2 long
<BLANKLINE>
A, x is int 2 int
A, x is long 2 long
A, x is long 2 long
A, x is long 2 long
<BLANKLINE>
B, x is long 2 long
"""
print(midimport_run.getvalue(), end='')
def assert_raise(func, *args):
try:
......@@ -70,23 +134,31 @@ def test_badcall():
assert_raise(A.meth)
assert_raise(A().meth[cy.int])
assert_raise(A.meth[cy.int])
assert_raise(B().meth, 1, 2, 3)
def test_nomatch():
"""
>>> func1(None, ())
Traceback (most recent call last):
TypeError: No matching signature found
"""
ctypedef long double long_double
cpdef multiarg(cython.integral x, cython.floating y):
if cython.integral is int:
print "x is an int,",
print("x is an int,", end=' ')
else:
print "x is a long,",
print("x is a long,", end=' ')
if cython.floating is long_double:
print "y is a long double:",
print("y is a long double:", end=' ')
elif float is cython.floating:
print "y is a float:",
print("y is a float:", end=' ')
else:
print "y is a double:",
print("y is a double:", end=' ')
print x, y
print(x, y)
def test_multiarg():
"""
......@@ -104,3 +176,13 @@ def test_multiarg():
multiarg[int, float](1, 2.0)
multiarg[cy.int, cy.float](1, 2.0)
multiarg(4, 5.0)
def test_ambiguousmatch():
"""
>>> multiarg(5, ())
Traceback (most recent call last):
TypeError: Function call with ambiguous argument types
>>> multiarg((), 2.0)
Traceback (most recent call last):
TypeError: Function call with ambiguous argument types
"""
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment