Commit ee914074 authored by Mark Florisson's avatar Mark Florisson

Allow type checks on fused types

parent e7426088
...@@ -2027,6 +2027,8 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2027,6 +2027,8 @@ class FusedCFuncDefNode(StatListNode):
Gives a list of fused types and the parent environment, make copies Gives a list of fused types and the parent environment, make copies
of the original cdef function. of the original cdef function.
""" """
from Cython.Compiler import ParseTreeTransforms
permutations = self.node.type.get_all_specific_permutations() permutations = self.node.type.get_all_specific_permutations()
for cname, fused_to_specific in permutations: for cname, fused_to_specific in permutations:
copied_node = copy.deepcopy(self.node) copied_node = copy.deepcopy(self.node)
...@@ -2053,6 +2055,8 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2053,6 +2055,8 @@ class FusedCFuncDefNode(StatListNode):
cname = self.node.type.get_specific_cname(cname) cname = self.node.type.get_specific_cname(cname)
copied_node.entry.func_cname = copied_node.entry.cname = cname copied_node.entry.func_cname = copied_node.entry.cname = cname
ParseTreeTransforms.ReplaceFusedTypeChecks(copied_node.local_scope)(copied_node)
class PyArgDeclNode(Node): class PyArgDeclNode(Node):
# Argument which must be a Python object (used # Argument which must be a Python object (used
......
...@@ -1882,6 +1882,94 @@ class TransformBuiltinMethods(EnvTransform): ...@@ -1882,6 +1882,94 @@ class TransformBuiltinMethods(EnvTransform):
return node return node
class ReplaceFusedTypeChecks(VisitorTransform):
"""
This is not a transform in the pipeline. It is invoked on the specific
versions of a cdef function with fused argument types. It filters out any
type branches that don't match. e.g.
if fused_t is mytype:
...
elif fused_t in other_fused_type:
...
"""
def __init__(self, local_scope):
super(ReplaceFusedTypeChecks, self).__init__()
self.local_scope = local_scope
def visit_IfStatNode(self, node):
if_clauses = node.if_clauses[:]
self.visitchildren(node)
if if_clauses != node.if_clauses:
if node.if_clauses:
return node.if_clauses[0]
return node.else_clause
return node
def visit_IfClauseNode(self, node):
cond = node.condition
if isinstance(cond, ExprNodes.PrimaryCmpNode):
type1, type2 = self.get_types(cond)
op = cond.operator
type1 = self.specialize_type(type1, cond.operand1.pos)
if type1 and type2:
if op == 'is':
type2 = self.specialize_type(type2, cond.operand1.pos)
if type1.same_as(type2):
return node.body
elif op in ('in', 'not_in'):
# We have to do an instance check directly, as operand2
# needs to be a fused type and not a type with a subtype
# that is fused. First unpack the typedef
if isinstance(type2, PyrexTypes.CTypedefType):
type2 = type2.typedef_base_type
if type1.is_fused or not isinstance(type2, PyrexTypes.FusedType):
error(cond.pos, "Can use 'in' or 'not in' only on a "
"specific and a fused type")
elif op == 'in':
if type1 in type2.types:
return node.body
else:
if type1 not in type2.types:
return node.body
return None
return node
def get_types(self, node):
if node.operand1.is_name and node.operand2.is_name:
return self.get_type(node.operand1), self.get_type(node.operand2)
return None, None
def get_type(self, node):
type = PyrexTypes.parse_basic_type(node.name)
if not type:
# Don't use self.lookup_type() as it will specialize
entry = self.local_scope.lookup(node.name)
if entry and entry.is_type:
type = entry.type
return type
def specialize_type(self, type, pos):
try:
return type.specialize(self.local_scope.fused_to_specific)
except KeyError:
error(pos, "Type is not specific")
return type
def visit_Node(self, node):
self.visitchildren(node)
return node
class DebugTransform(CythonTransform): class DebugTransform(CythonTransform):
""" """
Create debug information and all functions' visibility to extern in order Create debug information and all functions' visibility to extern in order
......
cimport cython
ctypedef char *string_t
ctypedef cython.fused_type(int, long, float, string_t) fused_t
ctypedef cython.fused_type(int, long) other_t
ctypedef cython.fused_type(int, float) unresolved_t
cdef func(fused_t a, other_t b):
cdef int int_a
cdef string_t string_a
cdef other_t other_a
if fused_t is other_t:
print 'fused_t is other_t'
other_a = a
if fused_t is int:
print 'fused_t is int'
int_a = a
if fused_t is string_t:
print 'fused_t is string_t'
string_a = a
if fused_t in unresolved_t:
print 'fused_t in unresolved_t'
if int in unresolved_t:
print 'int in unresolved_t'
if string_t in unresolved_t:
print 'string_t in unresolved_t'
def test_int_int():
"""
>>> test_int_int()
fused_t is other_t
fused_t is int
fused_t in unresolved_t
int in unresolved_t
"""
cdef int x = 1
cdef int y = 2
func(x, y)
def test_int_long():
"""
>>> test_int_long()
fused_t is int
fused_t in unresolved_t
int in unresolved_t
"""
cdef int x = 1
cdef long y = 2
func(x, y)
def test_float_int():
"""
>>> test_float_int()
fused_t in unresolved_t
int in unresolved_t
"""
cdef float x = 1
cdef int y = 2
func(x, y)
def test_string_int():
"""
>>> test_string_int()
fused_t is string_t
int in unresolved_t
"""
cdef string_t x = b"spam"
cdef int y = 2
func(x, y)
cdef if_then_else(fused_t a, other_t b):
cdef other_t other_a
cdef string_t string_a
cdef fused_t specific_a
if fused_t is other_t:
print 'fused_t is other_t'
other_a = a
elif fused_t is string_t:
print 'fused_t is string_t'
string_a = a
else:
print 'none of the above'
specific_a = a
def test_if_then_else_long_long():
"""
>>> test_if_then_else_long_long()
fused_t is other_t
"""
cdef long x = 0, y = 0
if_then_else(x, y)
def test_if_then_else_string_int():
"""
>>> test_if_then_else_string_int()
fused_t is string_t
"""
cdef string_t x = b"spam"
cdef int y = 0
if_then_else(x, y)
def test_if_then_else_float_int():
"""
>>> test_if_then_else_float_int()
none of the above
"""
cdef float x = 0.0
cdef int y = 1
if_then_else(x, y)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment