Commit 3bac3dd7 authored by Robert Bradshaw's avatar Robert Bradshaw Committed by GitHub

Merge pull request #1777 from scoder/fix_complex_comparison

Fix complex comparison
parents d4536e4f 99c86175
...@@ -1672,7 +1672,7 @@ class ImagNode(AtomicExprNode): ...@@ -1672,7 +1672,7 @@ class ImagNode(AtomicExprNode):
node = ImagNode(self.pos, value=self.value) node = ImagNode(self.pos, value=self.value)
if dst_type.is_pyobject: if dst_type.is_pyobject:
node.is_temp = 1 node.is_temp = 1
node.type = PyrexTypes.py_object_type node.type = Builtin.complex_type
# We still need to perform normal coerce_to processing on the # We still need to perform normal coerce_to processing on the
# result, because we might be coercing to an extension type, # result, because we might be coercing to an extension type,
# in which case a type test node will be needed. # in which case a type test node will be needed.
...@@ -11866,22 +11866,22 @@ class CmpNode(object): ...@@ -11866,22 +11866,22 @@ class CmpNode(object):
new_common_type = None new_common_type = None
# catch general errors # catch general errors
if type1 == str_type and (type2.is_string or type2 in (bytes_type, unicode_type)) or \ if (type1 == str_type and (type2.is_string or type2 in (bytes_type, unicode_type)) or
type2 == str_type and (type1.is_string or type1 in (bytes_type, unicode_type)): type2 == str_type and (type1.is_string or type1 in (bytes_type, unicode_type))):
error(self.pos, "Comparisons between bytes/unicode and str are not portable to Python 3") error(self.pos, "Comparisons between bytes/unicode and str are not portable to Python 3")
new_common_type = error_type new_common_type = error_type
# try to use numeric comparisons where possible # try to use numeric comparisons where possible
elif type1.is_complex or type2.is_complex: elif type1.is_complex or type2.is_complex:
if op not in ('==', '!=') \ if (op not in ('==', '!=')
and (type1.is_complex or type1.is_numeric) \ and (type1.is_complex or type1.is_numeric)
and (type2.is_complex or type2.is_numeric): and (type2.is_complex or type2.is_numeric)):
error(self.pos, "complex types are unordered") error(self.pos, "complex types are unordered")
new_common_type = error_type new_common_type = error_type
elif type1.is_pyobject: elif type1.is_pyobject:
new_common_type = type1 new_common_type = Builtin.complex_type if type1.subtype_of(Builtin.complex_type) else py_object_type
elif type2.is_pyobject: elif type2.is_pyobject:
new_common_type = type2 new_common_type = Builtin.complex_type if type2.subtype_of(Builtin.complex_type) else py_object_type
else: else:
new_common_type = PyrexTypes.widest_numeric_type(type1, type2) new_common_type = PyrexTypes.widest_numeric_type(type1, type2)
elif type1.is_numeric and type2.is_numeric: elif type1.is_numeric and type2.is_numeric:
......
# ticket: 305 # ticket: 305
from cpython.object cimport Py_EQ, Py_NE
cimport cython cimport cython
cdef class Complex3j:
"""
>>> Complex3j() == 3j
True
>>> Complex3j() == Complex3j()
True
>>> Complex3j() != 3j
False
>>> Complex3j() != 3
True
>>> Complex3j() != Complex3j()
False
"""
def __richcmp__(a, b, int op):
if op == Py_EQ or op == Py_NE:
if isinstance(a, Complex3j):
eq = isinstance(b, Complex3j) or b == 3j
else:
eq = isinstance(b, Complex3j) and a == 3j
return eq if op == Py_EQ else not eq
return NotImplemented
def test_object_conversion(o): def test_object_conversion(o):
""" """
>>> test_object_conversion(2) >>> test_object_conversion(2)
...@@ -13,6 +39,7 @@ def test_object_conversion(o): ...@@ -13,6 +39,7 @@ def test_object_conversion(o):
cdef double complex b = o cdef double complex b = o
return (a, b) return (a, b)
def test_arithmetic(double complex z, double complex w): def test_arithmetic(double complex z, double complex w):
""" """
>>> test_arithmetic(2j, 4j) >>> test_arithmetic(2j, 4j)
...@@ -24,6 +51,7 @@ def test_arithmetic(double complex z, double complex w): ...@@ -24,6 +51,7 @@ def test_arithmetic(double complex z, double complex w):
""" """
return +z, -z+0, z+w, z-w, z*w, z/w return +z, -z+0, z+w, z-w, z*w, z/w
def test_div(double complex a, double complex b, expected): def test_div(double complex a, double complex b, expected):
""" """
>>> big = 2.0**1023 >>> big = 2.0**1023
...@@ -34,6 +62,7 @@ def test_div(double complex a, double complex b, expected): ...@@ -34,6 +62,7 @@ def test_div(double complex a, double complex b, expected):
if '_c99_' not in __name__: if '_c99_' not in __name__:
assert a / b == expected, (a / b, expected) assert a / b == expected, (a / b, expected)
def test_pow(double complex z, double complex w, tol=None): def test_pow(double complex z, double complex w, tol=None):
""" """
Various implementations produce slightly different results... Various implementations produce slightly different results...
...@@ -55,6 +84,7 @@ def test_pow(double complex z, double complex w, tol=None): ...@@ -55,6 +84,7 @@ def test_pow(double complex z, double complex w, tol=None):
else: else:
return abs(z**w / <object>z ** <object>w - 1) < tol return abs(z**w / <object>z ** <object>w - 1) < tol
def test_int_pow(double complex z, int n, tol=None): def test_int_pow(double complex z, int n, tol=None):
""" """
>>> [test_int_pow(complex(0, 1), k, 1e-15) for k in range(-4, 5)] >>> [test_int_pow(complex(0, 1), k, 1e-15) for k in range(-4, 5)]
...@@ -71,6 +101,7 @@ def test_int_pow(double complex z, int n, tol=None): ...@@ -71,6 +101,7 @@ def test_int_pow(double complex z, int n, tol=None):
else: else:
return abs(z**n / <object>z ** <object>n - 1) < tol return abs(z**n / <object>z ** <object>n - 1) < tol
@cython.cdivision(False) @cython.cdivision(False)
def test_div_by_zero(double complex z): def test_div_by_zero(double complex z):
""" """
...@@ -83,6 +114,7 @@ def test_div_by_zero(double complex z): ...@@ -83,6 +114,7 @@ def test_div_by_zero(double complex z):
""" """
return 1/z return 1/z
def test_coercion(int a, float b, double c, float complex d, double complex e): def test_coercion(int a, float b, double c, float complex d, double complex e):
""" """
>>> test_coercion(1, 1.5, 2.5, 4+1j, 10j) >>> test_coercion(1, 1.5, 2.5, 4+1j, 10j)
...@@ -101,29 +133,34 @@ def test_coercion(int a, float b, double c, float complex d, double complex e): ...@@ -101,29 +133,34 @@ def test_coercion(int a, float b, double c, float complex d, double complex e):
z = e; print z z = e; print z
return z + a + b + c + d + e return z + a + b + c + d + e
def test_compare(double complex a, double complex b): def test_compare(double complex a, double complex b):
""" """
>>> test_compare(3, 3) >>> test_compare(3, 3)
(True, False) (True, False, False, False, False, True)
>>> test_compare(3j, 3j) >>> test_compare(3j, 3j)
(True, False) (True, False, True, True, True, False)
>>> test_compare(3j, 4j) >>> test_compare(3j, 4j)
(False, True) (False, True, True, False, True, True)
>>> test_compare(3, 4) >>> test_compare(3, 4)
(False, True) (False, True, False, False, False, True)
""" """
return a == b, a != b return a == b, a != b, a == 3j, 3j == b, a == Complex3j(), Complex3j() != b
def test_compare_coerce(double complex a, int b): def test_compare_coerce(double complex a, int b):
""" """
>>> test_compare_coerce(3, 4) >>> test_compare_coerce(3, 4)
(False, True) (False, True, False, False, False, True)
>>> test_compare_coerce(4+1j, 4) >>> test_compare_coerce(4+1j, 4)
(False, True) (False, True, False, True, False, True)
>>> test_compare_coerce(4, 4) >>> test_compare_coerce(4, 4)
(True, False) (True, False, False, False, False, True)
>>> test_compare_coerce(3j, 4)
(False, True, True, False, True, False)
""" """
return a == b, a != b return a == b, a != b, a == 3j, 4+1j == a, a == Complex3j(), Complex3j() != a
def test_literal(): def test_literal():
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment