Commit a7996f74 authored by Alexandre Vassalotti's avatar Alexandre Vassalotti

Use PyDict_GetItemWithError instead of PyDict_GetItem in cpickle.

parent 28774771
...@@ -1691,7 +1691,7 @@ fast_save_enter(PicklerObject *self, PyObject *obj) ...@@ -1691,7 +1691,7 @@ fast_save_enter(PicklerObject *self, PyObject *obj)
key = PyLong_FromVoidPtr(obj); key = PyLong_FromVoidPtr(obj);
if (key == NULL) if (key == NULL)
return 0; return 0;
if (PyDict_GetItem(self->fast_memo, key)) { if (PyDict_GetItemWithError(self->fast_memo, key)) {
Py_DECREF(key); Py_DECREF(key);
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"fast mode: can't pickle cyclic objects " "fast mode: can't pickle cyclic objects "
...@@ -1700,6 +1700,9 @@ fast_save_enter(PicklerObject *self, PyObject *obj) ...@@ -1700,6 +1700,9 @@ fast_save_enter(PicklerObject *self, PyObject *obj)
self->fast_nesting = -1; self->fast_nesting = -1;
return 0; return 0;
} }
if (PyErr_Occurred()) {
return 0;
}
if (PyDict_SetItem(self->fast_memo, key, Py_None) < 0) { if (PyDict_SetItem(self->fast_memo, key, Py_None) < 0) {
Py_DECREF(key); Py_DECREF(key);
self->fast_nesting = -1; self->fast_nesting = -1;
...@@ -3142,12 +3145,17 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name) ...@@ -3142,12 +3145,17 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name)
if (extension_key == NULL) { if (extension_key == NULL) {
goto error; goto error;
} }
code_obj = PyDict_GetItem(st->extension_registry, extension_key); code_obj = PyDict_GetItemWithError(st->extension_registry,
extension_key);
Py_DECREF(extension_key); Py_DECREF(extension_key);
/* The object is not registered in the extension registry. /* The object is not registered in the extension registry.
This is the most likely code path. */ This is the most likely code path. */
if (code_obj == NULL) if (code_obj == NULL) {
if (PyErr_Occurred()) {
goto error;
}
goto gen_global; goto gen_global;
}
/* XXX: pickle.py doesn't check neither the type, nor the range /* XXX: pickle.py doesn't check neither the type, nor the range
of the value returned by the extension_registry. It should for of the value returned by the extension_registry. It should for
...@@ -3712,12 +3720,21 @@ save(PicklerObject *self, PyObject *obj, int pers_save) ...@@ -3712,12 +3720,21 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
*/ */
if (self->dispatch_table == NULL) { if (self->dispatch_table == NULL) {
PickleState *st = _Pickle_GetGlobalState(); PickleState *st = _Pickle_GetGlobalState();
reduce_func = PyDict_GetItem(st->dispatch_table, (PyObject *)type); reduce_func = PyDict_GetItemWithError(st->dispatch_table,
/* PyDict_GetItem() unlike PyObject_GetItem() and (PyObject *)type);
PyObject_GetAttr() returns a borrowed ref */ if (reduce_func == NULL) {
Py_XINCREF(reduce_func); if (PyErr_Occurred()) {
goto error;
}
} else {
/* PyDict_GetItemWithError() returns a borrowed reference.
Increase the reference count to be consistent with
PyObject_GetItem and _PyObject_GetAttrId used below. */
Py_INCREF(reduce_func);
}
} else { } else {
reduce_func = PyObject_GetItem(self->dispatch_table, (PyObject *)type); reduce_func = PyObject_GetItem(self->dispatch_table,
(PyObject *)type);
if (reduce_func == NULL) { if (reduce_func == NULL) {
if (PyErr_ExceptionMatches(PyExc_KeyError)) if (PyErr_ExceptionMatches(PyExc_KeyError))
PyErr_Clear(); PyErr_Clear();
...@@ -5564,20 +5581,26 @@ load_extension(UnpicklerObject *self, int nbytes) ...@@ -5564,20 +5581,26 @@ load_extension(UnpicklerObject *self, int nbytes)
py_code = PyLong_FromLong(code); py_code = PyLong_FromLong(code);
if (py_code == NULL) if (py_code == NULL)
return -1; return -1;
obj = PyDict_GetItem(st->extension_cache, py_code); obj = PyDict_GetItemWithError(st->extension_cache, py_code);
if (obj != NULL) { if (obj != NULL) {
/* Bingo. */ /* Bingo. */
Py_DECREF(py_code); Py_DECREF(py_code);
PDATA_APPEND(self->stack, obj, -1); PDATA_APPEND(self->stack, obj, -1);
return 0; return 0;
} }
if (PyErr_Occurred()) {
Py_DECREF(py_code);
return -1;
}
/* Look up the (module_name, class_name) pair. */ /* Look up the (module_name, class_name) pair. */
pair = PyDict_GetItem(st->inverted_registry, py_code); pair = PyDict_GetItemWithError(st->inverted_registry, py_code);
if (pair == NULL) { if (pair == NULL) {
Py_DECREF(py_code); Py_DECREF(py_code);
PyErr_Format(PyExc_ValueError, "unregistered extension " if (!PyErr_Occurred()) {
"code %ld", code); PyErr_Format(PyExc_ValueError, "unregistered extension "
"code %ld", code);
}
return -1; return -1;
} }
/* Since the extension registry is manipulable via Python code, /* Since the extension registry is manipulable via Python code,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment