Commit 27b5adbb authored by will-ca's avatar will-ca Committed by GitHub

Make fused function dispatch O(n) for `cpdef` functions. (GH-3366)

* Rewrote signature matching for fused cpdef function dispatch to use a pre-built tree index in a mutable default argument and be O(n).

* Added test to ensure proper differentiation between ambiguously compatible and definitely compatible arguments.

* Added test to ensure fused cpdef's can be called by the module itself during import.

* Added test to ensure consistent handling of ambiguous fused cpdef signatures.

* Test for explicitly defined fused cpdef method.

* Add .komodoproject to .gitignore.

* Add /cython_debug/ to .gitignore.

Closes #1385.
parent 4fd90184
...@@ -23,6 +23,7 @@ Demos/*/*.html ...@@ -23,6 +23,7 @@ Demos/*/*.html
/TEST_TMP/ /TEST_TMP/
/build/ /build/
/cython_build/
/wheelhouse*/ /wheelhouse*/
!tests/build/ !tests/build/
/dist/ /dist/
...@@ -52,3 +53,5 @@ MANIFEST ...@@ -52,3 +53,5 @@ MANIFEST
/.idea /.idea
/*.iml /*.iml
# Komodo EDIT/IDE project files
/*.komodoproject
...@@ -584,6 +584,26 @@ class FusedCFuncDefNode(StatListNode): ...@@ -584,6 +584,26 @@ class FusedCFuncDefNode(StatListNode):
{{endif}} {{endif}}
""") """)
def _fused_signature_index(self, pyx_code):
"""
Generate Cython code for constructing a persistent nested dictionary index of
fused type specialization signatures.
"""
pyx_code.put_chunk(
u"""
if not _fused_sigindex:
for sig in <dict>signatures:
sigindex_node = _fused_sigindex
sig_series = sig.strip('()').split('|')
for sig_type in sig_series[:-1]:
if sig_type not in sigindex_node:
sigindex_node[sig_type] = sigindex_node = {}
else:
sigindex_node = sigindex_node[sig_type]
sigindex_node[sig_series[-1]] = sig
"""
)
def make_fused_cpdef(self, orig_py_func, env, is_def): def make_fused_cpdef(self, orig_py_func, env, is_def):
""" """
This creates the function that is indexable from Python and does This creates the function that is indexable from Python and does
...@@ -620,10 +640,14 @@ class FusedCFuncDefNode(StatListNode): ...@@ -620,10 +640,14 @@ class FusedCFuncDefNode(StatListNode):
pyx_code.put_chunk( pyx_code.put_chunk(
u""" u"""
def __pyx_fused_cpdef(signatures, args, kwargs, defaults): def __pyx_fused_cpdef(signatures, args, kwargs, defaults, *, _fused_sigindex={}):
# FIXME: use a typed signature - currently fails badly because # FIXME: use a typed signature - currently fails badly because
# default arguments inherit the types we specify here! # default arguments inherit the types we specify here!
cdef list search_list
cdef dict sn, sigindex_node
dest_sig = [None] * {{n_fused}} dest_sig = [None] * {{n_fused}}
if kwargs is not None and not kwargs: if kwargs is not None and not kwargs:
...@@ -691,23 +715,36 @@ class FusedCFuncDefNode(StatListNode): ...@@ -691,23 +715,36 @@ class FusedCFuncDefNode(StatListNode):
env.use_utility_code(Code.UtilityCode.load_cached("Import", "ImportExport.c")) env.use_utility_code(Code.UtilityCode.load_cached("Import", "ImportExport.c"))
env.use_utility_code(Code.UtilityCode.load_cached("ImportNumPyArray", "ImportExport.c")) env.use_utility_code(Code.UtilityCode.load_cached("ImportNumPyArray", "ImportExport.c"))
self._fused_signature_index(pyx_code)
pyx_code.put_chunk( pyx_code.put_chunk(
u""" u"""
candidates = [] sigindex_matches = []
for sig in <dict>signatures: sigindex_candidates = [_fused_sigindex]
match_found = False
src_sig = sig.strip('()').split('|') for dst_type in dest_sig:
for i in range(len(dest_sig)): found_matches = []
dst_type = dest_sig[i] found_candidates = []
if dst_type is not None: # Make two seperate lists: One for signature sub-trees
if src_sig[i] == dst_type: # with at least one definite match, and another for
match_found = True # signature sub-trees with only ambiguous matches
# (where `dest_sig[i] is None`).
if dst_type is None:
for sn in sigindex_matches:
found_matches.extend(sn.values())
for sn in sigindex_candidates:
found_candidates.extend(sn.values())
else: else:
match_found = False for search_list in (sigindex_matches, sigindex_candidates):
for sn in search_list:
if dst_type in sn:
found_matches.append(sn[dst_type])
sigindex_matches = found_matches
sigindex_candidates = found_candidates
if not (found_matches or found_candidates):
break break
if match_found: candidates = sigindex_matches
candidates.append(sig)
if not candidates: if not candidates:
raise TypeError("No matching signature found") raise TypeError("No matching signature found")
......
# cython: language_level=3
# mode: run
cimport cython cimport cython
import sys, io
cy = __import__("cython") cy = __import__("cython")
cpdef func1(self, cython.integral x): cpdef func1(self, cython.integral x):
print "%s," % (self,), print(f"{self},", end=' ')
if cython.integral is int: if cython.integral is int:
print 'x is int', x, cython.typeof(x) print('x is int', x, cython.typeof(x))
else: else:
print 'x is long', x, cython.typeof(x) print('x is long', x, cython.typeof(x))
class A(object): class A(object):
...@@ -16,6 +20,18 @@ class A(object): ...@@ -16,6 +20,18 @@ class A(object):
def __str__(self): def __str__(self):
return "A" return "A"
cdef class B:
cpdef int meth(self, cython.integral x):
print(f"{self},", end=' ')
if cython.integral is int:
print('x is int', x, cython.typeof(x))
else:
print('x is long', x, cython.typeof(x))
return 0
def __str__(self):
return "B"
pyfunc = func1 pyfunc = func1
def test_fused_cpdef(): def test_fused_cpdef():
...@@ -32,23 +48,71 @@ def test_fused_cpdef(): ...@@ -32,23 +48,71 @@ def test_fused_cpdef():
A, x is long 2 long A, x is long 2 long
A, x is long 2 long A, x is long 2 long
A, x is long 2 long A, x is long 2 long
<BLANKLINE>
B, x is long 2 long
""" """
func1[int](None, 2) func1[int](None, 2)
func1[long](None, 2) func1[long](None, 2)
func1(None, 2) func1(None, 2)
print print()
pyfunc[cy.int](None, 2) pyfunc[cy.int](None, 2)
pyfunc(None, 2) pyfunc(None, 2)
print print()
A.meth[cy.int](A(), 2) A.meth[cy.int](A(), 2)
A.meth(A(), 2) A.meth(A(), 2)
A().meth[cy.long](2) A().meth[cy.long](2)
A().meth(2) A().meth(2)
print()
B().meth(2)
midimport_run = io.StringIO()
if sys.version_info.major < 3:
# Monkey-patch midimport_run.write to accept non-unicode strings under Python 2.
midimport_run.write = lambda c: io.StringIO.write(midimport_run, unicode(c))
realstdout = sys.stdout
sys.stdout = midimport_run
try:
# Run `test_fused_cpdef()` during import and save the result for
# `test_midimport_run()`.
test_fused_cpdef()
except Exception as e:
midimport_run.write(f"{e!r}\n")
finally:
sys.stdout = realstdout
def test_midimport_run():
# At one point, dynamically calling fused cpdef functions during import
# would fail because the type signature-matching indices weren't
# yet initialized.
# (See Compiler.FusedNode.FusedCFuncDefNode._fused_signature_index,
# GH-3366.)
"""
>>> test_midimport_run()
None, x is int 2 int
None, x is long 2 long
None, x is long 2 long
<BLANKLINE>
None, x is int 2 int
None, x is long 2 long
<BLANKLINE>
A, x is int 2 int
A, x is long 2 long
A, x is long 2 long
A, x is long 2 long
<BLANKLINE>
B, x is long 2 long
"""
print(midimport_run.getvalue(), end='')
def assert_raise(func, *args): def assert_raise(func, *args):
try: try:
...@@ -70,23 +134,31 @@ def test_badcall(): ...@@ -70,23 +134,31 @@ def test_badcall():
assert_raise(A.meth) assert_raise(A.meth)
assert_raise(A().meth[cy.int]) assert_raise(A().meth[cy.int])
assert_raise(A.meth[cy.int]) assert_raise(A.meth[cy.int])
assert_raise(B().meth, 1, 2, 3)
def test_nomatch():
"""
>>> func1(None, ())
Traceback (most recent call last):
TypeError: No matching signature found
"""
ctypedef long double long_double ctypedef long double long_double
cpdef multiarg(cython.integral x, cython.floating y): cpdef multiarg(cython.integral x, cython.floating y):
if cython.integral is int: if cython.integral is int:
print "x is an int,", print("x is an int,", end=' ')
else: else:
print "x is a long,", print("x is a long,", end=' ')
if cython.floating is long_double: if cython.floating is long_double:
print "y is a long double:", print("y is a long double:", end=' ')
elif float is cython.floating: elif float is cython.floating:
print "y is a float:", print("y is a float:", end=' ')
else: else:
print "y is a double:", print("y is a double:", end=' ')
print x, y print(x, y)
def test_multiarg(): def test_multiarg():
""" """
...@@ -104,3 +176,13 @@ def test_multiarg(): ...@@ -104,3 +176,13 @@ def test_multiarg():
multiarg[int, float](1, 2.0) multiarg[int, float](1, 2.0)
multiarg[cy.int, cy.float](1, 2.0) multiarg[cy.int, cy.float](1, 2.0)
multiarg(4, 5.0) multiarg(4, 5.0)
def test_ambiguousmatch():
"""
>>> multiarg(5, ())
Traceback (most recent call last):
TypeError: Function call with ambiguous argument types
>>> multiarg((), 2.0)
Traceback (most recent call last):
TypeError: Function call with ambiguous argument types
"""
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment