Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
C
cython
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Labels
Merge Requests
0
Merge Requests
0
Analytics
Analytics
Repository
Value Stream
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Commits
Open sidebar
nexedi
cython
Commits
1472e87b
Commit
1472e87b
authored
May 09, 2011
by
Mark Florisson
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Support fused cdef methods
parent
a8305590
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
229 additions
and
171 deletions
+229
-171
Cython/Compiler/ExprNodes.py
Cython/Compiler/ExprNodes.py
+26
-19
Cython/Compiler/ModuleNode.py
Cython/Compiler/ModuleNode.py
+46
-80
Cython/Compiler/Nodes.py
Cython/Compiler/Nodes.py
+38
-10
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/ParseTreeTransforms.py
+5
-8
Cython/Compiler/PyrexTypes.py
Cython/Compiler/PyrexTypes.py
+47
-52
Cython/Compiler/Symtab.py
Cython/Compiler/Symtab.py
+11
-0
tests/run/fused_types.pyx
tests/run/fused_types.pyx
+1
-2
tests/run/public_fused_types.srctree
tests/run/public_fused_types.srctree
+55
-0
No files found.
Cython/Compiler/ExprNodes.py
View file @
1472e87b
...
@@ -595,7 +595,10 @@ class ExprNode(Node):
...
@@ -595,7 +595,10 @@ class ExprNode(Node):
for
signature
in
src_type
.
get_all_specific_function_types
():
for
signature
in
src_type
.
get_all_specific_function_types
():
if
signature
.
same_as
(
dst_type
):
if
signature
.
same_as
(
dst_type
):
return
CoerceFusedToSpecific
(
src
,
signature
)
src
.
type
=
signature
src
.
entry
=
src
.
type
.
entry
src
.
entry
.
used
=
True
return
self
error
(
self
.
pos
,
"Type is not specific"
)
error
(
self
.
pos
,
"Type is not specific"
)
self
.
type
=
error_type
self
.
type
=
error_type
...
@@ -2330,10 +2333,9 @@ class IndexNode(ExprNode):
...
@@ -2330,10 +2333,9 @@ class IndexNode(ExprNode):
NameNode with specific entry just after analysis of expressions by
NameNode with specific entry just after analysis of expressions by
AnalyseExpressionsTransform.
AnalyseExpressionsTransform.
"""
"""
base_type
=
self
.
base
.
type
self
.
type
=
PyrexTypes
.
error_type
self
.
type
=
PyrexTypes
.
error_type
base_type
=
self
.
base
.
type
specific_types
=
[]
specific_types
=
[]
positions
=
[]
positions
=
[]
...
@@ -2378,6 +2380,15 @@ class IndexNode(ExprNode):
...
@@ -2378,6 +2380,15 @@ class IndexNode(ExprNode):
for
signature
in
self
.
base
.
type
.
get_all_specific_function_types
():
for
signature
in
self
.
base
.
type
.
get_all_specific_function_types
():
if
type
.
same_as
(
signature
):
if
type
.
same_as
(
signature
):
self
.
type
=
signature
self
.
type
=
signature
if
self
.
base
.
is_attribute
:
# Pretend to be a normal attribute, for cdef extension
# methods
self
.
entry
=
signature
.
entry
self
.
is_attribute
=
self
.
base
.
is_attribute
self
.
obj
=
self
.
base
.
obj
self
.
entry
.
used
=
True
break
break
else
:
else
:
assert
False
assert
False
...
@@ -3026,11 +3037,13 @@ class SimpleCallNode(CallNode):
...
@@ -3026,11 +3037,13 @@ class SimpleCallNode(CallNode):
function
=
self
.
function
function
=
self
.
function
function
.
is_called
=
1
function
.
is_called
=
1
self
.
function
.
analyse_types
(
env
)
self
.
function
.
analyse_types
(
env
)
if
function
.
is_attribute
and
function
.
entry
and
function
.
entry
.
is_cmethod
:
if
function
.
is_attribute
and
function
.
entry
and
function
.
entry
.
is_cmethod
:
# Take ownership of the object from which the attribute
# Take ownership of the object from which the attribute
# was obtained, because we need to pass it as 'self'.
# was obtained, because we need to pass it as 'self'.
self
.
self
=
function
.
obj
self
.
self
=
function
.
obj
function
.
obj
=
CloneNode
(
self
.
self
)
function
.
obj
=
CloneNode
(
self
.
self
)
func_type
=
self
.
function_type
()
func_type
=
self
.
function_type
()
if
func_type
.
is_pyobject
:
if
func_type
.
is_pyobject
:
self
.
arg_tuple
=
TupleNode
(
self
.
pos
,
args
=
self
.
args
)
self
.
arg_tuple
=
TupleNode
(
self
.
pos
,
args
=
self
.
args
)
...
@@ -3111,15 +3124,13 @@ class SimpleCallNode(CallNode):
...
@@ -3111,15 +3124,13 @@ class SimpleCallNode(CallNode):
elif
(
isinstance
(
self
.
function
,
IndexNode
)
and
elif
(
isinstance
(
self
.
function
,
IndexNode
)
and
self
.
function
.
base
.
type
.
is_fused
):
self
.
function
.
base
.
type
.
is_fused
):
overloaded_entry
=
self
.
function
.
type
.
entry
overloaded_entry
=
self
.
function
.
type
.
entry
self
.
function
.
entry
=
self
.
function
.
type
.
entry
else
:
else
:
overloaded_entry
=
None
overloaded_entry
=
None
if
overloaded_entry
:
if
overloaded_entry
:
if
self
.
function
.
type
.
is_fused
:
if
self
.
function
.
type
.
is_fused
:
alternatives
=
[]
functypes
=
self
.
function
.
type
.
get_all_specific_function_types
()
PyrexTypes
.
map_with_specific_entries
(
self
.
function
.
entry
,
alternatives
=
[
f
.
entry
for
f
in
functypes
]
alternatives
.
append
)
else
:
else
:
alternatives
=
overloaded_entry
.
all_alternatives
()
alternatives
=
overloaded_entry
.
all_alternatives
()
...
@@ -3129,6 +3140,8 @@ class SimpleCallNode(CallNode):
...
@@ -3129,6 +3140,8 @@ class SimpleCallNode(CallNode):
self
.
type
=
PyrexTypes
.
error_type
self
.
type
=
PyrexTypes
.
error_type
self
.
result_code
=
"<error>"
self
.
result_code
=
"<error>"
return
return
entry
.
used
=
True
self
.
function
.
entry
=
entry
self
.
function
.
entry
=
entry
self
.
function
.
type
=
entry
.
type
self
.
function
.
type
=
entry
.
type
func_type
=
self
.
function_type
()
func_type
=
self
.
function_type
()
...
@@ -3812,6 +3825,12 @@ class AttributeNode(ExprNode):
...
@@ -3812,6 +3825,12 @@ class AttributeNode(ExprNode):
#print "...obj_code =", obj_code ###
#print "...obj_code =", obj_code ###
if
self
.
entry
and
self
.
entry
.
is_cmethod
:
if
self
.
entry
and
self
.
entry
.
is_cmethod
:
if
obj
.
type
.
is_extension_type
:
if
obj
.
type
.
is_extension_type
:
# If the attribute was specialized through indexing, make sure
# to get the right fused name, as our entry was replaced by our
# parent index node (AnalyseExpressionsTransform)
if
self
.
type
.
from_fused
:
self
.
member
=
self
.
entry
.
cname
return
"((struct %s *)%s%s%s)->%s"
%
(
return
"((struct %s *)%s%s%s)->%s"
%
(
obj
.
type
.
vtabstruct_cname
,
obj_code
,
self
.
op
,
obj
.
type
.
vtabstruct_cname
,
obj_code
,
self
.
op
,
obj
.
type
.
vtabslot_cname
,
self
.
member
)
obj
.
type
.
vtabslot_cname
,
self
.
member
)
...
@@ -7379,18 +7398,6 @@ class CoercionNode(ExprNode):
...
@@ -7379,18 +7398,6 @@ class CoercionNode(ExprNode):
file
,
line
,
col
=
self
.
pos
file
,
line
,
col
=
self
.
pos
code
.
annotate
((
file
,
line
,
col
-
1
),
AnnotationItem
(
style
=
'coerce'
,
tag
=
'coerce'
,
text
=
'[%s] to [%s]'
%
(
self
.
arg
.
type
,
self
.
type
)))
code
.
annotate
((
file
,
line
,
col
-
1
),
AnnotationItem
(
style
=
'coerce'
,
tag
=
'coerce'
,
text
=
'[%s] to [%s]'
%
(
self
.
arg
.
type
,
self
.
type
)))
class
CoerceFusedToSpecific
(
CoercionNode
):
def
__init__
(
self
,
arg
,
dst_type
):
super
(
CoerceFusedToSpecific
,
self
).
__init__
(
arg
)
self
.
type
=
dst_type
self
.
specialized_cname
=
dst_type
.
entry
.
cname
def
calculate_result_code
(
self
):
return
self
.
specialized_cname
def
generate_result_code
(
self
,
code
):
pass
class
CastNode
(
CoercionNode
):
class
CastNode
(
CoercionNode
):
# Wrap a node in a C type cast.
# Wrap a node in a C type cast.
...
...
Cython/Compiler/ModuleNode.py
View file @
1472e87b
...
@@ -156,19 +156,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
...
@@ -156,19 +156,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
f
.
close
()
f
.
close
()
def
generate_public_declaration
(
self
,
entry
,
h_code
,
i_code
):
def
generate_public_declaration
(
self
,
entry
,
h_code
,
i_code
):
PyrexTypes
.
map_with_specific_entries
(
entry
,
self
.
_generate_public_declaration
,
h_code
,
i_code
)
def
_generate_public_declaration
(
self
,
entry
,
h_code
,
i_code
):
h_code
.
putln
(
"%s %s;"
%
(
h_code
.
putln
(
"%s %s;"
%
(
Naming
.
extern_c_macro
,
Naming
.
extern_c_macro
,
entry
.
type
.
declaration_code
(
entry
.
type
.
declaration_code
(
entry
.
cname
,
dll_linkage
=
"DL_IMPORT"
)))
entry
.
cname
,
dll_linkage
=
"DL_IMPORT"
)))
if
i_code
:
if
i_code
:
i_code
.
putln
(
"cdef extern %s"
%
i_code
.
putln
(
"cdef extern %s"
%
entry
.
type
.
declaration_code
(
cname
,
pyrex
=
1
))
entry
.
type
.
declaration_code
(
entry
.
cname
,
pyrex
=
1
))
def
api_name
(
self
,
env
):
def
api_name
(
self
,
env
):
return
env
.
qualified_name
.
replace
(
"."
,
"__"
)
return
env
.
qualified_name
.
replace
(
"."
,
"__"
)
...
@@ -992,39 +986,38 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
...
@@ -992,39 +986,38 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
dll_linkage
=
"DL_EXPORT"
,
definition
=
definition
)
dll_linkage
=
"DL_EXPORT"
,
definition
=
definition
)
def
generate_cfunction_predeclarations
(
self
,
env
,
code
,
definition
):
def
generate_cfunction_predeclarations
(
self
,
env
,
code
,
definition
):
func
=
self
.
_generate_cfunction_predeclaration
for
entry
in
env
.
cfunc_entries
:
for
entry
in
env
.
cfunc_entries
:
PyrexTypes
.
map_with_specific_entries
(
entry
,
func
,
code
,
definition
)
should_declare
=
(
not
entry
.
in_cinclude
and
(
definition
or
entry
.
defined_in_pxd
or
def
_generate_cfunction_predeclaration
(
self
,
entry
,
code
,
definition
):
entry
.
visibility
==
'extern'
))
if
entry
.
inline_func_in_pxd
or
(
not
entry
.
in_cinclude
and
(
definition
or
entry
.
defined_in_pxd
or
entry
.
visibility
==
'extern'
)
):
if
entry
.
used
and
(
entry
.
inline_func_in_pxd
or
should_declare
):
if
entry
.
visibility
==
'public'
:
if
entry
.
visibility
==
'public'
:
storage_class
=
"%s "
%
Naming
.
extern_c_macro
storage_class
=
"%s "
%
Naming
.
extern_c_macro
dll_linkage
=
"DL_EXPORT"
dll_linkage
=
"DL_EXPORT"
elif
entry
.
visibility
==
'extern'
:
elif
entry
.
visibility
==
'extern'
:
storage_class
=
"%s "
%
Naming
.
extern_c_macro
storage_class
=
"%s "
%
Naming
.
extern_c_macro
dll_linkage
=
"DL_IMPORT"
dll_linkage
=
"DL_IMPORT"
elif
entry
.
visibility
==
'private'
:
elif
entry
.
visibility
==
'private'
:
storage_class
=
"static "
storage_class
=
"static "
dll_linkage
=
None
dll_linkage
=
None
else
:
else
:
storage_class
=
"static "
storage_class
=
"static "
dll_linkage
=
None
dll_linkage
=
None
type
=
entry
.
type
type
=
entry
.
type
if
not
definition
and
entry
.
defined_in_pxd
:
if
not
definition
and
entry
.
defined_in_pxd
:
type
=
CPtrType
(
type
)
type
=
CPtrType
(
type
)
header
=
type
.
declaration_code
(
entry
.
cname
,
header
=
type
.
declaration_code
(
entry
.
cname
,
dll_linkage
=
dll_linkage
)
dll_linkage
=
dll_linkage
)
if
entry
.
func_modifiers
:
if
entry
.
func_modifiers
:
modifiers
=
"%s "
%
' '
.
join
(
entry
.
func_modifiers
).
upper
()
modifiers
=
"%s "
%
' '
.
join
(
entry
.
func_modifiers
).
upper
()
else
:
else
:
modifiers
=
''
modifiers
=
''
code
.
putln
(
"%s%s%s; /*proto*/"
%
(
code
.
putln
(
"%s%s%s; /*proto*/"
%
(
storage_class
,
storage_class
,
modifiers
,
modifiers
,
header
))
header
))
def
generate_typeobj_definitions
(
self
,
env
,
code
):
def
generate_typeobj_definitions
(
self
,
env
,
code
):
full_module_name
=
env
.
qualified_name
full_module_name
=
env
.
qualified_name
...
@@ -2048,28 +2041,15 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
...
@@ -2048,28 +2041,15 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def
generate_c_function_export_code
(
self
,
env
,
code
):
def
generate_c_function_export_code
(
self
,
env
,
code
):
# Generate code to create PyCFunction wrappers for exported C functions.
# Generate code to create PyCFunction wrappers for exported C functions.
func
=
self
.
_generate_c_function_export_code
for
entry
in
env
.
cfunc_entries
:
for
entry
in
env
.
cfunc_entries
:
from_fused
=
entry
.
type
.
is_fused
if
entry
.
api
or
entry
.
defined_in_pxd
:
if
entry
.
api
or
entry
.
defined_in_pxd
:
PyrexTypes
.
map_with_specific_entries
(
entry
,
func
,
env
,
env
.
use_utility_code
(
function_export_utility_code
)
code
,
from_fused
)
signature
=
entry
.
type
.
signature_string
()
code
.
putln
(
'if (__Pyx_ExportFunction("%s", (void (*)(void))%s, "%s") < 0) %s'
%
(
def
_generate_c_function_export_code
(
self
,
entry
,
env
,
code
,
from_fused
):
entry
.
name
,
env
.
use_utility_code
(
function_export_utility_code
)
entry
.
cname
,
signature
=
entry
.
type
.
signature_string
()
signature
,
s
=
'if (__Pyx_ExportFunction("%s", (void (*)(void))%s, "%s") < 0) %s'
code
.
error_goto
(
self
.
pos
)))
if
from_fused
:
# Specific version of a fused function. Fused functions can never
# be declared public or api, but they may need to be exported when
# declared in a .pxd. We need to give them a unique name in that
# case
name
=
entry
.
cname
else
:
name
=
entry
.
name
code
.
putln
(
s
%
(
name
,
entry
.
cname
,
signature
,
code
.
error_goto
(
self
.
pos
)))
def
generate_type_import_code_for_module
(
self
,
module
,
env
,
code
):
def
generate_type_import_code_for_module
(
self
,
module
,
env
,
code
):
# Generate type import code for all exported extension types in
# Generate type import code for all exported extension types in
...
@@ -2095,30 +2075,16 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
...
@@ -2095,30 +2075,16 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
module
.
qualified_name
,
module
.
qualified_name
,
temp
,
temp
,
code
.
error_goto
(
self
.
pos
)))
code
.
error_goto
(
self
.
pos
)))
for
entry
in
entries
:
for
entry
in
entries
:
PyrexTypes
.
map_with_specific_entries
(
entry
,
code
.
putln
(
self
.
_import_cdef_func
,
'if (__Pyx_ImportFunction(%s, "%s", (void (**)(void))&%s, "%s") < 0) %s'
%
(
code
,
temp
,
temp
,
entry
.
name
,
entry
.
type
.
is_fused
)
entry
.
cname
,
entry
.
type
.
signature_string
(),
code
.
error_goto
(
self
.
pos
)))
code
.
putln
(
"Py_DECREF(%s); %s = 0;"
%
(
temp
,
temp
))
code
.
putln
(
"Py_DECREF(%s); %s = 0;"
%
(
temp
,
temp
))
def
_import_cdef_func
(
self
,
entry
,
code
,
temp
,
from_fused
):
if
from_fused
:
name
=
entry
.
cname
else
:
name
=
entry
.
name
code
.
putln
(
'if (__Pyx_ImportFunction(%s, "%s", (void (**)(void))&%s, "%s") < 0) %s'
%
(
temp
,
name
,
entry
.
cname
,
entry
.
type
.
signature_string
(),
code
.
error_goto
(
self
.
pos
)))
def
generate_type_init_code
(
self
,
env
,
code
):
def
generate_type_init_code
(
self
,
env
,
code
):
# Generate type import code for extern extension types
# Generate type import code for extern extension types
# and type ready code for non-extern ones.
# and type ready code for non-extern ones.
...
...
Cython/Compiler/Nodes.py
View file @
1472e87b
...
@@ -1749,7 +1749,6 @@ class CFuncDefNode(FuncDefNode):
...
@@ -1749,7 +1749,6 @@ class CFuncDefNode(FuncDefNode):
inline_in_pxd
=
False
inline_in_pxd
=
False
decorators
=
None
decorators
=
None
directive_locals
=
None
directive_locals
=
None
cname_postfix
=
None
def
unqualified_name
(
self
):
def
unqualified_name
(
self
):
return
self
.
entry
.
name
return
self
.
entry
.
name
...
@@ -2050,6 +2049,8 @@ class FusedCFuncDefNode(StatListNode):
...
@@ -2050,6 +2049,8 @@ class FusedCFuncDefNode(StatListNode):
if
n
.
cfunc_declarator
.
optional_arg_count
:
if
n
.
cfunc_declarator
.
optional_arg_count
:
assert
n
.
type
.
op_arg_struct
assert
n
.
type
.
op_arg_struct
assert
n
.
type
.
entry
assert
node
.
type
.
is_fused
assert
node
.
type
.
is_fused
node
.
entry
.
fused_cfunction
=
self
node
.
entry
.
fused_cfunction
=
self
...
@@ -2065,19 +2066,29 @@ class FusedCFuncDefNode(StatListNode):
...
@@ -2065,19 +2066,29 @@ class FusedCFuncDefNode(StatListNode):
# print 'Node %s has %d specializations:' % (self.node.entry.name,
# print 'Node %s has %d specializations:' % (self.node.entry.name,
# len(permutations))
# len(permutations))
# import pprint; pprint.pprint([d for cname, d in permutations])
# import pprint; pprint.pprint([d for cname, d in permutations])
env
.
cfunc_entries
.
remove
(
self
.
node
.
entry
)
for
cname
,
fused_to_specific
in
permutations
:
for
cname
,
fused_to_specific
in
permutations
:
copied_node
=
copy
.
deepcopy
(
self
.
node
)
copied_node
=
copy
.
deepcopy
(
self
.
node
)
# Make the types in our CFuncType specific
# Make the types in our CFuncType specific
newtype
=
copied_node
.
type
.
specialize
(
fused_to_specific
)
type
=
copied_node
.
type
.
specialize
(
fused_to_specific
)
copied_node
.
type
=
newtype
entry
=
copied_node
.
entry
copied_node
.
entry
.
type
=
newtype
newtype
.
entry
=
copied_node
.
entry
copied_node
.
type
=
type
entry
.
type
,
type
.
entry
=
type
,
entry
self
.
node
.
cfunc_declarator
.
declare_optional_arg_struct
(
entry
.
used
=
(
entry
.
used
or
newtype
,
env
,
fused_cname
=
cname
)
self
.
node
.
entry
.
defined_in_pxd
or
env
.
is_c_class_scope
or
entry
.
is_cmethod
)
copied_node
.
return_type
=
newtype
.
return_type
if
self
.
node
.
cfunc_declarator
.
optional_arg_count
:
self
.
node
.
cfunc_declarator
.
declare_optional_arg_struct
(
type
,
env
,
fused_cname
=
cname
)
copied_node
.
return_type
=
type
.
return_type
copied_node
.
create_local_scope
(
env
)
copied_node
.
create_local_scope
(
env
)
copied_node
.
local_scope
.
fused_to_specific
=
fused_to_specific
copied_node
.
local_scope
.
fused_to_specific
=
fused_to_specific
...
@@ -2090,8 +2101,8 @@ class FusedCFuncDefNode(StatListNode):
...
@@ -2090,8 +2101,8 @@ class FusedCFuncDefNode(StatListNode):
for
arg
in
copied_node
.
cfunc_declarator
.
args
:
for
arg
in
copied_node
.
cfunc_declarator
.
args
:
arg
.
type
=
arg
.
type
.
specialize
(
fused_to_specific
)
arg
.
type
=
arg
.
type
.
specialize
(
fused_to_specific
)
cname
=
self
.
node
.
type
.
get_specific_cname
(
cname
)
type
.
specialize_entry
(
entry
,
cname
)
copied_node
.
entry
.
func_cname
=
copied_node
.
entry
.
cname
=
cname
env
.
cfunc_entries
.
append
(
entry
)
num_errors
=
Errors
.
num_errors
num_errors
=
Errors
.
num_errors
transform
=
ParseTreeTransforms
.
ReplaceFusedTypeChecks
(
transform
=
ParseTreeTransforms
.
ReplaceFusedTypeChecks
(
...
@@ -2101,6 +2112,23 @@ class FusedCFuncDefNode(StatListNode):
...
@@ -2101,6 +2112,23 @@ class FusedCFuncDefNode(StatListNode):
if
Errors
.
num_errors
>
num_errors
:
if
Errors
.
num_errors
>
num_errors
:
break
break
def
generate_function_definitions
(
self
,
env
,
code
):
for
stat
in
self
.
stats
:
# print stat.entry, stat.entry.used
if
stat
.
entry
.
used
:
stat
.
generate_function_definitions
(
env
,
code
)
def
generate_execution_code
(
self
,
code
):
for
stat
in
self
.
stats
:
if
stat
.
entry
.
used
:
code
.
mark_pos
(
stat
.
pos
)
stat
.
generate_execution_code
(
code
)
def
annotate
(
self
,
code
):
for
stat
in
self
.
stats
:
if
stat
.
entry
.
used
:
stat
.
annotate
(
code
)
class
PyArgDeclNode
(
Node
):
class
PyArgDeclNode
(
Node
):
# Argument which must be a Python object (used
# Argument which must be a Python object (used
...
...
Cython/Compiler/ParseTreeTransforms.py
View file @
1472e87b
...
@@ -1344,20 +1344,17 @@ class AnalyseExpressionsTransform(CythonTransform):
...
@@ -1344,20 +1344,17 @@ class AnalyseExpressionsTransform(CythonTransform):
def
visit_IndexNode
(
self
,
node
):
def
visit_IndexNode
(
self
,
node
):
"""
"""
Replace index nodes used to specialize cdef functions with fused
Replace index nodes used to specialize cdef functions with fused
argument types with a NameNode referring to the function with
argument types with the Attribute- or NameNode referring to the
specialized entry and type.
function. We then need to copy over the specialization properties to
the attribute or name node.
"""
"""
self
.
visit_Node
(
node
)
self
.
visit_Node
(
node
)
type
=
node
.
type
type
=
node
.
type
if
type
.
is_cfunction
and
node
.
base
.
type
.
is_fused
:
if
type
.
is_cfunction
and
node
.
base
.
type
.
is_fused
:
node
.
base
.
type
=
node
.
type
node
.
base
.
entry
=
node
.
type
.
entry
node
=
node
.
base
node
=
node
.
base
if
not
node
.
is_name
:
error
(
node
.
pos
,
"Can only index a fused function once"
)
node
.
type
=
PyrexTypes
.
error_type
else
:
node
.
type
=
type
node
.
entry
=
type
.
entry
return
node
return
node
...
...
Cython/Compiler/PyrexTypes.py
View file @
1472e87b
...
@@ -39,20 +39,14 @@ class BaseType(object):
...
@@ -39,20 +39,14 @@ class BaseType(object):
"""
"""
return
self
return
self
def
get_fused_types
(
self
,
result
=
None
,
seen
=
None
):
def
get_fused_types
(
self
,
result
=
None
,
seen
=
None
,
subtypes
=
None
):
if
self
.
subtypes
:
subtypes
=
subtypes
or
self
.
subtypes
if
subtypes
:
def
add_fused_types
(
types
):
for
type
in
types
or
():
if
type
not
in
seen
:
seen
.
add
(
type
)
result
.
append
(
type
)
if
result
is
None
:
if
result
is
None
:
result
=
[]
result
=
[]
seen
=
cython
.
set
()
seen
=
cython
.
set
()
for
attr
in
s
elf
.
s
ubtypes
:
for
attr
in
subtypes
:
list_or_subtype
=
getattr
(
self
,
attr
)
list_or_subtype
=
getattr
(
self
,
attr
)
if
isinstance
(
list_or_subtype
,
BaseType
):
if
isinstance
(
list_or_subtype
,
BaseType
):
...
@@ -1763,10 +1757,13 @@ class CFuncType(CType):
...
@@ -1763,10 +1757,13 @@ class CFuncType(CType):
# with_gil boolean Acquire gil around function body
# with_gil boolean Acquire gil around function body
# templates [string] or None
# templates [string] or None
# cached_specialized_types [CFuncType] cached specialized versions of the CFuncType if defined in a pxd
# cached_specialized_types [CFuncType] cached specialized versions of the CFuncType if defined in a pxd
# from_fused boolean Indicates whether this is a specialized
# C function
is_cfunction
=
1
is_cfunction
=
1
original_sig
=
None
original_sig
=
None
cached_specialized_types
=
None
cached_specialized_types
=
None
from_fused
=
False
subtypes
=
[
'return_type'
,
'args'
]
subtypes
=
[
'return_type'
,
'args'
]
...
@@ -1994,17 +1991,20 @@ class CFuncType(CType):
...
@@ -1994,17 +1991,20 @@ class CFuncType(CType):
else
:
else
:
new_templates
=
[
v
.
specialize
(
values
)
for
v
in
self
.
templates
]
new_templates
=
[
v
.
specialize
(
values
)
for
v
in
self
.
templates
]
return
CFuncType
(
self
.
return_type
.
specialize
(
values
),
result
=
CFuncType
(
self
.
return_type
.
specialize
(
values
),
[
arg
.
specialize
(
values
)
for
arg
in
self
.
args
],
[
arg
.
specialize
(
values
)
for
arg
in
self
.
args
],
has_varargs
=
0
,
has_varargs
=
0
,
exception_value
=
self
.
exception_value
,
exception_value
=
self
.
exception_value
,
exception_check
=
self
.
exception_check
,
exception_check
=
self
.
exception_check
,
calling_convention
=
self
.
calling_convention
,
calling_convention
=
self
.
calling_convention
,
nogil
=
self
.
nogil
,
nogil
=
self
.
nogil
,
with_gil
=
self
.
with_gil
,
with_gil
=
self
.
with_gil
,
is_overridable
=
self
.
is_overridable
,
is_overridable
=
self
.
is_overridable
,
optional_arg_count
=
self
.
optional_arg_count
,
optional_arg_count
=
self
.
optional_arg_count
,
templates
=
new_templates
)
templates
=
new_templates
)
result
.
from_fused
=
self
.
is_fused
return
result
def
opt_arg_cname
(
self
,
arg_name
):
def
opt_arg_cname
(
self
,
arg_name
):
return
self
.
op_arg_struct
.
base_type
.
scope
.
lookup
(
arg_name
).
cname
return
self
.
op_arg_struct
.
base_type
.
scope
.
lookup
(
arg_name
).
cname
...
@@ -2040,6 +2040,10 @@ class CFuncType(CType):
...
@@ -2040,6 +2040,10 @@ class CFuncType(CType):
elif
self
.
cached_specialized_types
is
not
None
:
elif
self
.
cached_specialized_types
is
not
None
:
return
self
.
cached_specialized_types
return
self
.
cached_specialized_types
cfunc_entries
=
self
.
entry
.
scope
.
cfunc_entries
cfunc_entries
.
remove
(
self
.
entry
)
result
=
[]
result
=
[]
permutations
=
self
.
get_all_specific_permutations
()
permutations
=
self
.
get_all_specific_permutations
()
for
cname
,
fused_to_specific
in
permutations
:
for
cname
,
fused_to_specific
in
permutations
:
...
@@ -2050,55 +2054,46 @@ class CFuncType(CType):
...
@@ -2050,55 +2054,46 @@ class CFuncType(CType):
self
.
declare_opt_arg_struct
(
new_func_type
,
cname
)
self
.
declare_opt_arg_struct
(
new_func_type
,
cname
)
new_entry
=
copy
.
deepcopy
(
self
.
entry
)
new_entry
=
copy
.
deepcopy
(
self
.
entry
)
new_
entry
.
cname
=
self
.
get_specific_cname
(
cname
)
new_
func_type
.
specialize_entry
(
new_entry
,
cname
)
new_entry
.
type
=
new_func_type
new_entry
.
type
=
new_func_type
new_func_type
.
entry
=
new_entry
new_func_type
.
entry
=
new_entry
result
.
append
(
new_func_type
)
result
.
append
(
new_func_type
)
cfunc_entries
.
append
(
new_entry
)
self
.
cached_specialized_types
=
result
self
.
cached_specialized_types
=
result
return
result
return
result
def
get_specific_cname
(
self
,
fused_cname
):
def
get_fused_types
(
self
,
result
=
None
,
seen
=
None
,
subtypes
=
None
):
"""
"Return fused types in the order they appear as parameter types"
Given the cname for a permutation of fused types, return the cname
return
super
(
CFuncType
,
self
).
get_fused_types
(
result
,
seen
,
for the corresponding function with specific types.
subtypes
=
[
'args'
])
"""
assert
self
.
is_fused
return
get_fused_cname
(
fused_cname
,
self
.
entry
.
func_cname
)
def
specialize_entry
(
self
,
entry
,
cname
):
assert
not
self
.
is_fused
entry
.
name
=
get_fused_cname
(
cname
,
entry
.
name
)
if
entry
.
is_cmethod
:
entry
.
cname
=
entry
.
name
if
entry
.
is_inherited
:
entry
.
cname
=
"%s.%s"
%
(
Naming
.
obj_base_cname
,
entry
.
cname
)
else
:
entry
.
cname
=
get_fused_cname
(
cname
,
entry
.
cname
)
if
entry
.
func_cname
:
entry
.
func_cname
=
get_fused_cname
(
cname
,
entry
.
func_cname
)
def
get_fused_cname
(
fused_cname
,
orig_cname
):
def
get_fused_cname
(
fused_cname
,
orig_cname
):
"""
"""
Given the fused cname id and an original cname, return a specialized cname
Given the fused cname id and an original cname, return a specialized cname
"""
"""
assert
fused_cname
and
orig_cname
return
'%s%s%s'
%
(
Naming
.
fused_func_prefix
,
fused_cname
,
orig_cname
)
return
'%s%s%s'
%
(
Naming
.
fused_func_prefix
,
fused_cname
,
orig_cname
)
def
map_with_specific_entries
(
entry
,
func
,
*
args
,
**
kwargs
):
"""
Call func for every specific function instance. If this is not a
signature with fused types, call it with the entry for this cdef
function.
"""
type
=
entry
.
type
if
type
.
is_cfunction
and
(
entry
.
fused_cfunction
or
type
.
is_fused
):
if
entry
.
fused_cfunction
:
# cdef with fused types defined in this file
for
cfunction
in
entry
.
fused_cfunction
.
nodes
:
func
(
cfunction
.
entry
,
*
args
,
**
kwargs
)
else
:
# cdef with fused types defined in another file, create their
# signatures
for
func_type
in
type
.
get_all_specific_function_types
():
func
(
func_type
.
entry
,
*
args
,
**
kwargs
)
else
:
# a normal cdef or not a c function
func
(
entry
,
*
args
,
**
kwargs
)
def
get_all_specific_permutations
(
fused_types
,
id
=
""
,
f2s
=
()):
def
get_all_specific_permutations
(
fused_types
,
id
=
""
,
f2s
=
()):
fused_type
=
fused_types
[
0
]
fused_type
=
fused_types
[
0
]
result
=
[]
result
=
[]
...
...
Cython/Compiler/Symtab.py
View file @
1472e87b
...
@@ -734,6 +734,7 @@ class Scope(object):
...
@@ -734,6 +734,7 @@ class Scope(object):
else:
else:
return outer.is_cpp()
return outer.is_cpp()
class PreImportScope(Scope):
class PreImportScope(Scope):
namespace_cname = Naming.preimport_cname
namespace_cname = Naming.preimport_cname
...
@@ -1696,6 +1697,7 @@ class CClassScope(ClassScope):
...
@@ -1696,6 +1697,7 @@ class CClassScope(ClassScope):
if
defining
:
if
defining
:
entry
.
func_cname
=
self
.
mangle
(
Naming
.
func_prefix
,
name
)
entry
.
func_cname
=
self
.
mangle
(
Naming
.
func_prefix
,
name
)
entry
.
utility_code
=
utility_code
entry
.
utility_code
=
utility_code
type
.
entry
=
entry
return
entry
return
entry
def
add_cfunction
(
self
,
name
,
type
,
pos
,
cname
,
visibility
,
modifiers
):
def
add_cfunction
(
self
,
name
,
type
,
pos
,
cname
,
visibility
,
modifiers
):
...
@@ -1744,6 +1746,14 @@ class CClassScope(ClassScope):
...
@@ -1744,6 +1746,14 @@ class CClassScope(ClassScope):
base_entry
.
type
,
None
,
'private'
)
base_entry
.
type
,
None
,
'private'
)
entry
.
is_variable
=
1
entry
.
is_variable
=
1
self
.
inherited_var_entries
.
append
(
entry
)
self
.
inherited_var_entries
.
append
(
entry
)
# If the class defined in a pxd, specific entries have not been added.
# Ensure now that the parent (base) scope has specific entries
# Iterate over a copy as get_all_specific_function_types() will mutate
for
base_entry
in
base_scope
.
cfunc_entries
[:]:
if
base_entry
.
type
.
is_fused
:
base_entry
.
type
.
get_all_specific_function_types
()
for
base_entry
in
base_scope
.
cfunc_entries
:
for
base_entry
in
base_scope
.
cfunc_entries
:
entry
=
self
.
add_cfunction
(
base_entry
.
name
,
base_entry
.
type
,
entry
=
self
.
add_cfunction
(
base_entry
.
name
,
base_entry
.
type
,
base_entry
.
pos
,
adapt
(
base_entry
.
cname
),
base_entry
.
pos
,
adapt
(
base_entry
.
cname
),
...
@@ -1819,6 +1829,7 @@ class CppClassScope(Scope):
...
@@ -1819,6 +1829,7 @@ class CppClassScope(Scope):
if
prev_entry
:
if
prev_entry
:
entry
.
overloaded_alternatives
=
prev_entry
.
all_alternatives
()
entry
.
overloaded_alternatives
=
prev_entry
.
all_alternatives
()
entry
.
utility_code
=
utility_code
entry
.
utility_code
=
utility_code
type
.
entry
=
entry
return
entry
return
entry
def
declare_inherited_cpp_attributes
(
self
,
base_scope
):
def
declare_inherited_cpp_attributes
(
self
,
base_scope
):
...
...
tests/run/fused_types.pyx
View file @
1472e87b
...
@@ -169,8 +169,7 @@ cdef opt_args(integral x, floating y = 4.0):
...
@@ -169,8 +169,7 @@ cdef opt_args(integral x, floating y = 4.0):
def
test_opt_args
():
def
test_opt_args
():
"""
"""
ToDO: enable and fix
>>> test_opt_args()
test_opt_args()
3 4.0
3 4.0
3 4.0
3 4.0
3 4.0
3 4.0
...
...
tests/run/public_fused_types.srctree
View file @
1472e87b
...
@@ -39,6 +39,12 @@ cdef object_t add_simple(object_t obj, simple_t simple)
...
@@ -39,6 +39,12 @@ cdef object_t add_simple(object_t obj, simple_t simple)
cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple)
cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple)
cdef public_optional_args(object_t obj, simple_t simple = *)
cdef public_optional_args(object_t obj, simple_t simple = *)
ctypedef cython.fused_type(float, double) floating
ctypedef cython.fused_type(int, long) integral
cdef class TestFusedExtMethods(object):
cdef floating method(self, integral x, floating y)
######## header.h ########
######## header.h ########
typedef int extern_int;
typedef int extern_int;
...
@@ -58,6 +64,16 @@ cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple):
...
@@ -58,6 +64,16 @@ cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple):
cdef public_optional_args(object_t obj, simple_t simple = 6):
cdef public_optional_args(object_t obj, simple_t simple = 6):
return obj.a, simple
return obj.a, simple
cdef class TestFusedExtMethods(object):
cdef floating method(self, integral x, floating y):
if integral is int:
x += 1
if floating is double:
y += 2.0
return x + y
######## b.pyx ########
######## b.pyx ########
from a cimport *
from a cimport *
...
@@ -92,3 +108,42 @@ assert public_optional_args[mystruct_t, int](mystruct) == (5, 6)
...
@@ -92,3 +108,42 @@ assert public_optional_args[mystruct_t, int](mystruct) == (5, 6)
assert public_optional_args[mystruct_t, float](mystruct) == (5, 6.0)
assert public_optional_args[mystruct_t, float](mystruct) == (5, 6.0)
assert public_optional_args[mystruct_t, float](mystruct, 7.0) == (5, 7.0)
assert public_optional_args[mystruct_t, float](mystruct, 7.0) == (5, 7.0)
cdef TestFusedExtMethods obj = TestFusedExtMethods()
cdef int x = 4
cdef float y = 5.0
cdef long a = 6
cdef double b = 7.0
cdef double (*func)(TestFusedExtMethods, long, double)
func = obj.method
assert func(obj, a, b) == 15.0
func = <double (*)(TestFusedExtMethods, long, double)> obj.method
assert func(obj, x, y) == 11.0
func = obj.method[long, double]
assert func(obj, a, y) == 13.0
assert obj.method(x, <double> a) == 13.0
assert obj.method[int, double](x, b) == 14.0
# Test inheritance
cdef class Subclass(TestFusedExtMethods):
cdef floating method(self, integral x, floating y):
return -x -y
cdef Subclass myobj = Subclass()
assert myobj.method[int, float](5, 5.0) == -10
cdef float (*meth)(Subclass, int, float)
meth = myobj.method
assert meth(myobj, 5, 5.0) == -10
meth = myobj.method[int, float]
assert meth(myobj, 5, 5.0) == -10
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment