Commit fc0547e4 authored by Mark Florisson's avatar Mark Florisson

Allow type expressions in comparisons

parent ee914074
...@@ -141,7 +141,9 @@ class Context(object): ...@@ -141,7 +141,9 @@ class Context(object):
FlattenInListTransform(), FlattenInListTransform(),
WithTransform(self), WithTransform(self),
DecoratorTransform(self), DecoratorTransform(self),
# PrintTree(),
AnalyseDeclarationsTransform(self), AnalyseDeclarationsTransform(self),
# PrintTree(),
AutoTestDictTransform(self), AutoTestDictTransform(self),
EmbedSignature(self), EmbedSignature(self),
EarlyReplaceBuiltinCalls(self), ## Necessary? EarlyReplaceBuiltinCalls(self), ## Necessary?
...@@ -159,6 +161,7 @@ class Context(object): ...@@ -159,6 +161,7 @@ class Context(object):
DropRefcountingTransform(), DropRefcountingTransform(),
FinalOptimizePhase(self), FinalOptimizePhase(self),
GilCheck(), GilCheck(),
# PrintTree(),
] ]
def create_pyx_pipeline(self, options, result, py=False): def create_pyx_pipeline(self, options, result, py=False):
......
...@@ -2055,6 +2055,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2055,6 +2055,7 @@ class FusedCFuncDefNode(StatListNode):
cname = self.node.type.get_specific_cname(cname) cname = self.node.type.get_specific_cname(cname)
copied_node.entry.func_cname = copied_node.entry.cname = cname copied_node.entry.func_cname = copied_node.entry.cname = cname
# TransformBuiltinMethods(copied_node)
ParseTreeTransforms.ReplaceFusedTypeChecks(copied_node.local_scope)(copied_node) ParseTreeTransforms.ReplaceFusedTypeChecks(copied_node.local_scope)(copied_node)
......
...@@ -1912,11 +1912,13 @@ class ReplaceFusedTypeChecks(VisitorTransform): ...@@ -1912,11 +1912,13 @@ class ReplaceFusedTypeChecks(VisitorTransform):
def visit_IfClauseNode(self, node): def visit_IfClauseNode(self, node):
cond = node.condition cond = node.condition
if isinstance(cond, ExprNodes.PrimaryCmpNode): if isinstance(cond, ExprNodes.PrimaryCmpNode):
type1, type2 = self.get_types(cond) type1 = cond.operand1.analyse_as_type(self.local_scope)
op = cond.operator type2 = cond.operand2.analyse_as_type(self.local_scope)
type1 = self.specialize_type(type1, cond.operand1.pos)
if type1 and type2: if type1 and type2:
type1 = self.specialize_type(type1, cond.operand1.pos)
op = cond.operator
if op == 'is': if op == 'is':
type2 = self.specialize_type(type2, cond.operand1.pos) type2 = self.specialize_type(type2, cond.operand1.pos)
if type1.same_as(type2): if type1.same_as(type2):
...@@ -1942,22 +1944,6 @@ class ReplaceFusedTypeChecks(VisitorTransform): ...@@ -1942,22 +1944,6 @@ class ReplaceFusedTypeChecks(VisitorTransform):
return node return node
def get_types(self, node):
if node.operand1.is_name and node.operand2.is_name:
return self.get_type(node.operand1), self.get_type(node.operand2)
return None, None
def get_type(self, node):
type = PyrexTypes.parse_basic_type(node.name)
if not type:
# Don't use self.lookup_type() as it will specialize
entry = self.local_scope.lookup(node.name)
if entry and entry.is_type:
type = entry.type
return type
def specialize_type(self, type, pos): def specialize_type(self, type, pos):
try: try:
return type.specialize(self.local_scope.fused_to_specific) return type.specialize(self.local_scope.fused_to_specific)
......
cimport cython cimport cython
cimport check_fused_types_pxd
ctypedef char *string_t ctypedef char *string_t
ctypedef cython.fused_type(int, long, float, string_t) fused_t ctypedef cython.fused_type(int, long, float, string_t) fused_t
ctypedef cython.fused_type(int, long) other_t ctypedef cython.fused_type(int, long) other_t
ctypedef cython.fused_type(int, float) unresolved_t
cdef func(fused_t a, other_t b): cdef func(fused_t a, other_t b):
cdef int int_a cdef int int_a
...@@ -22,13 +22,13 @@ cdef func(fused_t a, other_t b): ...@@ -22,13 +22,13 @@ cdef func(fused_t a, other_t b):
print 'fused_t is string_t' print 'fused_t is string_t'
string_a = a string_a = a
if fused_t in unresolved_t: if fused_t in check_fused_types_pxd.unresolved_t:
print 'fused_t in unresolved_t' print 'fused_t in unresolved_t'
if int in unresolved_t: if int in check_fused_types_pxd.unresolved_t:
print 'int in unresolved_t' print 'int in unresolved_t'
if string_t in unresolved_t: if string_t in check_fused_types_pxd.unresolved_t:
print 'string_t in unresolved_t' print 'string_t in unresolved_t'
......
cimport cython
ctypedef cython.fused_type(int, float) unresolved_t
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment