Commit 3bac3dd7 authored by Robert Bradshaw's avatar Robert Bradshaw Committed by GitHub

Merge pull request #1777 from scoder/fix_complex_comparison

Fix complex comparison
parents d4536e4f 99c86175
......@@ -1672,7 +1672,7 @@ class ImagNode(AtomicExprNode):
node = ImagNode(self.pos, value=self.value)
if dst_type.is_pyobject:
node.is_temp = 1
node.type = PyrexTypes.py_object_type
node.type = Builtin.complex_type
# We still need to perform normal coerce_to processing on the
# result, because we might be coercing to an extension type,
# in which case a type test node will be needed.
......@@ -11866,22 +11866,22 @@ class CmpNode(object):
new_common_type = None
# catch general errors
if type1 == str_type and (type2.is_string or type2 in (bytes_type, unicode_type)) or \
type2 == str_type and (type1.is_string or type1 in (bytes_type, unicode_type)):
if (type1 == str_type and (type2.is_string or type2 in (bytes_type, unicode_type)) or
type2 == str_type and (type1.is_string or type1 in (bytes_type, unicode_type))):
error(self.pos, "Comparisons between bytes/unicode and str are not portable to Python 3")
new_common_type = error_type
# try to use numeric comparisons where possible
elif type1.is_complex or type2.is_complex:
if op not in ('==', '!=') \
and (type1.is_complex or type1.is_numeric) \
and (type2.is_complex or type2.is_numeric):
if (op not in ('==', '!=')
and (type1.is_complex or type1.is_numeric)
and (type2.is_complex or type2.is_numeric)):
error(self.pos, "complex types are unordered")
new_common_type = error_type
elif type1.is_pyobject:
new_common_type = type1
new_common_type = Builtin.complex_type if type1.subtype_of(Builtin.complex_type) else py_object_type
elif type2.is_pyobject:
new_common_type = type2
new_common_type = Builtin.complex_type if type2.subtype_of(Builtin.complex_type) else py_object_type
else:
new_common_type = PyrexTypes.widest_numeric_type(type1, type2)
elif type1.is_numeric and type2.is_numeric:
......
# ticket: 305
from cpython.object cimport Py_EQ, Py_NE
cimport cython
cdef class Complex3j:
"""
>>> Complex3j() == 3j
True
>>> Complex3j() == Complex3j()
True
>>> Complex3j() != 3j
False
>>> Complex3j() != 3
True
>>> Complex3j() != Complex3j()
False
"""
def __richcmp__(a, b, int op):
if op == Py_EQ or op == Py_NE:
if isinstance(a, Complex3j):
eq = isinstance(b, Complex3j) or b == 3j
else:
eq = isinstance(b, Complex3j) and a == 3j
return eq if op == Py_EQ else not eq
return NotImplemented
def test_object_conversion(o):
"""
>>> test_object_conversion(2)
......@@ -13,6 +39,7 @@ def test_object_conversion(o):
cdef double complex b = o
return (a, b)
def test_arithmetic(double complex z, double complex w):
"""
>>> test_arithmetic(2j, 4j)
......@@ -24,6 +51,7 @@ def test_arithmetic(double complex z, double complex w):
"""
return +z, -z+0, z+w, z-w, z*w, z/w
def test_div(double complex a, double complex b, expected):
"""
>>> big = 2.0**1023
......@@ -34,6 +62,7 @@ def test_div(double complex a, double complex b, expected):
if '_c99_' not in __name__:
assert a / b == expected, (a / b, expected)
def test_pow(double complex z, double complex w, tol=None):
"""
Various implementations produce slightly different results...
......@@ -55,6 +84,7 @@ def test_pow(double complex z, double complex w, tol=None):
else:
return abs(z**w / <object>z ** <object>w - 1) < tol
def test_int_pow(double complex z, int n, tol=None):
"""
>>> [test_int_pow(complex(0, 1), k, 1e-15) for k in range(-4, 5)]
......@@ -71,6 +101,7 @@ def test_int_pow(double complex z, int n, tol=None):
else:
return abs(z**n / <object>z ** <object>n - 1) < tol
@cython.cdivision(False)
def test_div_by_zero(double complex z):
"""
......@@ -83,6 +114,7 @@ def test_div_by_zero(double complex z):
"""
return 1/z
def test_coercion(int a, float b, double c, float complex d, double complex e):
"""
>>> test_coercion(1, 1.5, 2.5, 4+1j, 10j)
......@@ -101,29 +133,34 @@ def test_coercion(int a, float b, double c, float complex d, double complex e):
z = e; print z
return z + a + b + c + d + e
def test_compare(double complex a, double complex b):
"""
>>> test_compare(3, 3)
(True, False)
(True, False, False, False, False, True)
>>> test_compare(3j, 3j)
(True, False)
(True, False, True, True, True, False)
>>> test_compare(3j, 4j)
(False, True)
(False, True, True, False, True, True)
>>> test_compare(3, 4)
(False, True)
(False, True, False, False, False, True)
"""
return a == b, a != b
return a == b, a != b, a == 3j, 3j == b, a == Complex3j(), Complex3j() != b
def test_compare_coerce(double complex a, int b):
"""
>>> test_compare_coerce(3, 4)
(False, True)
(False, True, False, False, False, True)
>>> test_compare_coerce(4+1j, 4)
(False, True)
(False, True, False, True, False, True)
>>> test_compare_coerce(4, 4)
(True, False)
(True, False, False, False, False, True)
>>> test_compare_coerce(3j, 4)
(False, True, True, False, True, False)
"""
return a == b, a != b
return a == b, a != b, a == 3j, 4+1j == a, a == Complex3j(), Complex3j() != a
def test_literal():
"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment