Commit eeb896c4 authored by Martin Panter

Issue #24802: Copy bytes-like objects to null-terminated buffers if necessary

This avoids possible buffer overreads when int(), float(), compile(), exec()
and eval() are passed bytes-like objects. Similar code is removed from the
complex() constructor, where it was not reachable.

Patch by John Leitch, Serhiy Storchaka and Martin Panter.
parent 9ad0aae6
...@@ -530,6 +530,27 @@ if 1: ...@@ -530,6 +530,27 @@ if 1:
check_limit("a", "[0]") check_limit("a", "[0]")
check_limit("a", "*a") check_limit("a", "*a")
def test_null_terminated(self):
# Check that compile(), eval() and exec() only read the bytes covered by
# the buffer they are given.
# The source code is null-terminated internally, but bytes-like
# objects are accepted, and these are not necessarily terminated by a
# null byte.
# Exception changed from TypeError to ValueError in 3.5, so the checks
# below match on the message rather than on a specific exception type.
# A null inside the source itself must be rejected, for str ...
with self.assertRaisesRegex(Exception, "cannot contain null"):
compile("123\x00", "<dummy>", "eval")
# ... and for a bytes-like object whose buffer includes the null.
with self.assertRaisesRegex(Exception, "cannot contain null"):
compile(memoryview(b"123\x00"), "<dummy>", "eval")
# [1:-1] drops the first and last byte, so each view covers only b"23".
# The byte just past the end of the view is a null, a digit and a "$"
# respectively; none of them may influence the compiled result.
code = compile(memoryview(b"123\x00")[1:-1], "<dummy>", "eval")
self.assertEqual(eval(code), 23)
code = compile(memoryview(b"1234")[1:-1], "<dummy>", "eval")
self.assertEqual(eval(code), 23)
code = compile(memoryview(b"$23$")[1:-1], "<dummy>", "eval")
self.assertEqual(eval(code), 23)
# Also test when eval() and exec() do the compilation step
self.assertEqual(eval(memoryview(b"1234")[1:-1]), 23)
namespace = dict()
# The sliced view covers b"x = 12": the leading "a" and trailing "3" are
# outside it.
exec(memoryview(b"ax = 123")[1:-1], namespace)
self.assertEqual(namespace['x'], 12)
class TestStackSize(unittest.TestCase): class TestStackSize(unittest.TestCase):
# These tests check that the computed stack size for a code object # These tests check that the computed stack size for a code object
......
...@@ -31,7 +31,6 @@ class GeneralFloatCases(unittest.TestCase): ...@@ -31,7 +31,6 @@ class GeneralFloatCases(unittest.TestCase):
self.assertEqual(float(3.14), 3.14) self.assertEqual(float(3.14), 3.14)
self.assertEqual(float(314), 314.0) self.assertEqual(float(314), 314.0)
self.assertEqual(float(" 3.14 "), 3.14) self.assertEqual(float(" 3.14 "), 3.14)
self.assertEqual(float(b" 3.14 "), 3.14)
self.assertRaises(ValueError, float, " 0x3.1 ") self.assertRaises(ValueError, float, " 0x3.1 ")
self.assertRaises(ValueError, float, " -0x3.p-1 ") self.assertRaises(ValueError, float, " -0x3.p-1 ")
self.assertRaises(ValueError, float, " +0x3.p-1 ") self.assertRaises(ValueError, float, " +0x3.p-1 ")
...@@ -43,7 +42,6 @@ class GeneralFloatCases(unittest.TestCase): ...@@ -43,7 +42,6 @@ class GeneralFloatCases(unittest.TestCase):
self.assertRaises(ValueError, float, "+.inf") self.assertRaises(ValueError, float, "+.inf")
self.assertRaises(ValueError, float, ".") self.assertRaises(ValueError, float, ".")
self.assertRaises(ValueError, float, "-.") self.assertRaises(ValueError, float, "-.")
self.assertRaises(ValueError, float, b"-")
self.assertRaises(TypeError, float, {}) self.assertRaises(TypeError, float, {})
self.assertRaisesRegex(TypeError, "not 'dict'", float, {}) self.assertRaisesRegex(TypeError, "not 'dict'", float, {})
# Lone surrogate # Lone surrogate
...@@ -57,6 +55,42 @@ class GeneralFloatCases(unittest.TestCase): ...@@ -57,6 +55,42 @@ class GeneralFloatCases(unittest.TestCase):
float(b'.' + b'1'*1000) float(b'.' + b'1'*1000)
float('.' + '1'*1000) float('.' + '1'*1000)
def test_non_numeric_input_types(self):
# Test possible non-numeric types for the argument x, including
# subclasses of the explicitly documented accepted types.
class CustomStr(str): pass
class CustomBytes(bytes): pass
class CustomByteArray(bytearray): pass
# Each factory takes a bytes object and returns the same content wrapped
# in one of the argument types under test (CustomStr decodes it first).
factories = [
bytes,
bytearray,
lambda b: CustomStr(b.decode()),
CustomBytes,
CustomByteArray,
memoryview,
]
# array is an optional extra: its factory is only added when the import
# succeeds.
try:
from array import array
except ImportError:
pass
else:
factories.append(lambda b: array('B', b))
for f in factories:
x = f(b" 3.14 ")
with self.subTest(type(x)):
# Leading and trailing spaces are accepted for every input type.
self.assertEqual(float(x), 3.14)
# Sixteen "A" bytes are not a number and must raise ValueError.
with self.assertRaisesRegex(ValueError, "could not convert"):
float(f(b'A' * 0x10))
def test_float_memoryview(self):
# Every view below is sliced with [1:4] and therefore covers exactly
# b'2.3'. What differs is the byte that follows the view in the
# underlying object: nothing, a null, a space, a letter or a digit.
# float() must return 2.3 in every case, i.e. it must not read past the
# end of the view.
self.assertEqual(float(memoryview(b'12.3')[1:4]), 2.3)
self.assertEqual(float(memoryview(b'12.3\x00')[1:4]), 2.3)
self.assertEqual(float(memoryview(b'12.3 ')[1:4]), 2.3)
self.assertEqual(float(memoryview(b'12.3A')[1:4]), 2.3)
self.assertEqual(float(memoryview(b'12.34')[1:4]), 2.3)
def test_error_message(self): def test_error_message(self):
testlist = ('\xbd', '123\xbd', ' 123 456 ') testlist = ('\xbd', '123\xbd', ' 123 456 ')
for s in testlist: for s in testlist:
......
...@@ -276,16 +276,40 @@ class IntTestCases(unittest.TestCase): ...@@ -276,16 +276,40 @@ class IntTestCases(unittest.TestCase):
class CustomBytes(bytes): pass class CustomBytes(bytes): pass
class CustomByteArray(bytearray): pass class CustomByteArray(bytearray): pass
values = [b'100', factories = [
bytearray(b'100'), bytes,
CustomStr('100'), bytearray,
CustomBytes(b'100'), lambda b: CustomStr(b.decode()),
CustomByteArray(b'100')] CustomBytes,
CustomByteArray,
for x in values: memoryview,
msg = 'x has type %s' % type(x).__name__ ]
self.assertEqual(int(x), 100, msg=msg) try:
self.assertEqual(int(x, 2), 4, msg=msg) from array import array
except ImportError:
pass
else:
factories.append(lambda b: array('B', b))
for f in factories:
x = f(b'100')
with self.subTest(type(x)):
self.assertEqual(int(x), 100)
if isinstance(x, (str, bytes, bytearray)):
self.assertEqual(int(x, 2), 4)
else:
msg = "can't convert non-string"
with self.assertRaisesRegex(TypeError, msg):
int(x, 2)
with self.assertRaisesRegex(ValueError, 'invalid literal'):
int(f(b'A' * 0x10))
def test_int_memoryview(self):
# Every view below is sliced with [1:3] and therefore covers exactly
# b'23'. What differs is the byte that follows the view in the
# underlying object: nothing, a null, a space, a letter or a digit.
# int() must return 23 in every case, i.e. it must not read past the
# end of the view.
self.assertEqual(int(memoryview(b'123')[1:3]), 23)
self.assertEqual(int(memoryview(b'123\x00')[1:3]), 23)
self.assertEqual(int(memoryview(b'123 ')[1:3]), 23)
self.assertEqual(int(memoryview(b'123A')[1:3]), 23)
self.assertEqual(int(memoryview(b'1234')[1:3]), 23)
def test_string_float(self): def test_string_float(self):
self.assertRaises(ValueError, int, '1.2') self.assertRaises(ValueError, int, '1.2')
......
...@@ -10,6 +10,10 @@ Release date: tba ...@@ -10,6 +10,10 @@ Release date: tba
Core and Builtins Core and Builtins
----------------- -----------------
- Issue #24802: Avoid buffer overreads when int(), float(), compile(), exec()
and eval() are passed bytes-like objects. These objects are not
necessarily terminated by a null byte, but the functions assumed they were.
- Issue #24402: Fix input() to prompt to the redirected stdout when - Issue #24402: Fix input() to prompt to the redirected stdout when
sys.stdout.fileno() fails. sys.stdout.fileno() fails.
......
...@@ -1264,12 +1264,30 @@ PyNumber_Long(PyObject *o) ...@@ -1264,12 +1264,30 @@ PyNumber_Long(PyObject *o)
/* The below check is done in PyLong_FromUnicode(). */ /* The below check is done in PyLong_FromUnicode(). */
return PyLong_FromUnicodeObject(o, 10); return PyLong_FromUnicodeObject(o, 10);
if (PyObject_GetBuffer(o, &view, PyBUF_SIMPLE) == 0) { if (PyBytes_Check(o))
/* need to do extra error checking that PyLong_FromString() /* need to do extra error checking that PyLong_FromString()
* doesn't do. In particular int('9\x005') must raise an * doesn't do. In particular int('9\x005') must raise an
* exception, not truncate at the null. * exception, not truncate at the null.
*/ */
PyObject *result = _PyLong_FromBytes(view.buf, view.len, 10); return _PyLong_FromBytes(PyBytes_AS_STRING(o),
PyBytes_GET_SIZE(o), 10);
if (PyByteArray_Check(o))
return _PyLong_FromBytes(PyByteArray_AS_STRING(o),
PyByteArray_GET_SIZE(o), 10);
if (PyObject_GetBuffer(o, &view, PyBUF_SIMPLE) == 0) {
PyObject *result, *bytes;
/* Copy to NUL-terminated buffer. */
bytes = PyBytes_FromStringAndSize((const char *)view.buf, view.len);
if (bytes == NULL) {
PyBuffer_Release(&view);
return NULL;
}
result = _PyLong_FromBytes(PyBytes_AS_STRING(bytes),
PyBytes_GET_SIZE(bytes), 10);
Py_DECREF(bytes);
PyBuffer_Release(&view); PyBuffer_Release(&view);
return result; return result;
} }
......
...@@ -767,7 +767,6 @@ complex_subtype_from_string(PyTypeObject *type, PyObject *v) ...@@ -767,7 +767,6 @@ complex_subtype_from_string(PyTypeObject *type, PyObject *v)
int got_bracket=0; int got_bracket=0;
PyObject *s_buffer = NULL; PyObject *s_buffer = NULL;
Py_ssize_t len; Py_ssize_t len;
Py_buffer view = {NULL, NULL};
if (PyUnicode_Check(v)) { if (PyUnicode_Check(v)) {
s_buffer = _PyUnicode_TransformDecimalAndSpaceToASCII(v); s_buffer = _PyUnicode_TransformDecimalAndSpaceToASCII(v);
...@@ -777,10 +776,6 @@ complex_subtype_from_string(PyTypeObject *type, PyObject *v) ...@@ -777,10 +776,6 @@ complex_subtype_from_string(PyTypeObject *type, PyObject *v)
if (s == NULL) if (s == NULL)
goto error; goto error;
} }
else if (PyObject_GetBuffer(v, &view, PyBUF_SIMPLE) == 0) {
s = (const char *)view.buf;
len = view.len;
}
else { else {
PyErr_Format(PyExc_TypeError, PyErr_Format(PyExc_TypeError,
"complex() argument must be a string or a number, not '%.200s'", "complex() argument must be a string or a number, not '%.200s'",
...@@ -895,7 +890,6 @@ complex_subtype_from_string(PyTypeObject *type, PyObject *v) ...@@ -895,7 +890,6 @@ complex_subtype_from_string(PyTypeObject *type, PyObject *v)
if (s-start != len) if (s-start != len)
goto parse_error; goto parse_error;
PyBuffer_Release(&view);
Py_XDECREF(s_buffer); Py_XDECREF(s_buffer);
return complex_subtype_from_doubles(type, x, y); return complex_subtype_from_doubles(type, x, y);
...@@ -903,7 +897,6 @@ complex_subtype_from_string(PyTypeObject *type, PyObject *v) ...@@ -903,7 +897,6 @@ complex_subtype_from_string(PyTypeObject *type, PyObject *v)
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"complex() arg is a malformed string"); "complex() arg is a malformed string");
error: error:
PyBuffer_Release(&view);
Py_XDECREF(s_buffer); Py_XDECREF(s_buffer);
return NULL; return NULL;
} }
......
...@@ -144,9 +144,24 @@ PyFloat_FromString(PyObject *v) ...@@ -144,9 +144,24 @@ PyFloat_FromString(PyObject *v)
return NULL; return NULL;
} }
} }
else if (PyBytes_Check(v)) {
s = PyBytes_AS_STRING(v);
len = PyBytes_GET_SIZE(v);
}
else if (PyByteArray_Check(v)) {
s = PyByteArray_AS_STRING(v);
len = PyByteArray_GET_SIZE(v);
}
else if (PyObject_GetBuffer(v, &view, PyBUF_SIMPLE) == 0) { else if (PyObject_GetBuffer(v, &view, PyBUF_SIMPLE) == 0) {
s = (const char *)view.buf; s = (const char *)view.buf;
len = view.len; len = view.len;
/* Copy to NUL-terminated buffer. */
s_buffer = PyBytes_FromStringAndSize(s, len);
if (s_buffer == NULL) {
PyBuffer_Release(&view);
return NULL;
}
s = PyBytes_AS_STRING(s_buffer);
} }
else { else {
PyErr_Format(PyExc_TypeError, PyErr_Format(PyExc_TypeError,
......
...@@ -560,20 +560,37 @@ Return a Unicode string of one character with ordinal i; 0 <= i <= 0x10ffff."); ...@@ -560,20 +560,37 @@ Return a Unicode string of one character with ordinal i; 0 <= i <= 0x10ffff.");
static const char * static const char *
source_as_string(PyObject *cmd, const char *funcname, const char *what, PyCompilerFlags *cf, Py_buffer *view) source_as_string(PyObject *cmd, const char *funcname, const char *what, PyCompilerFlags *cf, PyObject **cmd_copy)
{ {
const char *str; const char *str;
Py_ssize_t size; Py_ssize_t size;
Py_buffer view;
*cmd_copy = NULL;
if (PyUnicode_Check(cmd)) { if (PyUnicode_Check(cmd)) {
cf->cf_flags |= PyCF_IGNORE_COOKIE; cf->cf_flags |= PyCF_IGNORE_COOKIE;
str = PyUnicode_AsUTF8AndSize(cmd, &size); str = PyUnicode_AsUTF8AndSize(cmd, &size);
if (str == NULL) if (str == NULL)
return NULL; return NULL;
} }
else if (PyObject_GetBuffer(cmd, view, PyBUF_SIMPLE) == 0) { else if (PyBytes_Check(cmd)) {
str = (const char *)view->buf; str = PyBytes_AS_STRING(cmd);
size = view->len; size = PyBytes_GET_SIZE(cmd);
}
else if (PyByteArray_Check(cmd)) {
str = PyByteArray_AS_STRING(cmd);
size = PyByteArray_GET_SIZE(cmd);
}
else if (PyObject_GetBuffer(cmd, &view, PyBUF_SIMPLE) == 0) {
/* Copy to NUL-terminated buffer. */
*cmd_copy = PyBytes_FromStringAndSize(
(const char *)view.buf, view.len);
PyBuffer_Release(&view);
if (*cmd_copy == NULL) {
return NULL;
}
str = PyBytes_AS_STRING(*cmd_copy);
size = PyBytes_GET_SIZE(*cmd_copy);
} }
else { else {
PyErr_Format(PyExc_TypeError, PyErr_Format(PyExc_TypeError,
...@@ -585,7 +602,7 @@ source_as_string(PyObject *cmd, const char *funcname, const char *what, PyCompil ...@@ -585,7 +602,7 @@ source_as_string(PyObject *cmd, const char *funcname, const char *what, PyCompil
if (strlen(str) != size) { if (strlen(str) != size) {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"source code string cannot contain null bytes"); "source code string cannot contain null bytes");
PyBuffer_Release(view); Py_CLEAR(*cmd_copy);
return NULL; return NULL;
} }
return str; return str;
...@@ -594,7 +611,7 @@ source_as_string(PyObject *cmd, const char *funcname, const char *what, PyCompil ...@@ -594,7 +611,7 @@ source_as_string(PyObject *cmd, const char *funcname, const char *what, PyCompil
static PyObject * static PyObject *
builtin_compile(PyObject *self, PyObject *args, PyObject *kwds) builtin_compile(PyObject *self, PyObject *args, PyObject *kwds)
{ {
Py_buffer view = {NULL, NULL}; PyObject *cmd_copy;
const char *str; const char *str;
PyObject *filename; PyObject *filename;
char *startstr; char *startstr;
...@@ -681,12 +698,12 @@ builtin_compile(PyObject *self, PyObject *args, PyObject *kwds) ...@@ -681,12 +698,12 @@ builtin_compile(PyObject *self, PyObject *args, PyObject *kwds)
goto finally; goto finally;
} }
str = source_as_string(cmd, "compile", "string, bytes or AST", &cf, &view); str = source_as_string(cmd, "compile", "string, bytes or AST", &cf, &cmd_copy);
if (str == NULL) if (str == NULL)
goto error; goto error;
result = Py_CompileStringObject(str, filename, start[mode], &cf, optimize); result = Py_CompileStringObject(str, filename, start[mode], &cf, optimize);
PyBuffer_Release(&view); Py_XDECREF(cmd_copy);
goto finally; goto finally;
error: error:
...@@ -754,9 +771,8 @@ Return the tuple ((x-x%y)/y, x%y). Invariant: div*y + mod == x."); ...@@ -754,9 +771,8 @@ Return the tuple ((x-x%y)/y, x%y). Invariant: div*y + mod == x.");
static PyObject * static PyObject *
builtin_eval(PyObject *self, PyObject *args) builtin_eval(PyObject *self, PyObject *args)
{ {
PyObject *cmd, *result, *tmp = NULL; PyObject *cmd, *result, *cmd_copy;
PyObject *globals = Py_None, *locals = Py_None; PyObject *globals = Py_None, *locals = Py_None;
Py_buffer view = {NULL, NULL};
const char *str; const char *str;
PyCompilerFlags cf; PyCompilerFlags cf;
...@@ -806,7 +822,7 @@ builtin_eval(PyObject *self, PyObject *args) ...@@ -806,7 +822,7 @@ builtin_eval(PyObject *self, PyObject *args)
} }
cf.cf_flags = PyCF_SOURCE_IS_UTF8; cf.cf_flags = PyCF_SOURCE_IS_UTF8;
str = source_as_string(cmd, "eval", "string, bytes or code", &cf, &view); str = source_as_string(cmd, "eval", "string, bytes or code", &cf, &cmd_copy);
if (str == NULL) if (str == NULL)
return NULL; return NULL;
...@@ -815,8 +831,7 @@ builtin_eval(PyObject *self, PyObject *args) ...@@ -815,8 +831,7 @@ builtin_eval(PyObject *self, PyObject *args)
(void)PyEval_MergeCompilerFlags(&cf); (void)PyEval_MergeCompilerFlags(&cf);
result = PyRun_StringFlags(str, Py_eval_input, globals, locals, &cf); result = PyRun_StringFlags(str, Py_eval_input, globals, locals, &cf);
PyBuffer_Release(&view); Py_XDECREF(cmd_copy);
Py_XDECREF(tmp);
return result; return result;
} }
...@@ -882,12 +897,13 @@ builtin_exec(PyObject *self, PyObject *args) ...@@ -882,12 +897,13 @@ builtin_exec(PyObject *self, PyObject *args)
v = PyEval_EvalCode(prog, globals, locals); v = PyEval_EvalCode(prog, globals, locals);
} }
else { else {
Py_buffer view = {NULL, NULL}; PyObject *prog_copy;
const char *str; const char *str;
PyCompilerFlags cf; PyCompilerFlags cf;
cf.cf_flags = PyCF_SOURCE_IS_UTF8; cf.cf_flags = PyCF_SOURCE_IS_UTF8;
str = source_as_string(prog, "exec", str = source_as_string(prog, "exec",
"string, bytes or code", &cf, &view); "string, bytes or code", &cf,
&prog_copy);
if (str == NULL) if (str == NULL)
return NULL; return NULL;
if (PyEval_MergeCompilerFlags(&cf)) if (PyEval_MergeCompilerFlags(&cf))
...@@ -895,7 +911,7 @@ builtin_exec(PyObject *self, PyObject *args) ...@@ -895,7 +911,7 @@ builtin_exec(PyObject *self, PyObject *args)
locals, &cf); locals, &cf);
else else
v = PyRun_String(str, Py_file_input, globals, locals); v = PyRun_String(str, Py_file_input, globals, locals);
PyBuffer_Release(&view); Py_XDECREF(prog_copy);
} }
if (v == NULL) if (v == NULL)
return NULL; return NULL;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment