From c430314b095f32e0a3e6d31d6b6dc1ec5985ead2 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw <robertwb@gmail.com> Date: Sat, 30 May 2015 23:09:17 -0700 Subject: [PATCH] Allow composite fused types. --- Cython/Compiler/Nodes.py | 2 -- Cython/Compiler/PyrexTypes.py | 8 +++++++- tests/errors/fused_types.pyx | 1 - tests/run/fused_types.pyx | 15 +++++++++++++++ 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 2694ca8ba..67e1eb2a1 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -1215,8 +1215,6 @@ class FusedTypeNode(CBaseTypeNode): if type in types: error(type_node.pos, "Type specified multiple times") - elif type.is_fused: - error(type_node.pos, "Cannot fuse a fused type") else: types.append(type) diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index a81b0395d..b07e55320 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -1339,7 +1339,13 @@ class FusedType(CType): exception_check = 0 def __init__(self, types, name=None): - self.types = types + # Use list rather than set to preserve order. + flattened_types = [t for t in types if not t.is_fused] + subtypes = sum([t.types for t in types if t.is_fused], []) + for type in subtypes: + if type not in flattened_types: + flattened_types.append(type) + self.types = flattened_types self.name = name def declaration_code(self, entity_code, for_display = 0, diff --git a/tests/errors/fused_types.pyx b/tests/errors/fused_types.pyx index 6b3465abd..e4b6bdba9 100644 --- a/tests/errors/fused_types.pyx +++ b/tests/errors/fused_types.pyx @@ -67,7 +67,6 @@ func(x, y) _ERRORS = u""" 10:15: fused_type does not take keyword arguments 15:38: Type specified multiple times -17:33: Cannot fuse a fused type 26:4: Invalid use of fused types, type cannot be specialized 26:4: Not enough types specified to specialize the function, int2_t is still fused 27:4: Invalid use of fused types, type cannot be specialized diff --git a/tests/run/fused_types.pyx b/tests/run/fused_types.pyx index 2fc44e2c6..819cda13d 100644 --- a/tests/run/fused_types.pyx +++ b/tests/run/fused_types.pyx @@ -19,6 +19,7 @@ other_t = cython.fused_type(int, double) ctypedef double *p_double ctypedef int *p_int fused_type3 = cython.fused_type(int, double) +fused_composite = cython.fused_type(fused_type2, fused_type3) def test_pure(): """ @@ -349,3 +350,17 @@ def test_index_fused_args(cython.floating f, ints_t i): double int """ _test_index_fused_args[cython.floating, ints_t](f, i) + +def test_composite(fused_composite x): + """ + >>> test_composite('a') + 'a' + >>> test_composite(3) + 6 + >>> test_composite(3.0) + 6.0 + """ + if fused_composite is string_t: + return x + else: + return 2 * x -- 2.30.9