Commit 0aeac107 authored by Raymond Hettinger's avatar Raymond Hettinger

* Add __eq__ and __ne__ so that things like list.index() work properly

  for lists of mixed types.
* Test that sort works.
parent 10959b1c
...@@ -8,10 +8,6 @@ ...@@ -8,10 +8,6 @@
# and Tim Peters # and Tim Peters
# Todo:
# Add rich comparisons for equality testing with other types
""" """
This is a Py2.3 implementation of decimal floating point arithmetic based on This is a Py2.3 implementation of decimal floating point arithmetic based on
the General Decimal Arithmetic Specification: the General Decimal Arithmetic Specification:
...@@ -644,6 +640,16 @@ class Decimal(object): ...@@ -644,6 +640,16 @@ class Decimal(object):
return -1 return -1
return 1 return 1
def __eq__(self, other):
if not isinstance(other, (Decimal, int, long)):
return False
return self.__cmp__(other) == 0
def __ne__(self, other):
if not isinstance(other, (Decimal, int, long)):
return True
return self.__cmp__(other) != 0
def compare(self, other, context=None): def compare(self, other, context=None):
"""Compares one to another. """Compares one to another.
......
...@@ -33,6 +33,7 @@ import pickle, copy ...@@ -33,6 +33,7 @@ import pickle, copy
from decimal import * from decimal import *
from test.test_support import TestSkipped, run_unittest, run_doctest, is_resource_enabled from test.test_support import TestSkipped, run_unittest, run_doctest, is_resource_enabled
import threading import threading
import random
# Tests are built around these assumed context defaults # Tests are built around these assumed context defaults
DefaultContext.prec=9 DefaultContext.prec=9
...@@ -841,17 +842,17 @@ class DecimalUsabilityTest(unittest.TestCase): ...@@ -841,17 +842,17 @@ class DecimalUsabilityTest(unittest.TestCase):
self.assertEqual(cmp(dc,45), 0) self.assertEqual(cmp(dc,45), 0)
#a Decimal and uncomparable #a Decimal and uncomparable
try: da == 'ugly' self.assertNotEqual(da, 'ugly')
except TypeError: pass self.assertNotEqual(da, 32.7)
else: self.fail('Did not raised an error!') self.assertNotEqual(da, object())
self.assertNotEqual(da, object)
try: da == '32.7'
except TypeError: pass # sortable
else: self.fail('Did not raised an error!') a = map(Decimal, xrange(100))
b = a[:]
try: da == object random.shuffle(a)
except TypeError: pass a.sort()
else: self.fail('Did not raised an error!') self.assertEqual(a, b)
def test_copy_and_deepcopy_methods(self): def test_copy_and_deepcopy_methods(self):
d = Decimal('43.24') d = Decimal('43.24')
...@@ -1078,6 +1079,10 @@ class ContextAPItests(unittest.TestCase): ...@@ -1078,6 +1079,10 @@ class ContextAPItests(unittest.TestCase):
v2 = vars(e)[k] v2 = vars(e)[k]
self.assertEqual(v1, v2) self.assertEqual(v1, v2)
def test_equality_with_other_types(self):
self.assert_(Decimal(10) in ['a', 1.0, Decimal(10), (1,2), {}])
self.assert_(Decimal(10) not in ['a', 1.0, (1,2), {}])
def test_main(arith=False, verbose=None): def test_main(arith=False, verbose=None):
""" Execute the tests. """ Execute the tests.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment