Commit 8ce9f162 authored by Tim Peters's avatar Tim Peters

PyUnicode_Join(): Two primary aims:

1. u1.join([u2]) is u2
2. Be more careful about C-level int overflow.

Since PySequence_Fast() isn't needed to achieve #1, it's not used -- but
the code could sure be simpler if it were.
parent 00f8da77
...@@ -3975,49 +3975,110 @@ int fixtitle(PyUnicodeObject *self) ...@@ -3975,49 +3975,110 @@ int fixtitle(PyUnicodeObject *self)
return 1; return 1;
} }
PyObject *PyUnicode_Join(PyObject *separator, PyObject *
PyObject *seq) PyUnicode_Join(PyObject *separator, PyObject *seq)
{ {
PyObject *internal_separator = NULL;
Py_UNICODE *sep; Py_UNICODE *sep;
int seplen; size_t seplen;
PyUnicodeObject *res = NULL; PyUnicodeObject *res = NULL;
int reslen = 0; size_t sz; /* # allocated bytes for string in res */
Py_UNICODE *p; size_t reslen; /* # used bytes */
int sz = 100; Py_UNICODE *p; /* pointer to free byte in res's string area */
PyObject *it; /* iterator */
PyObject *item;
int i; int i;
PyObject *it; PyObject *temp;
it = PyObject_GetIter(seq); it = PyObject_GetIter(seq);
if (it == NULL) if (it == NULL)
return NULL; return NULL;
if (separator == NULL) { item = PyIter_Next(it);
Py_UNICODE blank = ' '; if (item == NULL) {
sep = ␣ if (PyErr_Occurred())
seplen = 1; goto onError;
/* empty sequence; return u"" */
res = _PyUnicode_New(0);
goto Done;
} }
else {
separator = PyUnicode_FromObject(separator); /* If this is the only item, maybe we can get out cheap. */
if (separator == NULL) res = (PyUnicodeObject *)item;
item = PyIter_Next(it);
if (item == NULL) {
if (PyErr_Occurred())
goto onError; goto onError;
sep = PyUnicode_AS_UNICODE(separator); /* There's only one item in the sequence. */
seplen = PyUnicode_GET_SIZE(separator); if (PyUnicode_CheckExact(res)) /* whatever.join([u]) -> u */
goto Done;
} }
/* There are at least two to join (item != NULL), or there's only
* one but it's not an exact Unicode (item == NULL). res needs
* conversion to Unicode in either case.
* Caution: we may need to ensure a copy is made, and that's trickier
* than it sounds because, e.g., PyUnicode_FromObject() may return
* a shared object (which must not be mutated).
*/
if (! PyUnicode_Check(res) && ! PyString_Check(res)) {
PyErr_Format(PyExc_TypeError,
"sequence item 0: expected string or Unicode,"
" %.80s found",
res->ob_type->tp_name);
Py_XDECREF(item);
goto onError;
}
temp = PyUnicode_FromObject((PyObject *)res);
if (temp == NULL) {
Py_XDECREF(item);
goto onError;
}
Py_DECREF(res);
if (item == NULL) {
/* res was the only item */
res = (PyUnicodeObject *)temp;
goto Done;
}
/* There are at least two items. As above, temp may be a shared object,
* so we need to copy it.
*/
reslen = PyUnicode_GET_SIZE(temp);
sz = reslen + 100; /* breathing room */
if (sz < reslen || sz > INT_MAX) /* overflow -- no breathing room */
sz = reslen;
res = _PyUnicode_New(sz); res = _PyUnicode_New(sz);
if (res == NULL) if (res == NULL) {
Py_DECREF(item);
goto onError; goto onError;
}
p = PyUnicode_AS_UNICODE(res); p = PyUnicode_AS_UNICODE(res);
reslen = 0; Py_UNICODE_COPY(p, PyUnicode_AS_UNICODE(temp), (int)reslen);
p += reslen;
Py_DECREF(temp);
for (i = 0; ; ++i) { if (separator == NULL) {
int itemlen; Py_UNICODE blank = ' ';
PyObject *item = PyIter_Next(it); sep = &blank;
if (item == NULL) { seplen = 1;
if (PyErr_Occurred()) }
else {
internal_separator = PyUnicode_FromObject(separator);
if (internal_separator == NULL) {
Py_DECREF(item);
goto onError; goto onError;
break;
} }
sep = PyUnicode_AS_UNICODE(internal_separator);
seplen = PyUnicode_GET_SIZE(internal_separator);
}
i = 1;
do {
size_t itemlen;
size_t newreslen;
/* Catenate the separator, then item. */
/* First convert item to Unicode. */
if (!PyUnicode_Check(item)) { if (!PyUnicode_Check(item)) {
PyObject *v; PyObject *v;
if (!PyString_Check(item)) { if (!PyString_Check(item)) {
...@@ -4034,36 +4095,55 @@ PyObject *PyUnicode_Join(PyObject *separator, ...@@ -4034,36 +4095,55 @@ PyObject *PyUnicode_Join(PyObject *separator,
if (item == NULL) if (item == NULL)
goto onError; goto onError;
} }
/* Make sure we have enough space for the separator and the item. */
itemlen = PyUnicode_GET_SIZE(item); itemlen = PyUnicode_GET_SIZE(item);
while (reslen + itemlen + seplen >= sz) { newreslen = reslen + seplen + itemlen;
if (_PyUnicode_Resize(&res, sz*2) < 0) { if (newreslen < reslen || newreslen > INT_MAX)
goto Overflow;
if (newreslen > sz) {
do {
size_t oldsize = sz;
sz += sz;
if (sz < oldsize || sz > INT_MAX)
goto Overflow;
} while (newreslen > sz);
if (_PyUnicode_Resize(&res, (int)sz) < 0) {
Py_DECREF(item); Py_DECREF(item);
goto onError; goto onError;
} }
sz *= 2;
p = PyUnicode_AS_UNICODE(res) + reslen; p = PyUnicode_AS_UNICODE(res) + reslen;
} }
if (i > 0) { Py_UNICODE_COPY(p, sep, (int)seplen);
Py_UNICODE_COPY(p, sep, seplen);
p += seplen; p += seplen;
reslen += seplen; Py_UNICODE_COPY(p, PyUnicode_AS_UNICODE(item), (int)itemlen);
}
Py_UNICODE_COPY(p, PyUnicode_AS_UNICODE(item), itemlen);
p += itemlen; p += itemlen;
reslen += itemlen;
Py_DECREF(item); Py_DECREF(item);
} reslen = newreslen;
if (_PyUnicode_Resize(&res, reslen) < 0)
++i;
item = PyIter_Next(it);
} while (item != NULL);
if (PyErr_Occurred())
goto onError;
if (_PyUnicode_Resize(&res, (int)reslen) < 0)
goto onError; goto onError;
Py_XDECREF(separator); Done:
Py_XDECREF(internal_separator);
Py_DECREF(it); Py_DECREF(it);
return (PyObject *)res; return (PyObject *)res;
Overflow:
PyErr_SetString(PyExc_OverflowError,
"join() is too long for a Python string");
Py_DECREF(item);
/* fall through */
onError: onError:
Py_XDECREF(separator); Py_XDECREF(internal_separator);
Py_XDECREF(res);
Py_DECREF(it); Py_DECREF(it);
Py_XDECREF(res);
return NULL; return NULL;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment