Commit 38e0cad5 authored by Facundo Batista's avatar Facundo Batista

The methods always return Decimal classes, even if they're

executed through a subclass (thanks Mark Dickinson).
Added a bit of testing for this.
parent cfaf3c48
...@@ -1473,7 +1473,7 @@ class Decimal(object): ...@@ -1473,7 +1473,7 @@ class Decimal(object):
pos += 1 pos += 1
payload = payload[pos:] payload = payload[pos:]
return Decimal((self._sign, payload, self._exp)) return Decimal((self._sign, payload, self._exp))
return self return Decimal(self)
def _fix(self, context): def _fix(self, context):
"""Round if it is necessary to keep self within prec precision. """Round if it is necessary to keep self within prec precision.
...@@ -1494,7 +1494,7 @@ class Decimal(object): ...@@ -1494,7 +1494,7 @@ class Decimal(object):
return self._fix_nan(context) return self._fix_nan(context)
else: else:
# self is +/-Infinity; return unaltered # self is +/-Infinity; return unaltered
return self return Decimal(self)
# if self is zero then exponent should be between Etiny and # if self is zero then exponent should be between Etiny and
# Emax if _clamp==0, and between Etiny and Etop if _clamp==1. # Emax if _clamp==0, and between Etiny and Etop if _clamp==1.
...@@ -1507,7 +1507,7 @@ class Decimal(object): ...@@ -1507,7 +1507,7 @@ class Decimal(object):
context._raise_error(Clamped) context._raise_error(Clamped)
return Decimal((self._sign, (0,), new_exp)) return Decimal((self._sign, (0,), new_exp))
else: else:
return self return Decimal(self)
# exp_min is the smallest allowable exponent of the result, # exp_min is the smallest allowable exponent of the result,
# equal to max(self.adjusted()-context.prec+1, Etiny) # equal to max(self.adjusted()-context.prec+1, Etiny)
...@@ -1551,7 +1551,7 @@ class Decimal(object): ...@@ -1551,7 +1551,7 @@ class Decimal(object):
return Decimal((self._sign, self_padded, Etop)) return Decimal((self._sign, self_padded, Etop))
# here self was representable to begin with; return unchanged # here self was representable to begin with; return unchanged
return self return Decimal(self)
_pick_rounding_function = {} _pick_rounding_function = {}
...@@ -1678,10 +1678,10 @@ class Decimal(object): ...@@ -1678,10 +1678,10 @@ class Decimal(object):
return context._raise_error(InvalidOperation, 'sNaN', return context._raise_error(InvalidOperation, 'sNaN',
1, modulo) 1, modulo)
if self_is_nan: if self_is_nan:
return self return self._fix_nan(context)
if other_is_nan: if other_is_nan:
return other return other._fix_nan(context)
return modulo return modulo._fix_nan(context)
# check inputs: we apply same restrictions as Python's pow() # check inputs: we apply same restrictions as Python's pow()
if not (self._isinteger() and if not (self._isinteger() and
...@@ -2179,7 +2179,7 @@ class Decimal(object): ...@@ -2179,7 +2179,7 @@ class Decimal(object):
if exp._isinfinity() or self._isinfinity(): if exp._isinfinity() or self._isinfinity():
if exp._isinfinity() and self._isinfinity(): if exp._isinfinity() and self._isinfinity():
return self # if both are inf, it is OK return Decimal(self) # if both are inf, it is OK
return context._raise_error(InvalidOperation, return context._raise_error(InvalidOperation,
'quantize with one INF') 'quantize with one INF')
...@@ -2254,7 +2254,7 @@ class Decimal(object): ...@@ -2254,7 +2254,7 @@ class Decimal(object):
rounding = rounding mode rounding = rounding mode
""" """
if self._is_special: if self._is_special:
return self return Decimal(self)
if not self: if not self:
return Decimal((self._sign, (0,), exp)) return Decimal((self._sign, (0,), exp))
...@@ -2285,9 +2285,9 @@ class Decimal(object): ...@@ -2285,9 +2285,9 @@ class Decimal(object):
ans = self._check_nans(context=context) ans = self._check_nans(context=context)
if ans: if ans:
return ans return ans
return self return Decimal(self)
if self._exp >= 0: if self._exp >= 0:
return self return Decimal(self)
if not self: if not self:
return Decimal((self._sign, (0,), 0)) return Decimal((self._sign, (0,), 0))
if context is None: if context is None:
...@@ -2310,9 +2310,9 @@ class Decimal(object): ...@@ -2310,9 +2310,9 @@ class Decimal(object):
ans = self._check_nans(context=context) ans = self._check_nans(context=context)
if ans: if ans:
return ans return ans
return self return Decimal(self)
if self._exp >= 0: if self._exp >= 0:
return self return Decimal(self)
else: else:
return self._rescale(0, rounding) return self._rescale(0, rounding)
...@@ -2426,6 +2426,9 @@ class Decimal(object): ...@@ -2426,6 +2426,9 @@ class Decimal(object):
""" """
other = _convert_other(other, raiseit=True) other = _convert_other(other, raiseit=True)
if context is None:
context = getcontext()
if self._is_special or other._is_special: if self._is_special or other._is_special:
# If one operand is a quiet NaN and the other is number, then the # If one operand is a quiet NaN and the other is number, then the
# number is always returned # number is always returned
...@@ -2433,9 +2436,9 @@ class Decimal(object): ...@@ -2433,9 +2436,9 @@ class Decimal(object):
on = other._isnan() on = other._isnan()
if sn or on: if sn or on:
if on == 1 and sn != 2: if on == 1 and sn != 2:
return self return self._fix_nan(context)
if sn == 1 and on != 2: if sn == 1 and on != 2:
return other return other._fix_nan(context)
return self._check_nans(other, context) return self._check_nans(other, context)
c = self.__cmp__(other) c = self.__cmp__(other)
...@@ -2455,8 +2458,6 @@ class Decimal(object): ...@@ -2455,8 +2458,6 @@ class Decimal(object):
else: else:
ans = self ans = self
if context is None:
context = getcontext()
if context._rounding_decision == ALWAYS_ROUND: if context._rounding_decision == ALWAYS_ROUND:
return ans._fix(context) return ans._fix(context)
return ans return ans
...@@ -2469,6 +2470,9 @@ class Decimal(object): ...@@ -2469,6 +2470,9 @@ class Decimal(object):
""" """
other = _convert_other(other, raiseit=True) other = _convert_other(other, raiseit=True)
if context is None:
context = getcontext()
if self._is_special or other._is_special: if self._is_special or other._is_special:
# If one operand is a quiet NaN and the other is number, then the # If one operand is a quiet NaN and the other is number, then the
# number is always returned # number is always returned
...@@ -2476,9 +2480,9 @@ class Decimal(object): ...@@ -2476,9 +2480,9 @@ class Decimal(object):
on = other._isnan() on = other._isnan()
if sn or on: if sn or on:
if on == 1 and sn != 2: if on == 1 and sn != 2:
return self return self._fix_nan(context)
if sn == 1 and on != 2: if sn == 1 and on != 2:
return other return other._fix_nan(context)
return self._check_nans(other, context) return self._check_nans(other, context)
c = self.__cmp__(other) c = self.__cmp__(other)
...@@ -2490,8 +2494,6 @@ class Decimal(object): ...@@ -2490,8 +2494,6 @@ class Decimal(object):
else: else:
ans = other ans = other
if context is None:
context = getcontext()
if context._rounding_decision == ALWAYS_ROUND: if context._rounding_decision == ALWAYS_ROUND:
return ans._fix(context) return ans._fix(context)
return ans return ans
...@@ -3087,6 +3089,9 @@ class Decimal(object): ...@@ -3087,6 +3089,9 @@ class Decimal(object):
"""Compares the values numerically with their sign ignored.""" """Compares the values numerically with their sign ignored."""
other = _convert_other(other, raiseit=True) other = _convert_other(other, raiseit=True)
if context is None:
context = getcontext()
if self._is_special or other._is_special: if self._is_special or other._is_special:
# If one operand is a quiet NaN and the other is number, then the # If one operand is a quiet NaN and the other is number, then the
# number is always returned # number is always returned
...@@ -3094,9 +3099,9 @@ class Decimal(object): ...@@ -3094,9 +3099,9 @@ class Decimal(object):
on = other._isnan() on = other._isnan()
if sn or on: if sn or on:
if on == 1 and sn != 2: if on == 1 and sn != 2:
return self return self._fix_nan(context)
if sn == 1 and on != 2: if sn == 1 and on != 2:
return other return other._fix_nan(context)
return self._check_nans(other, context) return self._check_nans(other, context)
c = self.copy_abs().__cmp__(other.copy_abs()) c = self.copy_abs().__cmp__(other.copy_abs())
...@@ -3108,8 +3113,6 @@ class Decimal(object): ...@@ -3108,8 +3113,6 @@ class Decimal(object):
else: else:
ans = self ans = self
if context is None:
context = getcontext()
if context._rounding_decision == ALWAYS_ROUND: if context._rounding_decision == ALWAYS_ROUND:
return ans._fix(context) return ans._fix(context)
return ans return ans
...@@ -3118,6 +3121,9 @@ class Decimal(object): ...@@ -3118,6 +3121,9 @@ class Decimal(object):
"""Compares the values numerically with their sign ignored.""" """Compares the values numerically with their sign ignored."""
other = _convert_other(other, raiseit=True) other = _convert_other(other, raiseit=True)
if context is None:
context = getcontext()
if self._is_special or other._is_special: if self._is_special or other._is_special:
# If one operand is a quiet NaN and the other is number, then the # If one operand is a quiet NaN and the other is number, then the
# number is always returned # number is always returned
...@@ -3125,9 +3131,9 @@ class Decimal(object): ...@@ -3125,9 +3131,9 @@ class Decimal(object):
on = other._isnan() on = other._isnan()
if sn or on: if sn or on:
if on == 1 and sn != 2: if on == 1 and sn != 2:
return self return self._fix_nan(context)
if sn == 1 and on != 2: if sn == 1 and on != 2:
return other return other._fix_nan(context)
return self._check_nans(other, context) return self._check_nans(other, context)
c = self.copy_abs().__cmp__(other.copy_abs()) c = self.copy_abs().__cmp__(other.copy_abs())
...@@ -3139,8 +3145,6 @@ class Decimal(object): ...@@ -3139,8 +3145,6 @@ class Decimal(object):
else: else:
ans = other ans = other
if context is None:
context = getcontext()
if context._rounding_decision == ALWAYS_ROUND: if context._rounding_decision == ALWAYS_ROUND:
return ans._fix(context) return ans._fix(context)
return ans return ans
...@@ -3296,7 +3300,7 @@ class Decimal(object): ...@@ -3296,7 +3300,7 @@ class Decimal(object):
return context._raise_error(InvalidOperation) return context._raise_error(InvalidOperation)
if self._isinfinity(): if self._isinfinity():
return self return Decimal(self)
# get values, pad if necessary # get values, pad if necessary
torot = int(other) torot = int(other)
...@@ -3334,7 +3338,7 @@ class Decimal(object): ...@@ -3334,7 +3338,7 @@ class Decimal(object):
return context._raise_error(InvalidOperation) return context._raise_error(InvalidOperation)
if self._isinfinity(): if self._isinfinity():
return self return Decimal(self)
d = Decimal((self._sign, self._int, self._exp + int(other))) d = Decimal((self._sign, self._int, self._exp + int(other)))
d = d._fix(context) d = d._fix(context)
...@@ -3355,12 +3359,12 @@ class Decimal(object): ...@@ -3355,12 +3359,12 @@ class Decimal(object):
return context._raise_error(InvalidOperation) return context._raise_error(InvalidOperation)
if self._isinfinity(): if self._isinfinity():
return self return Decimal(self)
# get values, pad if necessary # get values, pad if necessary
torot = int(other) torot = int(other)
if not torot: if not torot:
return self return Decimal(self)
rotdig = self._int rotdig = self._int
topad = context.prec - len(rotdig) topad = context.prec - len(rotdig)
if topad: if topad:
...@@ -3751,7 +3755,7 @@ class Context(object): ...@@ -3751,7 +3755,7 @@ class Context(object):
>>> ExtendedContext.copy_decimal(Decimal('-1.00')) >>> ExtendedContext.copy_decimal(Decimal('-1.00'))
Decimal("-1.00") Decimal("-1.00")
""" """
return a return Decimal(a)
def copy_negate(self, a): def copy_negate(self, a):
"""Returns a copy of the operand with the sign inverted. """Returns a copy of the operand with the sign inverted.
......
...@@ -1072,6 +1072,21 @@ class DecimalUsabilityTest(unittest.TestCase): ...@@ -1072,6 +1072,21 @@ class DecimalUsabilityTest(unittest.TestCase):
checkSameDec("to_eng_string") checkSameDec("to_eng_string")
checkSameDec("to_integral") checkSameDec("to_integral")
def test_subclassing(self):
# Different behaviours when subclassing Decimal
class MyDecimal(Decimal):
pass
d1 = MyDecimal(1)
d2 = MyDecimal(2)
d = d1 + d2
self.assertTrue(type(d) is Decimal)
d = d1.max(d2)
self.assertTrue(type(d) is Decimal)
class DecimalPythonAPItests(unittest.TestCase): class DecimalPythonAPItests(unittest.TestCase):
def test_pickle(self): def test_pickle(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment