diff --git a/CHANGELOG.md b/CHANGELOG.md index 91c7001..afa8c77 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,14 @@ -## Upcoming (TBC) +## Upcoming (TBD) + +### Features + +* Add support for opening 'file:' URIs with parameters. [(#234)](https://github.com/dbcli/litecli/pull/234) ### Bug Fixes * Avoid Click 8.1.* to prevent messing up the pager when the PAGER env var has a string with spaces. + ## 1.16.0 - 2025-08-16 ### Features diff --git a/litecli/sqlexecute.py b/litecli/sqlexecute.py index bba774a..263cfb3 100644 --- a/litecli/sqlexecute.py +++ b/litecli/sqlexecute.py @@ -14,6 +14,7 @@ import sqlparse import os.path +from urllib.parse import urlparse from .packages import special @@ -68,12 +69,19 @@ def connect(self, database=None): db = database or self.dbname _logger.debug("Connection DB Params: \n\tdatabase: %r", db) - db_name = os.path.expanduser(db) - db_dir_name = os.path.dirname(os.path.abspath(db_name)) - if not os.path.exists(db_dir_name): - raise Exception("Path does not exist: {}".format(db_dir_name)) + location = urlparse(db) + if location.scheme and location.scheme == "file": + uri = True + db_name = db + db_filename = location.path + else: + uri = False + db_filename = db_name = os.path.expanduser(db) + db_dir_name = os.path.dirname(os.path.abspath(db_filename)) + if not os.path.exists(db_dir_name): + raise Exception("Path does not exist: {}".format(db_dir_name)) - conn = sqlite3.connect(database=db_name, isolation_level=None) + conn = sqlite3.connect(database=db_name, isolation_level=None, uri=uri) conn.text_factory = lambda x: x.decode("utf-8", "backslashreplace") if self.conn: self.conn.close() @@ -81,7 +89,7 @@ def connect(self, database=None): self.conn = conn # Update them after the connection is made to ensure that it was a # successful connection. - self.dbname = db + self.dbname = db_filename def run(self, statement): """Execute the sql in the database and return the results. The results diff --git a/tests/test_main.py b/tests/test_main.py index 6d4bc1d..f2365be 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -6,11 +6,12 @@ from unittest.mock import patch import click +import pytest from click.testing import CliRunner from litecli.main import cli, LiteCli from litecli.packages.special.main import COMMANDS as SPECIAL_COMMANDS -from utils import dbtest, run +from utils import dbtest, run, create_db, db_connection test_dir = os.path.abspath(os.path.dirname(__file__)) project_dir = os.path.dirname(test_dir) @@ -330,3 +331,30 @@ def test_get_prompt(mock_datetime): # 12. Windows path lc.connect("C:\\Users\\litecli\\litecli_test.db") assert lc.get_prompt(r"\d") == "C:\\Users\\litecli\\litecli_test.db" + + +@pytest.mark.parametrize( + "uri, expected_dbname", + [ + ("file:{tmp_path}/test.db", "{tmp_path}/test.db"), + ("file:{tmp_path}/test.db?mode=ro", "{tmp_path}/test.db"), + ("file:{tmp_path}/test.db?mode=ro&cache=shared", "{tmp_path}/test.db"), + ], +) +def test_file_uri(tmp_path, uri, expected_dbname): + """ + Test that `file:` URIs are correctly handled + ref: + https://docs.python.org/3/library/sqlite3.html#sqlite3-uri-tricks + https://www.sqlite.org/c3ref/open.html#urifilenameexamples + """ + # - ensure db exists + db_path = tmp_path / "test.db" + create_db(db_path) + db_connection(db_path) + uri = uri.format(tmp_path=tmp_path) + + lc = LiteCli() + lc.connect(uri) + + assert lc.get_prompt(r"\d") == expected_dbname.format(tmp_path=tmp_path)