Commit 6b3b0fc4 authored by Raymond Hettinger

Forward port r68941 adding itertools.compress().

parent ace67339
...@@ -286,7 +286,7 @@ counts less than one:: ...@@ -286,7 +286,7 @@ counts less than one::
Section 4.6.3, Exercise 19*\. Section 4.6.3, Exercise 19*\.
* To enumerate all distinct multisets of a given size over a given set of * To enumerate all distinct multisets of a given size over a given set of
elements, see the :func:`combinations_with_replacement` function in the elements, see :func:`combinations_with_replacement` in the
:ref:`itertools-recipes` for itertools:: :ref:`itertools-recipes` for itertools::
map(Counter, combinations_with_replacement('ABC', 2)) --> AA AB AC BB BC CC map(Counter, combinations_with_replacement('ABC', 2)) --> AA AB AC BB BC CC
......
...@@ -133,6 +133,20 @@ loops that truncate the stream. ...@@ -133,6 +133,20 @@ loops that truncate the stream.
The number of items returned is ``n! / r! / (n-r)!`` when ``0 <= r <= n`` The number of items returned is ``n! / r! / (n-r)!`` when ``0 <= r <= n``
or zero when ``r > n``. or zero when ``r > n``.
.. function:: compress(data, selectors)
Make an iterator that filters elements from *data* returning only those that
have a corresponding element in *selectors* that evaluates to ``True``.
Stops when either the *data* or the *selectors* iterable has been exhausted.
Equivalent to::
def compress(data, selectors):
    # compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F
    # Pair each datum with its selector; keep the data whose selector is true.
    flagged = zip(data, selectors)
    return (datum for datum, flag in flagged if flag)
.. versionadded:: 2.7
.. function:: count([n]) .. function:: count([n])
Make an iterator that returns consecutive integers starting with *n*. If not Make an iterator that returns consecutive integers starting with *n*. If not
...@@ -594,10 +608,6 @@ which incur interpreter overhead. ...@@ -594,10 +608,6 @@ which incur interpreter overhead.
s = list(iterable) s = list(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s)+1)) return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
def compress(data, selectors):
    """Return the items of *data* whose paired selector is true.

    compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F
    """
    paired = zip(data, selectors)
    return (item for item, keep in paired if keep)
def combinations_with_replacement(iterable, r): def combinations_with_replacement(iterable, r):
"combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC" "combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC"
# number items returned: (n+r-1)! / r! / (n-1)! # number items returned: (n+r-1)! / r! / (n-1)!
......
...@@ -195,6 +195,21 @@ class TestBasicOps(unittest.TestCase): ...@@ -195,6 +195,21 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1) self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1) self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1)
def test_compress(self):
self.assertEqual(list(compress('ABCDEF', [1,0,1,0,1,1])), list('ACEF'))
self.assertEqual(list(compress('ABCDEF', [0,0,0,0,0,0])), list(''))
self.assertEqual(list(compress('ABCDEF', [1,1,1,1,1,1])), list('ABCDEF'))
self.assertEqual(list(compress('ABCDEF', [1,0,1])), list('AC'))
self.assertEqual(list(compress('ABC', [0,1,1,1,1,1])), list('BC'))
n = 10000
data = chain.from_iterable(repeat(range(6), n))
selectors = chain.from_iterable(repeat((0, 1)))
self.assertEqual(list(compress(data, selectors)), [1,3,5] * n)
self.assertRaises(TypeError, compress, None, range(6)) # 1st arg not iterable
self.assertRaises(TypeError, compress, range(6), None) # 2nd arg not iterable
self.assertRaises(TypeError, compress, range(6)) # too few args
self.assertRaises(TypeError, compress, range(6), None) # too many args
def test_count(self): def test_count(self):
self.assertEqual(lzip('abc',count()), [('a', 0), ('b', 1), ('c', 2)]) self.assertEqual(lzip('abc',count()), [('a', 0), ('b', 1), ('c', 2)])
self.assertEqual(lzip('abc',count(3)), [('a', 3), ('b', 4), ('c', 5)]) self.assertEqual(lzip('abc',count(3)), [('a', 3), ('b', 4), ('c', 5)])
...@@ -715,6 +730,9 @@ class TestExamples(unittest.TestCase): ...@@ -715,6 +730,9 @@ class TestExamples(unittest.TestCase):
self.assertEqual(list(combinations(range(4), 3)), self.assertEqual(list(combinations(range(4), 3)),
[(0,1,2), (0,1,3), (0,2,3), (1,2,3)]) [(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
def test_compress(self):
self.assertEqual(list(compress('ABCDEF', [1,0,1,0,1,1])), list('ACEF'))
def test_count(self): def test_count(self):
self.assertEqual(list(islice(count(10), 5)), [10, 11, 12, 13, 14]) self.assertEqual(list(islice(count(10), 5)), [10, 11, 12, 13, 14])
...@@ -795,6 +813,10 @@ class TestGC(unittest.TestCase): ...@@ -795,6 +813,10 @@ class TestGC(unittest.TestCase):
a = [] a = []
self.makecycle(combinations([1,2,a,3], 3), a) self.makecycle(combinations([1,2,a,3], 3), a)
def test_compress(self):
    # Hand makecycle() a compress iterator together with a fresh container.
    # NOTE(review): the container is not part of the data or selectors here;
    # presumably makecycle() links the two itself -- confirm against the helper.
    container = []
    self.makecycle(compress('ABCDEF', [1,0,1,0,1,0]), container)
def test_cycle(self): def test_cycle(self):
a = [] a = []
self.makecycle(cycle([a]*2), a) self.makecycle(cycle([a]*2), a)
...@@ -948,6 +970,15 @@ class TestVariousIteratorArgs(unittest.TestCase): ...@@ -948,6 +970,15 @@ class TestVariousIteratorArgs(unittest.TestCase):
self.assertRaises(TypeError, list, chain(N(s))) self.assertRaises(TypeError, list, chain(N(s)))
self.assertRaises(ZeroDivisionError, list, chain(E(s))) self.assertRaises(ZeroDivisionError, list, chain(E(s)))
def test_compress(self):
    # Drive compress() with data supplied through each iterator-style helper
    # (G, I, Ig, S, L, R, X, N, E are defined elsewhere in this test module).
    for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
        # With an always-true selector the data must pass through unchanged.
        for g in (G, I, Ig, S, L, R):
            self.assertEqual(list(compress(g(s), repeat(1))), list(g(s)))
        # X(s) and N(s) must be rejected when the compress object is created.
        self.assertRaises(TypeError, compress, X(s), repeat(1))
        self.assertRaises(TypeError, compress, N(s), repeat(1))
        # An error raised while iterating E(s) must propagate to the consumer.
        self.assertRaises(ZeroDivisionError, list, compress(E(s), repeat(1)))
def test_product(self): def test_product(self):
for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)): for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
self.assertRaises(TypeError, product, X(s)) self.assertRaises(TypeError, product, X(s))
...@@ -1144,7 +1175,7 @@ class SubclassWithKwargsTest(unittest.TestCase): ...@@ -1144,7 +1175,7 @@ class SubclassWithKwargsTest(unittest.TestCase):
def test_keywords_in_subclass(self): def test_keywords_in_subclass(self):
# count is not subclassable... # count is not subclassable...
for cls in (repeat, zip, filter, filterfalse, chain, map, for cls in (repeat, zip, filter, filterfalse, chain, map,
starmap, islice, takewhile, dropwhile, cycle): starmap, islice, takewhile, dropwhile, cycle, compress):
class Subclass(cls): class Subclass(cls):
def __init__(self, newarg=None, *args): def __init__(self, newarg=None, *args):
cls.__init__(self, *args) cls.__init__(self, *args)
...@@ -1281,10 +1312,6 @@ Samuele ...@@ -1281,10 +1312,6 @@ Samuele
... s = list(iterable) ... s = list(iterable)
... return chain.from_iterable(combinations(s, r) for r in range(len(s)+1)) ... return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
>>> def compress(data, selectors):
... "compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F"
... return (d for d, s in zip(data, selectors) if s)
>>> def combinations_with_replacement(iterable, r): >>> def combinations_with_replacement(iterable, r):
... "combinations_with_replacement('ABC', 3) --> AA AB AC BB BC CC" ... "combinations_with_replacement('ABC', 3) --> AA AB AC BB BC CC"
... pool = tuple(iterable) ... pool = tuple(iterable)
...@@ -1380,9 +1407,6 @@ perform as purported. ...@@ -1380,9 +1407,6 @@ perform as purported.
>>> list(powerset([1,2,3])) >>> list(powerset([1,2,3]))
[(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)] [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
>>> list(compress('abcdef', [1,0,1,0,1,1]))
['a', 'c', 'e', 'f']
>>> list(combinations_with_replacement('abc', 2)) >>> list(combinations_with_replacement('abc', 2))
[('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')] [('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')]
......
...@@ -150,6 +150,8 @@ Library ...@@ -150,6 +150,8 @@ Library
- Issue #4863: distutils.mwerkscompiler has been removed. - Issue #4863: distutils.mwerkscompiler has been removed.
- Added a new function: itertools.compress().
- Fix and properly document the multiprocessing module's logging - Fix and properly document the multiprocessing module's logging
support, expose the internal levels and provide proper usage support, expose the internal levels and provide proper usage
examples. examples.
......
...@@ -2331,6 +2331,162 @@ static PyTypeObject permutations_type = { ...@@ -2331,6 +2331,162 @@ static PyTypeObject permutations_type = {
}; };
/* compress object ************************************************************/
/* Equivalent to:
def compress(data, selectors):
"compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F"
return (d for d, s in zip(data, selectors) if s)
*/
/* Instance layout of an itertools.compress object. */
typedef struct {
PyObject_HEAD
PyObject *data;       /* iterator obtained from the first constructor argument */
PyObject *selectors;  /* iterator obtained from the second constructor argument */
} compressobject;
static PyTypeObject compress_type;
/* tp_new for compress: takes exactly two positional arguments, obtains an
   iterator from each, and stores both in a freshly allocated object.
   Returns NULL with an exception set on any failure. */
static PyObject *
compress_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
    PyObject *data_arg, *selectors_arg;
    PyObject *data_iter, *selectors_iter;
    compressobject *self;

    /* Keywords are refused only for the exact base type. */
    if (type == &compress_type && !_PyArg_NoKeywords("compress()", kwds))
        return NULL;

    if (!PyArg_UnpackTuple(args, "compress", 2, 2, &data_arg, &selectors_arg))
        return NULL;

    data_iter = PyObject_GetIter(data_arg);
    if (data_iter == NULL)
        return NULL;

    selectors_iter = PyObject_GetIter(selectors_arg);
    if (selectors_iter == NULL) {
        Py_DECREF(data_iter);
        return NULL;
    }

    /* create compressobject structure */
    self = (compressobject *)type->tp_alloc(type, 0);
    if (self == NULL) {
        Py_DECREF(data_iter);
        Py_DECREF(selectors_iter);
        return NULL;
    }

    /* The new object takes over both iterator references. */
    self->data = data_iter;
    self->selectors = selectors_iter;
    return (PyObject *)self;
}
/* Destructor: drop both owned iterator references and free the object. */
static void
compress_dealloc(compressobject *lz)
{
/* Remove the object from GC tracking before its fields are released. */
PyObject_GC_UnTrack(lz);
Py_XDECREF(lz->data);
Py_XDECREF(lz->selectors);
Py_TYPE(lz)->tp_free(lz);
}
/* GC traversal hook: report the two iterator references held by the object. */
static int
compress_traverse(compressobject *lz, visitproc visit, void *arg)
{
/* Py_VISIT relies on the parameters being named exactly `visit` and `arg`. */
Py_VISIT(lz->data);
Py_VISIT(lz->selectors);
return 0;
}
/* tp_iternext for compress: return the next datum whose paired selector is
   true.  Returns NULL when either underlying iterator returns NULL or when
   evaluating a selector's truth value fails. */
static PyObject *
compress_next(compressobject *lz)
{
    PyObject *data_it = lz->data;
    PyObject *selector_it = lz->selectors;
    iternextfunc next_datum = Py_TYPE(data_it)->tp_iternext;
    iternextfunc next_selector = Py_TYPE(selector_it)->tp_iternext;
    PyObject *candidate, *flag;
    int truth;

    for (;;) {
        /* The data iterator is always advanced before the selector
           iterator so that, as in the pure Python version, the data
           input gets the first chance to raise an exception. */
        candidate = next_datum(data_it);
        if (candidate == NULL)
            return NULL;

        flag = next_selector(selector_it);
        if (flag == NULL) {
            /* The datum already fetched is discarded. */
            Py_DECREF(candidate);
            return NULL;
        }

        truth = PyObject_IsTrue(flag);
        Py_DECREF(flag);
        if (truth == 1)
            return candidate;   /* selected: hand ownership to the caller */

        Py_DECREF(candidate);
        if (truth == -1)
            return NULL;        /* truth test failed */
        /* selector was false: move on to the next pair */
    }
}
/* Docstring for itertools.compress.  The signature line uses the documented
   parameter names; both arguments may be arbitrary iterables, because the
   constructor only calls PyObject_GetIter() on them. */
PyDoc_STRVAR(compress_doc,
"compress(data, selectors) --> iterator over selected data\n\
\n\
Return data elements corresponding to true selector elements.\n\
Forms a shorter iterator from selected data elements using the\n\
selectors to choose the data elements.");
/* Type object for itertools.compress.  The flags make it GC-enabled and
   subclassable; the object is its own iterator (tp_iter returns self) and
   compress_next supplies the items. */
static PyTypeObject compress_type = {
PyVarObject_HEAD_INIT(NULL, 0)
"itertools.compress", /* tp_name */
sizeof(compressobject), /* tp_basicsize */
0, /* tp_itemsize */
/* methods */
(destructor)compress_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
PyObject_GenericGetAttr, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_BASETYPE, /* tp_flags */
compress_doc, /* tp_doc */
(traverseproc)compress_traverse, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */
(iternextfunc)compress_next, /* tp_iternext */
0, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
compress_new, /* tp_new */
PyObject_GC_Del, /* tp_free */
};
/* filterfalse object ************************************************************/ /* filterfalse object ************************************************************/
typedef struct { typedef struct {
...@@ -3041,6 +3197,7 @@ PyInit_itertools(void) ...@@ -3041,6 +3197,7 @@ PyInit_itertools(void)
&islice_type, &islice_type,
&starmap_type, &starmap_type,
&chain_type, &chain_type,
&compress_type,
&filterfalse_type, &filterfalse_type,
&count_type, &count_type,
&ziplongest_type, &ziplongest_type,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment