|
24 | 24 | import sqlite3 as sqlite |
25 | 25 | import unittest |
26 | 26 |
|
| 27 | +from test.support import import_helper |
27 | 28 | from test.support.os_helper import TESTFN, unlink |
28 | 29 |
|
29 | 30 | from .util import memory_database, cx_limit, with_tracebacks |
30 | 31 | from .util import MemoryDatabaseMixin |
31 | 32 |
|
| 33 | +# TODO(picnixz): increase test coverage for other callbacks |
| 34 | +# such as 'func', 'step', 'finalize', and 'collation'. |
| 35 | + |
32 | 36 |
|
33 | 37 | class CollationTests(MemoryDatabaseMixin, unittest.TestCase): |
34 | 38 |
|
@@ -129,8 +133,59 @@ def test_deregister_collation(self): |
129 | 133 | self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll') |
130 | 134 |
|
131 | 135 |
|
| 136 | +class AuthorizerTests(MemoryDatabaseMixin, unittest.TestCase): |
| 137 | + |
| 138 | + def assert_not_authorized(self, func, /, *args, **kwargs): |
| 139 | + with self.assertRaisesRegex(sqlite.DatabaseError, "not authorized"): |
| 140 | + func(*args, **kwargs) |
| 141 | + |
| 142 | + # When a handler has an invalid signature, the exception raised is |
| 143 | + # the same that would be raised if the handler "negatively" replied. |
| 144 | + |
| 145 | + def test_authorizer_invalid_signature(self): |
| 146 | + self.cx.execute("create table if not exists test(a number)") |
| 147 | + self.cx.set_authorizer(lambda: None) |
| 148 | + self.assert_not_authorized(self.cx.execute, "select * from test") |
| 149 | + |
| 150 | + # Tests for checking that callback context mutations do not crash. |
| 151 | + # Regression tests for https://github.com/python/cpython/issues/142830. |
| 152 | + |
| 153 | + @with_tracebacks(ZeroDivisionError, regex="hello world") |
| 154 | + def test_authorizer_concurrent_mutation_in_call(self): |
| 155 | + self.cx.execute("create table if not exists test(a number)") |
| 156 | + |
| 157 | + class Handler: |
| 158 | + cx = self.cx |
| 159 | + def __call__(self, *a, **kw): |
| 160 | + self.cx.set_authorizer(None) |
| 161 | + raise ZeroDivisionError("hello world") |
| 162 | + |
| 163 | + self.cx.set_authorizer(Handler()) |
| 164 | + self.assert_not_authorized(self.cx.execute, "select * from test") |
| 165 | + |
| 166 | + @with_tracebacks(OverflowError) |
| 167 | + def test_authorizer_concurrent_mutation_with_overflown_value(self): |
| 168 | + _testcapi = import_helper.import_module("_testcapi") |
| 169 | + self.cx.execute("create table if not exists test(a number)") |
| 170 | + |
| 171 | + class Handler: |
| 172 | + cx = self.cx |
| 173 | + def __call__(self, *a, **kw): |
| 174 | + self.cx.set_authorizer(None) |
| 175 | + # We expect 'int' at the C level, so this one will raise |
| 176 | + # when converting via PyLong_Int(). |
| 177 | + return _testcapi.INT_MAX + 1 |
| 178 | + |
| 179 | + self.cx.set_authorizer(Handler()) |
| 180 | + self.assert_not_authorized(self.cx.execute, "select * from test") |
| 181 | + |
| 182 | + |
132 | 183 | class ProgressTests(MemoryDatabaseMixin, unittest.TestCase): |
133 | 184 |
|
| 185 | + def assert_interrupted(self, func, /, *args, **kwargs): |
| 186 | + with self.assertRaisesRegex(sqlite.OperationalError, "interrupted"): |
| 187 | + func(*args, **kwargs) |
| 188 | + |
134 | 189 | def test_progress_handler_used(self): |
135 | 190 | """ |
136 | 191 | Test that the progress handler is invoked once it is set. |
@@ -219,11 +274,51 @@ def bad_progress(): |
219 | 274 | create table foo(a, b) |
220 | 275 | """) |
221 | 276 |
|
222 | | - def test_progress_handler_keyword_args(self): |
| 277 | + def test_set_progress_handler_keyword_args(self): |
223 | 278 | with self.assertRaisesRegex(TypeError, |
224 | 279 | 'takes at least 1 positional argument'): |
225 | 280 | self.con.set_progress_handler(progress_handler=lambda: None, n=1) |
226 | 281 |
|
| 282 | + # When a handler has an invalid signature, the exception raised is |
| 283 | + # the same that would be raised if the handler "negatively" replied. |
| 284 | + |
| 285 | + def test_progress_handler_invalid_signature(self): |
| 286 | + self.cx.execute("create table if not exists test(a number)") |
| 287 | + self.cx.set_progress_handler(lambda x: None, 1) |
| 288 | + self.assert_interrupted(self.cx.execute, "select * from test") |
| 289 | + |
| 290 | + # Tests for checking that callback context mutations do not crash. |
| 291 | + # Regression tests for https://github.com/python/cpython/issues/142830. |
| 292 | + |
| 293 | + @with_tracebacks(ZeroDivisionError, regex="hello world") |
| 294 | + def test_progress_handler_concurrent_mutation_in_call(self): |
| 295 | + self.cx.execute("create table if not exists test(a number)") |
| 296 | + |
| 297 | + class Handler: |
| 298 | + cx = self.cx |
| 299 | + def __call__(self, *a, **kw): |
| 300 | + self.cx.set_progress_handler(None, 1) |
| 301 | + raise ZeroDivisionError("hello world") |
| 302 | + |
| 303 | + self.cx.set_progress_handler(Handler(), 1) |
| 304 | + self.assert_interrupted(self.cx.execute, "select * from test") |
| 305 | + |
| 306 | + def test_progress_handler_concurrent_mutation_in_conversion(self): |
| 307 | + self.cx.execute("create table if not exists test(a number)") |
| 308 | + |
| 309 | + class Handler: |
| 310 | + cx = self.cx |
| 311 | + def __bool__(self): |
| 312 | + # clear the progress handler |
| 313 | + self.cx.set_progress_handler(None, 1) |
| 314 | + raise ValueError # force PyObject_True() to fail |
| 315 | + |
| 316 | + self.cx.set_progress_handler(Handler.__init__, 1) |
| 317 | + self.assert_interrupted(self.cx.execute, "select * from test") |
| 318 | + |
| 319 | + # Running with tracebacks makes the second execution of this |
| 320 | + # function raise another exception because of a database change. |
| 321 | + |
227 | 322 |
|
228 | 323 | class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase): |
229 | 324 |
|
@@ -345,11 +440,42 @@ def test_trace_bad_handler(self): |
345 | 440 | cx.set_trace_callback(lambda stmt: 5/0) |
346 | 441 | cx.execute("select 1") |
347 | 442 |
|
348 | | - def test_trace_keyword_args(self): |
| 443 | + def test_set_trace_callback_keyword_args(self): |
349 | 444 | with self.assertRaisesRegex(TypeError, |
350 | 445 | 'takes exactly 1 positional argument'): |
351 | 446 | self.con.set_trace_callback(trace_callback=lambda: None) |
352 | 447 |
|
| 448 | + # When a handler has an invalid signature, the exception raised is |
| 449 | + # the same that would be raised if the handler "negatively" replied, |
| 450 | + # but for the trace handler, exceptions are never re-raised (only |
| 451 | + # printed when needed). |
| 452 | + |
| 453 | + @with_tracebacks( |
| 454 | + TypeError, |
| 455 | + regex=r".*<lambda>\(\) missing 6 required positional arguments", |
| 456 | + ) |
| 457 | + def test_trace_handler_invalid_signature(self): |
| 458 | + self.cx.execute("create table if not exists test(a number)") |
| 459 | + self.cx.set_trace_callback(lambda x, y, z, t, a, b, c: None) |
| 460 | + self.cx.execute("select * from test") |
| 461 | + |
| 462 | + # Tests for checking that callback context mutations do not crash. |
| 463 | + # Regression tests for https://github.com/python/cpython/issues/142830. |
| 464 | + |
| 465 | + @with_tracebacks(ZeroDivisionError, regex="hello world") |
| 466 | + def test_trace_callback_concurrent_mutation_in_call(self): |
| 467 | + self.cx.execute("create table if not exists test(a number)") |
| 468 | + |
| 469 | + class Handler: |
| 470 | + cx = self.cx |
| 471 | + def __call__(self, statement): |
| 472 | + # clear the progress handler |
| 473 | + self.cx.set_trace_callback(None) |
| 474 | + raise ZeroDivisionError("hello world") |
| 475 | + |
| 476 | + self.cx.set_trace_callback(Handler()) |
| 477 | + self.cx.execute("select * from test") |
| 478 | + |
353 | 479 |
|
354 | 480 | if __name__ == "__main__": |
355 | 481 | unittest.main() |
0 commit comments