Commit 7846cbbc authored by da-woods's avatar da-woods Committed by GitHub

Include return type in fused types of function pointers (GH-4678)

For fused functions it makes sense that the return type is ignored
(a function can't be specialized based on return type alone) but
for function pointers the return type should be included (since
such a pointer might be an argument to a fused function)

Fixes https://github.com/cython/cython/issues/4644
parent 0b3ccd7f
...@@ -78,7 +78,7 @@ class BaseType(object): ...@@ -78,7 +78,7 @@ class BaseType(object):
""" """
return self return self
def get_fused_types(self, result=None, seen=None, subtypes=None): def get_fused_types(self, result=None, seen=None, subtypes=None, include_function_return_type=False):
subtypes = subtypes or self.subtypes subtypes = subtypes or self.subtypes
if not subtypes: if not subtypes:
return None return None
...@@ -91,10 +91,10 @@ class BaseType(object): ...@@ -91,10 +91,10 @@ class BaseType(object):
list_or_subtype = getattr(self, attr) list_or_subtype = getattr(self, attr)
if list_or_subtype: if list_or_subtype:
if isinstance(list_or_subtype, BaseType): if isinstance(list_or_subtype, BaseType):
list_or_subtype.get_fused_types(result, seen) list_or_subtype.get_fused_types(result, seen, include_function_return_type=include_function_return_type)
else: else:
for subtype in list_or_subtype: for subtype in list_or_subtype:
subtype.get_fused_types(result, seen) subtype.get_fused_types(result, seen, include_function_return_type=include_function_return_type)
return result return result
...@@ -1845,7 +1845,7 @@ class FusedType(CType): ...@@ -1845,7 +1845,7 @@ class FusedType(CType):
else: else:
raise CannotSpecialize() raise CannotSpecialize()
def get_fused_types(self, result=None, seen=None): def get_fused_types(self, result=None, seen=None, include_function_return_type=False):
if result is None: if result is None:
return [self] return [self]
...@@ -2757,6 +2757,11 @@ class CPtrType(CPointerBaseType): ...@@ -2757,6 +2757,11 @@ class CPtrType(CPointerBaseType):
return self.base_type.find_cpp_operation_type(operator, operand_type) return self.base_type.find_cpp_operation_type(operator, operand_type)
return None return None
def get_fused_types(self, result=None, seen=None, include_function_return_type=False):
# For function pointers, include the return type - unlike for fused functions themselves,
# where the return type cannot be an independent fused type (i.e. is derived or non-fused).
return super(CPointerBaseType, self).get_fused_types(result, seen, include_function_return_type=True)
class CNullPtrType(CPtrType): class CNullPtrType(CPtrType):
...@@ -3232,10 +3237,13 @@ class CFuncType(CType): ...@@ -3232,10 +3237,13 @@ class CFuncType(CType):
return result return result
def get_fused_types(self, result=None, seen=None, subtypes=None): def get_fused_types(self, result=None, seen=None, subtypes=None, include_function_return_type=False):
"""Return fused types in the order they appear as parameter types""" """Return fused types in the order they appear as parameter types"""
return super(CFuncType, self).get_fused_types(result, seen, return super(CFuncType, self).get_fused_types(
subtypes=['args']) result, seen,
# for function pointer types, we consider the result type; for plain function
# types we don't (because it must be derivable from the arguments)
subtypes=self.subtypes if include_function_return_type else ['args'])
def specialize_entry(self, entry, cname): def specialize_entry(self, entry, cname):
assert not self.is_fused assert not self.is_fused
...@@ -3865,7 +3873,7 @@ class CppClassType(CType): ...@@ -3865,7 +3873,7 @@ class CppClassType(CType):
def is_template_type(self): def is_template_type(self):
return self.templates is not None and self.template_type is None return self.templates is not None and self.template_type is None
def get_fused_types(self, result=None, seen=None): def get_fused_types(self, result=None, seen=None, include_function_return_type=False):
if result is None: if result is None:
result = [] result = []
seen = set() seen = set()
......
...@@ -510,3 +510,48 @@ def convert_to_ptr(cython.floating x): ...@@ -510,3 +510,48 @@ def convert_to_ptr(cython.floating x):
return handle_float(&x) return handle_float(&x)
elif cython.floating is double: elif cython.floating is double:
return handle_double(&x) return handle_double(&x)
cdef double get_double():
return 1.0
cdef float get_float():
return 0.0
cdef call_func_pointer(cython.floating (*f)()):
return f()
def test_fused_func_pointer():
"""
>>> test_fused_func_pointer()
1.0
0.0
"""
print(call_func_pointer(get_double))
print(call_func_pointer(get_float))
cdef double get_double_from_int(int i):
return i
cdef call_func_pointer_with_1(cython.floating (*f)(cython.integral)):
return f(1)
def test_fused_func_pointer2():
"""
>>> test_fused_func_pointer2()
1.0
"""
print(call_func_pointer_with_1(get_double_from_int))
cdef call_function_that_calls_fused_pointer(object (*f)(cython.floating (*)(cython.integral))):
if cython.floating is double and cython.integral is int:
return 5*f(get_double_from_int)
else:
return None # practically it's hard to make this kind of function useful...
def test_fused_func_pointer_multilevel():
"""
>>> test_fused_func_pointer_multilevel()
5.0
None
"""
print(call_function_that_calls_fused_pointer(call_func_pointer_with_1[double, int]))
print(call_function_that_calls_fused_pointer(call_func_pointer_with_1[float, int]))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment