Commit 57cb50bb authored by Raymond Hettinger's avatar Raymond Hettinger

Improve diff for assertCountEqual() to actually show the differing counts.

New output looks like this:

Traceback (most recent call last):
  File "test.py", line 5, in test_ce
    self.assertCountEqual('abracadabra xx', 'simsalabim xx')
AssertionError: Element counts were not equal:
Expected 5, got 2:  'a'
Expected 2, got 1:  'b'
Expected 0, got 2:  'i'
Expected 0, got 2:  'm'
Expected 0, got 1:  'l'
Expected 0, got 2:  's'
Expected 1, got 0:  'c'
Expected 1, got 0:  'd'
Expected 2, got 0:  'r'
parent 8325e975
...@@ -10,7 +10,8 @@ import collections ...@@ -10,7 +10,8 @@ import collections
from . import result from . import result
from .util import (strclass, safe_repr, sorted_list_difference, from .util import (strclass, safe_repr, sorted_list_difference,
unorderable_list_difference) unorderable_list_difference, _count_diff_all_purpose,
_count_diff_hashable)
__unittest = True __unittest = True
...@@ -1022,23 +1023,22 @@ class TestCase(object): ...@@ -1022,23 +1023,22 @@ class TestCase(object):
expected = collections.Counter(expected_seq) expected = collections.Counter(expected_seq)
except TypeError: except TypeError:
# Handle case with unhashable elements # Handle case with unhashable elements
missing, unexpected = unorderable_list_difference(expected_seq, actual_seq) differences = _count_diff_all_purpose(expected_seq, actual_seq)
else: else:
if actual == expected: if actual == expected:
return return
missing = list(expected - actual) differences = _count_diff_hashable(expected_seq, actual_seq)
unexpected = list(actual - expected)
if differences:
errors = [] standardMsg = 'Element counts were not equal:\n'
if missing: lines = []
errors.append('Expected, but missing:\n %s' % for act, exp, elem in differences:
safe_repr(missing)) line = 'Expected %d, got %d: %r' % (exp, act, elem)
if unexpected: lines.append(line)
errors.append('Unexpected, but present:\n %s' % diffMsg = '\n'.join(lines)
safe_repr(unexpected)) standardMsg = self._truncateMessage(standardMsg, diffMsg)
if errors: msg = self._formatMessage(msg, standardMsg)
standardMsg = '\n'.join(errors) self.fail(msg)
self.fail(self._formatMessage(msg, standardMsg))
def assertMultiLineEqual(self, first, second, msg=None): def assertMultiLineEqual(self, first, second, msg=None):
"""Assert that two multi-line strings are equal.""" """Assert that two multi-line strings are equal."""
......
...@@ -229,12 +229,6 @@ class TestLongMessage(unittest.TestCase): ...@@ -229,12 +229,6 @@ class TestLongMessage(unittest.TestCase):
"^Missing: 'key'$", "^Missing: 'key'$",
"^Missing: 'key' : oops$"]) "^Missing: 'key' : oops$"])
def testassertCountEqual(self):
self.assertMessages('assertCountEqual', ([], [None]),
[r"\[None\]$", "^oops$",
r"\[None\]$",
r"\[None\] : oops$"])
def testAssertMultiLineEqual(self): def testAssertMultiLineEqual(self):
self.assertMessages('assertMultiLineEqual', ("", "foo"), self.assertMessages('assertMultiLineEqual', ("", "foo"),
[r"\+ foo$", "^oops$", [r"\+ foo$", "^oops$",
......
"""Various utility functions.""" """Various utility functions."""
from collections import namedtuple, Counter
__unittest = True __unittest = True
_MAX_LENGTH = 80 _MAX_LENGTH = 80
...@@ -12,7 +14,6 @@ def safe_repr(obj, short=False): ...@@ -12,7 +14,6 @@ def safe_repr(obj, short=False):
return result return result
return result[:_MAX_LENGTH] + ' [truncated]...' return result[:_MAX_LENGTH] + ' [truncated]...'
def strclass(cls): def strclass(cls):
return "%s.%s" % (cls.__module__, cls.__name__) return "%s.%s" % (cls.__module__, cls.__name__)
...@@ -77,3 +78,58 @@ def unorderable_list_difference(expected, actual): ...@@ -77,3 +78,58 @@ def unorderable_list_difference(expected, actual):
def three_way_cmp(x, y): def three_way_cmp(x, y):
"""Return -1 if x < y, 0 if x == y and 1 if x > y""" """Return -1 if x < y, 0 if x == y and 1 if x > y"""
return (x > y) - (x < y) return (x > y) - (x < y)
_Mismatch = namedtuple('Mismatch', 'actual expected value')
def _count_diff_all_purpose(actual, expected):
'Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ'
# elements need not be hashable
s, t = list(actual), list(expected)
m, n = len(s), len(t)
NULL = object()
result = []
for i, elem in enumerate(s):
if elem is NULL:
continue
cnt_s = cnt_t = 0
for j in range(i, m):
if s[j] == elem:
cnt_s += 1
s[j] = NULL
for j, other_elem in enumerate(t):
if other_elem == elem:
cnt_t += 1
t[j] = NULL
if cnt_s != cnt_t:
diff = _Mismatch(cnt_s, cnt_t, elem)
result.append(diff)
for i, elem in enumerate(t):
if elem is NULL:
continue
cnt_t = 0
for j in range(i, n):
if t[j] == elem:
cnt_t += 1
t[j] = NULL
diff = _Mismatch(0, cnt_t, elem)
result.append(diff)
return result
def _count_diff_hashable(actual, expected):
'Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ'
# elements must be hashable
s, t = Counter(actual), Counter(expected)
if s == t:
return []
result = []
for elem, cnt_s in s.items():
cnt_t = t[elem]
if cnt_s != cnt_t:
diff = _Mismatch(cnt_s, cnt_t, elem)
result.append(diff)
for elem, cnt_t in t.items():
if elem not in s:
diff = _Mismatch(0, cnt_t, elem)
result.append(diff)
return result
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment