Commit 43f3d87d authored by Stefan Behnel's avatar Stefan Behnel

fix type inference for overloaded C++ operators

parent eb801873
......@@ -6988,6 +6988,13 @@ class UnopNode(ExprNode):
def infer_type(self, env):
operand_type = self.operand.infer_type(env)
if operand_type.is_cpp_class or operand_type.is_ptr:
cpp_type = operand_type.find_cpp_operation_type(self.operator)
if cpp_type is not None:
return cpp_type
return self.infer_unop_type(env, operand_type)
def infer_unop_type(self, env, operand_type):
if operand_type.is_pyobject:
return py_object_type
else:
......@@ -7042,30 +7049,23 @@ class UnopNode(ExprNode):
self.type = PyrexTypes.error_type
def analyse_cpp_operation(self, env):
type = self.operand.type
if type.is_ptr:
type = type.base_type
function = type.scope.lookup("operator%s" % self.operator)
if not function:
error(self.pos, "'%s' operator not defined for %s"
% (self.operator, type))
cpp_type = self.operand.type.find_cpp_operation_type(self.operator)
if cpp_type is None:
error(self.pos, "'%s' operator not defined for %s" % (
self.operator, type))
self.type_error()
return
func_type = function.type
if func_type.is_ptr:
func_type = func_type.base_type
self.type = func_type.return_type
self.type = cpp_type
class NotNode(ExprNode):
class NotNode(UnopNode):
# 'not' operator
#
# operand ExprNode
operator = '!'
type = PyrexTypes.c_bint_type
subexprs = ['operand']
def calculate_constant_result(self):
self.constant_result = not self.operand.constant_result
......@@ -7076,23 +7076,19 @@ class NotNode(ExprNode):
except Exception, e:
self.compile_time_value_error(e)
def infer_type(self, env):
def infer_unop_type(self, env, operand_type):
return PyrexTypes.c_bint_type
def analyse_types(self, env):
self.operand.analyse_types(env)
if self.operand.type.is_cpp_class:
type = self.operand.type
function = type.scope.lookup("operator!")
if not function:
error(self.pos, "'!' operator not defined for %s"
% (type))
operand_type = self.operand.type
if operand_type.is_cpp_class:
cpp_type = operand_type.find_cpp_operation_type(self.operator)
if not cpp_type:
error(self.pos, "'!' operator not defined for %s" % operand_type)
self.type = PyrexTypes.error_type
return
func_type = function.type
if func_type.is_ptr:
func_type = func_type.base_type
self.type = func_type.return_type
self.type = cpp_type
else:
self.operand = self.operand.coerce_to_boolean(env)
......@@ -7181,6 +7177,12 @@ class DereferenceNode(CUnopNode):
operator = '*'
def infer_unop_type(self, env, operand_type):
if operand_type.is_ptr:
return operand_type.base_type
else:
return PyrexTypes.error_type
def analyse_c_operation(self, env):
if self.operand.type.is_ptr:
self.type = self.operand.type.base_type
......@@ -7213,19 +7215,23 @@ def inc_dec_constructor(is_prefix, operator):
return lambda pos, **kwds: DecrementIncrementNode(pos, is_prefix=is_prefix, operator=operator, **kwds)
class AmpersandNode(ExprNode):
class AmpersandNode(CUnopNode):
# The C address-of operator.
#
# operand ExprNode
operator = '&'
subexprs = ['operand']
def infer_type(self, env):
return PyrexTypes.c_ptr_type(self.operand.infer_type(env))
def infer_unop_type(self, env, operand_type):
return PyrexTypes.c_ptr_type(operand_type)
def analyse_types(self, env):
self.operand.analyse_types(env)
argtype = self.operand.type
if argtype.is_cpp_class:
cpp_type = argtype.find_cpp_operation_type(self.operator)
if cpp_type is not None:
self.type = cpp_type
return
if not (argtype.is_cfunction or argtype.is_reference or self.operand.is_addressable()):
if argtype.is_memoryviewslice:
self.error("Cannot take address of memoryview slice")
......@@ -7932,6 +7938,16 @@ class CBinopNode(BinopNode):
self.operator,
self.operand2.result())
def compute_c_result_type(self, type1, type2):
cpp_type = None
if type1.is_cpp_class or type1.is_ptr:
cpp_type = type1.find_cpp_operation_type(self.operator, type2)
# FIXME: handle the reversed case?
#if cpp_type is None and (type2.is_cpp_class or type2.is_ptr):
# cpp_type = type2.find_cpp_operation_type(self.operator, type1)
# FIXME: do we need to handle other cases here?
return cpp_type
def c_binop_constructor(operator):
def make_binop_node(pos, **operands):
......
......@@ -2966,11 +2966,14 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
return node
if not node.operand.is_literal:
return node
if isinstance(node.operand, ExprNodes.BoolNode):
return ExprNodes.IntNode(node.pos, value = str(node.constant_result),
if isinstance(node, ExprNodes.NotNode):
return ExprNodes.BoolNode(node.pos, value = bool(node.constant_result),
constant_result = bool(node.constant_result))
elif isinstance(node.operand, ExprNodes.BoolNode):
return ExprNodes.IntNode(node.pos, value = str(int(node.constant_result)),
type = PyrexTypes.c_int_type,
constant_result = node.constant_result)
if node.operator == '+':
constant_result = int(node.constant_result))
elif node.operator == '+':
return self._handle_UnaryPlusNode(node)
elif node.operator == '-':
return self._handle_UnaryMinusNode(node)
......
......@@ -2292,6 +2292,11 @@ class CPtrType(CPointerBaseType):
def invalid_value(self):
return "1"
def find_cpp_operation_type(self, operator, operand_type=None):
if self.base_type.is_cpp_class:
return self.base_type.find_cpp_operation_type(operator, operand_type=None)
return None
class CNullPtrType(CPtrType):
is_null_ptr = 1
......@@ -3164,6 +3169,19 @@ class CppClassType(CType):
def attributes_known(self):
return self.scope is not None
def find_cpp_operation_type(self, operator, operand_type=None):
operands = [self]
if operand_type is not None:
operands.append(operand_type)
# pos == None => no errors
operator_entry = self.scope.lookup_operator_for_types(None, operator, operands)
if not operator_entry:
return None
func_type = operator_entry.type
if func_type.is_ptr:
func_type = func_type.base_type
return func_type.return_type
class TemplatePlaceholderType(CType):
......
# tag: cpp
from cython cimport typeof
cimport cython.operator
from cython.operator cimport dereference as deref
cdef out(s):
print s.decode('ASCII')
from libc.string cimport const_char
cdef out(s, result_type=None):
print '%s [%s]' % (s.decode('ascii'), result_type)
cdef extern from "cpp_operators_helper.h":
cdef cppclass TestOps:
char* operator+()
char* operator-()
char* operator*()
char* operator~()
char* operator!()
char* operator++()
char* operator--()
char* operator++(int)
char* operator--(int)
char* operator+(int)
char* operator-(int)
char* operator*(int)
char* operator/(int)
char* operator%(int)
char* operator|(int)
char* operator&(int)
char* operator^(int)
char* operator,(int)
char* operator<<(int)
char* operator>>(int)
char* operator==(int)
char* operator!=(int)
char* operator>=(int)
char* operator<=(int)
char* operator>(int)
char* operator<(int)
char* operator[](int)
char* operator()(int)
const_char* operator+()
const_char* operator-()
const_char* operator*()
const_char* operator~()
const_char* operator!()
const_char* operator++()
const_char* operator--()
const_char* operator++(int)
const_char* operator--(int)
const_char* operator+(int)
const_char* operator-(int)
const_char* operator*(int)
const_char* operator/(int)
const_char* operator%(int)
const_char* operator|(int)
const_char* operator&(int)
const_char* operator^(int)
const_char* operator,(int)
const_char* operator<<(int)
const_char* operator>>(int)
const_char* operator==(int)
const_char* operator!=(int)
const_char* operator>=(int)
const_char* operator<=(int)
const_char* operator>(int)
const_char* operator<(int)
const_char* operator[](int)
const_char* operator()(int)
def test_unops():
"""
>>> test_unops()
unary +
unary -
unary ~
unary *
unary !
unary + [const_char *]
unary - [const_char *]
unary ~ [const_char *]
unary * [const_char *]
unary ! [const_char *]
"""
cdef TestOps* t = new TestOps()
out(+t[0])
out(-t[0])
out(~t[0])
out(deref(t[0]))
out(not t[0])
out(+t[0], typeof(+t[0]))
out(-t[0], typeof(-t[0]))
out(~t[0], typeof(~t[0]))
x = deref(t[0])
out(x, typeof(x))
out(not t[0], typeof(not t[0]))
del t
def test_incdec():
"""
>>> test_incdec()
unary ++
unary --
post ++
post --
unary ++ [const_char *]
unary -- [const_char *]
post ++ [const_char *]
post -- [const_char *]
"""
cdef TestOps* t = new TestOps()
out(cython.operator.preincrement(t[0]))
out(cython.operator.predecrement(t[0]))
out(cython.operator.postincrement(t[0]))
out(cython.operator.postdecrement(t[0]))
a = cython.operator.preincrement(t[0])
out(a, typeof(a))
b = cython.operator.predecrement(t[0])
out(b, typeof(b))
c = cython.operator.postincrement(t[0])
out(c, typeof(c))
d = cython.operator.postdecrement(t[0])
out(d, typeof(d))
del t
def test_binop():
"""
>>> test_binop()
binary +
binary -
binary *
binary /
binary %
binary &
binary |
binary ^
binary <<
binary >>
binary COMMA
binary + [const_char *]
binary - [const_char *]
binary * [const_char *]
binary / [const_char *]
binary % [const_char *]
binary & [const_char *]
binary | [const_char *]
binary ^ [const_char *]
binary << [const_char *]
binary >> [const_char *]
binary COMMA [const_char *]
"""
cdef TestOps* t = new TestOps()
out(t[0] + 1)
out(t[0] - 1)
out(t[0] * 1)
out(t[0] / 1)
out(t[0] % 1)
out(t[0] + 1, typeof(t[0] + 1))
out(t[0] - 1, typeof(t[0] - 1))
out(t[0] * 1, typeof(t[0] * 1))
out(t[0] / 1, typeof(t[0] / 1))
out(t[0] % 1, typeof(t[0] % 1))
out(t[0] & 1)
out(t[0] | 1)
out(t[0] ^ 1)
out(t[0] & 1, typeof(t[0] & 1))
out(t[0] | 1, typeof(t[0] | 1))
out(t[0] ^ 1, typeof(t[0] ^ 1))
out(t[0] << 1)
out(t[0] >> 1)
out(t[0] << 1, typeof(t[0] << 1))
out(t[0] >> 1, typeof(t[0] >> 1))
out(cython.operator.comma(t[0], 1))
x = cython.operator.comma(t[0], 1)
out(x, typeof(x))
del t
def test_cmp():
"""
>>> test_cmp()
binary ==
binary !=
binary >=
binary >
binary <=
binary <
binary == [const_char *]
binary != [const_char *]
binary >= [const_char *]
binary > [const_char *]
binary <= [const_char *]
binary < [const_char *]
"""
cdef TestOps* t = new TestOps()
out(t[0] == 1)
out(t[0] != 1)
out(t[0] >= 1)
out(t[0] > 1)
out(t[0] <= 1)
out(t[0] < 1)
out(t[0] == 1, typeof(t[0] == 1))
out(t[0] != 1, typeof(t[0] != 1))
out(t[0] >= 1, typeof(t[0] >= 1))
out(t[0] > 1, typeof(t[0] > 1))
out(t[0] <= 1, typeof(t[0] <= 1))
out(t[0] < 1, typeof(t[0] < 1))
del t
def test_index_call():
"""
>>> test_index_call()
binary []
binary ()
binary [] [const_char *]
binary () [const_char *]
"""
cdef TestOps* t = new TestOps()
out(t[0][100])
out(t[0](100))
out(t[0][100], typeof(t[0][100]))
out(t[0](100), typeof(t[0](100)))
del t
# tag: cpp
from cython cimport typeof
from cython.operator cimport dereference as d
from cython.operator cimport preincrement as incr
from libcpp.vector cimport vector
def test_reversed_vector_iteration(L):
"""
>>> test_reversed_vector_iteration([1,2,3])
int: 3
int: 2
int: 1
int
"""
cdef vector[int] v = L
it = v.rbegin()
while it != v.rend():
a = d(it)
incr(it)
print('%s: %s' % (typeof(a), a))
print(typeof(a))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment