Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
C
cython
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Labels
Merge Requests
0
Merge Requests
0
Analytics
Analytics
Repository
Value Stream
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Commits
Open sidebar
nexedi
cython
Commits
1472e87b
Commit
1472e87b
authored
May 09, 2011
by
Mark Florisson
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Support fused cdef methods
parent
a8305590
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
229 additions
and
171 deletions
+229
-171
Cython/Compiler/ExprNodes.py
Cython/Compiler/ExprNodes.py
+26
-19
Cython/Compiler/ModuleNode.py
Cython/Compiler/ModuleNode.py
+46
-80
Cython/Compiler/Nodes.py
Cython/Compiler/Nodes.py
+38
-10
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/ParseTreeTransforms.py
+5
-8
Cython/Compiler/PyrexTypes.py
Cython/Compiler/PyrexTypes.py
+47
-52
Cython/Compiler/Symtab.py
Cython/Compiler/Symtab.py
+11
-0
tests/run/fused_types.pyx
tests/run/fused_types.pyx
+1
-2
tests/run/public_fused_types.srctree
tests/run/public_fused_types.srctree
+55
-0
No files found.
Cython/Compiler/ExprNodes.py
View file @
1472e87b
...
...
@@ -595,7 +595,10 @@ class ExprNode(Node):
for
signature
in
src_type
.
get_all_specific_function_types
():
if
signature
.
same_as
(
dst_type
):
return
CoerceFusedToSpecific
(
src
,
signature
)
src
.
type
=
signature
src
.
entry
=
src
.
type
.
entry
src
.
entry
.
used
=
True
return
self
error
(
self
.
pos
,
"Type is not specific"
)
self
.
type
=
error_type
...
...
@@ -2330,10 +2333,9 @@ class IndexNode(ExprNode):
NameNode with specific entry just after analysis of expressions by
AnalyseExpressionsTransform.
"""
base_type
=
self
.
base
.
type
self
.
type
=
PyrexTypes
.
error_type
base_type
=
self
.
base
.
type
specific_types
=
[]
positions
=
[]
...
...
@@ -2378,6 +2380,15 @@ class IndexNode(ExprNode):
for
signature
in
self
.
base
.
type
.
get_all_specific_function_types
():
if
type
.
same_as
(
signature
):
self
.
type
=
signature
if
self
.
base
.
is_attribute
:
# Pretend to be a normal attribute, for cdef extension
# methods
self
.
entry
=
signature
.
entry
self
.
is_attribute
=
self
.
base
.
is_attribute
self
.
obj
=
self
.
base
.
obj
self
.
entry
.
used
=
True
break
else
:
assert
False
...
...
@@ -3026,11 +3037,13 @@ class SimpleCallNode(CallNode):
function
=
self
.
function
function
.
is_called
=
1
self
.
function
.
analyse_types
(
env
)
if
function
.
is_attribute
and
function
.
entry
and
function
.
entry
.
is_cmethod
:
# Take ownership of the object from which the attribute
# was obtained, because we need to pass it as 'self'.
self
.
self
=
function
.
obj
function
.
obj
=
CloneNode
(
self
.
self
)
func_type
=
self
.
function_type
()
if
func_type
.
is_pyobject
:
self
.
arg_tuple
=
TupleNode
(
self
.
pos
,
args
=
self
.
args
)
...
...
@@ -3111,15 +3124,13 @@ class SimpleCallNode(CallNode):
elif
(
isinstance
(
self
.
function
,
IndexNode
)
and
self
.
function
.
base
.
type
.
is_fused
):
overloaded_entry
=
self
.
function
.
type
.
entry
self
.
function
.
entry
=
self
.
function
.
type
.
entry
else
:
overloaded_entry
=
None
if
overloaded_entry
:
if
self
.
function
.
type
.
is_fused
:
alternatives
=
[]
PyrexTypes
.
map_with_specific_entries
(
self
.
function
.
entry
,
alternatives
.
append
)
functypes
=
self
.
function
.
type
.
get_all_specific_function_types
()
alternatives
=
[
f
.
entry
for
f
in
functypes
]
else
:
alternatives
=
overloaded_entry
.
all_alternatives
()
...
...
@@ -3129,6 +3140,8 @@ class SimpleCallNode(CallNode):
self
.
type
=
PyrexTypes
.
error_type
self
.
result_code
=
"<error>"
return
entry
.
used
=
True
self
.
function
.
entry
=
entry
self
.
function
.
type
=
entry
.
type
func_type
=
self
.
function_type
()
...
...
@@ -3812,6 +3825,12 @@ class AttributeNode(ExprNode):
#print "...obj_code =", obj_code ###
if
self
.
entry
and
self
.
entry
.
is_cmethod
:
if
obj
.
type
.
is_extension_type
:
# If the attribute was specialized through indexing, make sure
# to get the right fused name, as our entry was replaced by our
# parent index node (AnalyseExpressionsTransform)
if
self
.
type
.
from_fused
:
self
.
member
=
self
.
entry
.
cname
return
"((struct %s *)%s%s%s)->%s"
%
(
obj
.
type
.
vtabstruct_cname
,
obj_code
,
self
.
op
,
obj
.
type
.
vtabslot_cname
,
self
.
member
)
...
...
@@ -7379,18 +7398,6 @@ class CoercionNode(ExprNode):
file
,
line
,
col
=
self
.
pos
code
.
annotate
((
file
,
line
,
col
-
1
),
AnnotationItem
(
style
=
'coerce'
,
tag
=
'coerce'
,
text
=
'[%s] to [%s]'
%
(
self
.
arg
.
type
,
self
.
type
)))
class
CoerceFusedToSpecific
(
CoercionNode
):
def
__init__
(
self
,
arg
,
dst_type
):
super
(
CoerceFusedToSpecific
,
self
).
__init__
(
arg
)
self
.
type
=
dst_type
self
.
specialized_cname
=
dst_type
.
entry
.
cname
def
calculate_result_code
(
self
):
return
self
.
specialized_cname
def
generate_result_code
(
self
,
code
):
pass
class
CastNode
(
CoercionNode
):
# Wrap a node in a C type cast.
...
...
Cython/Compiler/ModuleNode.py
View file @
1472e87b
...
...
@@ -156,19 +156,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
f
.
close
()
def
generate_public_declaration
(
self
,
entry
,
h_code
,
i_code
):
PyrexTypes
.
map_with_specific_entries
(
entry
,
self
.
_generate_public_declaration
,
h_code
,
i_code
)
def
_generate_public_declaration
(
self
,
entry
,
h_code
,
i_code
):
h_code
.
putln
(
"%s %s;"
%
(
Naming
.
extern_c_macro
,
entry
.
type
.
declaration_code
(
entry
.
cname
,
dll_linkage
=
"DL_IMPORT"
)))
entry
.
cname
,
dll_linkage
=
"DL_IMPORT"
)))
if
i_code
:
i_code
.
putln
(
"cdef extern %s"
%
entry
.
type
.
declaration_code
(
cname
,
pyrex
=
1
))
entry
.
type
.
declaration_code
(
entry
.
cname
,
pyrex
=
1
))
def
api_name
(
self
,
env
):
return
env
.
qualified_name
.
replace
(
"."
,
"__"
)
...
...
@@ -992,39 +986,38 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
dll_linkage
=
"DL_EXPORT"
,
definition
=
definition
)
def
generate_cfunction_predeclarations
(
self
,
env
,
code
,
definition
):
func
=
self
.
_generate_cfunction_predeclaration
for
entry
in
env
.
cfunc_entries
:
PyrexTypes
.
map_with_specific_entries
(
entry
,
func
,
code
,
definition
)
def
_generate_cfunction_predeclaration
(
self
,
entry
,
code
,
definition
):
if
entry
.
inline_func_in_pxd
or
(
not
entry
.
in_cinclude
and
(
definition
or
entry
.
defined_in_pxd
or
entry
.
visibility
==
'extern'
)
):
if
entry
.
visibility
==
'public'
:
storage_class
=
"%s "
%
Naming
.
extern_c_macro
dll_linkage
=
"DL_EXPORT"
elif
entry
.
visibility
==
'extern'
:
storage_class
=
"%s "
%
Naming
.
extern_c_macro
dll_linkage
=
"DL_IMPORT"
elif
entry
.
visibility
==
'private'
:
storage_class
=
"static "
dll_linkage
=
None
else
:
storage_class
=
"static "
dll_linkage
=
None
type
=
entry
.
type
should_declare
=
(
not
entry
.
in_cinclude
and
(
definition
or
entry
.
defined_in_pxd
or
entry
.
visibility
==
'extern'
))
if
entry
.
used
and
(
entry
.
inline_func_in_pxd
or
should_declare
):
if
entry
.
visibility
==
'public'
:
storage_class
=
"%s "
%
Naming
.
extern_c_macro
dll_linkage
=
"DL_EXPORT"
elif
entry
.
visibility
==
'extern'
:
storage_class
=
"%s "
%
Naming
.
extern_c_macro
dll_linkage
=
"DL_IMPORT"
elif
entry
.
visibility
==
'private'
:
storage_class
=
"static "
dll_linkage
=
None
else
:
storage_class
=
"static "
dll_linkage
=
None
type
=
entry
.
type
if
not
definition
and
entry
.
defined_in_pxd
:
type
=
CPtrType
(
type
)
header
=
type
.
declaration_code
(
entry
.
cname
,
dll_linkage
=
dll_linkage
)
if
entry
.
func_modifiers
:
modifiers
=
"%s "
%
' '
.
join
(
entry
.
func_modifiers
).
upper
()
else
:
modifiers
=
''
code
.
putln
(
"%s%s%s; /*proto*/"
%
(
storage_class
,
modifiers
,
header
))
if
not
definition
and
entry
.
defined_in_pxd
:
type
=
CPtrType
(
type
)
header
=
type
.
declaration_code
(
entry
.
cname
,
dll_linkage
=
dll_linkage
)
if
entry
.
func_modifiers
:
modifiers
=
"%s "
%
' '
.
join
(
entry
.
func_modifiers
).
upper
()
else
:
modifiers
=
''
code
.
putln
(
"%s%s%s; /*proto*/"
%
(
storage_class
,
modifiers
,
header
))
def
generate_typeobj_definitions
(
self
,
env
,
code
):
full_module_name
=
env
.
qualified_name
...
...
@@ -2048,28 +2041,15 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def
generate_c_function_export_code
(
self
,
env
,
code
):
# Generate code to create PyCFunction wrappers for exported C functions.
func
=
self
.
_generate_c_function_export_code
for
entry
in
env
.
cfunc_entries
:
from_fused
=
entry
.
type
.
is_fused
if
entry
.
api
or
entry
.
defined_in_pxd
:
PyrexTypes
.
map_with_specific_entries
(
entry
,
func
,
env
,
code
,
from_fused
)
def
_generate_c_function_export_code
(
self
,
entry
,
env
,
code
,
from_fused
):
env
.
use_utility_code
(
function_export_utility_code
)
signature
=
entry
.
type
.
signature_string
()
s
=
'if (__Pyx_ExportFunction("%s", (void (*)(void))%s, "%s") < 0) %s'
if
from_fused
:
# Specific version of a fused function. Fused functions can never
# be declared public or api, but they may need to be exported when
# declared in a .pxd. We need to give them a unique name in that
# case
name
=
entry
.
cname
else
:
name
=
entry
.
name
code
.
putln
(
s
%
(
name
,
entry
.
cname
,
signature
,
code
.
error_goto
(
self
.
pos
)))
env
.
use_utility_code
(
function_export_utility_code
)
signature
=
entry
.
type
.
signature_string
()
code
.
putln
(
'if (__Pyx_ExportFunction("%s", (void (*)(void))%s, "%s") < 0) %s'
%
(
entry
.
name
,
entry
.
cname
,
signature
,
code
.
error_goto
(
self
.
pos
)))
def
generate_type_import_code_for_module
(
self
,
module
,
env
,
code
):
# Generate type import code for all exported extension types in
...
...
@@ -2095,30 +2075,16 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
module
.
qualified_name
,
temp
,
code
.
error_goto
(
self
.
pos
)))
for
entry
in
entries
:
PyrexTypes
.
map_with_specific_entries
(
entry
,
self
.
_import_cdef_func
,
code
,
temp
,
entry
.
type
.
is_fused
)
code
.
putln
(
'if (__Pyx_ImportFunction(%s, "%s", (void (**)(void))&%s, "%s") < 0) %s'
%
(
temp
,
entry
.
name
,
entry
.
cname
,
entry
.
type
.
signature_string
(),
code
.
error_goto
(
self
.
pos
)))
code
.
putln
(
"Py_DECREF(%s); %s = 0;"
%
(
temp
,
temp
))
def
_import_cdef_func
(
self
,
entry
,
code
,
temp
,
from_fused
):
if
from_fused
:
name
=
entry
.
cname
else
:
name
=
entry
.
name
code
.
putln
(
'if (__Pyx_ImportFunction(%s, "%s", (void (**)(void))&%s, "%s") < 0) %s'
%
(
temp
,
name
,
entry
.
cname
,
entry
.
type
.
signature_string
(),
code
.
error_goto
(
self
.
pos
)))
def
generate_type_init_code
(
self
,
env
,
code
):
# Generate type import code for extern extension types
# and type ready code for non-extern ones.
...
...
Cython/Compiler/Nodes.py
View file @
1472e87b
...
...
@@ -1749,7 +1749,6 @@ class CFuncDefNode(FuncDefNode):
inline_in_pxd
=
False
decorators
=
None
directive_locals
=
None
cname_postfix
=
None
def
unqualified_name
(
self
):
return
self
.
entry
.
name
...
...
@@ -2050,6 +2049,8 @@ class FusedCFuncDefNode(StatListNode):
if
n
.
cfunc_declarator
.
optional_arg_count
:
assert
n
.
type
.
op_arg_struct
assert
n
.
type
.
entry
assert
node
.
type
.
is_fused
node
.
entry
.
fused_cfunction
=
self
...
...
@@ -2065,19 +2066,29 @@ class FusedCFuncDefNode(StatListNode):
# print 'Node %s has %d specializations:' % (self.node.entry.name,
# len(permutations))
# import pprint; pprint.pprint([d for cname, d in permutations])
env
.
cfunc_entries
.
remove
(
self
.
node
.
entry
)
for
cname
,
fused_to_specific
in
permutations
:
copied_node
=
copy
.
deepcopy
(
self
.
node
)
# Make the types in our CFuncType specific
newtype
=
copied_node
.
type
.
specialize
(
fused_to_specific
)
copied_node
.
type
=
newtype
copied_node
.
entry
.
type
=
newtype
newtype
.
entry
=
copied_node
.
entry
type
=
copied_node
.
type
.
specialize
(
fused_to_specific
)
entry
=
copied_node
.
entry
copied_node
.
type
=
type
entry
.
type
,
type
.
entry
=
type
,
entry
self
.
node
.
cfunc_declarator
.
declare_optional_arg_struct
(
newtype
,
env
,
fused_cname
=
cname
)
entry
.
used
=
(
entry
.
used
or
self
.
node
.
entry
.
defined_in_pxd
or
env
.
is_c_class_scope
or
entry
.
is_cmethod
)
copied_node
.
return_type
=
newtype
.
return_type
if
self
.
node
.
cfunc_declarator
.
optional_arg_count
:
self
.
node
.
cfunc_declarator
.
declare_optional_arg_struct
(
type
,
env
,
fused_cname
=
cname
)
copied_node
.
return_type
=
type
.
return_type
copied_node
.
create_local_scope
(
env
)
copied_node
.
local_scope
.
fused_to_specific
=
fused_to_specific
...
...
@@ -2090,8 +2101,8 @@ class FusedCFuncDefNode(StatListNode):
for
arg
in
copied_node
.
cfunc_declarator
.
args
:
arg
.
type
=
arg
.
type
.
specialize
(
fused_to_specific
)
cname
=
self
.
node
.
type
.
get_specific_cname
(
cname
)
copied_node
.
entry
.
func_cname
=
copied_node
.
entry
.
cname
=
cname
type
.
specialize_entry
(
entry
,
cname
)
env
.
cfunc_entries
.
append
(
entry
)
num_errors
=
Errors
.
num_errors
transform
=
ParseTreeTransforms
.
ReplaceFusedTypeChecks
(
...
...
@@ -2101,6 +2112,23 @@ class FusedCFuncDefNode(StatListNode):
if
Errors
.
num_errors
>
num_errors
:
break
def
generate_function_definitions
(
self
,
env
,
code
):
for
stat
in
self
.
stats
:
# print stat.entry, stat.entry.used
if
stat
.
entry
.
used
:
stat
.
generate_function_definitions
(
env
,
code
)
def
generate_execution_code
(
self
,
code
):
for
stat
in
self
.
stats
:
if
stat
.
entry
.
used
:
code
.
mark_pos
(
stat
.
pos
)
stat
.
generate_execution_code
(
code
)
def
annotate
(
self
,
code
):
for
stat
in
self
.
stats
:
if
stat
.
entry
.
used
:
stat
.
annotate
(
code
)
class
PyArgDeclNode
(
Node
):
# Argument which must be a Python object (used
...
...
Cython/Compiler/ParseTreeTransforms.py
View file @
1472e87b
...
...
@@ -1344,20 +1344,17 @@ class AnalyseExpressionsTransform(CythonTransform):
def
visit_IndexNode
(
self
,
node
):
"""
Replace index nodes used to specialize cdef functions with fused
argument types with a NameNode referring to the function with
specialized entry and type.
argument types with the Attribute- or NameNode referring to the
function. We then need to copy over the specialization properties to
the attribute or name node.
"""
self
.
visit_Node
(
node
)
type
=
node
.
type
if
type
.
is_cfunction
and
node
.
base
.
type
.
is_fused
:
node
.
base
.
type
=
node
.
type
node
.
base
.
entry
=
node
.
type
.
entry
node
=
node
.
base
if
not
node
.
is_name
:
error
(
node
.
pos
,
"Can only index a fused function once"
)
node
.
type
=
PyrexTypes
.
error_type
else
:
node
.
type
=
type
node
.
entry
=
type
.
entry
return
node
...
...
Cython/Compiler/PyrexTypes.py
View file @
1472e87b
...
...
@@ -39,20 +39,14 @@ class BaseType(object):
"""
return
self
def
get_fused_types
(
self
,
result
=
None
,
seen
=
None
):
if
self
.
subtypes
:
def
add_fused_types
(
types
):
for
type
in
types
or
():
if
type
not
in
seen
:
seen
.
add
(
type
)
result
.
append
(
type
)
def
get_fused_types
(
self
,
result
=
None
,
seen
=
None
,
subtypes
=
None
):
subtypes
=
subtypes
or
self
.
subtypes
if
subtypes
:
if
result
is
None
:
result
=
[]
seen
=
cython
.
set
()
for
attr
in
s
elf
.
s
ubtypes
:
for
attr
in
subtypes
:
list_or_subtype
=
getattr
(
self
,
attr
)
if
isinstance
(
list_or_subtype
,
BaseType
):
...
...
@@ -1763,10 +1757,13 @@ class CFuncType(CType):
# with_gil boolean Acquire gil around function body
# templates [string] or None
# cached_specialized_types [CFuncType] cached specialized versions of the CFuncType if defined in a pxd
# from_fused boolean Indicates whether this is a specialized
# C function
is_cfunction
=
1
original_sig
=
None
cached_specialized_types
=
None
from_fused
=
False
subtypes
=
[
'return_type'
,
'args'
]
...
...
@@ -1994,17 +1991,20 @@ class CFuncType(CType):
else
:
new_templates
=
[
v
.
specialize
(
values
)
for
v
in
self
.
templates
]
return
CFuncType
(
self
.
return_type
.
specialize
(
values
),
[
arg
.
specialize
(
values
)
for
arg
in
self
.
args
],
has_varargs
=
0
,
exception_value
=
self
.
exception_value
,
exception_check
=
self
.
exception_check
,
calling_convention
=
self
.
calling_convention
,
nogil
=
self
.
nogil
,
with_gil
=
self
.
with_gil
,
is_overridable
=
self
.
is_overridable
,
optional_arg_count
=
self
.
optional_arg_count
,
templates
=
new_templates
)
result
=
CFuncType
(
self
.
return_type
.
specialize
(
values
),
[
arg
.
specialize
(
values
)
for
arg
in
self
.
args
],
has_varargs
=
0
,
exception_value
=
self
.
exception_value
,
exception_check
=
self
.
exception_check
,
calling_convention
=
self
.
calling_convention
,
nogil
=
self
.
nogil
,
with_gil
=
self
.
with_gil
,
is_overridable
=
self
.
is_overridable
,
optional_arg_count
=
self
.
optional_arg_count
,
templates
=
new_templates
)
result
.
from_fused
=
self
.
is_fused
return
result
def
opt_arg_cname
(
self
,
arg_name
):
return
self
.
op_arg_struct
.
base_type
.
scope
.
lookup
(
arg_name
).
cname
...
...
@@ -2040,6 +2040,10 @@ class CFuncType(CType):
elif
self
.
cached_specialized_types
is
not
None
:
return
self
.
cached_specialized_types
cfunc_entries
=
self
.
entry
.
scope
.
cfunc_entries
cfunc_entries
.
remove
(
self
.
entry
)
result
=
[]
permutations
=
self
.
get_all_specific_permutations
()
for
cname
,
fused_to_specific
in
permutations
:
...
...
@@ -2050,55 +2054,46 @@ class CFuncType(CType):
self
.
declare_opt_arg_struct
(
new_func_type
,
cname
)
new_entry
=
copy
.
deepcopy
(
self
.
entry
)
new_
entry
.
cname
=
self
.
get_specific_cname
(
cname
)
new_
func_type
.
specialize_entry
(
new_entry
,
cname
)
new_entry
.
type
=
new_func_type
new_func_type
.
entry
=
new_entry
result
.
append
(
new_func_type
)
cfunc_entries
.
append
(
new_entry
)
self
.
cached_specialized_types
=
result
return
result
def
get_specific_cname
(
self
,
fused_cname
):
"""
Given the cname for a permutation of fused types, return the cname
for the corresponding function with specific types.
"""
assert
self
.
is_fused
return
get_fused_cname
(
fused_cname
,
self
.
entry
.
func_cname
)
def
get_fused_types
(
self
,
result
=
None
,
seen
=
None
,
subtypes
=
None
):
"Return fused types in the order they appear as parameter types"
return
super
(
CFuncType
,
self
).
get_fused_types
(
result
,
seen
,
subtypes
=
[
'args'
])
def
specialize_entry
(
self
,
entry
,
cname
):
assert
not
self
.
is_fused
entry
.
name
=
get_fused_cname
(
cname
,
entry
.
name
)
if
entry
.
is_cmethod
:
entry
.
cname
=
entry
.
name
if
entry
.
is_inherited
:
entry
.
cname
=
"%s.%s"
%
(
Naming
.
obj_base_cname
,
entry
.
cname
)
else
:
entry
.
cname
=
get_fused_cname
(
cname
,
entry
.
cname
)
if
entry
.
func_cname
:
entry
.
func_cname
=
get_fused_cname
(
cname
,
entry
.
func_cname
)
def
get_fused_cname
(
fused_cname
,
orig_cname
):
"""
Given the fused cname id and an original cname, return a specialized cname
"""
assert
fused_cname
and
orig_cname
return
'%s%s%s'
%
(
Naming
.
fused_func_prefix
,
fused_cname
,
orig_cname
)
def
map_with_specific_entries
(
entry
,
func
,
*
args
,
**
kwargs
):
"""
Call func for every specific function instance. If this is not a
signature with fused types, call it with the entry for this cdef
function.
"""
type
=
entry
.
type
if
type
.
is_cfunction
and
(
entry
.
fused_cfunction
or
type
.
is_fused
):
if
entry
.
fused_cfunction
:
# cdef with fused types defined in this file
for
cfunction
in
entry
.
fused_cfunction
.
nodes
:
func
(
cfunction
.
entry
,
*
args
,
**
kwargs
)
else
:
# cdef with fused types defined in another file, create their
# signatures
for
func_type
in
type
.
get_all_specific_function_types
():
func
(
func_type
.
entry
,
*
args
,
**
kwargs
)
else
:
# a normal cdef or not a c function
func
(
entry
,
*
args
,
**
kwargs
)
def
get_all_specific_permutations
(
fused_types
,
id
=
""
,
f2s
=
()):
fused_type
=
fused_types
[
0
]
result
=
[]
...
...
Cython/Compiler/Symtab.py
View file @
1472e87b
...
...
@@ -734,6 +734,7 @@ class Scope(object):
else:
return outer.is_cpp()
class PreImportScope(Scope):
namespace_cname = Naming.preimport_cname
...
...
@@ -1696,6 +1697,7 @@ class CClassScope(ClassScope):
if
defining
:
entry
.
func_cname
=
self
.
mangle
(
Naming
.
func_prefix
,
name
)
entry
.
utility_code
=
utility_code
type
.
entry
=
entry
return
entry
def
add_cfunction
(
self
,
name
,
type
,
pos
,
cname
,
visibility
,
modifiers
):
...
...
@@ -1744,6 +1746,14 @@ class CClassScope(ClassScope):
base_entry
.
type
,
None
,
'private'
)
entry
.
is_variable
=
1
self
.
inherited_var_entries
.
append
(
entry
)
# If the class defined in a pxd, specific entries have not been added.
# Ensure now that the parent (base) scope has specific entries
# Iterate over a copy as get_all_specific_function_types() will mutate
for
base_entry
in
base_scope
.
cfunc_entries
[:]:
if
base_entry
.
type
.
is_fused
:
base_entry
.
type
.
get_all_specific_function_types
()
for
base_entry
in
base_scope
.
cfunc_entries
:
entry
=
self
.
add_cfunction
(
base_entry
.
name
,
base_entry
.
type
,
base_entry
.
pos
,
adapt
(
base_entry
.
cname
),
...
...
@@ -1819,6 +1829,7 @@ class CppClassScope(Scope):
if
prev_entry
:
entry
.
overloaded_alternatives
=
prev_entry
.
all_alternatives
()
entry
.
utility_code
=
utility_code
type
.
entry
=
entry
return
entry
def
declare_inherited_cpp_attributes
(
self
,
base_scope
):
...
...
tests/run/fused_types.pyx
View file @
1472e87b
...
...
@@ -169,8 +169,7 @@ cdef opt_args(integral x, floating y = 4.0):
def
test_opt_args
():
"""
ToDO: enable and fix
test_opt_args()
>>> test_opt_args()
3 4.0
3 4.0
3 4.0
...
...
tests/run/public_fused_types.srctree
View file @
1472e87b
...
...
@@ -39,6 +39,12 @@ cdef object_t add_simple(object_t obj, simple_t simple)
cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple)
cdef public_optional_args(object_t obj, simple_t simple = *)
ctypedef cython.fused_type(float, double) floating
ctypedef cython.fused_type(int, long) integral
cdef class TestFusedExtMethods(object):
cdef floating method(self, integral x, floating y)
######## header.h ########
typedef int extern_int;
...
...
@@ -58,6 +64,16 @@ cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple):
cdef public_optional_args(object_t obj, simple_t simple = 6):
return obj.a, simple
cdef class TestFusedExtMethods(object):
cdef floating method(self, integral x, floating y):
if integral is int:
x += 1
if floating is double:
y += 2.0
return x + y
######## b.pyx ########
from a cimport *
...
...
@@ -92,3 +108,42 @@ assert public_optional_args[mystruct_t, int](mystruct) == (5, 6)
assert public_optional_args[mystruct_t, float](mystruct) == (5, 6.0)
assert public_optional_args[mystruct_t, float](mystruct, 7.0) == (5, 7.0)
cdef TestFusedExtMethods obj = TestFusedExtMethods()
cdef int x = 4
cdef float y = 5.0
cdef long a = 6
cdef double b = 7.0
cdef double (*func)(TestFusedExtMethods, long, double)
func = obj.method
assert func(obj, a, b) == 15.0
func = <double (*)(TestFusedExtMethods, long, double)> obj.method
assert func(obj, x, y) == 11.0
func = obj.method[long, double]
assert func(obj, a, y) == 13.0
assert obj.method(x, <double> a) == 13.0
assert obj.method[int, double](x, b) == 14.0
# Test inheritance
cdef class Subclass(TestFusedExtMethods):
cdef floating method(self, integral x, floating y):
return -x -y
cdef Subclass myobj = Subclass()
assert myobj.method[int, float](5, 5.0) == -10
cdef float (*meth)(Subclass, int, float)
meth = myobj.method
assert meth(myobj, 5, 5.0) == -10
meth = myobj.method[int, float]
assert meth(myobj, 5, 5.0) == -10
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment