Commit 430ec7cf authored by Mark Florisson's avatar Mark Florisson

Change semantics of fused types

parent 82c13a65
...@@ -2332,8 +2332,6 @@ class IndexNode(ExprNode): ...@@ -2332,8 +2332,6 @@ class IndexNode(ExprNode):
""" """
base_type = self.base.type base_type = self.base.type
def err(msg, pos=None):
error(pos or self.pos, msg)
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
specific_types = [] specific_types = []
...@@ -2347,11 +2345,15 @@ class IndexNode(ExprNode): ...@@ -2347,11 +2345,15 @@ class IndexNode(ExprNode):
positions.append(arg.pos) positions.append(arg.pos)
specific_types.append(arg.analyse_as_type(env)) specific_types.append(arg.analyse_as_type(env))
else: else:
return err("Can only index fused functions with types") return error(self.pos, "Can only index fused functions with types")
fused_types = base_type.get_fused_types() fused_types = base_type.get_fused_types()
if len(specific_types) > len(fused_types): if len(specific_types) > len(fused_types):
return err("Too many types specified") return error(self.pos, "Too many types specified")
elif len(specific_types) < len(fused_types):
t = fused_types[len(specific_types)]
return error(self.pos, "Not enough types specified to specialize "
"the function, %s is still fused" % t)
# See if our index types form valid specializations # See if our index types form valid specializations
for pos, specific_type, fused_type in zip(positions, for pos, specific_type, fused_type in zip(positions,
...@@ -2359,27 +2361,19 @@ class IndexNode(ExprNode): ...@@ -2359,27 +2361,19 @@ class IndexNode(ExprNode):
fused_types): fused_types):
if not Utils.any([specific_type.same_as(t) if not Utils.any([specific_type.same_as(t)
for t in fused_type.types]): for t in fused_type.types]):
return err("Type not in fused type", pos=pos) return error(pos, "Type not in fused type")
if specific_type is None or specific_type.is_error: if specific_type is None or specific_type.is_error:
return return
fused_to_specific = dict(zip(fused_types, specific_types)) fused_to_specific = dict(zip(fused_types, specific_types))
# If we are only partially fused, specialize accordingly
for fused_type in fused_types:
if fused_type not in fused_to_specific:
fused_to_specific[fused_type] = fused_type
type = base_type.specialize(fused_to_specific) type = base_type.specialize(fused_to_specific)
if type is not base_type: if type.is_fused:
import copy # Only partially specific, this is invalid
e = copy.copy(base_type.entry) error(self.pos,
e.type = type "Index operation makes function only partially specific")
type.entry = e else:
if not type.is_fused:
# Fully specific, find the signature with the specialized entry # Fully specific, find the signature with the specialized entry
for signature in self.base.type.get_all_specific_function_types(): for signature in self.base.type.get_all_specific_function_types():
if type.same_as(signature): if type.same_as(signature):
...@@ -2387,9 +2381,6 @@ class IndexNode(ExprNode): ...@@ -2387,9 +2381,6 @@ class IndexNode(ExprNode):
break break
else: else:
assert False assert False
else:
# Only partially specific
self.type = type
gil_message = "Indexing Python object" gil_message = "Indexing Python object"
...@@ -3117,8 +3108,10 @@ class SimpleCallNode(CallNode): ...@@ -3117,8 +3108,10 @@ class SimpleCallNode(CallNode):
return return
elif hasattr(self.function, 'entry'): elif hasattr(self.function, 'entry'):
overloaded_entry = self.function.entry overloaded_entry = self.function.entry
elif isinstance(self.function, IndexNode) and self.function.type.is_fused: elif (isinstance(self.function, IndexNode) and
self.function.base.type.is_fused):
overloaded_entry = self.function.type.entry overloaded_entry = self.function.type.entry
self.function.entry = self.function.type.entry
else: else:
overloaded_entry = None overloaded_entry = None
......
...@@ -937,28 +937,23 @@ class FusedTypeNode(CBaseTypeNode): ...@@ -937,28 +937,23 @@ class FusedTypeNode(CBaseTypeNode):
child_attrs = [] child_attrs = []
def analyse(self, env): def analyse(self, env):
self.types = [type.analyse_as_type(env) for type in self.types] # Note: this list may still contain multiple of the same entries
types = [type.analyse_as_type(env) for type in self.types]
if len(self.types) == 1: if len(self.types) == 1:
return self.types[0] return types[0]
types = []
seen = cython.set() seen = cython.set()
for type_node, type in zip(self.types, types):
for type in self.types: if type in seen:
self.add_type(type, types, seen) error(type_node.pos, "Type specified multiple times")
else:
return PyrexTypes.FusedType(types)
def add_type(self, type, types, seen):
if type not in seen:
seen.add(type) seen.add(type)
if type.is_fused: if type.is_fused:
for specific_type in PyrexTypes.get_specific_types(type): error(type_node.pos, "Cannot fuse a fused type")
self.add_type(specific_type, types, seen)
else: self.types = types
types.append(type) return PyrexTypes.FusedType(types)
class CVarDefNode(StatNode): class CVarDefNode(StatNode):
...@@ -1202,14 +1197,11 @@ class CTypeDefNode(StatNode): ...@@ -1202,14 +1197,11 @@ class CTypeDefNode(StatNode):
child_attrs = ["base_type", "declarator"] child_attrs = ["base_type", "declarator"]
def analyse_declarations(self, env): def analyse_declarations(self, env):
"""
If we are a fused type, do a normal type declaration, as we want
declared variables to have a FusedType type, not a CTypeDefType.
"""
base = self.base_type.analyse(env) base = self.base_type.analyse(env)
name_declarator, type = self.declarator.analyse(base, env) name_declarator, type = self.declarator.analyse(base, env)
name = name_declarator.name name = name_declarator.name
cname = name_declarator.cname cname = name_declarator.cname
entry = env.declare_typedef(name, type, self.pos, entry = env.declare_typedef(name, type, self.pos,
cname = cname, visibility = self.visibility, api = self.api) cname = cname, visibility = self.visibility, api = self.api)
...@@ -2040,6 +2032,9 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2040,6 +2032,9 @@ class FusedCFuncDefNode(StatListNode):
from Cython.Compiler import ParseTreeTransforms from Cython.Compiler import ParseTreeTransforms
permutations = self.node.type.get_all_specific_permutations() permutations = self.node.type.get_all_specific_permutations()
# print 'Node %s has %d specializations:' % (self.node.entry.name,
# len(permutations))
# import pprint; pprint.pprint([d for cname, d in permutations])
for cname, fused_to_specific in permutations: for cname, fused_to_specific in permutations:
copied_node = copy.deepcopy(self.node) copied_node = copy.deepcopy(self.node)
......
...@@ -1347,24 +1347,17 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -1347,24 +1347,17 @@ class AnalyseExpressionsTransform(CythonTransform):
argument types with a NameNode referring to the function with argument types with a NameNode referring to the function with
specialized entry and type. specialized entry and type.
""" """
was_nested = self.nested_index_node
self.nested_index_node = True
self.visit_Node(node) self.visit_Node(node)
self.nested_index_node = was_nested
type = node.type type = node.type
if type.is_cfunction and type.is_fused and not self.nested_index_node: if type.is_cfunction and node.base.type.is_fused:
error(node.pos, "Not enough types were specified to indicate a "
"specialized function")
elif type.is_cfunction and node.base.type.is_fused:
while not node.is_name:
node = node.base node = node.base
if not node.is_name:
error(node.pos, "Can only index a fused function once")
node.type = PyrexTypes.error_type
else:
node.type = type node.type = type
node.entry = type.entry node.entry = type.entry
print node.entry.cname
return node
return node return node
...@@ -1905,6 +1898,12 @@ class ReplaceFusedTypeChecks(VisitorTransform): ...@@ -1905,6 +1898,12 @@ class ReplaceFusedTypeChecks(VisitorTransform):
... ...
""" """
# Defer the import until now to avoid circularity...
from Cython.Compiler import Optimize
transform = Optimize.ConstantFolding()
transform.check_constant_value_not_set = False
def __init__(self, local_scope): def __init__(self, local_scope):
super(ReplaceFusedTypeChecks, self).__init__() super(ReplaceFusedTypeChecks, self).__init__()
self.local_scope = local_scope self.local_scope = local_scope
...@@ -1914,12 +1913,8 @@ class ReplaceFusedTypeChecks(VisitorTransform): ...@@ -1914,12 +1913,8 @@ class ReplaceFusedTypeChecks(VisitorTransform):
Filters out any if clauses with false compile time type check Filters out any if clauses with false compile time type check
expression. expression.
""" """
from Cython.Compiler import Optimize
self.visitchildren(node) self.visitchildren(node)
transform = Optimize.ConstantFolding() return self.transform(node)
transform.check_constant_value_not_set = False
return transform(node)
def visit_PrimaryCmpNode(self, node): def visit_PrimaryCmpNode(self, node):
type1 = node.operand1.analyse_as_type(self.local_scope) type1 = node.operand1.analyse_as_type(self.local_scope)
...@@ -1932,7 +1927,7 @@ class ReplaceFusedTypeChecks(VisitorTransform): ...@@ -1932,7 +1927,7 @@ class ReplaceFusedTypeChecks(VisitorTransform):
type1 = self.specialize_type(type1, node.operand1.pos) type1 = self.specialize_type(type1, node.operand1.pos)
op = node.operator op = node.operator
if op in ('is', 'is not', '==', '!='): if op in ('is', 'is_not', '==', '!='):
type2 = self.specialize_type(type2, node.operand2.pos) type2 = self.specialize_type(type2, node.operand2.pos)
is_same = type1.same_as(type2) is_same = type1.same_as(type2)
......
...@@ -677,6 +677,7 @@ class FusedType(PyrexType): ...@@ -677,6 +677,7 @@ class FusedType(PyrexType):
""" """
is_fused = 1 is_fused = 1
name = None
def __init__(self, types): def __init__(self, types):
self.types = types self.types = types
......
...@@ -12,10 +12,28 @@ dtype4 = cython.typedef(cython.fused_type(int, long, kw=None)) ...@@ -12,10 +12,28 @@ dtype4 = cython.typedef(cython.fused_type(int, long, kw=None))
ctypedef public cython.fused_type(int, long) dtype7 ctypedef public cython.fused_type(int, long) dtype7
ctypedef api cython.fused_type(int, long) dtype8 ctypedef api cython.fused_type(int, long) dtype8
ctypedef cython.fused_type(short, short int, int) int_t
ctypedef cython.fused_type(int, long) int2_t
ctypedef cython.fused_type(int2_t, int) dtype9
ctypedef cython.fused_type(float, double) floating
cdef func(floating x, int2_t y):
print x, y
cdef float x = 10.0
cdef int y = 10
func[float](x, y)
func[float][int](x, y)
func[float, int](x)
func[float, int](x, y, y)
func(x, y=y)
# This is all valid # This is all valid
ctypedef fused_type(int, long, float) dtype5 ctypedef fused_type(int, long, float) dtype5
ctypedef cython.fused_type(int, long) dtype6 ctypedef cython.fused_type(int, long) dtype6
func[float, int](x, y)
func(x, y)
_ERRORS = u""" _ERRORS = u"""
fused_types.pyx:7:13: Can only fuse types with cython.fused_type() fused_types.pyx:7:13: Can only fuse types with cython.fused_type()
...@@ -24,4 +42,11 @@ fused_types.pyx:9:20: 'foo' is not a type identifier ...@@ -24,4 +42,11 @@ fused_types.pyx:9:20: 'foo' is not a type identifier
fused_types.pyx:10:23: fused_type does not take keyword arguments fused_types.pyx:10:23: fused_type does not take keyword arguments
fused_types.pyx:12:0: Fused types cannot be public or api fused_types.pyx:12:0: Fused types cannot be public or api
fused_types.pyx:13:0: Fused types cannot be public or api fused_types.pyx:13:0: Fused types cannot be public or api
fused_types.pyx:15:34: Type specified multiple times
fused_types.pyx:17:27: Cannot fuse a fused type
fused_types.pyx:26:4: Not enough types specified to specialize the function, int2_t is still fused
fused_types.pyx:27:4: Not enough types specified to specialize the function, int2_t is still fused
fused_types.pyx:28:16: Call with wrong number of arguments (expected 2, got 1)
fused_types.pyx:29:16: Call with wrong number of arguments (expected 2, got 3)
fused_types.pyx:30:4: Keyword and starred arguments not allowed in cdef functions.
""" """
...@@ -5,12 +5,15 @@ ctypedef char *string_t ...@@ -5,12 +5,15 @@ ctypedef char *string_t
ctypedef cython.fused_type(int, long, float, string_t) fused_t ctypedef cython.fused_type(int, long, float, string_t) fused_t
ctypedef cython.fused_type(int, long) other_t ctypedef cython.fused_type(int, long) other_t
ctypedef cython.fused_type(short, short int, short, int) base_t ctypedef cython.fused_type(short int, int) base_t
ctypedef cython.fused_type(float complex, double complex, ctypedef cython.fused_type(float complex, double complex,
int complex, long complex) complex_t int complex, long complex) complex_t
ctypedef base_t **base_t_p_p ctypedef base_t **base_t_p_p
ctypedef cython.fused_type(char, base_t_p_p, fused_t, complex_t) composed_t # ctypedef cython.fused_type(char, base_t_p_p, fused_t, complex_t) composed_t
ctypedef cython.fused_type(char, int, float, string_t, float complex,
double complex, int complex, long complex,
cython.p_p_int) composed_t
cdef func(fused_t a, other_t b): cdef func(fused_t a, other_t b):
...@@ -160,8 +163,8 @@ def test_composed_types(): ...@@ -160,8 +163,8 @@ def test_composed_types():
(0.9+0.4j) (0.9+0.4j)
<BLANKLINE> <BLANKLINE>
not a complex number not a complex number
9 10 7 8
19 15
<BLANKLINE> <BLANKLINE>
7 8 7 8
<BLANKLINE> <BLANKLINE>
...@@ -177,7 +180,7 @@ def test_composed_types(): ...@@ -177,7 +180,7 @@ def test_composed_types():
print result print result
print print
print composed(c + 2, d + 2) print composed(c, d)
print print
composed(&cp, &dp) composed(&cp, &dp)
......
# mode: run # mode: run
cimport cython cimport cython
#from cython cimport p_double, p_int
from cpython cimport Py_INCREF from cpython cimport Py_INCREF
from Cython import Shadow as pure_cython from Cython import Shadow as pure_cython
ctypedef char * string_t ctypedef char * string_t
ctypedef cython.fused_type(float, double) floating
ctypedef cython.fused_type(int, long) integral
ctypedef cython.fused_type(int, long, float, double, string_t) fused_type1 ctypedef cython.fused_type(int, long, float, double, string_t) fused_type1
ctypedef cython.fused_type(string_t) fused_type2 ctypedef cython.fused_type(string_t) fused_type2
ctypedef fused_type1 *composed_t
ctypedef cython.fused_type(int, long, float, double) other_t
ctypedef double *p_double
ctypedef int *p_int
def test_pure(): def test_pure():
...@@ -101,3 +108,76 @@ def test_fused_with_pointer(): ...@@ -101,3 +108,76 @@ def test_fused_with_pointer():
print fused_with_pointer(float_array) print fused_with_pointer(float_array)
print print
print fused_with_pointer(string_array) print fused_with_pointer(string_array)
cdef test_specialize(fused_type1 x, fused_type1 *y, composed_t z, other_t *a):
cdef fused_type1 result
if composed_t is p_double:
print "double pointer"
if fused_type1 in floating:
result = x + y[0] + z[0] + a[0]
return result
def test_specializations():
"""
>>> test_specializations()
double pointer
double pointer
double pointer
double pointer
double pointer
"""
cdef object (*f)(double, double *, double *, int *)
cdef double somedouble = 2.2
cdef double otherdouble = 3.3
cdef int someint = 4
cdef p_double somedouble_p = &somedouble
cdef p_double otherdouble_p = &otherdouble
cdef p_int someint_p = &someint
f = test_specialize
assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6
f = <object (*)(double, double *, double *, int *)> test_specialize
assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6
assert (<object (*)(double, double *, double *, int *)>
test_specialize)(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6
f = test_specialize[double, int]
assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6
assert test_specialize[double, int](1.1, somedouble_p, otherdouble_p, someint_p) == 10.6
# The following cases are not supported
# f = test_specialize[double][p_int]
# print f(1.1, somedouble_p, otherdouble_p)
# print
# print test_specialize[double][p_int](1.1, somedouble_p, otherdouble_p)
# print
# print test_specialize[double](1.1, somedouble_p, otherdouble_p)
# print
#cdef opt_args(integral x, floating y = 4.0):
# print x, y
def test_opt_args():
"""
ToDO: enable and fix
test_opt_args()
3 4.0
3 4.0
3 4.0
3 4.0
"""
#opt_args[int, float](3)
#opt_args[int, double](3)
#opt_args[int, float](3, 4.0)
#opt_args[int, double](3, 4.0)
...@@ -31,7 +31,7 @@ cdef public class MyExt [ type MyExtType, object MyExtObject ]: ...@@ -31,7 +31,7 @@ cdef public class MyExt [ type MyExtType, object MyExtObject ]:
ctypedef char *string_t ctypedef char *string_t
ctypedef cython.fused_type(int, float) simple_t ctypedef cython.fused_type(int, float) simple_t
ctypedef cython.fused_type(simple_t, string_t) less_simple_t ctypedef cython.fused_type(int, float, string_t) less_simple_t
ctypedef cython.fused_type(mystruct_t, myunion_t, MyExt) object_t ctypedef cython.fused_type(mystruct_t, myunion_t, MyExt) object_t
ctypedef cython.fused_type(str, unicode, bytes) builtin_t ctypedef cython.fused_type(str, unicode, bytes) builtin_t
...@@ -82,6 +82,3 @@ assert f(mystruct, 5).a == 10 ...@@ -82,6 +82,3 @@ assert f(mystruct, 5).a == 10
f = add_simple[mystruct_t, int] f = add_simple[mystruct_t, int]
assert f(mystruct, 5).a == 10 assert f(mystruct, 5).a == 10
f = add_simple[mystruct_t][int]
assert f(mystruct, 5).a == 10
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment