Commit 86fe8811 authored by da-woods's avatar da-woods Committed by GitHub

Specialize fused function local variables specified with pure-python (GH-3463)

These were previously getting missed. Added code to specialize them and tests to prove it.

Fixes https://github.com/cython/cython/issues/3142

Also fixes https://github.com/cython/cython/issues/3460 - (seems related enough to go in the same PR)
parent 6a30fecf
...@@ -1953,6 +1953,8 @@ class NameNode(AtomicExprNode): ...@@ -1953,6 +1953,8 @@ class NameNode(AtomicExprNode):
_, atype = annotation.analyse_type_annotation(env) _, atype = annotation.analyse_type_annotation(env)
if atype is None: if atype is None:
atype = unspecified_type if as_target and env.directives['infer_types'] != False else py_object_type atype = unspecified_type if as_target and env.directives['infer_types'] != False else py_object_type
if atype.is_fused and env.fused_to_specific:
atype = atype.specialize(env.fused_to_specific)
self.entry = env.declare_var(name, atype, self.pos, is_cdef=not as_target) self.entry = env.declare_var(name, atype, self.pos, is_cdef=not as_target)
self.entry.annotation = annotation.expr self.entry.annotation = annotation.expr
...@@ -13593,10 +13595,16 @@ class AnnotationNode(ExprNode): ...@@ -13593,10 +13595,16 @@ class AnnotationNode(ExprNode):
# 2. The Cython use where the annotation can indicate an # 2. The Cython use where the annotation can indicate an
# object type # object type
# #
# doesn't handle the pre PEP-563 version where the # Doesn't handle the pre PEP-563 version where the
# annotation is evaluated into a Python Object # annotation is evaluated into a Python Object.
subexprs = [] subexprs = []
# 'untyped' is set for fused specializations:
# Once a fused function has been created we don't want
# annotations to override an already set type.
untyped = False
def __init__(self, pos, expr, string=None): def __init__(self, pos, expr, string=None):
"""string is expected to already be a StringNode or None""" """string is expected to already be a StringNode or None"""
ExprNode.__init__(self, pos) ExprNode.__init__(self, pos)
...@@ -13617,6 +13625,9 @@ class AnnotationNode(ExprNode): ...@@ -13617,6 +13625,9 @@ class AnnotationNode(ExprNode):
return self.analyse_type_annotation(env)[1] return self.analyse_type_annotation(env)[1]
def analyse_type_annotation(self, env, assigned_value=None): def analyse_type_annotation(self, env, assigned_value=None):
if self.untyped:
# Already applied as a fused type, not re-evaluating it here.
return None, None
annotation = self.expr annotation = self.expr
base_type = None base_type = None
is_ambiguous = False is_ambiguous = False
......
...@@ -220,6 +220,10 @@ class FusedCFuncDefNode(StatListNode): ...@@ -220,6 +220,10 @@ class FusedCFuncDefNode(StatListNode):
arg.type = arg.type.specialize(fused_to_specific) arg.type = arg.type.specialize(fused_to_specific)
if arg.type.is_memoryviewslice: if arg.type.is_memoryviewslice:
arg.type.validate_memslice_dtype(arg.pos) arg.type.validate_memslice_dtype(arg.pos)
if arg.annotation:
# TODO might be nice if annotations were specialized instead?
# (Or might be hard to do reliably)
arg.annotation.untyped = True
def create_new_local_scope(self, node, env, f2s): def create_new_local_scope(self, node, env, f2s):
""" """
......
...@@ -1680,8 +1680,6 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1680,8 +1680,6 @@ class FuncDefNode(StatNode, BlockNode):
return arg return arg
if other_type is None: if other_type is None:
error(type_node.pos, "Not a type") error(type_node.pos, "Not a type")
elif other_type.is_fused and any(orig_type.same_as(t) for t in other_type.types):
pass # use specialized rather than fused type
elif orig_type is not py_object_type and not orig_type.same_as(other_type): elif orig_type is not py_object_type and not orig_type.same_as(other_type):
error(arg.base_type.pos, "Signature does not agree with previous declaration") error(arg.base_type.pos, "Signature does not agree with previous declaration")
error(type_node.pos, "Previous declaration here") error(type_node.pos, "Previous declaration here")
......
...@@ -1890,6 +1890,8 @@ if VALUE is not None: ...@@ -1890,6 +1890,8 @@ if VALUE is not None:
for var, type_node in node.directive_locals.items(): for var, type_node in node.directive_locals.items():
if not lenv.lookup_here(var): # don't redeclare args if not lenv.lookup_here(var): # don't redeclare args
type = type_node.analyse_as_type(lenv) type = type_node.analyse_as_type(lenv)
if type and type.is_fused and lenv.fused_to_specific:
type = type.specialize(lenv.fused_to_specific)
if type: if type:
lenv.declare_var(var, type, type_node.pos) lenv.declare_var(var, type, type_node.pos)
else: else:
......
cimport cython
ctypedef fused NotInPy: ctypedef fused NotInPy:
int int
float float
cdef class TestCls: cdef class TestCls:
@cython.locals(loc = NotInPy)
cpdef cpfunc(self, NotInPy arg) cpdef cpfunc(self, NotInPy arg)
# mode: run # mode: run
# tag: fused, pure3.0 # tag: fused, pure3.6
#cython: annotation_typing=True #cython: annotation_typing=True
...@@ -17,10 +17,11 @@ class TestCls: ...@@ -17,10 +17,11 @@ class TestCls:
>>> TestCls().func1(2) >>> TestCls().func1(2)
'int' 'int'
""" """
loc: 'NotInPy' = arg
return cython.typeof(arg) return cython.typeof(arg)
if cython.compiled: if cython.compiled:
@cython.locals(arg = NotInPy) # NameError in pure Python @cython.locals(arg=NotInPy, loc=NotInPy) # NameError for 'NotInPy' in pure Python
def func2(self, arg): def func2(self, arg):
""" """
>>> TestCls().func2(1.0) >>> TestCls().func2(1.0)
...@@ -28,6 +29,7 @@ class TestCls: ...@@ -28,6 +29,7 @@ class TestCls:
>>> TestCls().func2(2) >>> TestCls().func2(2)
'int' 'int'
""" """
loc = arg
return cython.typeof(arg) return cython.typeof(arg)
def cpfunc(self, arg): def cpfunc(self, arg):
...@@ -37,6 +39,7 @@ class TestCls: ...@@ -37,6 +39,7 @@ class TestCls:
>>> TestCls().cpfunc(2) >>> TestCls().cpfunc(2)
'int' 'int'
""" """
loc = arg
return cython.typeof(arg) return cython.typeof(arg)
def func1_inpy(self, arg: InPy): def func1_inpy(self, arg: InPy):
...@@ -46,9 +49,10 @@ class TestCls: ...@@ -46,9 +49,10 @@ class TestCls:
>>> TestCls().func1_inpy(2) >>> TestCls().func1_inpy(2)
'int' 'int'
""" """
loc: InPy = arg
return cython.typeof(arg) return cython.typeof(arg)
@cython.locals(arg = InPy) @cython.locals(arg = InPy, loc = InPy)
def func2_inpy(self, arg): def func2_inpy(self, arg):
""" """
>>> TestCls().func2_inpy(1.0) >>> TestCls().func2_inpy(1.0)
...@@ -56,5 +60,5 @@ class TestCls: ...@@ -56,5 +60,5 @@ class TestCls:
>>> TestCls().func2_inpy(2) >>> TestCls().func2_inpy(2)
'int' 'int'
""" """
loc = arg
return cython.typeof(arg) return cython.typeof(arg)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment