Commit 19fe14e7 authored by Tim Peters

Derivative of patch #102549, "simpler, faster(!) implementation of string.join".

Also fixes two long-standing bugs (present in 2.0):
1. .join() didn't check that the result size fit in an int.
2. string.join(s) when len(s)==1 returned s[0] regardless of s[0]'s
   type; e.g., "".join([3]) returned 3 (overly optimistic optimization).
I resisted a keen temptation to make .join() apply str() automagically.
parent e3d6e41d
...@@ -794,46 +794,55 @@ static PyObject * ...@@ -794,46 +794,55 @@ static PyObject *
string_join(PyStringObject *self, PyObject *args) string_join(PyStringObject *self, PyObject *args)
{ {
char *sep = PyString_AS_STRING(self); char *sep = PyString_AS_STRING(self);
int seplen = PyString_GET_SIZE(self); const int seplen = PyString_GET_SIZE(self);
PyObject *res = NULL; PyObject *res = NULL;
int reslen = 0;
char *p; char *p;
int seqlen = 0; int seqlen = 0;
int sz = 100; size_t sz = 0;
int i, slen, sz_incr; int i;
PyObject *orig, *seq, *item; PyObject *orig, *seq, *item;
if (!PyArg_ParseTuple(args, "O:join", &orig)) if (!PyArg_ParseTuple(args, "O:join", &orig))
return NULL; return NULL;
if (!(seq = PySequence_Fast(orig, ""))) { seq = PySequence_Fast(orig, "");
if (seq == NULL) {
if (PyErr_ExceptionMatches(PyExc_TypeError)) if (PyErr_ExceptionMatches(PyExc_TypeError))
PyErr_Format(PyExc_TypeError, PyErr_Format(PyExc_TypeError,
"sequence expected, %.80s found", "sequence expected, %.80s found",
orig->ob_type->tp_name); orig->ob_type->tp_name);
return NULL; return NULL;
} }
/* From here on out, errors go through finally: for proper
* reference count manipulations.
*/
seqlen = PySequence_Size(seq); seqlen = PySequence_Size(seq);
if (seqlen == 0) {
Py_DECREF(seq);
return PyString_FromString("");
}
if (seqlen == 1) { if (seqlen == 1) {
item = PySequence_Fast_GET_ITEM(seq, 0); item = PySequence_Fast_GET_ITEM(seq, 0);
if (!PyString_Check(item) && !PyUnicode_Check(item)) {
PyErr_Format(PyExc_TypeError,
"sequence item 0: expected string,"
" %.80s found",
item->ob_type->tp_name);
Py_DECREF(seq);
return NULL;
}
Py_INCREF(item); Py_INCREF(item);
Py_DECREF(seq); Py_DECREF(seq);
return item; return item;
} }
if (!(res = PyString_FromStringAndSize((char*)NULL, sz))) /* There are at least two things to join. Do a pre-pass to figure out
goto finally; * the total amount of space we'll need (sz), see whether any argument
* is absurd, and defer to the Unicode join if appropriate.
p = PyString_AS_STRING(res); */
for (i = 0; i < seqlen; i++) { for (i = 0; i < seqlen; i++) {
const size_t old_sz = sz;
item = PySequence_Fast_GET_ITEM(seq, i); item = PySequence_Fast_GET_ITEM(seq, i);
if (!PyString_Check(item)){ if (!PyString_Check(item)){
if (PyUnicode_Check(item)) { if (PyUnicode_Check(item)) {
Py_DECREF(res);
Py_DECREF(seq); Py_DECREF(seq);
return PyUnicode_Join((PyObject *)self, orig); return PyUnicode_Join((PyObject *)self, orig);
} }
...@@ -841,40 +850,45 @@ string_join(PyStringObject *self, PyObject *args) ...@@ -841,40 +850,45 @@ string_join(PyStringObject *self, PyObject *args)
"sequence item %i: expected string," "sequence item %i: expected string,"
" %.80s found", " %.80s found",
i, item->ob_type->tp_name); i, item->ob_type->tp_name);
goto finally; Py_DECREF(seq);
return NULL;
} }
slen = PyString_GET_SIZE(item); sz += PyString_GET_SIZE(item);
while (reslen + slen + seplen >= sz) { if (i != 0)
/* at least double the size of the string */ sz += seplen;
sz_incr = slen + seplen > sz ? slen + seplen : sz; if (sz < old_sz || sz > INT_MAX) {
if (_PyString_Resize(&res, sz + sz_incr)) { PyErr_SetString(PyExc_OverflowError,
goto finally; "join() is too long for a Python string");
} Py_DECREF(seq);
sz += sz_incr; return NULL;
p = PyString_AS_STRING(res) + reslen;
} }
if (i > 0) { }
/* Allocate result space. */
res = PyString_FromStringAndSize((char*)NULL, (int)sz);
if (res == NULL) {
Py_DECREF(seq);
return NULL;
}
/* Catenate everything. */
p = PyString_AS_STRING(res);
for (i = 0; i < seqlen; ++i) {
size_t n;
item = PySequence_Fast_GET_ITEM(seq, i);
n = PyString_GET_SIZE(item);
memcpy(p, PyString_AS_STRING(item), n);
p += n;
if (i < seqlen - 1) {
memcpy(p, sep, seplen); memcpy(p, sep, seplen);
p += seplen; p += seplen;
reslen += seplen;
} }
memcpy(p, PyString_AS_STRING(item), slen);
p += slen;
reslen += slen;
} }
if (_PyString_Resize(&res, reslen))
goto finally;
Py_DECREF(seq);
return res;
finally:
Py_DECREF(seq); Py_DECREF(seq);
Py_XDECREF(res); return res;
return NULL;
} }
static long static long
string_find_internal(PyStringObject *self, PyObject *args, int dir) string_find_internal(PyStringObject *self, PyObject *args, int dir)
{ {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment