Commit bd9e529d authored by Robert Bradshaw's avatar Robert Bradshaw

C++ cmp operators.

parent 45ea0053
...@@ -5303,9 +5303,9 @@ class CmpNode(object): ...@@ -5303,9 +5303,9 @@ class CmpNode(object):
def is_cpp_comparison(self): def is_cpp_comparison(self):
type1 = self.operand1.type type1 = self.operand1.type
type2 = self.operand2.type type2 = self.operand2.type
if type1.is_ptr: if type1.is_reference:
type1 = type1.base_type type1 = type1.base_type
if type2.is_ptr: if type2.is_reference:
type2 = type2.base_type type2 = type2.base_type
return type1.is_cpp_class or type2.is_cpp_class return type1.is_cpp_class or type2.is_cpp_class
...@@ -5569,6 +5569,9 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -5569,6 +5569,9 @@ class PrimaryCmpNode(ExprNode, CmpNode):
self.operand2.analyse_types(env) self.operand2.analyse_types(env)
if self.is_cpp_comparison(): if self.is_cpp_comparison():
self.analyse_cpp_comparison(env) self.analyse_cpp_comparison(env)
if self.cascade:
error(self.pos, "Cascading comparison not yet supported for cpp types.")
return
if self.cascade: if self.cascade:
self.cascade.analyse_types(env) self.cascade.analyse_types(env)
...@@ -5601,9 +5604,9 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -5601,9 +5604,9 @@ class PrimaryCmpNode(ExprNode, CmpNode):
def analyse_cpp_comparison(self, env): def analyse_cpp_comparison(self, env):
type1 = self.operand1.type type1 = self.operand1.type
type2 = self.operand2.type type2 = self.operand2.type
if type1.is_ptr: if type1.is_reference:
type1 = type1.base_type type1 = type1.base_type
if type2.is_ptr: if type2.is_reference:
type2 = type2.base_type type2 = type2.base_type
entry = env.lookup(type1.name) entry = env.lookup(type1.name)
function = entry.type.scope.lookup("operator%s" % self.operator) function = entry.type.scope.lookup("operator%s" % self.operator)
...@@ -5611,12 +5614,12 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -5611,12 +5614,12 @@ class PrimaryCmpNode(ExprNode, CmpNode):
error(self.pos, "Invalid types for '%s' (%s, %s)" % error(self.pos, "Invalid types for '%s' (%s, %s)" %
(self.operator, type1, type2)) (self.operator, type1, type2))
return return
entry = PyrexTypes.best_match([self.operand1, self.operand2], function.all_alternatives(), self.pos) entry = PyrexTypes.best_match([self.operand2], function.all_alternatives(), self.pos)
if entry is None: if entry is None:
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
self.result_code = "<error>" self.result_code = "<error>"
return return
if (entry.type.is_ptr): if entry.type.is_ptr:
self.type = entry.type.base_type.return_type self.type = entry.type.base_type.return_type
else: else:
self.type = entry.type.return_type self.type = entry.type.return_type
......
...@@ -2058,7 +2058,11 @@ def p_c_func_declarator(s, pos, ctx, base, cmethod_flag): ...@@ -2058,7 +2058,11 @@ def p_c_func_declarator(s, pos, ctx, base, cmethod_flag):
exception_value = exc_val, exception_check = exc_check, exception_value = exc_val, exception_check = exc_check,
nogil = nogil or ctx.nogil or with_gil, with_gil = with_gil) nogil = nogil or ctx.nogil or with_gil, with_gil = with_gil)
supported_overloaded_operators = set(['+', '-', '*', '/', '%', '++', '--', '~', '|', '&', '^', '<<', '>>' ]) supported_overloaded_operators = set([
'+', '-', '*', '/', '%',
'++', '--', '~', '|', '&', '^', '<<', '>>',
'==', '!=', '>=', '>', '<=', '<',
])
def p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag, def p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag,
assignable, nonempty): assignable, nonempty):
......
...@@ -27,6 +27,13 @@ cdef extern from "cpp_operators_helper.h": ...@@ -27,6 +27,13 @@ cdef extern from "cpp_operators_helper.h":
char* operator<<(int) char* operator<<(int)
char* operator>>(int) char* operator>>(int)
char* operator==(int)
char* operator!=(int)
char* operator>=(int)
char* operator<=(int)
char* operator>(int)
char* operator<(int)
def test_unops(): def test_unops():
""" """
>>> test_unops() >>> test_unops()
...@@ -85,3 +92,22 @@ def test_binop(): ...@@ -85,3 +92,22 @@ def test_binop():
print t[0] << 1 print t[0] << 1
print t[0] >> 1 print t[0] >> 1
del t del t
def test_cmp():
"""
>>> test_cmp()
binary ==
binary !=
binary >=
binary >
binary <=
binary <
"""
cdef TestOps* t = new TestOps()
print t[0] == 1
print t[0] != 1
print t[0] >= 1
print t[0] > 1
print t[0] <= 1
print t[0] < 1
del t
...@@ -30,5 +30,12 @@ public: ...@@ -30,5 +30,12 @@ public:
BIN_OP(|); BIN_OP(|);
BIN_OP(&); BIN_OP(&);
BIN_OP(^); BIN_OP(^);
BIN_OP(==);
BIN_OP(!=);
BIN_OP(<=);
BIN_OP(<);
BIN_OP(>=);
BIN_OP(>);
}; };
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment