Commit d65582ef authored by Dong-hee Na's avatar Dong-hee Na Committed by Dylan Trotter

Implement Complex.Add and Sub (#294)

parent 68416710
......@@ -49,6 +49,43 @@ func (c *Complex) Value() complex128 {
return c.value
}
func complexAdd(f *Frame, v, w *Object) (*Object, *BaseException) {
return complexArithmeticOp(f, "__add__", v, w, func(lhs, rhs complex128) complex128 {
return lhs + rhs
})
}
func complexEq(f *Frame, v, w *Object) (*Object, *BaseException) {
e, ok := complexCompare(toComplexUnsafe(v), w)
if !ok {
return NotImplemented, nil
}
return GetBool(e).ToObject(), nil
}
func complexHash(f *Frame, o *Object) (*Object, *BaseException) {
v := toComplexUnsafe(o).Value()
hashCombined := hashFloat(real(v)) + 1000003*hashFloat(imag(v))
if hashCombined == -1 {
hashCombined = -2
}
return NewInt(hashCombined).ToObject(), nil
}
func complexNE(f *Frame, v, w *Object) (*Object, *BaseException) {
e, ok := complexCompare(toComplexUnsafe(v), w)
if !ok {
return NotImplemented, nil
}
return GetBool(!e).ToObject(), nil
}
func complexRAdd(f *Frame, v, w *Object) (*Object, *BaseException) {
return complexArithmeticOp(f, "__radd__", v, w, func(lhs, rhs complex128) complex128 {
return lhs + rhs
})
}
func complexRepr(f *Frame, o *Object) (*Object, *BaseException) {
c := toComplexUnsafe(o).Value()
rs, is := "", ""
......@@ -68,7 +105,20 @@ func complexRepr(f *Frame, o *Object) (*Object, *BaseException) {
return NewStr(fmt.Sprintf("%s%s%s%sj%s", pre, rs, sign, is, post)).ToObject(), nil
}
func complexRSub(f *Frame, v, w *Object) (*Object, *BaseException) {
return complexArithmeticOp(f, "__rsub__", v, w, func(lhs, rhs complex128) complex128 {
return rhs - lhs
})
}
func complexSub(f *Frame, v, w *Object) (*Object, *BaseException) {
return complexArithmeticOp(f, "__rsub__", v, w, func(lhs, rhs complex128) complex128 {
return lhs - rhs
})
}
func initComplexType(dict map[string]*Object) {
ComplexType.slots.Add = &binaryOpSlot{complexAdd}
ComplexType.slots.Eq = &binaryOpSlot{complexEq}
ComplexType.slots.GE = &binaryOpSlot{complexCompareNotSupported}
ComplexType.slots.GT = &binaryOpSlot{complexCompareNotSupported}
......@@ -76,23 +126,10 @@ func initComplexType(dict map[string]*Object) {
ComplexType.slots.LE = &binaryOpSlot{complexCompareNotSupported}
ComplexType.slots.LT = &binaryOpSlot{complexCompareNotSupported}
ComplexType.slots.NE = &binaryOpSlot{complexNE}
ComplexType.slots.RAdd = &binaryOpSlot{complexRAdd}
ComplexType.slots.Repr = &unaryOpSlot{complexRepr}
}
func complexEq(f *Frame, v, w *Object) (*Object, *BaseException) {
e, ok := complexCompare(toComplexUnsafe(v), w)
if !ok {
return NotImplemented, nil
}
return GetBool(e).ToObject(), nil
}
func complexNE(f *Frame, v, w *Object) (*Object, *BaseException) {
e, ok := complexCompare(toComplexUnsafe(v), w)
if !ok {
return NotImplemented, nil
}
return GetBool(!e).ToObject(), nil
ComplexType.slots.RSub = &binaryOpSlot{complexRSub}
ComplexType.slots.Sub = &binaryOpSlot{complexSub}
}
func complexCompare(v *Complex, w *Object) (bool, bool) {
......@@ -132,11 +169,17 @@ func complexCoerce(o *Object) (complex128, bool) {
return complex(floatO, 0.0), true
}
func complexHash(f *Frame, o *Object) (*Object, *BaseException) {
v := toComplexUnsafe(o).Value()
hashCombined := hashFloat(real(v)) + 1000003*hashFloat(imag(v))
if hashCombined == -1 {
hashCombined = -2
func complexArithmeticOp(f *Frame, method string, v, w *Object, fun func(v, w complex128) complex128) (*Object, *BaseException) {
if w.isInstance(ComplexType) {
return NewComplex(fun(toComplexUnsafe(v).Value(), toComplexUnsafe(w).Value())).ToObject(), nil
}
return NewInt(hashCombined).ToObject(), nil
floatW, ok := floatCoerce(w)
if !ok {
if math.IsInf(floatW, 0) {
return nil, f.RaiseType(OverflowErrorType, "long int too large to convert to float")
}
return NotImplemented, nil
}
return NewComplex(fun(toComplexUnsafe(v).Value(), complex(floatW, 0))).ToObject(), nil
}
......@@ -16,6 +16,8 @@ package grumpy
import (
"math"
"math/big"
"math/cmplx"
"testing"
)
......@@ -42,6 +44,56 @@ func TestComplexEq(t *testing.T) {
}
}
func TestComplexBinaryOps(t *testing.T) {
cases := []struct {
fun func(f *Frame, v, w *Object) (*Object, *BaseException)
v, w *Object
want *Object
wantExc *BaseException
}{
{Add, NewComplex(1 + 3i).ToObject(), NewInt(1).ToObject(), NewComplex(2 + 3i).ToObject(), nil},
{Add, NewComplex(1 + 3i).ToObject(), NewFloat(-1).ToObject(), NewComplex(3i).ToObject(), nil},
{Add, NewComplex(1 + 3i).ToObject(), NewInt(1).ToObject(), NewComplex(2 + 3i).ToObject(), nil},
{Add, NewComplex(1 + 3i).ToObject(), NewComplex(-1 - 3i).ToObject(), NewComplex(0i).ToObject(), nil},
{Add, NewFloat(math.Inf(1)).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.Inf(1), 3)).ToObject(), nil},
{Add, NewFloat(math.Inf(-1)).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.Inf(-1), 3)).ToObject(), nil},
{Add, NewFloat(math.NaN()).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.NaN(), 3)).ToObject(), nil},
{Add, NewComplex(cmplx.NaN()).ToObject(), NewComplex(3i).ToObject(), NewComplex(cmplx.NaN()).ToObject(), nil},
{Add, NewFloat(math.Inf(-1)).ToObject(), NewComplex(complex(math.Inf(+1), 3)).ToObject(), NewComplex(complex(math.NaN(), 3)).ToObject(), nil},
{Add, NewComplex(1 + 3i).ToObject(), None, nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for +: 'complex' and 'NoneType'")},
{Add, None, NewComplex(1 + 3i).ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for +: 'NoneType' and 'complex'")},
{Add, NewInt(3).ToObject(), NewComplex(3i).ToObject(), NewComplex(3 + 3i).ToObject(), nil},
{Add, NewLong(big.NewInt(9999999)).ToObject(), NewComplex(3i).ToObject(), NewComplex(9999999 + 3i).ToObject(), nil},
{Add, NewFloat(3.5).ToObject(), NewComplex(3i).ToObject(), NewComplex(3.5 + 3i).ToObject(), nil},
{Sub, NewComplex(1 + 3i).ToObject(), NewComplex(1 + 3i).ToObject(), NewComplex(0i).ToObject(), nil},
{Sub, NewComplex(1 + 3i).ToObject(), NewComplex(3i).ToObject(), NewComplex(1).ToObject(), nil},
{Sub, NewComplex(1 + 3i).ToObject(), NewFloat(1).ToObject(), NewComplex(3i).ToObject(), nil},
{Sub, NewComplex(3i).ToObject(), NewFloat(1.2).ToObject(), NewComplex(-1.2 + 3i).ToObject(), nil},
{Sub, NewComplex(1 + 3i).ToObject(), NewComplex(1 + 3i).ToObject(), NewComplex(0i).ToObject(), nil},
{Sub, NewComplex(4 + 3i).ToObject(), NewInt(1).ToObject(), NewComplex(3 + 3i).ToObject(), nil},
{Sub, NewComplex(4 + 3i).ToObject(), NewLong(big.NewInt(99994)).ToObject(), NewComplex(-99990 + 3i).ToObject(), nil},
{Sub, NewFloat(math.Inf(1)).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.Inf(1), -3)).ToObject(), nil},
{Sub, NewFloat(math.Inf(-1)).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.Inf(-1), -3)).ToObject(), nil},
{Sub, NewComplex(1 + 3i).ToObject(), None, nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for -: 'complex' and 'NoneType'")},
{Sub, None, NewComplex(1 + 3i).ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for -: 'NoneType' and 'complex'")},
{Sub, NewFloat(math.NaN()).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.NaN(), -3)).ToObject(), nil},
{Sub, NewComplex(cmplx.NaN()).ToObject(), NewComplex(3i).ToObject(), NewComplex(cmplx.NaN()).ToObject(), nil},
{Sub, NewFloat(math.Inf(-1)).ToObject(), NewComplex(complex(math.Inf(-1), 3)).ToObject(), NewComplex(complex(math.NaN(), -3)).ToObject(), nil},
}
for _, cas := range cases {
switch got, result := checkInvokeResult(wrapFuncForTest(cas.fun), []*Object{cas.v, cas.w}, cas.want, cas.wantExc); result {
case checkInvokeResultExceptionMismatch:
t.Errorf("%s(%v, %v) raised %v, want %v", getFuncName(cas.fun), cas.v, cas.w, got, cas.wantExc)
case checkInvokeResultReturnValueMismatch:
if got == nil || cas.want == nil || !got.isInstance(ComplexType) || !cas.want.isInstance(ComplexType) ||
!complexesAreSame(toComplexUnsafe(got).Value(), toComplexUnsafe(cas.want).Value()) {
t.Errorf("%s(%v, %v) = %v, want %v", getFuncName(cas.fun), cas.v, cas.w, got, cas.want)
}
}
}
}
func TestComplexCompareNotSupported(t *testing.T) {
cases := []invokeTestCase{
{args: wrapArgs(complex(1, 2), 1), wantExc: mustCreateException(TypeErrorType, "no ordering relation is defined for complex numbers")},
......@@ -108,3 +160,11 @@ func TestComplexHash(t *testing.T) {
}
}
}
func floatsAreSame(a, b float64) bool {
return a == b || (math.IsNaN(a) && math.IsNaN(b))
}
func complexesAreSame(a, b complex128) bool {
return floatsAreSame(real(a), real(b)) && floatsAreSame(imag(a), imag(b))
}
......@@ -472,7 +472,6 @@ class OperatorTestCase(unittest.TestCase):
if dunder:
self.assertIs(dunder, orig)
@unittest.expectedFailure
def test_complex_operator(self):
self.assertRaises(TypeError, operator.lt, 1j, 2j)
self.assertRaises(TypeError, operator.le, 1j, 2j)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment