Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
C
cython
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Labels
Merge Requests
0
Merge Requests
0
Analytics
Analytics
Repository
Value Stream
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Commits
Open sidebar
nexedi
cython
Commits
47ce63c9
Commit
47ce63c9
authored
Apr 27, 2011
by
Mark Florisson
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Support for fused types in cdef functions
parent
b386862c
Changes
11
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
616 additions
and
60 deletions
+616
-60
Cython/Compiler/ExprNodes.py
Cython/Compiler/ExprNodes.py
+17
-4
Cython/Compiler/ModuleNode.py
Cython/Compiler/ModuleNode.py
+48
-30
Cython/Compiler/Naming.py
Cython/Compiler/Naming.py
+1
-0
Cython/Compiler/Nodes.py
Cython/Compiler/Nodes.py
+164
-4
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/ParseTreeTransforms.py
+51
-6
Cython/Compiler/Parsing.py
Cython/Compiler/Parsing.py
+45
-7
Cython/Compiler/PyrexTypes.py
Cython/Compiler/PyrexTypes.py
+132
-9
Cython/Compiler/Symtab.py
Cython/Compiler/Symtab.py
+7
-0
Cython/Shadow.py
Cython/Shadow.py
+26
-0
tests/errors/fused_types.pyx
tests/errors/fused_types.pyx
+22
-0
tests/run/fused_types.pyx
tests/run/fused_types.pyx
+103
-0
No files found.
Cython/Compiler/ExprNodes.py
View file @
47ce63c9
...
@@ -2949,6 +2949,7 @@ class SimpleCallNode(CallNode):
...
@@ -2949,6 +2949,7 @@ class SimpleCallNode(CallNode):
else
:
else
:
for
arg
in
self
.
args
:
for
arg
in
self
.
args
:
arg
.
analyse_types
(
env
)
arg
.
analyse_types
(
env
)
if
self
.
self
and
func_type
.
args
:
if
self
.
self
and
func_type
.
args
:
# Coerce 'self' to the type expected by the method.
# Coerce 'self' to the type expected by the method.
self_arg
=
func_type
.
args
[
0
]
self_arg
=
func_type
.
args
[
0
]
...
@@ -2965,10 +2966,13 @@ class SimpleCallNode(CallNode):
...
@@ -2965,10 +2966,13 @@ class SimpleCallNode(CallNode):
def
function_type
(
self
):
def
function_type
(
self
):
# Return the type of the function being called, coercing a function
# Return the type of the function being called, coercing a function
# pointer to a function if necessary.
# pointer to a function if necessary. If the function has fused
# arguments, return the specific type.
func_type
=
self
.
function
.
type
func_type
=
self
.
function
.
type
if
func_type
.
is_ptr
:
if
func_type
.
is_ptr
:
func_type
=
func_type
.
base_type
func_type
=
func_type
.
base_type
return
func_type
return
func_type
def
is_simple
(
self
):
def
is_simple
(
self
):
...
@@ -2982,6 +2986,7 @@ class SimpleCallNode(CallNode):
...
@@ -2982,6 +2986,7 @@ class SimpleCallNode(CallNode):
if
self
.
function
.
type
is
error_type
:
if
self
.
function
.
type
is
error_type
:
self
.
type
=
error_type
self
.
type
=
error_type
return
return
if
self
.
function
.
type
.
is_cpp_class
:
if
self
.
function
.
type
.
is_cpp_class
:
overloaded_entry
=
self
.
function
.
type
.
scope
.
lookup
(
"operator()"
)
overloaded_entry
=
self
.
function
.
type
.
scope
.
lookup
(
"operator()"
)
if
overloaded_entry
is
None
:
if
overloaded_entry
is
None
:
...
@@ -2992,8 +2997,16 @@ class SimpleCallNode(CallNode):
...
@@ -2992,8 +2997,16 @@ class SimpleCallNode(CallNode):
overloaded_entry
=
self
.
function
.
entry
overloaded_entry
=
self
.
function
.
entry
else
:
else
:
overloaded_entry
=
None
overloaded_entry
=
None
if
overloaded_entry
:
if
overloaded_entry
:
entry
=
PyrexTypes
.
best_match
(
self
.
args
,
overloaded_entry
.
all_alternatives
(),
self
.
pos
)
if
overloaded_entry
.
fused_cfunction
:
specific_cdef_funcs
=
overloaded_entry
.
fused_cfunction
.
nodes
alternatives
=
[
n
.
entry
for
n
in
specific_cdef_funcs
]
else
:
alternatives
=
overloaded_entry
.
all_alternatives
()
entry
=
PyrexTypes
.
best_match
(
self
.
args
,
alternatives
,
self
.
pos
,
env
)
if
not
entry
:
if
not
entry
:
self
.
type
=
PyrexTypes
.
error_type
self
.
type
=
PyrexTypes
.
error_type
self
.
result_code
=
"<error>"
self
.
result_code
=
"<error>"
...
@@ -3130,8 +3143,8 @@ class SimpleCallNode(CallNode):
...
@@ -3130,8 +3143,8 @@ class SimpleCallNode(CallNode):
for
actual_arg
in
self
.
args
[
len
(
formal_args
):]:
for
actual_arg
in
self
.
args
[
len
(
formal_args
):]:
arg_list_code
.
append
(
actual_arg
.
result
())
arg_list_code
.
append
(
actual_arg
.
result
())
result
=
"%s(%s)"
%
(
self
.
function
.
result
(),
', '
.
join
(
arg_list_code
))
result
=
"%s(%s)"
%
(
self
.
function
.
result
(),
', '
.
join
(
arg_list_code
))
return
result
return
result
def
generate_result_code
(
self
,
code
):
def
generate_result_code
(
self
,
code
):
...
...
Cython/Compiler/ModuleNode.py
View file @
47ce63c9
...
@@ -156,13 +156,22 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
...
@@ -156,13 +156,22 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
f
.
close
()
f
.
close
()
def
generate_public_declaration
(
self
,
entry
,
h_code
,
i_code
):
def
generate_public_declaration
(
self
,
entry
,
h_code
,
i_code
):
if
entry
.
fused_cfunction
:
for
cfunction
in
entry
.
fused_cfunction
.
nodes
:
self
.
_generate_public_declaration
(
cfunction
.
entry
,
cfunction
.
entry
.
cname
,
h_code
,
i_code
)
else
:
self
.
_generate_public_declaration
(
entry
,
entry
.
cname
,
h_code
,
i_code
)
def
_generate_public_declaration
(
self
,
entry
,
cname
,
h_code
,
i_code
):
h_code
.
putln
(
"%s %s;"
%
(
h_code
.
putln
(
"%s %s;"
%
(
Naming
.
extern_c_macro
,
Naming
.
extern_c_macro
,
entry
.
type
.
declaration_code
(
entry
.
type
.
declaration_code
(
entry
.
cname
,
dll_linkage
=
"DL_IMPORT"
)))
cname
,
dll_linkage
=
"DL_IMPORT"
)))
if
i_code
:
if
i_code
:
i_code
.
putln
(
"cdef extern %s"
%
i_code
.
putln
(
"cdef extern %s"
%
entry
.
type
.
declaration_code
(
entry
.
cname
,
pyrex
=
1
))
entry
.
type
.
declaration_code
(
cname
,
pyrex
=
1
))
def
api_name
(
self
,
env
):
def
api_name
(
self
,
env
):
return
env
.
qualified_name
.
replace
(
"."
,
"__"
)
return
env
.
qualified_name
.
replace
(
"."
,
"__"
)
...
@@ -987,34 +996,43 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
...
@@ -987,34 +996,43 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def
generate_cfunction_predeclarations
(
self
,
env
,
code
,
definition
):
def
generate_cfunction_predeclarations
(
self
,
env
,
code
,
definition
):
for
entry
in
env
.
cfunc_entries
:
for
entry
in
env
.
cfunc_entries
:
if
entry
.
inline_func_in_pxd
or
(
not
entry
.
in_cinclude
and
(
definition
if
entry
.
fused_cfunction
:
or
entry
.
defined_in_pxd
or
entry
.
visibility
==
'extern'
)):
for
node
in
entry
.
fused_cfunction
.
nodes
:
if
entry
.
visibility
==
'public'
:
self
.
_generate_cfunction_predeclaration
(
storage_class
=
"%s "
%
Naming
.
extern_c_macro
code
,
definition
,
node
.
entry
)
dll_linkage
=
"DL_EXPORT"
else
:
elif
entry
.
visibility
==
'extern'
:
self
.
_generate_cfunction_predeclaration
(
code
,
definition
,
entry
)
storage_class
=
"%s "
%
Naming
.
extern_c_macro
dll_linkage
=
"DL_IMPORT"
elif
entry
.
visibility
==
'private'
:
def
_generate_cfunction_predeclaration
(
self
,
code
,
definition
,
entry
):
storage_class
=
"static "
if
entry
.
inline_func_in_pxd
or
(
not
entry
.
in_cinclude
and
(
definition
dll_linkage
=
None
or
entry
.
defined_in_pxd
or
entry
.
visibility
==
'extern'
)):
else
:
if
entry
.
visibility
==
'public'
:
storage_class
=
"static "
storage_class
=
"%s "
%
Naming
.
extern_c_macro
dll_linkage
=
None
dll_linkage
=
"DL_EXPORT"
type
=
entry
.
type
elif
entry
.
visibility
==
'extern'
:
storage_class
=
"%s "
%
Naming
.
extern_c_macro
if
not
definition
and
entry
.
defined_in_pxd
:
dll_linkage
=
"DL_IMPORT"
type
=
CPtrType
(
type
)
elif
entry
.
visibility
==
'private'
:
header
=
type
.
declaration_code
(
entry
.
cname
,
storage_class
=
"static "
dll_linkage
=
dll_linkage
)
dll_linkage
=
None
if
entry
.
func_modifiers
:
else
:
modifiers
=
"%s "
%
' '
.
join
(
entry
.
func_modifiers
).
upper
()
storage_class
=
"static "
else
:
dll_linkage
=
None
modifiers
=
''
type
=
entry
.
type
code
.
putln
(
"%s%s%s; /*proto*/"
%
(
storage_class
,
if
not
definition
and
entry
.
defined_in_pxd
:
modifiers
,
type
=
CPtrType
(
type
)
header
))
header
=
type
.
declaration_code
(
entry
.
cname
,
dll_linkage
=
dll_linkage
)
if
entry
.
func_modifiers
:
modifiers
=
"%s "
%
' '
.
join
(
entry
.
func_modifiers
).
upper
()
else
:
modifiers
=
''
code
.
putln
(
"%s%s%s; /*proto*/"
%
(
storage_class
,
modifiers
,
header
))
def
generate_typeobj_definitions
(
self
,
env
,
code
):
def
generate_typeobj_definitions
(
self
,
env
,
code
):
full_module_name
=
env
.
qualified_name
full_module_name
=
env
.
qualified_name
...
...
Cython/Compiler/Naming.py
View file @
47ce63c9
...
@@ -93,6 +93,7 @@ enc_scope_cname = pyrex_prefix + "enc_scope"
...
@@ -93,6 +93,7 @@ enc_scope_cname = pyrex_prefix + "enc_scope"
frame_cname
=
pyrex_prefix
+
"frame"
frame_cname
=
pyrex_prefix
+
"frame"
frame_code_cname
=
pyrex_prefix
+
"frame_code"
frame_code_cname
=
pyrex_prefix
+
"frame_code"
binding_cfunc
=
pyrex_prefix
+
"binding_PyCFunctionType"
binding_cfunc
=
pyrex_prefix
+
"binding_PyCFunctionType"
fused_func_prefix
=
pyrex_prefix
+
'fuse_'
genexpr_id_ref
=
'genexpr'
genexpr_id_ref
=
'genexpr'
...
...
Cython/Compiler/Nodes.py
View file @
47ce63c9
This diff is collapsed.
Click to expand it.
Cython/Compiler/ParseTreeTransforms.py
View file @
47ce63c9
...
@@ -610,8 +610,9 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
...
@@ -610,8 +610,9 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
'operator.comma'
:
ExprNodes
.
c_binop_constructor
(
','
),
'operator.comma'
:
ExprNodes
.
c_binop_constructor
(
','
),
}
}
special_methods
=
cython
.
set
([
'declare'
,
'union'
,
'struct'
,
'typedef'
,
'sizeof'
,
special_methods
=
cython
.
set
([
'declare'
,
'union'
,
'struct'
,
'typedef'
,
'cast'
,
'pointer'
,
'compiled'
,
'NULL'
])
'sizeof'
,
'cast'
,
'pointer'
,
'compiled'
,
'NULL'
,
'fused_type'
])
special_methods
.
update
(
unop_method_nodes
.
keys
())
special_methods
.
update
(
unop_method_nodes
.
keys
())
def
__init__
(
self
,
context
,
compilation_directive_defaults
):
def
__init__
(
self
,
context
,
compilation_directive_defaults
):
...
@@ -896,6 +897,36 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
...
@@ -896,6 +897,36 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
return
self
.
visit_with_directives
(
node
.
body
,
directive_dict
)
return
self
.
visit_with_directives
(
node
.
body
,
directive_dict
)
return
self
.
visit_Node
(
node
)
return
self
.
visit_Node
(
node
)
def
visit_CTypeDefNode
(
self
,
node
):
"Don't skip ctypedefs"
self
.
visitchildren
(
node
)
return
node
def
visit_FusedTypeNode
(
self
,
node
):
"""
See if a function call expression in a ctypedef is actually
cython.fused_type()
"""
def
err
():
error
(
node
.
pos
,
"Can only fuse types with cython.fused_type()"
)
if
len
(
node
.
funcname
)
==
1
:
fused_type
,
=
node
.
funcname
else
:
cython_module
,
fused_type
=
node
.
funcname
wrong_module
=
cython_module
not
in
self
.
cython_module_names
if
wrong_module
or
fused_type
!=
u'fused_type'
:
err
()
return
node
if
not
self
.
directive_names
.
get
(
fused_type
):
err
()
return
node
class
WithTransform
(
CythonTransform
,
SkipDeclarations
):
class
WithTransform
(
CythonTransform
,
SkipDeclarations
):
# EXCINFO is manually set to a variable that contains
# EXCINFO is manually set to a variable that contains
...
@@ -1115,6 +1146,14 @@ if VALUE is not None:
...
@@ -1115,6 +1146,14 @@ if VALUE is not None:
return
node
return
node
def
visit_FuncDefNode
(
self
,
node
):
def
visit_FuncDefNode
(
self
,
node
):
"""
Analyse a function and its body, as that hasn't happend yet. Also
analyse the directive_locals set by @cython.locals(). Then, if we are
a function with fused arguments, replace the function (after it has
declared itself in the symbol table!) with a FusedCFuncDefNode, and
analyse its children (which are in turn normal functions). If we're a
normal function, just analyse the body of the function.
"""
self
.
seen_vars_stack
.
append
(
cython
.
set
())
self
.
seen_vars_stack
.
append
(
cython
.
set
())
lenv
=
node
.
local_scope
lenv
=
node
.
local_scope
node
.
body
.
analyse_control_flow
(
lenv
)
# this will be totally refactored
node
.
body
.
analyse_control_flow
(
lenv
)
# this will be totally refactored
...
@@ -1126,10 +1165,16 @@ if VALUE is not None:
...
@@ -1126,10 +1165,16 @@ if VALUE is not None:
lenv
.
declare_var
(
var
,
type
,
type_node
.
pos
)
lenv
.
declare_var
(
var
,
type
,
type_node
.
pos
)
else
:
else
:
error
(
type_node
.
pos
,
"Not a type"
)
error
(
type_node
.
pos
,
"Not a type"
)
node
.
body
.
analyse_declarations
(
lenv
)
self
.
env_stack
.
append
(
lenv
)
if
node
.
has_fused_arguments
:
self
.
visitchildren
(
node
)
node
=
Nodes
.
FusedCFuncDefNode
(
node
,
self
.
env_stack
[
-
1
])
self
.
env_stack
.
pop
()
self
.
visitchildren
(
node
)
else
:
node
.
body
.
analyse_declarations
(
lenv
)
self
.
env_stack
.
append
(
lenv
)
self
.
visitchildren
(
node
)
self
.
env_stack
.
pop
()
self
.
seen_vars_stack
.
pop
()
self
.
seen_vars_stack
.
pop
()
return
node
return
node
...
...
Cython/Compiler/Parsing.py
View file @
47ce63c9
...
@@ -2572,6 +2572,24 @@ def p_c_func_or_var_declaration(s, pos, ctx):
...
@@ -2572,6 +2572,24 @@ def p_c_func_or_var_declaration(s, pos, ctx):
overridable
=
ctx
.
overridable
)
overridable
=
ctx
.
overridable
)
return
result
return
result
def
p_typelist
(
s
):
"""
parse a list of basic c types as part of a function call, like
cython.fused_type(int, long, double)
"""
types
=
[]
pos
=
s
.
position
()
while
s
.
sy
==
'IDENT'
:
types
.
append
(
p_c_base_type
(
s
))
if
s
.
sy
!=
','
:
if
s
.
sy
!=
')'
:
s
.
expect
(
','
)
break
s
.
next
()
return
Nodes
.
FusedTypeNode
(
pos
,
types
=
types
)
def
p_ctypedef_statement
(
s
,
ctx
):
def
p_ctypedef_statement
(
s
,
ctx
):
# s.sy == 'ctypedef'
# s.sy == 'ctypedef'
pos
=
s
.
position
()
pos
=
s
.
position
()
...
@@ -2588,17 +2606,37 @@ def p_ctypedef_statement(s, ctx):
...
@@ -2588,17 +2606,37 @@ def p_ctypedef_statement(s, ctx):
return
p_c_enum_definition
(
s
,
pos
,
ctx
)
return
p_c_enum_definition
(
s
,
pos
,
ctx
)
else
:
else
:
return
p_c_struct_or_union_definition
(
s
,
pos
,
ctx
)
return
p_c_struct_or_union_definition
(
s
,
pos
,
ctx
)
elif
looking_at_expr
(
s
):
# ctypedef cython.fused_types(int, long) integral
if
s
.
sy
==
'IDENT'
:
funcname
=
[
s
.
systring
]
s
.
next
()
if
s
.
systring
==
u'.'
:
s
.
next
()
funcname
.
append
(
s
.
systring
)
s
.
expect
(
'IDENT'
)
s
.
expect
(
'('
)
base_type
=
p_typelist
(
s
)
s
.
expect
(
')'
)
# Check if funcname equals cython.fused_types in
# InterpretCompilerDirectives
base_type
.
funcname
=
funcname
else
:
s
.
error
(
"Syntax error in ctypedef statement"
)
else
:
else
:
base_type
=
p_c_base_type
(
s
,
nonempty
=
1
)
base_type
=
p_c_base_type
(
s
,
nonempty
=
1
)
if
base_type
.
name
is
None
:
if
base_type
.
name
is
None
:
s
.
error
(
"Syntax error in ctypedef statement"
)
s
.
error
(
"Syntax error in ctypedef statement"
)
declarator
=
p_c_declarator
(
s
,
ctx
,
is_type
=
1
,
nonempty
=
1
)
s
.
expect_newline
(
"Syntax error in ctypedef statement"
)
declarator
=
p_c_declarator
(
s
,
ctx
,
is_type
=
1
,
nonempty
=
1
)
return
Nodes
.
CTypeDefNode
(
s
.
expect_newline
(
"Syntax error in ctypedef statement"
)
pos
,
base_type
=
base_type
,
return
Nodes
.
CTypeDefNode
(
declarator
=
declarator
,
pos
,
base_type
=
base_type
,
visibility
=
visibility
,
api
=
api
,
declarator
=
declarator
,
in_pxd
=
ctx
.
level
==
'module_pxd'
)
visibility
=
visibility
,
api
=
api
,
in_pxd
=
ctx
.
level
==
'module_pxd'
)
def
p_decorators
(
s
):
def
p_decorators
(
s
):
decorators
=
[]
decorators
=
[]
...
...
Cython/Compiler/PyrexTypes.py
View file @
47ce63c9
...
@@ -2,6 +2,8 @@
...
@@ -2,6 +2,8 @@
# Pyrex - Types
# Pyrex - Types
#
#
import
cython
from
Code
import
UtilityCode
from
Code
import
UtilityCode
import
StringEncoding
import
StringEncoding
import
Naming
import
Naming
...
@@ -12,6 +14,9 @@ class BaseType(object):
...
@@ -12,6 +14,9 @@ class BaseType(object):
#
#
# Base class for all Pyrex types including pseudo-types.
# Base class for all Pyrex types including pseudo-types.
# List of attribute names of any subtypes
subtypes
=
[]
def
can_coerce_to_pyobject
(
self
,
env
):
def
can_coerce_to_pyobject
(
self
,
env
):
return
False
return
False
...
@@ -27,6 +32,42 @@ class BaseType(object):
...
@@ -27,6 +32,42 @@ class BaseType(object):
else
:
else
:
return
base_code
return
base_code
def
__deepcopy__
(
self
,
memo
):
"""
Types never need to be copied, if we do copy, Unfortunate Things
Will Happen!
"""
return
self
def
get_fused_types
(
self
,
result
=
None
,
seen
=
None
):
if
self
.
subtypes
:
def
add_fused_types
(
types
):
for
type
in
types
or
():
if
type
not
in
seen
:
seen
.
add
(
type
)
result
.
append
(
type
)
if
result
is
None
:
result
=
[]
seen
=
cython
.
set
()
for
attr
in
self
.
subtypes
:
list_or_subtype
=
getattr
(
self
,
attr
)
if
isinstance
(
list_or_subtype
,
BaseType
):
list_or_subtype
.
get_fused_types
(
result
,
seen
)
else
:
for
subtype
in
list_or_subtype
:
subtype
.
get_fused_types
(
result
,
seen
)
return
result
return
None
is_fused
=
property
(
get_fused_types
,
doc
=
"Whether this type or any of its "
"subtypes is a fused type"
)
class
PyrexType
(
BaseType
):
class
PyrexType
(
BaseType
):
#
#
# Base class for all Pyrex types.
# Base class for all Pyrex types.
...
@@ -195,7 +236,8 @@ class CTypedefType(BaseType):
...
@@ -195,7 +236,8 @@ class CTypedefType(BaseType):
to_py_utility_code
=
None
to_py_utility_code
=
None
from_py_utility_code
=
None
from_py_utility_code
=
None
subtypes
=
[
'typedef_base_type'
]
def
__init__
(
self
,
name
,
base_type
,
cname
,
is_external
=
0
):
def
__init__
(
self
,
name
,
base_type
,
cname
,
is_external
=
0
):
assert
not
base_type
.
is_complex
assert
not
base_type
.
is_complex
...
@@ -314,6 +356,9 @@ class BufferType(BaseType):
...
@@ -314,6 +356,9 @@ class BufferType(BaseType):
is_buffer
=
1
is_buffer
=
1
writable
=
True
writable
=
True
subtypes
=
[
'dtype'
]
def
__init__
(
self
,
base
,
dtype
,
ndim
,
mode
,
negative_indices
,
cast
):
def
__init__
(
self
,
base
,
dtype
,
ndim
,
mode
,
negative_indices
,
cast
):
self
.
base
=
base
self
.
base
=
base
self
.
dtype
=
dtype
self
.
dtype
=
dtype
...
@@ -345,7 +390,7 @@ class PyObjectType(PyrexType):
...
@@ -345,7 +390,7 @@ class PyObjectType(PyrexType):
buffer_defaults
=
None
buffer_defaults
=
None
is_extern
=
False
is_extern
=
False
is_subclassed
=
False
is_subclassed
=
False
def
__str__
(
self
):
def
__str__
(
self
):
return
"Python object"
return
"Python object"
...
@@ -618,6 +663,45 @@ class CType(PyrexType):
...
@@ -618,6 +663,45 @@ class CType(PyrexType):
return
0
return
0
class
FusedType
(
CType
):
"""
Represents a Fused Type. All it needs to do is keep track of the types
it aggregates, as it will be replaced with its specific version wherever
needed.
See http://wiki.cython.org/enhancements/fusedtypes
types [CSimpleBaseTypeNode] is the list of types to be fused
name str the name of the ctypedef
"""
is_fused
=
1
def
__init__
(
self
,
types
):
self
.
types
=
types
def
declaration_code
(
self
,
entity_code
,
for_display
=
0
,
dll_linkage
=
None
,
pyrex
=
0
):
if
pyrex
or
for_display
:
return
self
.
name
raise
Exception
(
"This may never happen, please report a bug"
)
def
__repr__
(
self
):
return
'FusedType(name=%r)'
%
self
.
name
def
specialize
(
self
,
values
):
return
values
[
self
]
def
get_fused_types
(
self
,
result
=
None
,
seen
=
None
):
if
result
is
None
:
return
[
self
]
if
self
not
in
seen
:
result
.
append
(
self
)
seen
.
add
(
self
)
class
CVoidType
(
CType
):
class
CVoidType
(
CType
):
#
#
# C "void" type
# C "void" type
...
@@ -1531,7 +1615,9 @@ class CArrayType(CType):
...
@@ -1531,7 +1615,9 @@ class CArrayType(CType):
# size integer or None Number of elements
# size integer or None Number of elements
is_array
=
1
is_array
=
1
subtypes
=
[
'base_type'
]
def
__init__
(
self
,
base_type
,
size
):
def
__init__
(
self
,
base_type
,
size
):
self
.
base_type
=
base_type
self
.
base_type
=
base_type
self
.
size
=
size
self
.
size
=
size
...
@@ -1577,6 +1663,8 @@ class CPtrType(CType):
...
@@ -1577,6 +1663,8 @@ class CPtrType(CType):
is_ptr
=
1
is_ptr
=
1
default_value
=
"0"
default_value
=
"0"
subtypes
=
[
'base_type'
]
def
__init__
(
self
,
base_type
):
def
__init__
(
self
,
base_type
):
self
.
base_type
=
base_type
self
.
base_type
=
base_type
...
@@ -1675,7 +1763,9 @@ class CFuncType(CType):
...
@@ -1675,7 +1763,9 @@ class CFuncType(CType):
is_cfunction
=
1
is_cfunction
=
1
original_sig
=
None
original_sig
=
None
subtypes
=
[
'return_type'
,
'args'
]
def
__init__
(
self
,
return_type
,
args
,
has_varargs
=
0
,
def
__init__
(
self
,
return_type
,
args
,
has_varargs
=
0
,
exception_value
=
None
,
exception_check
=
0
,
calling_convention
=
""
,
exception_value
=
None
,
exception_check
=
0
,
calling_convention
=
""
,
nogil
=
0
,
with_gil
=
0
,
is_overridable
=
0
,
optional_arg_count
=
0
,
nogil
=
0
,
with_gil
=
0
,
is_overridable
=
0
,
optional_arg_count
=
0
,
...
@@ -1691,7 +1781,7 @@ class CFuncType(CType):
...
@@ -1691,7 +1781,7 @@ class CFuncType(CType):
self
.
with_gil
=
with_gil
self
.
with_gil
=
with_gil
self
.
is_overridable
=
is_overridable
self
.
is_overridable
=
is_overridable
self
.
templates
=
templates
self
.
templates
=
templates
def
__repr__
(
self
):
def
__repr__
(
self
):
arg_reprs
=
map
(
repr
,
self
.
args
)
arg_reprs
=
map
(
repr
,
self
.
args
)
if
self
.
has_varargs
:
if
self
.
has_varargs
:
...
@@ -1915,7 +2005,7 @@ class CFuncType(CType):
...
@@ -1915,7 +2005,7 @@ class CFuncType(CType):
return
self
.
op_arg_struct
.
base_type
.
scope
.
lookup
(
arg_name
).
cname
return
self
.
op_arg_struct
.
base_type
.
scope
.
lookup
(
arg_name
).
cname
class
CFuncTypeArg
(
object
):
class
CFuncTypeArg
(
BaseType
):
# name string
# name string
# cname string
# cname string
# type PyrexType
# type PyrexType
...
@@ -1926,6 +2016,8 @@ class CFuncTypeArg(object):
...
@@ -1926,6 +2016,8 @@ class CFuncTypeArg(object):
or_none
=
False
or_none
=
False
accept_none
=
True
accept_none
=
True
subtypes
=
[
'type'
]
def
__init__
(
self
,
name
,
type
,
pos
,
cname
=
None
):
def
__init__
(
self
,
name
,
type
,
pos
,
cname
=
None
):
self
.
name
=
name
self
.
name
=
name
if
cname
is
not
None
:
if
cname
is
not
None
:
...
@@ -2478,7 +2570,7 @@ def is_promotion(src_type, dst_type):
...
@@ -2478,7 +2570,7 @@ def is_promotion(src_type, dst_type):
return
src_type
.
is_float
and
src_type
.
rank
<=
dst_type
.
rank
return
src_type
.
is_float
and
src_type
.
rank
<=
dst_type
.
rank
return
False
return
False
def
best_match
(
args
,
functions
,
pos
=
None
):
def
best_match
(
args
,
functions
,
pos
=
None
,
env
=
None
):
"""
"""
Given a list args of arguments and a list of functions, choose one
Given a list args of arguments and a list of functions, choose one
to call which seems to be the "best" fit for this list of arguments.
to call which seems to be the "best" fit for this list of arguments.
...
@@ -2546,12 +2638,33 @@ def best_match(args, functions, pos=None):
...
@@ -2546,12 +2638,33 @@ def best_match(args, functions, pos=None):
possibilities
=
[]
possibilities
=
[]
bad_types
=
[]
bad_types
=
[]
needed_coercions
=
{}
for
func
,
func_type
in
candidates
:
for
func
,
func_type
in
candidates
:
score
=
[
0
,
0
,
0
]
score
=
[
0
,
0
,
0
]
for
i
in
range
(
min
(
len
(
args
),
len
(
func_type
.
args
))):
for
i
in
range
(
min
(
len
(
args
),
len
(
func_type
.
args
))):
src_type
=
args
[
i
].
type
src_type
=
args
[
i
].
type
dst_type
=
func_type
.
args
[
i
].
type
dst_type
=
func_type
.
args
[
i
].
type
if
dst_type
.
assignable_from
(
src_type
):
assignable
=
dst_type
.
assignable_from
(
src_type
)
# Now take care of normal string literals. So when you call a cdef
# function that takes a char *, the coercion will mean that the
# type will simply become bytes. We need to do this coercion
# manually for overloaded and fused functions
if
not
assignable
and
src_type
.
is_pyobject
:
if
(
src_type
.
is_builtin_type
and
src_type
.
name
==
'str'
and
dst_type
.
resolve
()
is
c_char_ptr_type
):
c_src_type
=
c_char_ptr_type
else
:
c_src_type
=
src_type
.
default_coerced_ctype
()
if
c_src_type
:
assignable
=
dst_type
.
assignable_from
(
c_src_type
)
if
assignable
:
src_type
=
c_src_type
needed_coercions
[
func
]
=
i
,
dst_type
if
assignable
:
if
src_type
==
dst_type
or
dst_type
.
same_as
(
src_type
):
if
src_type
==
dst_type
or
dst_type
.
same_as
(
src_type
):
pass
# score 0
pass
# score 0
elif
is_promotion
(
src_type
,
dst_type
):
elif
is_promotion
(
src_type
,
dst_type
):
...
@@ -2567,18 +2680,28 @@ def best_match(args, functions, pos=None):
...
@@ -2567,18 +2680,28 @@ def best_match(args, functions, pos=None):
break
break
else
:
else
:
possibilities
.
append
((
score
,
func
))
# so we can sort it
possibilities
.
append
((
score
,
func
))
# so we can sort it
if
possibilities
:
if
possibilities
:
possibilities
.
sort
()
possibilities
.
sort
()
if
len
(
possibilities
)
>
1
and
possibilities
[
0
][
0
]
==
possibilities
[
1
][
0
]:
if
len
(
possibilities
)
>
1
and
possibilities
[
0
][
0
]
==
possibilities
[
1
][
0
]:
if
pos
is
not
None
:
if
pos
is
not
None
:
error
(
pos
,
"ambiguous overloaded method"
)
error
(
pos
,
"ambiguous overloaded method"
)
return
None
return
None
return
possibilities
[
0
][
1
]
function
=
possibilities
[
0
][
1
]
if
function
in
needed_coercions
and
env
:
arg_i
,
coerce_to_type
=
needed_coercions
[
function
]
args
[
arg_i
]
=
args
[
arg_i
].
coerce_to
(
coerce_to_type
,
env
)
return
function
if
pos
is
not
None
:
if
pos
is
not
None
:
if
len
(
bad_types
)
==
1
:
if
len
(
bad_types
)
==
1
:
error
(
pos
,
bad_types
[
0
][
1
])
error
(
pos
,
bad_types
[
0
][
1
])
else
:
else
:
error
(
pos
,
"no suitable method found"
)
error
(
pos
,
"no suitable method found"
)
return
None
return
None
def
widest_numeric_type
(
type1
,
type2
):
def
widest_numeric_type
(
type1
,
type2
):
...
...
Cython/Compiler/Symtab.py
View file @
47ce63c9
...
@@ -176,6 +176,7 @@ class Entry(object):
...
@@ -176,6 +176,7 @@ class Entry(object):
buffer_aux = None
buffer_aux = None
prev_entry = None
prev_entry = None
might_overflow = 0
might_overflow = 0
fused_cfunction = None
def __init__(self, name, cname, type, pos = None, init = None):
def __init__(self, name, cname, type, pos = None, init = None):
self.name = name
self.name = name
...
@@ -241,6 +242,7 @@ class Scope(object):
...
@@ -241,6 +242,7 @@ class Scope(object):
scope_prefix = ""
scope_prefix = ""
in_cinclude = 0
in_cinclude = 0
nogil = 0
nogil = 0
fused_to_specific = None
def __init__(self, name, outer_scope, parent_scope):
def __init__(self, name, outer_scope, parent_scope):
# The outer_scope is the next scope in the lookup chain.
# The outer_scope is the next scope in the lookup chain.
...
@@ -279,6 +281,9 @@ class Scope(object):
...
@@ -279,6 +281,9 @@ class Scope(object):
self.return_type = None
self.return_type = None
self.id_counters = {}
self.id_counters = {}
def __deepcopy__(self, memo):
return self
def start_branching(self, pos):
def start_branching(self, pos):
self.control_flow = self.control_flow.start_branch(pos)
self.control_flow = self.control_flow.start_branch(pos)
...
@@ -677,6 +682,8 @@ class Scope(object):
...
@@ -677,6 +682,8 @@ class Scope(object):
def lookup_type(self, name):
def lookup_type(self, name):
entry = self.lookup(name)
entry = self.lookup(name)
if entry and entry.is_type:
if entry and entry.is_type:
if entry.type.is_fused and self.fused_to_specific:
return entry.type.specialize(self.fused_to_specific)
return entry.type
return entry.type
def lookup_operator(self, operator, operands):
def lookup_operator(self, operator, operands):
...
...
Cython/Shadow.py
View file @
47ce63c9
...
@@ -225,6 +225,30 @@ class typedef(CythonType):
...
@@ -225,6 +225,30 @@ class typedef(CythonType):
value
=
cast
(
self
.
_basetype
,
*
arg
)
value
=
cast
(
self
.
_basetype
,
*
arg
)
return
value
return
value
class
_FusedType
(
CythonType
):
def
__call__
(
self
,
type
,
value
):
return
value
def
fused_type
(
*
args
):
if
not
args
:
raise
TypeError
(
"Expected at least one type as argument"
)
rank
=
-
1
for
type
in
args
:
if
type
not
in
(
py_int
,
py_long
,
py_float
,
py_complex
):
break
if
type_ordering
.
index
(
type
)
>
rank
:
result_type
=
type
else
:
return
result_type
# Not a simple numeric type, return a fused type instance. The result
# isn't really meant to be used, as we can't keep track of the context in
# pure-mode. Casting won't do anything in this case.
return
_FusedType
()
py_int
=
int
py_int
=
int
...
@@ -277,3 +301,5 @@ for t in int_types + float_types + complex_types + other_types:
...
@@ -277,3 +301,5 @@ for t in int_types + float_types + complex_types + other_types:
void
=
typedef
(
None
)
void
=
typedef
(
None
)
NULL
=
p_void
(
0
)
NULL
=
p_void
(
0
)
type_ordering
=
[
py_int
,
py_long
,
py_float
,
py_complex
]
\ No newline at end of file
tests/errors/fused_types.pyx
0 → 100644
View file @
47ce63c9
# mode: error
cimport
cython
from
cython
import
fused_type
# This is all invalid
ctypedef
foo
(
int
)
dtype1
ctypedef
foo
.
bar
(
float
)
dtype2
ctypedef
fused_type
(
foo
)
dtype3
dtype4
=
cython
.
typedef
(
cython
.
fused_type
(
int
,
long
,
kw
=
None
))
# This is all valid
ctypedef
fused_type
(
int
,
long
,
float
)
dtype5
ctypedef
cython
.
fused_type
(
int
,
long
)
dtype6
_ERRORS
=
u"""
fused_types.pyx:7:13: Can only fuse types with cython.fused_type()
fused_types.pyx:8:17: Can only fuse types with cython.fused_type()
fused_types.pyx:9:20: 'foo' is not a type identifier
fused_types.pyx:10:23: fused_type does not take keyword arguments
"""
tests/run/fused_types.pyx
0 → 100644
View file @
47ce63c9
# mode: run
cimport
cython
from
cpython
cimport
Py_INCREF
from
Cython
import
Shadow
as
pure_cython
ctypedef
char
*
string_t
ctypedef
cython
.
fused_type
(
int
,
long
,
float
,
double
,
string_t
)
fused_type1
ctypedef
cython
.
fused_type
(
string_t
)
fused_type2
def
test_pure
():
"""
>>> test_pure()
(10+0j)
"""
mytype
=
pure_cython
.
typedef
(
pure_cython
.
fused_type
(
int
,
long
,
complex
))
print
mytype
(
10
)
cdef
cdef
_func_with_fused_args
(
fused_type1
x
,
fused_type1
y
,
fused_type2
z
):
print
x
,
y
,
z
return
x
+
y
def
test_cdef_func_with_fused_args
():
"""
>>> test_cdef_func_with_fused_args()
spam ham eggs
spamham
10 20 butter
30
4.2 8.6 bunny
12.8
"""
print
cdef
_func_with_fused_args
(
'spam'
,
'ham'
,
'eggs'
)
print
cdef
_func_with_fused_args
(
10
,
20
,
'butter'
)
print
cdef
_func_with_fused_args
(
4.2
,
8.6
,
'bunny'
)
cdef
fused_type1
fused_with_pointer
(
fused_type1
*
array
):
for
i
in
range
(
5
):
print
array
[
i
]
obj
=
array
[
0
]
+
array
[
1
]
+
array
[
2
]
+
array
[
3
]
+
array
[
4
]
# if cython.typeof(fused_type1) is string_t:
Py_INCREF
(
obj
)
return
obj
def
test_fused_with_pointer
():
"""
>>> test_fused_with_pointer()
0
1
2
3
4
10
<BLANKLINE>
0
1
2
3
4
10
<BLANKLINE>
0.0
1.0
2.0
3.0
4.0
10.0
<BLANKLINE>
humpty
dumpty
fall
splatch
breakfast
humptydumptyfallsplatchbreakfast
"""
cdef
int
int_array
[
5
]
cdef
long
long_array
[
5
]
cdef
float
float_array
[
5
]
cdef
string_t
string_array
[
5
]
cdef
char
*
s1
=
"humpty"
,
*
s2
=
"dumpty"
,
*
s3
=
"fall"
,
*
s4
=
"splatch"
,
*
s5
=
"breakfast"
strings
=
[
"humpty"
,
"dumpty"
,
"fall"
,
"splatch"
,
"breakfast"
]
for
i
in
range
(
5
):
int_array
[
i
]
=
i
long_array
[
i
]
=
i
float_array
[
i
]
=
i
s
=
strings
[
i
]
string_array
[
i
]
=
s
print
fused_with_pointer
(
int_array
)
print
print
fused_with_pointer
(
long_array
)
print
print
fused_with_pointer
(
float_array
)
print
print
fused_with_pointer
(
string_array
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment