Commit 599f8475 authored by Ian Henriksen's avatar Ian Henriksen

Improved support for overloading unary operators.

parent d287f8b8
...@@ -189,7 +189,7 @@ def translate_cpp_exception(code, pos, inside, exception_value, nogil): ...@@ -189,7 +189,7 @@ def translate_cpp_exception(code, pos, inside, exception_value, nogil):
else: else:
raise_py_exception = '%s(); if (!PyErr_Occurred()) PyErr_SetString(PyExc_RuntimeError , "Error converting c++ exception.");' % exception_value.entry.cname raise_py_exception = '%s(); if (!PyErr_Occurred()) PyErr_SetString(PyExc_RuntimeError , "Error converting c++ exception.");' % exception_value.entry.cname
code.putln("try {") code.putln("try {")
code.putln("%s;" % inside) code.putln("%s" % inside)
code.putln("} catch(...) {") code.putln("} catch(...) {")
if nogil: if nogil:
code.put_ensure_gil(declare_gilstate=True) code.put_ensure_gil(declare_gilstate=True)
...@@ -9173,6 +9173,13 @@ class UnopNode(ExprNode): ...@@ -9173,6 +9173,13 @@ class UnopNode(ExprNode):
def generate_result_code(self, code): def generate_result_code(self, code):
if self.operand.type.is_pyobject: if self.operand.type.is_pyobject:
self.generate_py_operation_code(code) self.generate_py_operation_code(code)
elif self.is_temp:
if self.is_cpp_operation() and self.exception_check == '+':
translate_cpp_exception(code, self.pos,
"%s = %s %s;" % (self.result(), self.operator, self.operand.result()),
self.exception_value, self.in_nogil_context)
else:
code.putln("%s = %s %s;" % (self.result(), self.operator, self.operand.result()))
def generate_py_operation_code(self, code): def generate_py_operation_code(self, code):
function = self.py_operation_function(code) function = self.py_operation_function(code)
...@@ -9190,9 +9197,21 @@ class UnopNode(ExprNode): ...@@ -9190,9 +9197,21 @@ class UnopNode(ExprNode):
(self.operator, self.operand.type)) (self.operator, self.operand.type))
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
def analyse_cpp_operation(self, env): def analyse_cpp_operation(self, env, overload_check=True):
entry = env.lookup_operator(self.operator, [self.operand])
if overload_check and not entry:
self.type_error()
return
if entry:
self.exception_check = entry.type.exception_check
self.exception_value = entry.type.exception_value
if self.exception_check == '+':
self.is_temp = True
else:
self.exception_check = ''
self.exception_value = ''
cpp_type = self.operand.type.find_cpp_operation_type(self.operator) cpp_type = self.operand.type.find_cpp_operation_type(self.operator)
if cpp_type is None: if overload_check and cpp_type is None:
error(self.pos, "'%s' operator not defined for %s" % ( error(self.pos, "'%s' operator not defined for %s" % (
self.operator, type)) self.operator, type))
self.type_error() self.type_error()
...@@ -9225,12 +9244,7 @@ class NotNode(UnopNode): ...@@ -9225,12 +9244,7 @@ class NotNode(UnopNode):
self.operand = self.operand.analyse_types(env) self.operand = self.operand.analyse_types(env)
operand_type = self.operand.type operand_type = self.operand.type
if operand_type.is_cpp_class: if operand_type.is_cpp_class:
cpp_type = operand_type.find_cpp_operation_type(self.operator) self.analyse_cpp_operation(env)
if not cpp_type:
error(self.pos, "'!' operator not defined for %s" % operand_type)
self.type = PyrexTypes.error_type
return
self.type = cpp_type
else: else:
self.operand = self.operand.coerce_to_boolean(env) self.operand = self.operand.coerce_to_boolean(env)
return self return self
...@@ -9371,10 +9385,7 @@ class AmpersandNode(CUnopNode): ...@@ -9371,10 +9385,7 @@ class AmpersandNode(CUnopNode):
self.operand = self.operand.analyse_types(env) self.operand = self.operand.analyse_types(env)
argtype = self.operand.type argtype = self.operand.type
if argtype.is_cpp_class: if argtype.is_cpp_class:
cpp_type = argtype.find_cpp_operation_type(self.operator) self.analyse_cpp_operation(env, overload_check=False)
if cpp_type is not None:
self.type = cpp_type
return self
if not (argtype.is_cfunction or argtype.is_reference or self.operand.is_addressable()): if not (argtype.is_cfunction or argtype.is_reference or self.operand.is_addressable()):
if argtype.is_memoryviewslice: if argtype.is_memoryviewslice:
self.error("Cannot take address of memoryview slice") self.error("Cannot take address of memoryview slice")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment