Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
C
cython
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
Analytics
Analytics
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Commits
Issue Boards
Open sidebar
Boxiang Sun
cython
Commits
43f3d87d
Commit
43f3d87d
authored
Jul 24, 2012
by
Stefan Behnel
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix type inference for overloaded C++ operators
parent
eb801873
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
197 additions
and
126 deletions
+197
-126
Cython/Compiler/ExprNodes.py
Cython/Compiler/ExprNodes.py
+46
-30
Cython/Compiler/Optimize.py
Cython/Compiler/Optimize.py
+7
-4
Cython/Compiler/PyrexTypes.py
Cython/Compiler/PyrexTypes.py
+18
-0
tests/run/cpp_operators.pyx
tests/run/cpp_operators.pyx
+102
-92
tests/run/cpp_type_inference.pyx
tests/run/cpp_type_inference.pyx
+24
-0
No files found.
Cython/Compiler/ExprNodes.py
View file @
43f3d87d
...
@@ -6988,6 +6988,13 @@ class UnopNode(ExprNode):
...
@@ -6988,6 +6988,13 @@ class UnopNode(ExprNode):
def
infer_type
(
self
,
env
):
def
infer_type
(
self
,
env
):
operand_type
=
self
.
operand
.
infer_type
(
env
)
operand_type
=
self
.
operand
.
infer_type
(
env
)
if
operand_type
.
is_cpp_class
or
operand_type
.
is_ptr
:
cpp_type
=
operand_type
.
find_cpp_operation_type
(
self
.
operator
)
if
cpp_type
is
not
None
:
return
cpp_type
return
self
.
infer_unop_type
(
env
,
operand_type
)
def
infer_unop_type
(
self
,
env
,
operand_type
):
if
operand_type
.
is_pyobject
:
if
operand_type
.
is_pyobject
:
return
py_object_type
return
py_object_type
else
:
else
:
...
@@ -7042,30 +7049,23 @@ class UnopNode(ExprNode):
...
@@ -7042,30 +7049,23 @@ class UnopNode(ExprNode):
self
.
type
=
PyrexTypes
.
error_type
self
.
type
=
PyrexTypes
.
error_type
def
analyse_cpp_operation
(
self
,
env
):
def
analyse_cpp_operation
(
self
,
env
):
type
=
self
.
operand
.
type
cpp_type
=
self
.
operand
.
type
.
find_cpp_operation_type
(
self
.
operator
)
if
type
.
is_ptr
:
if
cpp_type
is
None
:
type
=
type
.
base_type
error
(
self
.
pos
,
"'%s' operator not defined for %s"
%
(
function
=
type
.
scope
.
lookup
(
"operator%s"
%
self
.
operator
)
self
.
operator
,
type
))
if
not
function
:
error
(
self
.
pos
,
"'%s' operator not defined for %s"
%
(
self
.
operator
,
type
))
self
.
type_error
()
self
.
type_error
()
return
return
func_type
=
function
.
type
self
.
type
=
cpp_type
if
func_type
.
is_ptr
:
func_type
=
func_type
.
base_type
self
.
type
=
func_type
.
return_type
class
NotNode
(
Expr
Node
):
class
NotNode
(
Unop
Node
):
# 'not' operator
# 'not' operator
#
#
# operand ExprNode
# operand ExprNode
operator
=
'!'
type
=
PyrexTypes
.
c_bint_type
type
=
PyrexTypes
.
c_bint_type
subexprs
=
[
'operand'
]
def
calculate_constant_result
(
self
):
def
calculate_constant_result
(
self
):
self
.
constant_result
=
not
self
.
operand
.
constant_result
self
.
constant_result
=
not
self
.
operand
.
constant_result
...
@@ -7076,23 +7076,19 @@ class NotNode(ExprNode):
...
@@ -7076,23 +7076,19 @@ class NotNode(ExprNode):
except
Exception
,
e
:
except
Exception
,
e
:
self
.
compile_time_value_error
(
e
)
self
.
compile_time_value_error
(
e
)
def
infer_
type
(
self
,
env
):
def
infer_
unop_type
(
self
,
env
,
operand_type
):
return
PyrexTypes
.
c_bint_type
return
PyrexTypes
.
c_bint_type
def
analyse_types
(
self
,
env
):
def
analyse_types
(
self
,
env
):
self
.
operand
.
analyse_types
(
env
)
self
.
operand
.
analyse_types
(
env
)
if
self
.
operand
.
type
.
is_cpp_class
:
operand_type
=
self
.
operand
.
type
type
=
self
.
operand
.
type
if
operand_type
.
is_cpp_class
:
function
=
type
.
scope
.
lookup
(
"operator!"
)
cpp_type
=
operand_type
.
find_cpp_operation_type
(
self
.
operator
)
if
not
function
:
if
not
cpp_type
:
error
(
self
.
pos
,
"'!' operator not defined for %s"
error
(
self
.
pos
,
"'!' operator not defined for %s"
%
operand_type
)
%
(
type
))
self
.
type
=
PyrexTypes
.
error_type
self
.
type
=
PyrexTypes
.
error_type
return
return
func_type
=
function
.
type
self
.
type
=
cpp_type
if
func_type
.
is_ptr
:
func_type
=
func_type
.
base_type
self
.
type
=
func_type
.
return_type
else
:
else
:
self
.
operand
=
self
.
operand
.
coerce_to_boolean
(
env
)
self
.
operand
=
self
.
operand
.
coerce_to_boolean
(
env
)
...
@@ -7181,6 +7177,12 @@ class DereferenceNode(CUnopNode):
...
@@ -7181,6 +7177,12 @@ class DereferenceNode(CUnopNode):
operator
=
'*'
operator
=
'*'
def
infer_unop_type
(
self
,
env
,
operand_type
):
if
operand_type
.
is_ptr
:
return
operand_type
.
base_type
else
:
return
PyrexTypes
.
error_type
def
analyse_c_operation
(
self
,
env
):
def
analyse_c_operation
(
self
,
env
):
if
self
.
operand
.
type
.
is_ptr
:
if
self
.
operand
.
type
.
is_ptr
:
self
.
type
=
self
.
operand
.
type
.
base_type
self
.
type
=
self
.
operand
.
type
.
base_type
...
@@ -7213,19 +7215,23 @@ def inc_dec_constructor(is_prefix, operator):
...
@@ -7213,19 +7215,23 @@ def inc_dec_constructor(is_prefix, operator):
return
lambda
pos
,
**
kwds
:
DecrementIncrementNode
(
pos
,
is_prefix
=
is_prefix
,
operator
=
operator
,
**
kwds
)
return
lambda
pos
,
**
kwds
:
DecrementIncrementNode
(
pos
,
is_prefix
=
is_prefix
,
operator
=
operator
,
**
kwds
)
class
AmpersandNode
(
Expr
Node
):
class
AmpersandNode
(
CUnop
Node
):
# The C address-of operator.
# The C address-of operator.
#
#
# operand ExprNode
# operand ExprNode
operator
=
'&'
subexprs
=
[
'operand'
]
def
infer_unop_type
(
self
,
env
,
operand_type
):
return
PyrexTypes
.
c_ptr_type
(
operand_type
)
def
infer_type
(
self
,
env
):
return
PyrexTypes
.
c_ptr_type
(
self
.
operand
.
infer_type
(
env
))
def
analyse_types
(
self
,
env
):
def
analyse_types
(
self
,
env
):
self
.
operand
.
analyse_types
(
env
)
self
.
operand
.
analyse_types
(
env
)
argtype
=
self
.
operand
.
type
argtype
=
self
.
operand
.
type
if
argtype
.
is_cpp_class
:
cpp_type
=
argtype
.
find_cpp_operation_type
(
self
.
operator
)
if
cpp_type
is
not
None
:
self
.
type
=
cpp_type
return
if
not
(
argtype
.
is_cfunction
or
argtype
.
is_reference
or
self
.
operand
.
is_addressable
()):
if
not
(
argtype
.
is_cfunction
or
argtype
.
is_reference
or
self
.
operand
.
is_addressable
()):
if
argtype
.
is_memoryviewslice
:
if
argtype
.
is_memoryviewslice
:
self
.
error
(
"Cannot take address of memoryview slice"
)
self
.
error
(
"Cannot take address of memoryview slice"
)
...
@@ -7932,6 +7938,16 @@ class CBinopNode(BinopNode):
...
@@ -7932,6 +7938,16 @@ class CBinopNode(BinopNode):
self
.
operator
,
self
.
operator
,
self
.
operand2
.
result
())
self
.
operand2
.
result
())
def
compute_c_result_type
(
self
,
type1
,
type2
):
cpp_type
=
None
if
type1
.
is_cpp_class
or
type1
.
is_ptr
:
cpp_type
=
type1
.
find_cpp_operation_type
(
self
.
operator
,
type2
)
# FIXME: handle the reversed case?
#if cpp_type is None and (type2.is_cpp_class or type2.is_ptr):
# cpp_type = type2.find_cpp_operation_type(self.operator, type1)
# FIXME: do we need to handle other cases here?
return
cpp_type
def
c_binop_constructor
(
operator
):
def
c_binop_constructor
(
operator
):
def
make_binop_node
(
pos
,
**
operands
):
def
make_binop_node
(
pos
,
**
operands
):
...
...
Cython/Compiler/Optimize.py
View file @
43f3d87d
...
@@ -2966,11 +2966,14 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
...
@@ -2966,11 +2966,14 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
return
node
return
node
if
not
node
.
operand
.
is_literal
:
if
not
node
.
operand
.
is_literal
:
return
node
return
node
if
isinstance
(
node
.
operand
,
ExprNodes
.
BoolNode
):
if
isinstance
(
node
,
ExprNodes
.
NotNode
):
return
ExprNodes
.
IntNode
(
node
.
pos
,
value
=
str
(
node
.
constant_result
),
return
ExprNodes
.
BoolNode
(
node
.
pos
,
value
=
bool
(
node
.
constant_result
),
constant_result
=
bool
(
node
.
constant_result
))
elif
isinstance
(
node
.
operand
,
ExprNodes
.
BoolNode
):
return
ExprNodes
.
IntNode
(
node
.
pos
,
value
=
str
(
int
(
node
.
constant_result
)),
type
=
PyrexTypes
.
c_int_type
,
type
=
PyrexTypes
.
c_int_type
,
constant_result
=
node
.
constant_result
)
constant_result
=
int
(
node
.
constant_result
)
)
if
node
.
operator
==
'+'
:
el
if
node
.
operator
==
'+'
:
return
self
.
_handle_UnaryPlusNode
(
node
)
return
self
.
_handle_UnaryPlusNode
(
node
)
elif
node
.
operator
==
'-'
:
elif
node
.
operator
==
'-'
:
return
self
.
_handle_UnaryMinusNode
(
node
)
return
self
.
_handle_UnaryMinusNode
(
node
)
...
...
Cython/Compiler/PyrexTypes.py
View file @
43f3d87d
...
@@ -2292,6 +2292,11 @@ class CPtrType(CPointerBaseType):
...
@@ -2292,6 +2292,11 @@ class CPtrType(CPointerBaseType):
def
invalid_value
(
self
):
def
invalid_value
(
self
):
return
"1"
return
"1"
def
find_cpp_operation_type
(
self
,
operator
,
operand_type
=
None
):
if
self
.
base_type
.
is_cpp_class
:
return
self
.
base_type
.
find_cpp_operation_type
(
operator
,
operand_type
=
None
)
return
None
class
CNullPtrType
(
CPtrType
):
class
CNullPtrType
(
CPtrType
):
is_null_ptr
=
1
is_null_ptr
=
1
...
@@ -3164,6 +3169,19 @@ class CppClassType(CType):
...
@@ -3164,6 +3169,19 @@ class CppClassType(CType):
def
attributes_known
(
self
):
def
attributes_known
(
self
):
return
self
.
scope
is
not
None
return
self
.
scope
is
not
None
def
find_cpp_operation_type
(
self
,
operator
,
operand_type
=
None
):
operands
=
[
self
]
if
operand_type
is
not
None
:
operands
.
append
(
operand_type
)
# pos == None => no errors
operator_entry
=
self
.
scope
.
lookup_operator_for_types
(
None
,
operator
,
operands
)
if
not
operator_entry
:
return
None
func_type
=
operator_entry
.
type
if
func_type
.
is_ptr
:
func_type
=
func_type
.
base_type
return
func_type
.
return_type
class
TemplatePlaceholderType
(
CType
):
class
TemplatePlaceholderType
(
CType
):
...
...
tests/run/cpp_operators.pyx
View file @
43f3d87d
# tag: cpp
# tag: cpp
from
cython
cimport
typeof
cimport
cython.operator
cimport
cython.operator
from
cython.operator
cimport
dereference
as
deref
from
cython.operator
cimport
dereference
as
deref
cdef
out
(
s
):
from
libc.string
cimport
const_char
print
s
.
decode
(
'ASCII'
)
cdef
out
(
s
,
result_type
=
None
):
print
'%s [%s]'
%
(
s
.
decode
(
'ascii'
),
result_type
)
cdef
extern
from
"cpp_operators_helper.h"
:
cdef
extern
from
"cpp_operators_helper.h"
:
cdef
cppclass
TestOps
:
cdef
cppclass
TestOps
:
char
*
operator
+
()
c
onst_c
har
*
operator
+
()
char
*
operator
-
()
c
onst_c
har
*
operator
-
()
char
*
operator
*
()
c
onst_c
har
*
operator
*
()
char
*
operator
~
()
c
onst_c
har
*
operator
~
()
char
*
operator
!
()
c
onst_c
har
*
operator
!
()
char
*
operator
++
()
c
onst_c
har
*
operator
++
()
char
*
operator
--
()
c
onst_c
har
*
operator
--
()
char
*
operator
++
(
int
)
c
onst_c
har
*
operator
++
(
int
)
char
*
operator
--
(
int
)
c
onst_c
har
*
operator
--
(
int
)
char
*
operator
+
(
int
)
c
onst_c
har
*
operator
+
(
int
)
char
*
operator
-
(
int
)
c
onst_c
har
*
operator
-
(
int
)
char
*
operator
*
(
int
)
c
onst_c
har
*
operator
*
(
int
)
char
*
operator
/
(
int
)
c
onst_c
har
*
operator
/
(
int
)
char
*
operator
%
(
int
)
c
onst_c
har
*
operator
%
(
int
)
char
*
operator
|
(
int
)
c
onst_c
har
*
operator
|
(
int
)
char
*
operator
&
(
int
)
c
onst_c
har
*
operator
&
(
int
)
char
*
operator
^
(
int
)
c
onst_c
har
*
operator
^
(
int
)
char
*
operator
,(
int
)
c
onst_c
har
*
operator
,(
int
)
char
*
operator
<<
(
int
)
c
onst_c
har
*
operator
<<
(
int
)
char
*
operator
>>
(
int
)
c
onst_c
har
*
operator
>>
(
int
)
char
*
operator
==
(
int
)
c
onst_c
har
*
operator
==
(
int
)
char
*
operator
!=
(
int
)
c
onst_c
har
*
operator
!=
(
int
)
char
*
operator
>=
(
int
)
c
onst_c
har
*
operator
>=
(
int
)
char
*
operator
<=
(
int
)
c
onst_c
har
*
operator
<=
(
int
)
char
*
operator
>
(
int
)
c
onst_c
har
*
operator
>
(
int
)
char
*
operator
<
(
int
)
c
onst_c
har
*
operator
<
(
int
)
char
*
operator
[](
int
)
c
onst_c
har
*
operator
[](
int
)
char
*
operator
()(
int
)
c
onst_c
har
*
operator
()(
int
)
def
test_unops
():
def
test_unops
():
"""
"""
>>> test_unops()
>>> test_unops()
unary +
unary +
[const_char *]
unary -
unary -
[const_char *]
unary ~
unary ~
[const_char *]
unary *
unary *
[const_char *]
unary !
unary !
[const_char *]
"""
"""
cdef
TestOps
*
t
=
new
TestOps
()
cdef
TestOps
*
t
=
new
TestOps
()
out
(
+
t
[
0
])
out
(
+
t
[
0
],
typeof
(
+
t
[
0
]))
out
(
-
t
[
0
])
out
(
-
t
[
0
],
typeof
(
-
t
[
0
]))
out
(
~
t
[
0
])
out
(
~
t
[
0
],
typeof
(
~
t
[
0
]))
out
(
deref
(
t
[
0
]))
x
=
deref
(
t
[
0
])
out
(
not
t
[
0
])
out
(
x
,
typeof
(
x
))
out
(
not
t
[
0
],
typeof
(
not
t
[
0
]))
del
t
del
t
def
test_incdec
():
def
test_incdec
():
"""
"""
>>> test_incdec()
>>> test_incdec()
unary ++
unary ++
[const_char *]
unary --
unary --
[const_char *]
post ++
post ++
[const_char *]
post --
post --
[const_char *]
"""
"""
cdef
TestOps
*
t
=
new
TestOps
()
cdef
TestOps
*
t
=
new
TestOps
()
out
(
cython
.
operator
.
preincrement
(
t
[
0
]))
a
=
cython
.
operator
.
preincrement
(
t
[
0
])
out
(
cython
.
operator
.
predecrement
(
t
[
0
]))
out
(
a
,
typeof
(
a
))
out
(
cython
.
operator
.
postincrement
(
t
[
0
]))
b
=
cython
.
operator
.
predecrement
(
t
[
0
])
out
(
cython
.
operator
.
postdecrement
(
t
[
0
]))
out
(
b
,
typeof
(
b
))
c
=
cython
.
operator
.
postincrement
(
t
[
0
])
out
(
c
,
typeof
(
c
))
d
=
cython
.
operator
.
postdecrement
(
t
[
0
])
out
(
d
,
typeof
(
d
))
del
t
del
t
def
test_binop
():
def
test_binop
():
"""
"""
>>> test_binop()
>>> test_binop()
binary +
binary +
[const_char *]
binary -
binary -
[const_char *]
binary *
binary *
[const_char *]
binary /
binary /
[const_char *]
binary %
binary %
[const_char *]
binary &
binary &
[const_char *]
binary |
binary |
[const_char *]
binary ^
binary ^
[const_char *]
binary <<
binary <<
[const_char *]
binary >>
binary >>
[const_char *]
binary COMMA
binary COMMA
[const_char *]
"""
"""
cdef
TestOps
*
t
=
new
TestOps
()
cdef
TestOps
*
t
=
new
TestOps
()
out
(
t
[
0
]
+
1
)
out
(
t
[
0
]
+
1
,
typeof
(
t
[
0
]
+
1
)
)
out
(
t
[
0
]
-
1
)
out
(
t
[
0
]
-
1
,
typeof
(
t
[
0
]
-
1
)
)
out
(
t
[
0
]
*
1
)
out
(
t
[
0
]
*
1
,
typeof
(
t
[
0
]
*
1
)
)
out
(
t
[
0
]
/
1
)
out
(
t
[
0
]
/
1
,
typeof
(
t
[
0
]
/
1
)
)
out
(
t
[
0
]
%
1
)
out
(
t
[
0
]
%
1
,
typeof
(
t
[
0
]
%
1
)
)
out
(
t
[
0
]
&
1
)
out
(
t
[
0
]
&
1
,
typeof
(
t
[
0
]
&
1
)
)
out
(
t
[
0
]
|
1
)
out
(
t
[
0
]
|
1
,
typeof
(
t
[
0
]
|
1
)
)
out
(
t
[
0
]
^
1
)
out
(
t
[
0
]
^
1
,
typeof
(
t
[
0
]
^
1
)
)
out
(
t
[
0
]
<<
1
)
out
(
t
[
0
]
<<
1
,
typeof
(
t
[
0
]
<<
1
)
)
out
(
t
[
0
]
>>
1
)
out
(
t
[
0
]
>>
1
,
typeof
(
t
[
0
]
>>
1
)
)
out
(
cython
.
operator
.
comma
(
t
[
0
],
1
))
x
=
cython
.
operator
.
comma
(
t
[
0
],
1
)
out
(
x
,
typeof
(
x
))
del
t
del
t
def
test_cmp
():
def
test_cmp
():
"""
"""
>>> test_cmp()
>>> test_cmp()
binary ==
binary ==
[const_char *]
binary !=
binary !=
[const_char *]
binary >=
binary >=
[const_char *]
binary >
binary >
[const_char *]
binary <=
binary <=
[const_char *]
binary <
binary <
[const_char *]
"""
"""
cdef
TestOps
*
t
=
new
TestOps
()
cdef
TestOps
*
t
=
new
TestOps
()
out
(
t
[
0
]
==
1
)
out
(
t
[
0
]
==
1
,
typeof
(
t
[
0
]
==
1
)
)
out
(
t
[
0
]
!=
1
)
out
(
t
[
0
]
!=
1
,
typeof
(
t
[
0
]
!=
1
)
)
out
(
t
[
0
]
>=
1
)
out
(
t
[
0
]
>=
1
,
typeof
(
t
[
0
]
>=
1
)
)
out
(
t
[
0
]
>
1
)
out
(
t
[
0
]
>
1
,
typeof
(
t
[
0
]
>
1
)
)
out
(
t
[
0
]
<=
1
)
out
(
t
[
0
]
<=
1
,
typeof
(
t
[
0
]
<=
1
)
)
out
(
t
[
0
]
<
1
)
out
(
t
[
0
]
<
1
,
typeof
(
t
[
0
]
<
1
)
)
del
t
del
t
def
test_index_call
():
def
test_index_call
():
"""
"""
>>> test_index_call()
>>> test_index_call()
binary []
binary []
[const_char *]
binary ()
binary ()
[const_char *]
"""
"""
cdef
TestOps
*
t
=
new
TestOps
()
cdef
TestOps
*
t
=
new
TestOps
()
out
(
t
[
0
][
100
])
out
(
t
[
0
][
100
]
,
typeof
(
t
[
0
][
100
])
)
out
(
t
[
0
](
100
))
out
(
t
[
0
](
100
)
,
typeof
(
t
[
0
](
100
))
)
del
t
del
t
tests/run/cpp_type_inference.pyx
0 → 100644
View file @
43f3d87d
# tag: cpp
from
cython
cimport
typeof
from
cython.operator
cimport
dereference
as
d
from
cython.operator
cimport
preincrement
as
incr
from
libcpp.vector
cimport
vector
def
test_reversed_vector_iteration
(
L
):
"""
>>> test_reversed_vector_iteration([1,2,3])
int: 3
int: 2
int: 1
int
"""
cdef
vector
[
int
]
v
=
L
it
=
v
.
rbegin
()
while
it
!=
v
.
rend
():
a
=
d
(
it
)
incr
(
it
)
print
(
'%s: %s'
%
(
typeof
(
a
),
a
))
print
(
typeof
(
a
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment