Commit 43f3d87d authored by Stefan Behnel's avatar Stefan Behnel

fix type inference for overloaded C++ operators

parent eb801873
...@@ -6988,6 +6988,13 @@ class UnopNode(ExprNode): ...@@ -6988,6 +6988,13 @@ class UnopNode(ExprNode):
def infer_type(self, env): def infer_type(self, env):
operand_type = self.operand.infer_type(env) operand_type = self.operand.infer_type(env)
if operand_type.is_cpp_class or operand_type.is_ptr:
cpp_type = operand_type.find_cpp_operation_type(self.operator)
if cpp_type is not None:
return cpp_type
return self.infer_unop_type(env, operand_type)
def infer_unop_type(self, env, operand_type):
if operand_type.is_pyobject: if operand_type.is_pyobject:
return py_object_type return py_object_type
else: else:
...@@ -7042,30 +7049,23 @@ class UnopNode(ExprNode): ...@@ -7042,30 +7049,23 @@ class UnopNode(ExprNode):
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
def analyse_cpp_operation(self, env): def analyse_cpp_operation(self, env):
type = self.operand.type cpp_type = self.operand.type.find_cpp_operation_type(self.operator)
if type.is_ptr: if cpp_type is None:
type = type.base_type error(self.pos, "'%s' operator not defined for %s" % (
function = type.scope.lookup("operator%s" % self.operator) self.operator, type))
if not function:
error(self.pos, "'%s' operator not defined for %s"
% (self.operator, type))
self.type_error() self.type_error()
return return
func_type = function.type self.type = cpp_type
if func_type.is_ptr:
func_type = func_type.base_type
self.type = func_type.return_type
class NotNode(ExprNode): class NotNode(UnopNode):
# 'not' operator # 'not' operator
# #
# operand ExprNode # operand ExprNode
operator = '!'
type = PyrexTypes.c_bint_type type = PyrexTypes.c_bint_type
subexprs = ['operand']
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = not self.operand.constant_result self.constant_result = not self.operand.constant_result
...@@ -7076,23 +7076,19 @@ class NotNode(ExprNode): ...@@ -7076,23 +7076,19 @@ class NotNode(ExprNode):
except Exception, e: except Exception, e:
self.compile_time_value_error(e) self.compile_time_value_error(e)
def infer_type(self, env): def infer_unop_type(self, env, operand_type):
return PyrexTypes.c_bint_type return PyrexTypes.c_bint_type
def analyse_types(self, env): def analyse_types(self, env):
self.operand.analyse_types(env) self.operand.analyse_types(env)
if self.operand.type.is_cpp_class: operand_type = self.operand.type
type = self.operand.type if operand_type.is_cpp_class:
function = type.scope.lookup("operator!") cpp_type = operand_type.find_cpp_operation_type(self.operator)
if not function: if not cpp_type:
error(self.pos, "'!' operator not defined for %s" error(self.pos, "'!' operator not defined for %s" % operand_type)
% (type))
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
return return
func_type = function.type self.type = cpp_type
if func_type.is_ptr:
func_type = func_type.base_type
self.type = func_type.return_type
else: else:
self.operand = self.operand.coerce_to_boolean(env) self.operand = self.operand.coerce_to_boolean(env)
...@@ -7181,6 +7177,12 @@ class DereferenceNode(CUnopNode): ...@@ -7181,6 +7177,12 @@ class DereferenceNode(CUnopNode):
operator = '*' operator = '*'
def infer_unop_type(self, env, operand_type):
if operand_type.is_ptr:
return operand_type.base_type
else:
return PyrexTypes.error_type
def analyse_c_operation(self, env): def analyse_c_operation(self, env):
if self.operand.type.is_ptr: if self.operand.type.is_ptr:
self.type = self.operand.type.base_type self.type = self.operand.type.base_type
...@@ -7213,19 +7215,23 @@ def inc_dec_constructor(is_prefix, operator): ...@@ -7213,19 +7215,23 @@ def inc_dec_constructor(is_prefix, operator):
return lambda pos, **kwds: DecrementIncrementNode(pos, is_prefix=is_prefix, operator=operator, **kwds) return lambda pos, **kwds: DecrementIncrementNode(pos, is_prefix=is_prefix, operator=operator, **kwds)
class AmpersandNode(ExprNode): class AmpersandNode(CUnopNode):
# The C address-of operator. # The C address-of operator.
# #
# operand ExprNode # operand ExprNode
operator = '&'
subexprs = ['operand'] def infer_unop_type(self, env, operand_type):
return PyrexTypes.c_ptr_type(operand_type)
def infer_type(self, env):
return PyrexTypes.c_ptr_type(self.operand.infer_type(env))
def analyse_types(self, env): def analyse_types(self, env):
self.operand.analyse_types(env) self.operand.analyse_types(env)
argtype = self.operand.type argtype = self.operand.type
if argtype.is_cpp_class:
cpp_type = argtype.find_cpp_operation_type(self.operator)
if cpp_type is not None:
self.type = cpp_type
return
if not (argtype.is_cfunction or argtype.is_reference or self.operand.is_addressable()): if not (argtype.is_cfunction or argtype.is_reference or self.operand.is_addressable()):
if argtype.is_memoryviewslice: if argtype.is_memoryviewslice:
self.error("Cannot take address of memoryview slice") self.error("Cannot take address of memoryview slice")
...@@ -7932,6 +7938,16 @@ class CBinopNode(BinopNode): ...@@ -7932,6 +7938,16 @@ class CBinopNode(BinopNode):
self.operator, self.operator,
self.operand2.result()) self.operand2.result())
def compute_c_result_type(self, type1, type2):
cpp_type = None
if type1.is_cpp_class or type1.is_ptr:
cpp_type = type1.find_cpp_operation_type(self.operator, type2)
# FIXME: handle the reversed case?
#if cpp_type is None and (type2.is_cpp_class or type2.is_ptr):
# cpp_type = type2.find_cpp_operation_type(self.operator, type1)
# FIXME: do we need to handle other cases here?
return cpp_type
def c_binop_constructor(operator): def c_binop_constructor(operator):
def make_binop_node(pos, **operands): def make_binop_node(pos, **operands):
......
...@@ -2966,11 +2966,14 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -2966,11 +2966,14 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
return node return node
if not node.operand.is_literal: if not node.operand.is_literal:
return node return node
if isinstance(node.operand, ExprNodes.BoolNode): if isinstance(node, ExprNodes.NotNode):
return ExprNodes.IntNode(node.pos, value = str(node.constant_result), return ExprNodes.BoolNode(node.pos, value = bool(node.constant_result),
constant_result = bool(node.constant_result))
elif isinstance(node.operand, ExprNodes.BoolNode):
return ExprNodes.IntNode(node.pos, value = str(int(node.constant_result)),
type = PyrexTypes.c_int_type, type = PyrexTypes.c_int_type,
constant_result = node.constant_result) constant_result = int(node.constant_result))
if node.operator == '+': elif node.operator == '+':
return self._handle_UnaryPlusNode(node) return self._handle_UnaryPlusNode(node)
elif node.operator == '-': elif node.operator == '-':
return self._handle_UnaryMinusNode(node) return self._handle_UnaryMinusNode(node)
......
...@@ -2292,6 +2292,11 @@ class CPtrType(CPointerBaseType): ...@@ -2292,6 +2292,11 @@ class CPtrType(CPointerBaseType):
def invalid_value(self): def invalid_value(self):
return "1" return "1"
def find_cpp_operation_type(self, operator, operand_type=None):
if self.base_type.is_cpp_class:
return self.base_type.find_cpp_operation_type(operator, operand_type=None)
return None
class CNullPtrType(CPtrType): class CNullPtrType(CPtrType):
is_null_ptr = 1 is_null_ptr = 1
...@@ -3164,6 +3169,19 @@ class CppClassType(CType): ...@@ -3164,6 +3169,19 @@ class CppClassType(CType):
def attributes_known(self): def attributes_known(self):
return self.scope is not None return self.scope is not None
def find_cpp_operation_type(self, operator, operand_type=None):
operands = [self]
if operand_type is not None:
operands.append(operand_type)
# pos == None => no errors
operator_entry = self.scope.lookup_operator_for_types(None, operator, operands)
if not operator_entry:
return None
func_type = operator_entry.type
if func_type.is_ptr:
func_type = func_type.base_type
return func_type.return_type
class TemplatePlaceholderType(CType): class TemplatePlaceholderType(CType):
......
# tag: cpp # tag: cpp
from cython cimport typeof
cimport cython.operator cimport cython.operator
from cython.operator cimport dereference as deref from cython.operator cimport dereference as deref
cdef out(s): from libc.string cimport const_char
print s.decode('ASCII')
cdef out(s, result_type=None):
print '%s [%s]' % (s.decode('ascii'), result_type)
cdef extern from "cpp_operators_helper.h": cdef extern from "cpp_operators_helper.h":
cdef cppclass TestOps: cdef cppclass TestOps:
char* operator+() const_char* operator+()
char* operator-() const_char* operator-()
char* operator*() const_char* operator*()
char* operator~() const_char* operator~()
char* operator!() const_char* operator!()
char* operator++() const_char* operator++()
char* operator--() const_char* operator--()
char* operator++(int) const_char* operator++(int)
char* operator--(int) const_char* operator--(int)
char* operator+(int) const_char* operator+(int)
char* operator-(int) const_char* operator-(int)
char* operator*(int) const_char* operator*(int)
char* operator/(int) const_char* operator/(int)
char* operator%(int) const_char* operator%(int)
char* operator|(int) const_char* operator|(int)
char* operator&(int) const_char* operator&(int)
char* operator^(int) const_char* operator^(int)
char* operator,(int) const_char* operator,(int)
char* operator<<(int) const_char* operator<<(int)
char* operator>>(int) const_char* operator>>(int)
char* operator==(int) const_char* operator==(int)
char* operator!=(int) const_char* operator!=(int)
char* operator>=(int) const_char* operator>=(int)
char* operator<=(int) const_char* operator<=(int)
char* operator>(int) const_char* operator>(int)
char* operator<(int) const_char* operator<(int)
char* operator[](int) const_char* operator[](int)
char* operator()(int) const_char* operator()(int)
def test_unops(): def test_unops():
""" """
>>> test_unops() >>> test_unops()
unary + unary + [const_char *]
unary - unary - [const_char *]
unary ~ unary ~ [const_char *]
unary * unary * [const_char *]
unary ! unary ! [const_char *]
""" """
cdef TestOps* t = new TestOps() cdef TestOps* t = new TestOps()
out(+t[0]) out(+t[0], typeof(+t[0]))
out(-t[0]) out(-t[0], typeof(-t[0]))
out(~t[0]) out(~t[0], typeof(~t[0]))
out(deref(t[0])) x = deref(t[0])
out(not t[0]) out(x, typeof(x))
out(not t[0], typeof(not t[0]))
del t del t
def test_incdec(): def test_incdec():
""" """
>>> test_incdec() >>> test_incdec()
unary ++ unary ++ [const_char *]
unary -- unary -- [const_char *]
post ++ post ++ [const_char *]
post -- post -- [const_char *]
""" """
cdef TestOps* t = new TestOps() cdef TestOps* t = new TestOps()
out(cython.operator.preincrement(t[0])) a = cython.operator.preincrement(t[0])
out(cython.operator.predecrement(t[0])) out(a, typeof(a))
out(cython.operator.postincrement(t[0])) b = cython.operator.predecrement(t[0])
out(cython.operator.postdecrement(t[0])) out(b, typeof(b))
c = cython.operator.postincrement(t[0])
out(c, typeof(c))
d = cython.operator.postdecrement(t[0])
out(d, typeof(d))
del t del t
def test_binop(): def test_binop():
""" """
>>> test_binop() >>> test_binop()
binary + binary + [const_char *]
binary - binary - [const_char *]
binary * binary * [const_char *]
binary / binary / [const_char *]
binary % binary % [const_char *]
binary & binary & [const_char *]
binary | binary | [const_char *]
binary ^ binary ^ [const_char *]
binary << binary << [const_char *]
binary >> binary >> [const_char *]
binary COMMA binary COMMA [const_char *]
""" """
cdef TestOps* t = new TestOps() cdef TestOps* t = new TestOps()
out(t[0] + 1) out(t[0] + 1, typeof(t[0] + 1))
out(t[0] - 1) out(t[0] - 1, typeof(t[0] - 1))
out(t[0] * 1) out(t[0] * 1, typeof(t[0] * 1))
out(t[0] / 1) out(t[0] / 1, typeof(t[0] / 1))
out(t[0] % 1) out(t[0] % 1, typeof(t[0] % 1))
out(t[0] & 1) out(t[0] & 1, typeof(t[0] & 1))
out(t[0] | 1) out(t[0] | 1, typeof(t[0] | 1))
out(t[0] ^ 1) out(t[0] ^ 1, typeof(t[0] ^ 1))
out(t[0] << 1) out(t[0] << 1, typeof(t[0] << 1))
out(t[0] >> 1) out(t[0] >> 1, typeof(t[0] >> 1))
out(cython.operator.comma(t[0], 1)) x = cython.operator.comma(t[0], 1)
out(x, typeof(x))
del t del t
def test_cmp(): def test_cmp():
""" """
>>> test_cmp() >>> test_cmp()
binary == binary == [const_char *]
binary != binary != [const_char *]
binary >= binary >= [const_char *]
binary > binary > [const_char *]
binary <= binary <= [const_char *]
binary < binary < [const_char *]
""" """
cdef TestOps* t = new TestOps() cdef TestOps* t = new TestOps()
out(t[0] == 1) out(t[0] == 1, typeof(t[0] == 1))
out(t[0] != 1) out(t[0] != 1, typeof(t[0] != 1))
out(t[0] >= 1) out(t[0] >= 1, typeof(t[0] >= 1))
out(t[0] > 1) out(t[0] > 1, typeof(t[0] > 1))
out(t[0] <= 1) out(t[0] <= 1, typeof(t[0] <= 1))
out(t[0] < 1) out(t[0] < 1, typeof(t[0] < 1))
del t del t
def test_index_call(): def test_index_call():
""" """
>>> test_index_call() >>> test_index_call()
binary [] binary [] [const_char *]
binary () binary () [const_char *]
""" """
cdef TestOps* t = new TestOps() cdef TestOps* t = new TestOps()
out(t[0][100]) out(t[0][100], typeof(t[0][100]))
out(t[0](100)) out(t[0](100), typeof(t[0](100)))
del t del t
# tag: cpp
from cython cimport typeof
from cython.operator cimport dereference as d
from cython.operator cimport preincrement as incr
from libcpp.vector cimport vector
def test_reversed_vector_iteration(L):
"""
>>> test_reversed_vector_iteration([1,2,3])
int: 3
int: 2
int: 1
int
"""
cdef vector[int] v = L
it = v.rbegin()
while it != v.rend():
a = d(it)
incr(it)
print('%s: %s' % (typeof(a), a))
print(typeof(a))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment