Commit 66f91ea9 authored by Raymond Hettinger's avatar Raymond Hettinger

C implementation of itertools.permutations().

parent 9c065740
...@@ -47,15 +47,6 @@ def fact(n): ...@@ -47,15 +47,6 @@ def fact(n):
'Factorial' 'Factorial'
return prod(range(1, n+1)) return prod(range(1, n+1))
def permutations(iterable, r=None):
# XXX use this until real permutations code is added
pool = tuple(iterable)
n = len(pool)
r = n if r is None else r
for indices in product(range(n), repeat=r):
if len(set(indices)) == r:
yield tuple(pool[i] for i in indices)
class TestBasicOps(unittest.TestCase): class TestBasicOps(unittest.TestCase):
def test_chain(self): def test_chain(self):
self.assertEqual(list(chain('abc', 'def')), list('abcdef')) self.assertEqual(list(chain('abc', 'def')), list('abcdef'))
...@@ -117,6 +108,8 @@ class TestBasicOps(unittest.TestCase): ...@@ -117,6 +108,8 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(len(set(c)), r) # no duplicate elements self.assertEqual(len(set(c)), r) # no duplicate elements
self.assertEqual(list(c), sorted(c)) # keep original ordering self.assertEqual(list(c), sorted(c)) # keep original ordering
self.assert_(all(e in values for e in c)) # elements taken from input iterable self.assert_(all(e in values for e in c)) # elements taken from input iterable
self.assertEqual(list(c),
[e for e in values if e in c]) # comb is a subsequence of the input iterable
self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version
self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version
...@@ -127,9 +120,10 @@ class TestBasicOps(unittest.TestCase): ...@@ -127,9 +120,10 @@ class TestBasicOps(unittest.TestCase):
def test_permutations(self): def test_permutations(self):
self.assertRaises(TypeError, permutations) # too few arguments self.assertRaises(TypeError, permutations) # too few arguments
self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
## self.assertRaises(TypeError, permutations, None) # pool is not iterable self.assertRaises(TypeError, permutations, None) # pool is not iterable
## self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative
## self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big
self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None
self.assertEqual(list(permutations(range(3), 2)), self.assertEqual(list(permutations(range(3), 2)),
[(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)]) [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
...@@ -182,7 +176,7 @@ class TestBasicOps(unittest.TestCase): ...@@ -182,7 +176,7 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(result, list(permutations(values))) # test default r self.assertEqual(result, list(permutations(values))) # test default r
# Test implementation detail: tuple re-use # Test implementation detail: tuple re-use
## self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1) self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1) self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1)
def test_count(self): def test_count(self):
...@@ -407,12 +401,23 @@ class TestBasicOps(unittest.TestCase): ...@@ -407,12 +401,23 @@ class TestBasicOps(unittest.TestCase):
list(product(*args, **dict(repeat=r)))) list(product(*args, **dict(repeat=r))))
self.assertEqual(len(list(product(*[range(7)]*6))), 7**6) self.assertEqual(len(list(product(*[range(7)]*6))), 7**6)
self.assertRaises(TypeError, product, range(6), None) self.assertRaises(TypeError, product, range(6), None)
def product2(*args, **kwds):
'Pure python version used in docs'
pools = map(tuple, args) * kwds.get('repeat', 1)
result = [[]]
for pool in pools:
result = [x+[y] for x in result for y in pool]
for prod in result:
yield tuple(prod)
argtypes = ['', 'abc', '', xrange(0), xrange(4), dict(a=1, b=2, c=3), argtypes = ['', 'abc', '', xrange(0), xrange(4), dict(a=1, b=2, c=3),
set('abcdefg'), range(11), tuple(range(13))] set('abcdefg'), range(11), tuple(range(13))]
for i in range(100): for i in range(100):
args = [random.choice(argtypes) for j in range(random.randrange(5))] args = [random.choice(argtypes) for j in range(random.randrange(5))]
expected_len = prod(map(len, args)) expected_len = prod(map(len, args))
self.assertEqual(len(list(product(*args))), expected_len) self.assertEqual(len(list(product(*args))), expected_len)
self.assertEqual(list(product(*args)), list(product2(*args)))
args = map(iter, args) args = map(iter, args)
self.assertEqual(len(list(product(*args))), expected_len) self.assertEqual(len(list(product(*args))), expected_len)
......
...@@ -699,7 +699,7 @@ Library ...@@ -699,7 +699,7 @@ Library
- Added itertools.product() which forms the Cartesian product of - Added itertools.product() which forms the Cartesian product of
the input iterables. the input iterables.
- Added itertools.combinations(). - Added itertools.combinations() and itertools.permutations().
- Patch #1541463: optimize performance of cgi.FieldStorage operations. - Patch #1541463: optimize performance of cgi.FieldStorage operations.
......
...@@ -2238,6 +2238,279 @@ static PyTypeObject combinations_type = { ...@@ -2238,6 +2238,279 @@ static PyTypeObject combinations_type = {
}; };
/* permutations object ************************************************************
def permutations(iterable, r=None):
'permutations(range(3), 2) --> (0,1) (0,2) (1,0) (1,2) (2,0) (2,1)'
pool = tuple(iterable)
n = len(pool)
r = n if r is None else r
indices = range(n)
cycles = range(n-r+1, n+1)[::-1]
yield tuple(pool[i] for i in indices[:r])
while n:
for i in reversed(range(r)):
cycles[i] -= 1
if cycles[i] == 0:
indices[i:] = indices[i+1:] + indices[i:i+1]
cycles[i] = n - i
else:
j = cycles[i]
indices[i], indices[-j] = indices[-j], indices[i]
yield tuple(pool[i] for i in indices[:r])
break
else:
return
*/
typedef struct {
PyObject_HEAD
PyObject *pool; /* input converted to a tuple */
Py_ssize_t *indices; /* one index per element in the pool */
Py_ssize_t *cycles; /* one rollover counter per element in the result */
PyObject *result; /* most recently returned result tuple */
Py_ssize_t r; /* size of result tuple */
int stopped; /* set to 1 when the permutations iterator is exhausted */
} permutationsobject;
static PyTypeObject permutations_type;
static PyObject *
permutations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
permutationsobject *po;
Py_ssize_t n;
Py_ssize_t r;
PyObject *robj = Py_None;
PyObject *pool = NULL;
PyObject *iterable = NULL;
Py_ssize_t *indices = NULL;
Py_ssize_t *cycles = NULL;
Py_ssize_t i;
static char *kwargs[] = {"iterable", "r", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:permutations", kwargs,
&iterable, &robj))
return NULL;
pool = PySequence_Tuple(iterable);
if (pool == NULL)
goto error;
n = PyTuple_GET_SIZE(pool);
r = n;
if (robj != Py_None) {
r = PyInt_AsSsize_t(robj);
if (r == -1 && PyErr_Occurred())
goto error;
}
if (r < 0) {
PyErr_SetString(PyExc_ValueError, "r must be non-negative");
goto error;
}
if (r > n) {
PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable");
goto error;
}
indices = PyMem_Malloc(n * sizeof(Py_ssize_t));
cycles = PyMem_Malloc(r * sizeof(Py_ssize_t));
if (indices == NULL || cycles == NULL) {
PyErr_NoMemory();
goto error;
}
for (i=0 ; i<n ; i++)
indices[i] = i;
for (i=0 ; i<r ; i++)
cycles[i] = n - i;
/* create permutationsobject structure */
po = (permutationsobject *)type->tp_alloc(type, 0);
if (po == NULL)
goto error;
po->pool = pool;
po->indices = indices;
po->cycles = cycles;
po->result = NULL;
po->r = r;
po->stopped = 0;
return (PyObject *)po;
error:
if (indices != NULL)
PyMem_Free(indices);
if (cycles != NULL)
PyMem_Free(cycles);
Py_XDECREF(pool);
return NULL;
}
static void
permutations_dealloc(permutationsobject *po)
{
PyObject_GC_UnTrack(po);
Py_XDECREF(po->pool);
Py_XDECREF(po->result);
PyMem_Free(po->indices);
PyMem_Free(po->cycles);
Py_TYPE(po)->tp_free(po);
}
static int
permutations_traverse(permutationsobject *po, visitproc visit, void *arg)
{
if (po->pool != NULL)
Py_VISIT(po->pool);
if (po->result != NULL)
Py_VISIT(po->result);
return 0;
}
static PyObject *
permutations_next(permutationsobject *po)
{
PyObject *elem;
PyObject *oldelem;
PyObject *pool = po->pool;
Py_ssize_t *indices = po->indices;
Py_ssize_t *cycles = po->cycles;
PyObject *result = po->result;
Py_ssize_t n = PyTuple_GET_SIZE(pool);
Py_ssize_t r = po->r;
Py_ssize_t i, j, k, index;
if (po->stopped)
return NULL;
if (result == NULL) {
/* On the first pass, initialize result tuple using the indices */
result = PyTuple_New(r);
if (result == NULL)
goto empty;
po->result = result;
for (i=0; i<r ; i++) {
index = indices[i];
elem = PyTuple_GET_ITEM(pool, index);
Py_INCREF(elem);
PyTuple_SET_ITEM(result, i, elem);
}
} else {
if (n == 0)
goto empty;
/* Copy the previous result tuple or re-use it if available */
if (Py_REFCNT(result) > 1) {
PyObject *old_result = result;
result = PyTuple_New(r);
if (result == NULL)
goto empty;
po->result = result;
for (i=0; i<r ; i++) {
elem = PyTuple_GET_ITEM(old_result, i);
Py_INCREF(elem);
PyTuple_SET_ITEM(result, i, elem);
}
Py_DECREF(old_result);
}
/* Now, we've got the only copy so we can update it in-place */
assert(r == 0 || Py_REFCNT(result) == 1);
/* Decrement rightmost cycle, moving leftward upon zero rollover */
for (i=r-1 ; i>=0 ; i--) {
cycles[i] -= 1;
if (cycles[i] == 0) {
/* rotatation: indices[i:] = indices[i+1:] + indices[i:i+1] */
index = indices[i];
for (j=i ; j<n-1 ; j++)
indices[j] = indices[j+1];
indices[n-1] = index;
cycles[i] = n - i;
} else {
j = cycles[i];
index = indices[i];
indices[i] = indices[n-j];
indices[n-j] = index;
for (k=i; k<r ; k++) {
/* start with i, the leftmost element that changed */
/* yield tuple(pool[k] for k in indices[:r]) */
index = indices[k];
elem = PyTuple_GET_ITEM(pool, index);
Py_INCREF(elem);
oldelem = PyTuple_GET_ITEM(result, k);
PyTuple_SET_ITEM(result, k, elem);
Py_DECREF(oldelem);
}
break;
}
}
/* If i is negative, then the cycles have all
rolled-over and we're done. */
if (i < 0)
goto empty;
}
Py_INCREF(result);
return result;
empty:
po->stopped = 1;
return NULL;
}
PyDoc_STRVAR(permutations_doc,
"permutations(iterables[, r]) --> permutations object\n\
\n\
Return successive r-length permutations of elements in the iterable.\n\n\
permutations(range(4), 3) --> (0,1,2), (0,1,3), (0,2,3), (1,2,3)");
static PyTypeObject permutations_type = {
PyVarObject_HEAD_INIT(NULL, 0)
"itertools.permutations", /* tp_name */
sizeof(permutationsobject), /* tp_basicsize */
0, /* tp_itemsize */
/* methods */
(destructor)permutations_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
PyObject_GenericGetAttr, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_BASETYPE, /* tp_flags */
permutations_doc, /* tp_doc */
(traverseproc)permutations_traverse, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */
(iternextfunc)permutations_next, /* tp_iternext */
0, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
permutations_new, /* tp_new */
PyObject_GC_Del, /* tp_free */
};
/* ifilter object ************************************************************/ /* ifilter object ************************************************************/
typedef struct { typedef struct {
...@@ -3295,6 +3568,7 @@ inititertools(void) ...@@ -3295,6 +3568,7 @@ inititertools(void)
&count_type, &count_type,
&izip_type, &izip_type,
&iziplongest_type, &iziplongest_type,
&permutations_type,
&product_type, &product_type,
&repeat_type, &repeat_type,
&groupby_type, &groupby_type,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment