Skip to content

Commit f089fb6

Browse files
committed
fix more crashes and move test where they belong
1 parent 4dd0652 commit f089fb6

File tree

6 files changed

+139
-113
lines changed

6 files changed

+139
-113
lines changed

Lib/test/test_sqlite3/test_dbapi.py

Lines changed: 0 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -2029,104 +2029,5 @@ def test_row_is_a_sequence(self):
20292029
self.assertIsInstance(row, Sequence)
20302030

20312031

2032-
class CallbackTests(unittest.TestCase):
2033-
2034-
def setUp(self):
2035-
super().setUp()
2036-
self.cx = sqlite.connect(":memory:")
2037-
self.addCleanup(self.cx.close)
2038-
self.cu = self.cx.cursor()
2039-
self.cu.execute("create table test(a number)")
2040-
2041-
class Handler:
2042-
cx = self.cx
2043-
2044-
self.handler_class = Handler
2045-
2046-
def assert_not_authorized(self, func, /, *args, **kwargs):
2047-
with self.assertRaisesRegex(sqlite.DatabaseError, "not authorized"):
2048-
func(*args, **kwargs)
2049-
2050-
def assert_interrupted(self, func, /, *args, **kwargs):
2051-
with self.assertRaisesRegex(sqlite.OperationalError, "interrupted"):
2052-
func(*args, **kwargs)
2053-
2054-
def assert_invalid_trace(self, func, /, *args, **kwargs):
2055-
# Exception in trace callbacks are entirely suppressed.
2056-
pass
2057-
2058-
# When a handler has an invalid signature, the exception raised is
2059-
# the same that would be raised if the handler "negatively" replied.
2060-
2061-
def test_authorizer_invalid_signature(self):
2062-
self.cx.set_authorizer(lambda: None)
2063-
self.assert_not_authorized(self.cx.execute, "select * from test")
2064-
2065-
def test_progress_handler_invalid_signature(self):
2066-
self.cx.set_progress_handler(lambda x: None, 1)
2067-
self.assert_interrupted(self.cx.execute, "select * from test")
2068-
2069-
def test_trace_callback_invalid_signature_traceback(self):
2070-
self.cx.set_trace_callback(lambda: None)
2071-
self.assert_invalid_trace(self.cx.execute, "select * from test")
2072-
2073-
# Tests for checking that callback context mutations do not crash.
2074-
# Regression tests for https://github.com/python/cpython/issues/142830.
2075-
2076-
def test_authorizer_concurrent_mutation_in_call(self):
2077-
class Handler(self.handler_class):
2078-
def __call__(self, *a, **kw):
2079-
self.cx.set_authorizer(None)
2080-
raise ValueError
2081-
2082-
self.cx.set_authorizer(Handler())
2083-
self.assert_not_authorized(self.cx.execute, "select * from test")
2084-
2085-
def test_authorizer_concurrent_mutation_with_overflown_value(self):
2086-
_testcapi = import_helper.import_module("_testcapi")
2087-
2088-
class Handler(self.handler_class):
2089-
def __call__(self, *a, **kw):
2090-
self.cx.set_authorizer(None)
2091-
# We expect 'int' at the C level, so this one will raise
2092-
# when converting via PyLong_Int().
2093-
return _testcapi.INT_MAX + 1
2094-
2095-
self.cx.set_authorizer(Handler())
2096-
self.assert_not_authorized(self.cx.execute, "select * from test")
2097-
2098-
def test_progress_handler_concurrent_mutation_in_call(self):
2099-
class Handler(self.handler_class):
2100-
def __call__(self, *a, **kw):
2101-
self.cx.set_authorizer(None)
2102-
raise ValueError
2103-
2104-
self.cx.set_progress_handler(Handler(), 1)
2105-
self.assert_interrupted(self.cx.execute, "select * from test")
2106-
2107-
def test_progress_handler_concurrent_mutation_in_conversion(self):
2108-
class Handler(self.handler_class):
2109-
def __bool__(self):
2110-
# clear the progress handler
2111-
self.cx.set_progress_handler(None, 1)
2112-
raise ValueError # force PyObject_True() to fail
2113-
2114-
self.cx.set_progress_handler(Handler.__init__, 1)
2115-
self.assert_interrupted(self.cx.execute, "select * from test")
2116-
2117-
def test_trace_callback_concurrent_mutation_in_call(self):
2118-
class Handler:
2119-
def __call__(self, statement):
2120-
# clear the progress handler
2121-
self.cx.set_progress_handler(None, 1)
2122-
raise ValueError
2123-
2124-
self.cx.set_trace_callback(Handler())
2125-
self.assert_invalid_trace(self.cx.execute, "select * from test")
2126-
2127-
# TODO(picnixz): increase test coverage for other callbacks
2128-
# such as 'func', 'step', 'finalize', and 'collation'.
2129-
2130-
21312032
if __name__ == "__main__":
21322033
unittest.main()

Lib/test/test_sqlite3/test_hooks.py

Lines changed: 128 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,15 @@
2424
import sqlite3 as sqlite
2525
import unittest
2626

27+
from test.support import import_helper
2728
from test.support.os_helper import TESTFN, unlink
2829

2930
from .util import memory_database, cx_limit, with_tracebacks
3031
from .util import MemoryDatabaseMixin
3132

33+
# TODO(picnixz): increase test coverage for other callbacks
34+
# such as 'func', 'step', 'finalize', and 'collation'.
35+
3236

3337
class CollationTests(MemoryDatabaseMixin, unittest.TestCase):
3438

@@ -129,8 +133,59 @@ def test_deregister_collation(self):
129133
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
130134

131135

136+
class AuthorizerTests(MemoryDatabaseMixin, unittest.TestCase):
137+
138+
def assert_not_authorized(self, func, /, *args, **kwargs):
139+
with self.assertRaisesRegex(sqlite.DatabaseError, "not authorized"):
140+
func(*args, **kwargs)
141+
142+
# When a handler has an invalid signature, the exception raised is
143+
# the same that would be raised if the handler "negatively" replied.
144+
145+
def test_authorizer_invalid_signature(self):
146+
self.cx.execute("create table if not exists test(a number)")
147+
self.cx.set_authorizer(lambda: None)
148+
self.assert_not_authorized(self.cx.execute, "select * from test")
149+
150+
# Tests for checking that callback context mutations do not crash.
151+
# Regression tests for https://github.com/python/cpython/issues/142830.
152+
153+
@with_tracebacks(ZeroDivisionError, regex="hello world")
154+
def test_authorizer_concurrent_mutation_in_call(self):
155+
self.cx.execute("create table if not exists test(a number)")
156+
157+
class Handler:
158+
cx = self.cx
159+
def __call__(self, *a, **kw):
160+
self.cx.set_authorizer(None)
161+
raise ZeroDivisionError("hello world")
162+
163+
self.cx.set_authorizer(Handler())
164+
self.assert_not_authorized(self.cx.execute, "select * from test")
165+
166+
@with_tracebacks(OverflowError)
167+
def test_authorizer_concurrent_mutation_with_overflown_value(self):
168+
_testcapi = import_helper.import_module("_testcapi")
169+
self.cx.execute("create table if not exists test(a number)")
170+
171+
class Handler:
172+
cx = self.cx
173+
def __call__(self, *a, **kw):
174+
self.cx.set_authorizer(None)
175+
# We expect 'int' at the C level, so this one will raise
176+
# when converting via PyLong_Int().
177+
return _testcapi.INT_MAX + 1
178+
179+
self.cx.set_authorizer(Handler())
180+
self.assert_not_authorized(self.cx.execute, "select * from test")
181+
182+
132183
class ProgressTests(MemoryDatabaseMixin, unittest.TestCase):
133184

185+
def assert_interrupted(self, func, /, *args, **kwargs):
186+
with self.assertRaisesRegex(sqlite.OperationalError, "interrupted"):
187+
func(*args, **kwargs)
188+
134189
def test_progress_handler_used(self):
135190
"""
136191
Test that the progress handler is invoked once it is set.
@@ -219,11 +274,51 @@ def bad_progress():
219274
create table foo(a, b)
220275
""")
221276

222-
def test_progress_handler_keyword_args(self):
277+
def test_set_progress_handler_keyword_args(self):
223278
with self.assertRaisesRegex(TypeError,
224279
'takes at least 1 positional argument'):
225280
self.con.set_progress_handler(progress_handler=lambda: None, n=1)
226281

282+
# When a handler has an invalid signature, the exception raised is
283+
# the same that would be raised if the handler "negatively" replied.
284+
285+
def test_progress_handler_invalid_signature(self):
286+
self.cx.execute("create table if not exists test(a number)")
287+
self.cx.set_progress_handler(lambda x: None, 1)
288+
self.assert_interrupted(self.cx.execute, "select * from test")
289+
290+
# Tests for checking that callback context mutations do not crash.
291+
# Regression tests for https://github.com/python/cpython/issues/142830.
292+
293+
@with_tracebacks(ZeroDivisionError, regex="hello world")
294+
def test_progress_handler_concurrent_mutation_in_call(self):
295+
self.cx.execute("create table if not exists test(a number)")
296+
297+
class Handler:
298+
cx = self.cx
299+
def __call__(self, *a, **kw):
300+
self.cx.set_progress_handler(None, 1)
301+
raise ZeroDivisionError("hello world")
302+
303+
self.cx.set_progress_handler(Handler(), 1)
304+
self.assert_interrupted(self.cx.execute, "select * from test")
305+
306+
def test_progress_handler_concurrent_mutation_in_conversion(self):
307+
self.cx.execute("create table if not exists test(a number)")
308+
309+
class Handler:
310+
cx = self.cx
311+
def __bool__(self):
312+
# clear the progress handler
313+
self.cx.set_progress_handler(None, 1)
314+
raise ValueError # force PyObject_True() to fail
315+
316+
self.cx.set_progress_handler(Handler.__init__, 1)
317+
self.assert_interrupted(self.cx.execute, "select * from test")
318+
319+
# Running with tracebacks makes the second execution of this
320+
# function raise another exception because of a database change.
321+
227322

228323
class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase):
229324

@@ -345,11 +440,42 @@ def test_trace_bad_handler(self):
345440
cx.set_trace_callback(lambda stmt: 5/0)
346441
cx.execute("select 1")
347442

348-
def test_trace_keyword_args(self):
443+
def test_set_trace_callback_keyword_args(self):
349444
with self.assertRaisesRegex(TypeError,
350445
'takes exactly 1 positional argument'):
351446
self.con.set_trace_callback(trace_callback=lambda: None)
352447

448+
# When a handler has an invalid signature, the exception raised is
449+
# the same that would be raised if the handler "negatively" replied,
450+
# but for the trace handler, exceptions are never re-raised (only
451+
# printed when needed).
452+
453+
@with_tracebacks(
454+
TypeError,
455+
regex=r".*<lambda>\(\) missing 6 required positional arguments",
456+
)
457+
def test_trace_handler_invalid_signature(self):
458+
self.cx.execute("create table if not exists test(a number)")
459+
self.cx.set_trace_callback(lambda x, y, z, t, a, b, c: None)
460+
self.cx.execute("select * from test")
461+
462+
# Tests for checking that callback context mutations do not crash.
463+
# Regression tests for https://github.com/python/cpython/issues/142830.
464+
465+
@with_tracebacks(ZeroDivisionError, regex="hello world")
466+
def test_trace_callback_concurrent_mutation_in_call(self):
467+
self.cx.execute("create table if not exists test(a number)")
468+
469+
class Handler:
470+
cx = self.cx
471+
def __call__(self, statement):
472+
# clear the progress handler
473+
self.cx.set_trace_callback(None)
474+
raise ZeroDivisionError("hello world")
475+
476+
self.cx.set_trace_callback(Handler())
477+
self.cx.execute("select * from test")
478+
353479

354480
if __name__ == "__main__":
355481
unittest.main()

Lib/test/test_sqlite3/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def check_tracebacks(self, cm, exc, exc_regex, msg_regex, obj_name):
5050
with contextlib.redirect_stderr(buf):
5151
yield
5252

53-
self.assertEqual(cm.unraisable.exc_type, exc)
53+
self.assertIsSubclass(cm.unraisable.exc_type, exc)
5454
if exc_regex:
5555
msg = str(cm.unraisable.exc_value)
5656
self.assertIsNotNone(exc_regex.search(msg), (exc_regex, msg))

Modules/_sqlite/connection.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,6 +1450,8 @@ trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql)
14501450

14511451
assert(ctx_vp != NULL);
14521452
pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp);
1453+
// Hold a reference to 'ctx' to prevent concurrent mutations.
1454+
Py_INCREF(ctx);
14531455
pysqlite_state *state = ctx->state;
14541456
assert(state != NULL);
14551457

@@ -1474,9 +1476,7 @@ trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql)
14741476
sqlite3_free((void *)expanded_sql);
14751477
}
14761478
if (py_statement) {
1477-
Py_INCREF(ctx);
14781479
PyObject *ret = PyObject_CallOneArg(ctx->callable, py_statement);
1479-
Py_DECREF(ctx);
14801480
Py_DECREF(py_statement);
14811481
Py_XDECREF(ret);
14821482
}
@@ -1485,6 +1485,7 @@ trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql)
14851485
}
14861486

14871487
exit:
1488+
Py_DECREF(ctx);
14881489
PyGILState_Release(gilstate);
14891490
return 0;
14901491
}

Modules/_sqlite/context.c

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ callback_context_new_impl(PyTypeObject *type, PyObject *callable)
4949
ctx->state = state;
5050
PyObject_GC_Track(ctx);
5151
return (PyObject *)ctx;
52-
5352
}
5453

5554
static int
@@ -104,8 +103,7 @@ static PyType_Spec callback_context_spec = {
104103
PyObject *
105104
pysqlite_create_callback_context(pysqlite_state *state, PyObject *callable)
106105
{
107-
PyTypeObject *type = state->CallbackContextType;
108-
return callback_context_new_impl(type, callable);
106+
return callback_context_new_impl(state->CallbackContextType, callable);
109107
}
110108

111109
int

Modules/_sqlite/context.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ extern "C" {
99

1010
#include "module.h"
1111

12+
/*
13+
* UDF/callback context structure.
14+
*
15+
* In order to ensure that the state pointer always outlives the callback
16+
* context, we make sure it owns a reference to the module itself.
17+
*/
1218
typedef struct {
1319
PyObject_HEAD
1420
PyObject *callable;
@@ -18,12 +24,6 @@ typedef struct {
1824

1925
#define pysqlite_CallbackContext_CAST(op) ((pysqlite_CallbackContext *)(op))
2026

21-
/* Allocate a UDF/callback context structure. In order to ensure that the state
22-
* pointer always outlives the callback context, we make sure it owns a
23-
* reference to the module itself. create_callback_context() is always called
24-
* from connection methods, so we use the defining class to fetch the module
25-
* pointer.
26-
*/
2727
PyObject *
2828
pysqlite_create_callback_context(pysqlite_state *state, PyObject *callable);
2929

0 commit comments

Comments
 (0)