Commit 8af9a568 authored by Robert Bradshaw's avatar Robert Bradshaw

Complex numeber comparison, etc.

parent 4f05119a
...@@ -4213,8 +4213,11 @@ class NumBinopNode(BinopNode): ...@@ -4213,8 +4213,11 @@ class NumBinopNode(BinopNode):
self.operator, self.operator,
self.operand2.result()) self.operand2.result())
else: else:
func = self.type.binary_op(self.operator)
if func is None:
error(self.pos, "binary operator %s not supported for %s" % (self.operator, self.type))
return "%s(%s, %s)" % ( return "%s(%s, %s)" % (
self.type.binop(self.operator), func,
self.operand1.result(), self.operand1.result(),
self.operand2.result()) self.operand2.result())
...@@ -4318,7 +4321,7 @@ class DivNode(NumBinopNode): ...@@ -4318,7 +4321,7 @@ class DivNode(NumBinopNode):
return "float division" return "float division"
def generate_evaluation_code(self, code): def generate_evaluation_code(self, code):
if not self.type.is_pyobject: if not self.type.is_pyobject and not self.type.is_complex:
if self.cdivision is None: if self.cdivision is None:
self.cdivision = (code.globalstate.directives['cdivision'] self.cdivision = (code.globalstate.directives['cdivision']
or not self.type.signed or not self.type.signed
...@@ -4331,7 +4334,11 @@ class DivNode(NumBinopNode): ...@@ -4331,7 +4334,11 @@ class DivNode(NumBinopNode):
def generate_div_warning_code(self, code): def generate_div_warning_code(self, code):
if not self.type.is_pyobject: if not self.type.is_pyobject:
if self.zerodivision_check: if self.zerodivision_check:
code.putln("if (unlikely(%s == 0)) {" % self.operand2.result()) if not self.infix:
zero_test = "%s(%s)" % (self.type.unary_op('zero'), self.operand2.result())
else:
zero_test = "%s == 0" % self.operand2.result()
code.putln("if (unlikely(%s)) {" % zero_test)
code.putln('PyErr_Format(PyExc_ZeroDivisionError, "%s");' % self.zero_division_message()) code.putln('PyErr_Format(PyExc_ZeroDivisionError, "%s");' % self.zero_division_message())
code.putln(code.error_goto(self.pos)) code.putln(code.error_goto(self.pos))
code.putln("}") code.putln("}")
...@@ -4344,7 +4351,7 @@ class DivNode(NumBinopNode): ...@@ -4344,7 +4351,7 @@ class DivNode(NumBinopNode):
code.putln('PyErr_Format(PyExc_OverflowError, "value too large to perform division");') code.putln('PyErr_Format(PyExc_OverflowError, "value too large to perform division");')
code.putln(code.error_goto(self.pos)) code.putln(code.error_goto(self.pos))
code.putln("}") code.putln("}")
if code.globalstate.directives['cdivision_warnings']: if code.globalstate.directives['cdivision_warnings'] and self.operand != '/':
code.globalstate.use_utility_code(cdivision_warning_utility_code) code.globalstate.use_utility_code(cdivision_warning_utility_code)
code.putln("if ((%s < 0) ^ (%s < 0)) {" % ( code.putln("if ((%s < 0) ^ (%s < 0)) {" % (
self.operand1.result(), self.operand1.result(),
...@@ -4355,7 +4362,9 @@ class DivNode(NumBinopNode): ...@@ -4355,7 +4362,9 @@ class DivNode(NumBinopNode):
code.putln("}") code.putln("}")
def calculate_result_code(self): def calculate_result_code(self):
if self.type.is_float and self.operator == '//': if self.type.is_complex:
return NumBinopNode.calculate_result_code(self)
elif self.type.is_float and self.operator == '//':
return "floor(%s / %s)" % ( return "floor(%s / %s)" % (
self.operand1.result(), self.operand1.result(),
self.operand2.result()) self.operand2.result())
...@@ -4705,7 +4714,13 @@ class CmpNode(object): ...@@ -4705,7 +4714,13 @@ class CmpNode(object):
or (self.cascade and self.cascade.is_python_result())) or (self.cascade and self.cascade.is_python_result()))
def check_types(self, env, operand1, op, operand2): def check_types(self, env, operand1, op, operand2):
if not self.types_okay(operand1, op, operand2): if operand1.type.is_complex or operand2.type.is_complex:
if op not in ('==', '!='):
error(self.pos, "complex types unordered")
common_type = PyrexTypes.widest_numeric_type(operand1.type, operand2.type)
self.operand1 = operand1.coerce_to(common_type, env)
self.operand2 = operand2.coerce_to(common_type, env)
elif not self.types_okay(operand1, op, operand2):
error(self.pos, "Invalid types for '%s' (%s, %s)" % error(self.pos, "Invalid types for '%s' (%s, %s)" %
(self.operator, operand1.type, operand2.type)) (self.operator, operand1.type, operand2.type))
...@@ -4754,6 +4769,16 @@ class CmpNode(object): ...@@ -4754,6 +4769,16 @@ class CmpNode(object):
richcmp_constants[op], richcmp_constants[op],
code.error_goto_if_null(result_code, self.pos))) code.error_goto_if_null(result_code, self.pos)))
code.put_gotref(result_code) code.put_gotref(result_code)
elif operand1.type.is_complex and not code.globalstate.directives['c99_complex']:
if op == "!=": negation = "!"
else: negation = ""
code.putln("%s = %s(%s%s(%s, %s));" % (
result_code,
coerce_result,
negation,
operand1.type.unary_op('eq'),
operand1.result(),
operand2.result()))
else: else:
type1 = operand1.type type1 = operand1.type
type2 = operand2.type type2 = operand2.type
...@@ -4881,6 +4906,17 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode): ...@@ -4881,6 +4906,17 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode):
self.not_const() self.not_const()
def calculate_result_code(self): def calculate_result_code(self):
if self.operand1.type.is_complex:
if self.operator == "!=":
negation = "!"
else:
negation = ""
return "(%s%s(%s, %s))" % (
negation,
self.operand1.type.binary_op('=='),
self.operand1.result(),
self.operand2.result())
else:
return "(%s %s %s)" % ( return "(%s %s %s)" % (
self.operand1.result(), self.operand1.result(),
self.c_operator(self.operator), self.c_operator(self.operator),
......
...@@ -1921,6 +1921,10 @@ class DefNode(FuncDefNode): ...@@ -1921,6 +1921,10 @@ class DefNode(FuncDefNode):
has_star_or_kw_args = self.star_arg is not None \ has_star_or_kw_args = self.star_arg is not None \
or self.starstar_arg is not None or has_kwonly_args or self.starstar_arg is not None or has_kwonly_args
for arg in self.args:
if not arg.type.is_pyobject and arg.type.from_py_function is None:
arg.type.create_from_py_utility_code(env)
if not self.signature_has_generic_args(): if not self.signature_has_generic_args():
if has_star_or_kw_args: if has_star_or_kw_args:
error(self.pos, "This method cannot have * or keyword arguments") error(self.pos, "This method cannot have * or keyword arguments")
...@@ -1951,8 +1955,6 @@ class DefNode(FuncDefNode): ...@@ -1951,8 +1955,6 @@ class DefNode(FuncDefNode):
error(arg.pos, "Non-default argument following default argument") error(arg.pos, "Non-default argument following default argument")
elif not arg.is_self_arg: elif not arg.is_self_arg:
positional_args.append(arg) positional_args.append(arg)
if arg.type.from_py_function is None:
arg.type.create_from_py_utility_code(env)
self.generate_tuple_and_keyword_parsing_code( self.generate_tuple_and_keyword_parsing_code(
positional_args, kw_only_args, end_label, code) positional_args, kw_only_args, end_label, code)
......
...@@ -753,18 +753,33 @@ class CComplexType(CNumericType): ...@@ -753,18 +753,33 @@ class CComplexType(CNumericType):
self.from_py_function = "__pyx_PyObject_As_" + self.specalization_name() self.from_py_function = "__pyx_PyObject_As_" + self.specalization_name()
return True return True
def binop(self, op): def lookup_op(self, nargs, op):
try: try:
return self.binops[op] return self.binops[nargs, op]
except KeyError: except KeyError:
if op in "+-*/": pass
from ExprNodes import compile_time_binary_operators try:
op_name = compile_time_binary_operators[op].__name__ op_name = complex_ops[nargs, op]
self.binops[op] = func_name = "%s_%s" % (self.specalization_name(), op_name) self.binops[nargs, op] = func_name = "%s_%s" % (self.specalization_name(), op_name)
return func_name return func_name
else: except KeyError:
error("Binary '%s' not supported in for %s" % (op, self)) return None
return "<error>"
def unary_op(self, op):
return self.lookup_op(1, op)
def binary_op(self, op):
return self.lookup_op(2, op)
complex_ops = {
(1, '-'): 'neg',
(1, 'zero'): 'is_zero',
(2, '+'): 'add',
(2, '-') : 'sub',
(2, '*'): 'mul',
(2, '/'): 'div',
(2, '=='): 'eq',
}
complex_generic_utility_code = UtilityCode( complex_generic_utility_code = UtilityCode(
proto=""" proto="""
...@@ -804,6 +819,7 @@ proto=""" ...@@ -804,6 +819,7 @@ proto="""
#define %(type_name)s_from_parts(x, y) ((x) + (y)*(%(type)s)_Complex_I) #define %(type_name)s_from_parts(x, y) ((x) + (y)*(%(type)s)_Complex_I)
#define %(type_name)s_is_zero(a) ((a) == 0) #define %(type_name)s_is_zero(a) ((a) == 0)
#define %(type_name)s_eq(a, b) ((a) == (b))
#define %(type_name)s_add(a, b) ((a)+(b)) #define %(type_name)s_add(a, b) ((a)+(b))
#define %(type_name)s_sub(a, b) ((a)-(b)) #define %(type_name)s_sub(a, b) ((a)-(b))
#define %(type_name)s_mul(a, b) ((a)*(b)) #define %(type_name)s_mul(a, b) ((a)*(b))
...@@ -819,6 +835,10 @@ proto=""" ...@@ -819,6 +835,10 @@ proto="""
return (a.real == 0) & (a.imag == 0); return (a.real == 0) & (a.imag == 0);
} }
static INLINE int %(type_name)s_eq(%(type)s a, %(type)s b) {
return (a.real == b.real) & (a.imag == b.imag);
}
static INLINE %(type)s %(type_name)s_add(%(type)s a, %(type)s b) { static INLINE %(type)s %(type_name)s_add(%(type)s a, %(type)s b) {
%(type)s z; %(type)s z;
z.real = a.real + b.real; z.real = a.real + b.real;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment