Commit d0b95046 authored by da-woods's avatar da-woods Committed by GitHub

Fix bug with complex powers of negative numbers (#5014)

* Fix bug with complex powers of negative numbers

A shortcut was incorrectly applied that returned NaN instead
of an imaginary number

* Add stress test
parent 9b989f43
......@@ -265,7 +265,7 @@ static {{type}} __Pyx_PyComplex_As_{{type_name}}(PyObject* o) {
if (a.imag == 0) {
if (a.real == 0) {
return a;
} else if (b.imag == 0) {
} else if ((b.imag == 0) && (a.real >= 0)) {
z.real = pow{{m}}(a.real, b.real);
z.imag = 0;
return z;
......
......@@ -80,6 +80,8 @@ def test_pow(double complex z, double complex w, tol=None):
True
>>> test_pow(-0.5, 1j, tol=1e-15)
True
>>> test_pow(-1, 0.5, tol=1e-15)
True
"""
if tol is None:
return z**w
......@@ -264,3 +266,87 @@ cpdef double complex complex_retval():
1j
"""
return 1j
def stress_test():
"""
Run the main operations on 1000 pseudo-random numbers to
try to spot anything accidentally missed from the test cases
(doesn't cover inf and NaN as inputs though)
>>> stress_test()
"""
cdef double complex x
cdef double complex y
from random import Random
from math import ldexp
r = Random()
r.seed("I'm a seed") # try to make the test somewhat reproducible
# copied from https://docs.python.org/3/library/random.html#recipes
# gets evenly distributed random numbers
def full_random():
mantissa = 0x10_0000_0000_0000 | r.getrandbits(52)
exponent = -53
x = 0
while not x:
x = r.getrandbits(32)
exponent += x.bit_length() - 32
return ldexp(mantissa, exponent)
for n in range(1, 1001):
if n % 50 == 0:
# strategical insert some 0 values
a = 0
else:
a = full_random()
if n % 51 == 0:
b = 0
else:
b = full_random()
if n % 52 == 0:
c = 0
else:
c = full_random()
if n % 53 == 0:
d = 0
else:
d = full_random()
x= a+1j*b
y = c+1j*d
py_dict = dict(x=x, y=y)
sum_ = x+y
sum_py = eval("x+y", py_dict)
delta_sum = abs(sum_/sum_py - 1)
assert delta_sum < 1e-15, f"{x} {y} {sum_} {sum_py} {delta_sum}"
minus = x-y
minus_py = eval("x-y", py_dict)
delta_minus = abs(minus/minus_py - 1)
assert delta_minus < 1e-15, f"{x} {y} {minus} {minus_py} {delta_minus}"
times = x*y
times_py = eval("x*y", py_dict)
delta_times = abs(times/times_py - 1)
assert delta_times < 1e-15, f"{x} {y} {times} {times_py} {delta_times}"
divide = x/y
divide_py = eval("x/y", py_dict)
delta_divide = abs(divide/divide_py - 1)
assert delta_divide < 1e-15, f"{x} {y} {divide} {divide_py} {delta_divide}"
divide2 = y/x
divide2_py = eval("y/x", py_dict)
delta_divide2 = abs(divide2/divide2_py - 1)
assert delta_divide2 < 1e-15, f"{x} {y} {divide2} {divide2_py} {delta_divide2}"
pow_ = x**y
pow_py = eval("x**y", py_dict)
delta_pow = abs(pow_/pow_py - 1)
assert delta_pow < 1e-15, f"{x} {y} {pow_} {pow_py} {delta_pow}"
pow2 = y**x
pow2_py = eval("y**x", py_dict)
delta_pow2 = abs(pow2/pow2_py - 1)
assert delta_pow2 < 1e-15, f"{x} {y} {pow2} {pow2_py} {delta_pow2}"
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment