Commit b9a0376b authored by gescheit's avatar gescheit Committed by Benjamin Peterson

closes bpo-37347: Fix refcount problem in sqlite3. (GH-14268)

parent 0827064c
...@@ -25,6 +25,7 @@ import datetime ...@@ -25,6 +25,7 @@ import datetime
import unittest import unittest
import sqlite3 as sqlite import sqlite3 as sqlite
import weakref import weakref
import functools
from test import support from test import support
class RegressionTests(unittest.TestCase): class RegressionTests(unittest.TestCase):
...@@ -383,72 +384,26 @@ class RegressionTests(unittest.TestCase): ...@@ -383,72 +384,26 @@ class RegressionTests(unittest.TestCase):
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
del self.con.isolation_level del self.con.isolation_level
def CheckBpo37347(self):
class Printer:
def log(self, *args):
return sqlite.SQLITE_OK
class UnhashableFunc: for method in [self.con.set_trace_callback,
__hash__ = None functools.partial(self.con.set_progress_handler, n=1),
self.con.set_authorizer]:
printer_instance = Printer()
method(printer_instance.log)
method(printer_instance.log)
self.con.execute("select 1") # trigger seg fault
method(None)
def __init__(self, return_value=None):
self.calls = 0
self.return_value = return_value
def __call__(self, *args, **kwargs):
self.calls += 1
return self.return_value
class UnhashableCallbacksTestCase(unittest.TestCase):
"""
https://bugs.python.org/issue34052
Registering unhashable callbacks raises TypeError, callbacks are not
registered in SQLite after such registration attempt.
"""
def setUp(self):
self.con = sqlite.connect(':memory:')
def tearDown(self):
self.con.close()
def test_progress_handler(self):
f = UnhashableFunc(return_value=0)
with self.assertRaisesRegex(TypeError, 'unhashable type'):
self.con.set_progress_handler(f, 1)
self.con.execute('SELECT 1')
self.assertFalse(f.calls)
def test_func(self):
func_name = 'func_name'
f = UnhashableFunc()
with self.assertRaisesRegex(TypeError, 'unhashable type'):
self.con.create_function(func_name, 0, f)
msg = 'no such function: %s' % func_name
with self.assertRaisesRegex(sqlite.OperationalError, msg):
self.con.execute('SELECT %s()' % func_name)
self.assertFalse(f.calls)
def test_authorizer(self):
f = UnhashableFunc(return_value=sqlite.SQLITE_DENY)
with self.assertRaisesRegex(TypeError, 'unhashable type'):
self.con.set_authorizer(f)
self.con.execute('SELECT 1')
self.assertFalse(f.calls)
def test_aggr(self):
class UnhashableType(type):
__hash__ = None
aggr_name = 'aggr_name'
with self.assertRaisesRegex(TypeError, 'unhashable type'):
self.con.create_aggregate(aggr_name, 0, UnhashableType('Aggr', (), {}))
msg = 'no such function: %s' % aggr_name
with self.assertRaisesRegex(sqlite.OperationalError, msg):
self.con.execute('SELECT %s()' % aggr_name)
def suite(): def suite():
regression_suite = unittest.makeSuite(RegressionTests, "Check") regression_suite = unittest.makeSuite(RegressionTests, "Check")
return unittest.TestSuite(( return unittest.TestSuite((
regression_suite, regression_suite,
unittest.makeSuite(UnhashableCallbacksTestCase),
)) ))
def test(): def test():
......
...@@ -1870,3 +1870,4 @@ Diego Rojas ...@@ -1870,3 +1870,4 @@ Diego Rojas
Edison Abahurire Edison Abahurire
Geoff Shannon Geoff Shannon
Batuhan Taskaya Batuhan Taskaya
Aleksandr Balezin
:meth:`sqlite3.Connection.create_aggregate`,
:meth:`sqlite3.Connection.create_function`,
:meth:`sqlite3.Connection.set_authorizer`,
:meth:`sqlite3.Connection.set_progress_handler`
:meth:`sqlite3.Connection.set_trace_callback`
methods lead to segfaults if some of these methods are called twice with an equal object but not the same. Now callbacks are stored more carefully. Patch by Aleksandr Balezin.
\ No newline at end of file
...@@ -186,10 +186,9 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject ...@@ -186,10 +186,9 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject
} }
self->check_same_thread = check_same_thread; self->check_same_thread = check_same_thread;
Py_XSETREF(self->function_pinboard, PyDict_New()); self->function_pinboard_trace_callback = NULL;
if (!self->function_pinboard) { self->function_pinboard_progress_handler = NULL;
return -1; self->function_pinboard_authorizer_cb = NULL;
}
Py_XSETREF(self->collations, PyDict_New()); Py_XSETREF(self->collations, PyDict_New());
if (!self->collations) { if (!self->collations) {
...@@ -249,19 +248,18 @@ void pysqlite_connection_dealloc(pysqlite_Connection* self) ...@@ -249,19 +248,18 @@ void pysqlite_connection_dealloc(pysqlite_Connection* self)
/* Clean up if user has not called .close() explicitly. */ /* Clean up if user has not called .close() explicitly. */
if (self->db) { if (self->db) {
Py_BEGIN_ALLOW_THREADS
SQLITE3_CLOSE(self->db); SQLITE3_CLOSE(self->db);
Py_END_ALLOW_THREADS
} }
Py_XDECREF(self->isolation_level); Py_XDECREF(self->isolation_level);
Py_XDECREF(self->function_pinboard); Py_XDECREF(self->function_pinboard_trace_callback);
Py_XDECREF(self->function_pinboard_progress_handler);
Py_XDECREF(self->function_pinboard_authorizer_cb);
Py_XDECREF(self->row_factory); Py_XDECREF(self->row_factory);
Py_XDECREF(self->text_factory); Py_XDECREF(self->text_factory);
Py_XDECREF(self->collations); Py_XDECREF(self->collations);
Py_XDECREF(self->statements); Py_XDECREF(self->statements);
Py_XDECREF(self->cursors); Py_XDECREF(self->cursors);
Py_TYPE(self)->tp_free((PyObject*)self); Py_TYPE(self)->tp_free((PyObject*)self);
} }
...@@ -342,9 +340,7 @@ PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args) ...@@ -342,9 +340,7 @@ PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args)
pysqlite_do_all_statements(self, ACTION_FINALIZE, 1); pysqlite_do_all_statements(self, ACTION_FINALIZE, 1);
if (self->db) { if (self->db) {
Py_BEGIN_ALLOW_THREADS
rc = SQLITE3_CLOSE(self->db); rc = SQLITE3_CLOSE(self->db);
Py_END_ALLOW_THREADS
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
_pysqlite_seterror(self->db, NULL); _pysqlite_seterror(self->db, NULL);
...@@ -808,6 +804,11 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self) ...@@ -808,6 +804,11 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self)
Py_SETREF(self->cursors, new_list); Py_SETREF(self->cursors, new_list);
} }
static void _destructor(void* args)
{
Py_DECREF((PyObject*)args);
}
PyObject* pysqlite_connection_create_function(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) PyObject* pysqlite_connection_create_function(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
{ {
static char *kwlist[] = {"name", "narg", "func", "deterministic", NULL}; static char *kwlist[] = {"name", "narg", "func", "deterministic", NULL};
...@@ -843,17 +844,16 @@ PyObject* pysqlite_connection_create_function(pysqlite_Connection* self, PyObjec ...@@ -843,17 +844,16 @@ PyObject* pysqlite_connection_create_function(pysqlite_Connection* self, PyObjec
flags |= SQLITE_DETERMINISTIC; flags |= SQLITE_DETERMINISTIC;
#endif #endif
} }
if (PyDict_SetItem(self->function_pinboard, func, Py_None) == -1) { Py_INCREF(func);
return NULL; rc = sqlite3_create_function_v2(self->db,
}
rc = sqlite3_create_function(self->db,
name, name,
narg, narg,
flags, flags,
(void*)func, (void*)func,
_pysqlite_func_callback, _pysqlite_func_callback,
NULL, NULL,
NULL); NULL,
&_destructor); // will decref func
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
/* Workaround for SQLite bug: no error code or string is available here */ /* Workaround for SQLite bug: no error code or string is available here */
...@@ -880,11 +880,16 @@ PyObject* pysqlite_connection_create_aggregate(pysqlite_Connection* self, PyObje ...@@ -880,11 +880,16 @@ PyObject* pysqlite_connection_create_aggregate(pysqlite_Connection* self, PyObje
kwlist, &name, &n_arg, &aggregate_class)) { kwlist, &name, &n_arg, &aggregate_class)) {
return NULL; return NULL;
} }
Py_INCREF(aggregate_class);
if (PyDict_SetItem(self->function_pinboard, aggregate_class, Py_None) == -1) { rc = sqlite3_create_function_v2(self->db,
return NULL; name,
} n_arg,
rc = sqlite3_create_function(self->db, name, n_arg, SQLITE_UTF8, (void*)aggregate_class, 0, &_pysqlite_step_callback, &_pysqlite_final_callback); SQLITE_UTF8,
(void*)aggregate_class,
0,
&_pysqlite_step_callback,
&_pysqlite_final_callback,
&_destructor); // will decref func
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
/* Workaround for SQLite bug: no error code or string is available here */ /* Workaround for SQLite bug: no error code or string is available here */
PyErr_SetString(pysqlite_OperationalError, "Error creating aggregate"); PyErr_SetString(pysqlite_OperationalError, "Error creating aggregate");
...@@ -1003,13 +1008,14 @@ static PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, P ...@@ -1003,13 +1008,14 @@ static PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, P
return NULL; return NULL;
} }
if (PyDict_SetItem(self->function_pinboard, authorizer_cb, Py_None) == -1) {
return NULL;
}
rc = sqlite3_set_authorizer(self->db, _authorizer_callback, (void*)authorizer_cb); rc = sqlite3_set_authorizer(self->db, _authorizer_callback, (void*)authorizer_cb);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
PyErr_SetString(pysqlite_OperationalError, "Error setting authorizer callback"); PyErr_SetString(pysqlite_OperationalError, "Error setting authorizer callback");
Py_XSETREF(self->function_pinboard_authorizer_cb, NULL);
return NULL; return NULL;
} else {
Py_INCREF(authorizer_cb);
Py_XSETREF(self->function_pinboard_authorizer_cb, authorizer_cb);
} }
Py_RETURN_NONE; Py_RETURN_NONE;
} }
...@@ -1033,12 +1039,12 @@ static PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* s ...@@ -1033,12 +1039,12 @@ static PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* s
if (progress_handler == Py_None) { if (progress_handler == Py_None) {
/* None clears the progress handler previously set */ /* None clears the progress handler previously set */
sqlite3_progress_handler(self->db, 0, 0, (void*)0); sqlite3_progress_handler(self->db, 0, 0, (void*)0);
Py_XSETREF(self->function_pinboard_progress_handler, NULL);
} else { } else {
if (PyDict_SetItem(self->function_pinboard, progress_handler, Py_None) == -1)
return NULL;
sqlite3_progress_handler(self->db, n, _progress_handler, progress_handler); sqlite3_progress_handler(self->db, n, _progress_handler, progress_handler);
Py_INCREF(progress_handler);
Py_XSETREF(self->function_pinboard_progress_handler, progress_handler);
} }
Py_RETURN_NONE; Py_RETURN_NONE;
} }
...@@ -1060,10 +1066,11 @@ static PyObject* pysqlite_connection_set_trace_callback(pysqlite_Connection* sel ...@@ -1060,10 +1066,11 @@ static PyObject* pysqlite_connection_set_trace_callback(pysqlite_Connection* sel
if (trace_callback == Py_None) { if (trace_callback == Py_None) {
/* None clears the trace callback previously set */ /* None clears the trace callback previously set */
sqlite3_trace(self->db, 0, (void*)0); sqlite3_trace(self->db, 0, (void*)0);
Py_XSETREF(self->function_pinboard_trace_callback, NULL);
} else { } else {
if (PyDict_SetItem(self->function_pinboard, trace_callback, Py_None) == -1)
return NULL;
sqlite3_trace(self->db, _trace_callback, trace_callback); sqlite3_trace(self->db, _trace_callback, trace_callback);
Py_INCREF(trace_callback);
Py_XSETREF(self->function_pinboard_trace_callback, trace_callback);
} }
Py_RETURN_NONE; Py_RETURN_NONE;
......
...@@ -85,11 +85,10 @@ typedef struct ...@@ -85,11 +85,10 @@ typedef struct
*/ */
PyObject* text_factory; PyObject* text_factory;
/* remember references to functions/classes used in /* remember references to object used in trace_callback/progress_handler/authorizer_cb */
* create_function/create/aggregate, use these as dictionary keys, so we PyObject* function_pinboard_trace_callback;
* can keep the total system refcount constant by clearing that dictionary PyObject* function_pinboard_progress_handler;
* in connection_dealloc */ PyObject* function_pinboard_authorizer_cb;
PyObject* function_pinboard;
/* a dictionary of registered collation name => collation callable mappings */ /* a dictionary of registered collation name => collation callable mappings */
PyObject* collations; PyObject* collations;
......
...@@ -1357,7 +1357,7 @@ class PyBuildExt(build_ext): ...@@ -1357,7 +1357,7 @@ class PyBuildExt(build_ext):
] ]
if CROSS_COMPILING: if CROSS_COMPILING:
sqlite_inc_paths = [] sqlite_inc_paths = []
MIN_SQLITE_VERSION_NUMBER = (3, 3, 9) MIN_SQLITE_VERSION_NUMBER = (3, 7, 2)
MIN_SQLITE_VERSION = ".".join([str(x) MIN_SQLITE_VERSION = ".".join([str(x)
for x in MIN_SQLITE_VERSION_NUMBER]) for x in MIN_SQLITE_VERSION_NUMBER])
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment