Commit 826f083a authored by Raymond Hettinger's avatar Raymond Hettinger

* Fix decimal's handling of foreign types. Now returns NotImplemented

  instead of raising a TypeError.  Allows other types to successfully
  implement __radd__() style methods.
* Remove future division import from test suite.
* Remove test suite's shadowing of __builtin__.dir().
parent b204c0c5
...@@ -645,6 +645,8 @@ class Decimal(object): ...@@ -645,6 +645,8 @@ class Decimal(object):
def __cmp__(self, other, context=None): def __cmp__(self, other, context=None):
other = _convert_other(other) other = _convert_other(other)
if other is NotImplemented:
return other
if self._is_special or other._is_special: if self._is_special or other._is_special:
ans = self._check_nans(other, context) ans = self._check_nans(other, context)
...@@ -696,12 +698,12 @@ class Decimal(object): ...@@ -696,12 +698,12 @@ class Decimal(object):
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, (Decimal, int, long)): if not isinstance(other, (Decimal, int, long)):
return False return NotImplemented
return self.__cmp__(other) == 0 return self.__cmp__(other) == 0
def __ne__(self, other): def __ne__(self, other):
if not isinstance(other, (Decimal, int, long)): if not isinstance(other, (Decimal, int, long)):
return True return NotImplemented
return self.__cmp__(other) != 0 return self.__cmp__(other) != 0
def compare(self, other, context=None): def compare(self, other, context=None):
...@@ -714,6 +716,8 @@ class Decimal(object): ...@@ -714,6 +716,8 @@ class Decimal(object):
Like __cmp__, but returns Decimal instances. Like __cmp__, but returns Decimal instances.
""" """
other = _convert_other(other) other = _convert_other(other)
if other is NotImplemented:
return other
#compare(NaN, NaN) = NaN #compare(NaN, NaN) = NaN
if (self._is_special or other and other._is_special): if (self._is_special or other and other._is_special):
...@@ -919,6 +923,8 @@ class Decimal(object): ...@@ -919,6 +923,8 @@ class Decimal(object):
-INF + INF (or the reverse) cause InvalidOperation errors. -INF + INF (or the reverse) cause InvalidOperation errors.
""" """
other = _convert_other(other) other = _convert_other(other)
if other is NotImplemented:
return other
if context is None: if context is None:
context = getcontext() context = getcontext()
...@@ -1006,6 +1012,8 @@ class Decimal(object): ...@@ -1006,6 +1012,8 @@ class Decimal(object):
def __sub__(self, other, context=None): def __sub__(self, other, context=None):
"""Return self + (-other)""" """Return self + (-other)"""
other = _convert_other(other) other = _convert_other(other)
if other is NotImplemented:
return other
if self._is_special or other._is_special: if self._is_special or other._is_special:
ans = self._check_nans(other, context=context) ans = self._check_nans(other, context=context)
...@@ -1023,6 +1031,8 @@ class Decimal(object): ...@@ -1023,6 +1031,8 @@ class Decimal(object):
def __rsub__(self, other, context=None): def __rsub__(self, other, context=None):
"""Return other + (-self)""" """Return other + (-self)"""
other = _convert_other(other) other = _convert_other(other)
if other is NotImplemented:
return other
tmp = Decimal(self) tmp = Decimal(self)
tmp._sign = 1 - tmp._sign tmp._sign = 1 - tmp._sign
...@@ -1068,6 +1078,8 @@ class Decimal(object): ...@@ -1068,6 +1078,8 @@ class Decimal(object):
(+-) INF * 0 (or its reverse) raise InvalidOperation. (+-) INF * 0 (or its reverse) raise InvalidOperation.
""" """
other = _convert_other(other) other = _convert_other(other)
if other is NotImplemented:
return other
if context is None: if context is None:
context = getcontext() context = getcontext()
...@@ -1140,6 +1152,10 @@ class Decimal(object): ...@@ -1140,6 +1152,10 @@ class Decimal(object):
computing the other value are not raised. computing the other value are not raised.
""" """
other = _convert_other(other) other = _convert_other(other)
if other is NotImplemented:
if divmod in (0, 1):
return NotImplemented
return (NotImplemented, NotImplemented)
if context is None: if context is None:
context = getcontext() context = getcontext()
...@@ -1292,6 +1308,8 @@ class Decimal(object): ...@@ -1292,6 +1308,8 @@ class Decimal(object):
def __rdiv__(self, other, context=None): def __rdiv__(self, other, context=None):
"""Swaps self/other and returns __div__.""" """Swaps self/other and returns __div__."""
other = _convert_other(other) other = _convert_other(other)
if other is NotImplemented:
return other
return other.__div__(self, context=context) return other.__div__(self, context=context)
__rtruediv__ = __rdiv__ __rtruediv__ = __rdiv__
...@@ -1304,6 +1322,8 @@ class Decimal(object): ...@@ -1304,6 +1322,8 @@ class Decimal(object):
def __rdivmod__(self, other, context=None): def __rdivmod__(self, other, context=None):
"""Swaps self/other and returns __divmod__.""" """Swaps self/other and returns __divmod__."""
other = _convert_other(other) other = _convert_other(other)
if other is NotImplemented:
return other
return other.__divmod__(self, context=context) return other.__divmod__(self, context=context)
def __mod__(self, other, context=None): def __mod__(self, other, context=None):
...@@ -1311,6 +1331,8 @@ class Decimal(object): ...@@ -1311,6 +1331,8 @@ class Decimal(object):
self % other self % other
""" """
other = _convert_other(other) other = _convert_other(other)
if other is NotImplemented:
return other
if self._is_special or other._is_special: if self._is_special or other._is_special:
ans = self._check_nans(other, context) ans = self._check_nans(other, context)
...@@ -1325,6 +1347,8 @@ class Decimal(object): ...@@ -1325,6 +1347,8 @@ class Decimal(object):
def __rmod__(self, other, context=None): def __rmod__(self, other, context=None):
"""Swaps self/other and returns __mod__.""" """Swaps self/other and returns __mod__."""
other = _convert_other(other) other = _convert_other(other)
if other is NotImplemented:
return other
return other.__mod__(self, context=context) return other.__mod__(self, context=context)
def remainder_near(self, other, context=None): def remainder_near(self, other, context=None):
...@@ -1332,6 +1356,8 @@ class Decimal(object): ...@@ -1332,6 +1356,8 @@ class Decimal(object):
Remainder nearest to 0- abs(remainder-near) <= other/2 Remainder nearest to 0- abs(remainder-near) <= other/2
""" """
other = _convert_other(other) other = _convert_other(other)
if other is NotImplemented:
return other
if self._is_special or other._is_special: if self._is_special or other._is_special:
ans = self._check_nans(other, context) ans = self._check_nans(other, context)
...@@ -1411,6 +1437,8 @@ class Decimal(object): ...@@ -1411,6 +1437,8 @@ class Decimal(object):
def __rfloordiv__(self, other, context=None): def __rfloordiv__(self, other, context=None):
"""Swaps self/other and returns __floordiv__.""" """Swaps self/other and returns __floordiv__."""
other = _convert_other(other) other = _convert_other(other)
if other is NotImplemented:
return other
return other.__floordiv__(self, context=context) return other.__floordiv__(self, context=context)
def __float__(self): def __float__(self):
...@@ -1661,6 +1689,8 @@ class Decimal(object): ...@@ -1661,6 +1689,8 @@ class Decimal(object):
If modulo is None (default), don't take it mod modulo. If modulo is None (default), don't take it mod modulo.
""" """
n = _convert_other(n) n = _convert_other(n)
if n is NotImplemented:
return n
if context is None: if context is None:
context = getcontext() context = getcontext()
...@@ -1747,6 +1777,8 @@ class Decimal(object): ...@@ -1747,6 +1777,8 @@ class Decimal(object):
def __rpow__(self, other, context=None): def __rpow__(self, other, context=None):
"""Swaps self/other and returns __pow__.""" """Swaps self/other and returns __pow__."""
other = _convert_other(other) other = _convert_other(other)
if other is NotImplemented:
return other
return other.__pow__(self, context=context) return other.__pow__(self, context=context)
def normalize(self, context=None): def normalize(self, context=None):
...@@ -2001,6 +2033,8 @@ class Decimal(object): ...@@ -2001,6 +2033,8 @@ class Decimal(object):
NaN (and signals if one is sNaN). Also rounds. NaN (and signals if one is sNaN). Also rounds.
""" """
other = _convert_other(other) other = _convert_other(other)
if other is NotImplemented:
return other
if self._is_special or other._is_special: if self._is_special or other._is_special:
# if one operand is a quiet NaN and the other is number, then the # if one operand is a quiet NaN and the other is number, then the
...@@ -2048,6 +2082,8 @@ class Decimal(object): ...@@ -2048,6 +2082,8 @@ class Decimal(object):
NaN (and signals if one is sNaN). Also rounds. NaN (and signals if one is sNaN). Also rounds.
""" """
other = _convert_other(other) other = _convert_other(other)
if other is NotImplemented:
return other
if self._is_special or other._is_special: if self._is_special or other._is_special:
# if one operand is a quiet NaN and the other is number, then the # if one operand is a quiet NaN and the other is number, then the
...@@ -2874,8 +2910,7 @@ def _convert_other(other): ...@@ -2874,8 +2910,7 @@ def _convert_other(other):
return other return other
if isinstance(other, (int, long)): if isinstance(other, (int, long)):
return Decimal(other) return Decimal(other)
return NotImplemented
raise TypeError, "You can interact Decimal only with int, long or Decimal data types."
_infinity_map = { _infinity_map = {
'inf' : 1, 'inf' : 1,
......
...@@ -24,8 +24,6 @@ you're working through IDLE, you can import this test module and call test_main( ...@@ -24,8 +24,6 @@ you're working through IDLE, you can import this test module and call test_main(
with the corresponding argument. with the corresponding argument.
""" """
from __future__ import division
import unittest import unittest
import glob import glob
import os, sys import os, sys
...@@ -54,9 +52,9 @@ if __name__ == '__main__': ...@@ -54,9 +52,9 @@ if __name__ == '__main__':
else: else:
file = __file__ file = __file__
testdir = os.path.dirname(file) or os.curdir testdir = os.path.dirname(file) or os.curdir
dir = testdir + os.sep + TESTDATADIR + os.sep directory = testdir + os.sep + TESTDATADIR + os.sep
skip_expected = not os.path.isdir(dir) skip_expected = not os.path.isdir(directory)
# Make sure it actually raises errors when not expected and caught in flags # Make sure it actually raises errors when not expected and caught in flags
# Slower, since it runs some things several times. # Slower, since it runs some things several times.
...@@ -109,7 +107,6 @@ class DecimalTest(unittest.TestCase): ...@@ -109,7 +107,6 @@ class DecimalTest(unittest.TestCase):
Changed for unittest. Changed for unittest.
""" """
def setUp(self): def setUp(self):
global dir
self.context = Context() self.context = Context()
for key in DefaultContext.traps.keys(): for key in DefaultContext.traps.keys():
DefaultContext.traps[key] = 1 DefaultContext.traps[key] = 1
...@@ -302,11 +299,11 @@ class DecimalTest(unittest.TestCase): ...@@ -302,11 +299,11 @@ class DecimalTest(unittest.TestCase):
# Dynamically build custom test definition for each file in the test # Dynamically build custom test definition for each file in the test
# directory and add the definitions to the DecimalTest class. This # directory and add the definitions to the DecimalTest class. This
# procedure insures that new files do not get skipped. # procedure insures that new files do not get skipped.
for filename in os.listdir(dir): for filename in os.listdir(directory):
if '.decTest' not in filename: if '.decTest' not in filename:
continue continue
head, tail = filename.split('.') head, tail = filename.split('.')
tester = lambda self, f=filename: self.eval_file(dir + f) tester = lambda self, f=filename: self.eval_file(directory + f)
setattr(DecimalTest, 'test_' + head, tester) setattr(DecimalTest, 'test_' + head, tester)
del filename, head, tail, tester del filename, head, tail, tester
...@@ -476,6 +473,52 @@ class DecimalImplicitConstructionTest(unittest.TestCase): ...@@ -476,6 +473,52 @@ class DecimalImplicitConstructionTest(unittest.TestCase):
def test_implicit_from_Decimal(self): def test_implicit_from_Decimal(self):
self.assertEqual(Decimal(5) + Decimal(45), Decimal(50)) self.assertEqual(Decimal(5) + Decimal(45), Decimal(50))
def test_rop(self):
# Allow other classes to be trained to interact with Decimals
class E:
def __divmod__(self, other):
return 'divmod ' + str(other)
def __rdivmod__(self, other):
return str(other) + ' rdivmod'
def __lt__(self, other):
return 'lt ' + str(other)
def __gt__(self, other):
return 'gt ' + str(other)
def __le__(self, other):
return 'le ' + str(other)
def __ge__(self, other):
return 'ge ' + str(other)
def __eq__(self, other):
return 'eq ' + str(other)
def __ne__(self, other):
return 'ne ' + str(other)
self.assertEqual(divmod(E(), Decimal(10)), 'divmod 10')
self.assertEqual(divmod(Decimal(10), E()), '10 rdivmod')
self.assertEqual(eval('Decimal(10) < E()'), 'gt 10')
self.assertEqual(eval('Decimal(10) > E()'), 'lt 10')
self.assertEqual(eval('Decimal(10) <= E()'), 'ge 10')
self.assertEqual(eval('Decimal(10) >= E()'), 'le 10')
self.assertEqual(eval('Decimal(10) == E()'), 'eq 10')
self.assertEqual(eval('Decimal(10) != E()'), 'ne 10')
# insert operator methods and then exercise them
for sym, lop, rop in (
('+', '__add__', '__radd__'),
('-', '__sub__', '__rsub__'),
('*', '__mul__', '__rmul__'),
('/', '__div__', '__rdiv__'),
('%', '__mod__', '__rmod__'),
('//', '__floordiv__', '__rfloordiv__'),
('**', '__pow__', '__rpow__'),
):
setattr(E, lop, lambda self, other: 'str' + lop + str(other))
setattr(E, rop, lambda self, other: str(other) + rop + 'str')
self.assertEqual(eval('E()' + sym + 'Decimal(10)'),
'str' + lop + '10')
self.assertEqual(eval('Decimal(10)' + sym + 'E()'),
'10' + rop + 'str')
class DecimalArithmeticOperatorsTest(unittest.TestCase): class DecimalArithmeticOperatorsTest(unittest.TestCase):
'''Unit tests for all arithmetic operators, binary and unary.''' '''Unit tests for all arithmetic operators, binary and unary.'''
......
...@@ -91,6 +91,11 @@ Library ...@@ -91,6 +91,11 @@ Library
- distutils.commands.upload was added to support uploading distribution - distutils.commands.upload was added to support uploading distribution
files to PyPI. files to PyPI.
- decimal operator and comparison methods now return NotImplemented
instead of raising a TypeError when interacting with other types. This
allows other classes to implement __radd__ style methods and have them
work as expected.
- Bug #1163325: Decimal infinities failed to hash. Attempting to - Bug #1163325: Decimal infinities failed to hash. Attempting to
hash a NaN raised an InvalidOperation instead of a TypeError. hash a NaN raised an InvalidOperation instead of a TypeError.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment