From 3e9a23bdd7fe62a54a91aab6889fdd46fef099dd Mon Sep 17 00:00:00 2001 From: Amjith Ramanujam Date: Sat, 7 Dec 2024 10:01:23 -0800 Subject: [PATCH 01/21] Initial working implementation. --- litecli/packages/special/dbcommands.py | 73 +++++++++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/litecli/packages/special/dbcommands.py b/litecli/packages/special/dbcommands.py index d779764..57dc702 100644 --- a/litecli/packages/special/dbcommands.py +++ b/litecli/packages/special/dbcommands.py @@ -6,6 +6,9 @@ import platform import shlex +import click +import llm + from litecli import __version__ from litecli.packages.special import iocommands from .main import special_command, RAW_QUERY, PARSED_QUERY @@ -90,6 +93,74 @@ def show_schema(cur, arg=None, **_): return [(None, tables, headers, status)] +@special_command( + ".ai", + ".llm", + "Ask AI to answer your question.", + arg_type=PARSED_QUERY, + case_sensitive=False, + aliases=("\\ai",), +) +def use_ai(cur, arg=None, **_): + schema_query = """ + SELECT sql FROM sqlite_master + WHERE sql IS NOT NULL + ORDER BY tbl_name, type DESC, name + """ + log.debug(schema_query) + cur.execute(schema_query) + db_schema = "\n".join([x for (x,) in cur.fetchall()]) + tables_query = """ + SELECT name FROM sqlite_master + WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + log.debug(tables_query) + cur.execute(tables_query) + sample_row = "SELECT * FROM {table} LIMIT 1" + sample_data = {} + for (table,) in cur.fetchall(): + sample_row_query = sample_row.format(table=table) + cur.execute(sample_row_query) + cols = [x[0] for x in cur.description] + row = cur.fetchone() + if row is None: + continue + sample_data[table] = list(zip(cols, row)) + + sys_prompt = f"""A SQLite database has the following schema: + {db_schema} + + Here is a sample data for each table: {sample_data} + + + Use the provided schema and the sample data to construct a SQL query that + can be run in SQLite3 to answer + + {arg} + + Do NOT include any additional formatting or explanation just the SQL query. + """ + log.debug(sys_prompt) + # model = llm.get_model("qwen2.5-coder") + model = llm.get_model("o1-preview") + # model = llm.get_model("o1-mini") + # model = llm.get_model("llama3.2") + # model = llm.get_model("gpt-4o") + resp = model.prompt(sys_prompt) + status = "" + headers = "" + click.echo(resp.text()) + click.echo("Ok to execute?") + ans = click.prompt("y/n") + if ans == "y": + results = cur.execute(resp.text()) + else: + results = None + + return [(None, results, headers, status)] + + @special_command( ".databases", ".databases", @@ -234,7 +305,7 @@ def describe(cur, arg, **_): @special_command( - ".import", + ".itables_query", ".import filename table", "Import data from filename into an existing table", arg_type=PARSED_QUERY, From e1defd524d9bf441fcb29affc3f1982e931e5bd4 Mon Sep 17 00:00:00 2001 From: Amjith Ramanujam Date: Fri, 13 Dec 2024 17:01:52 -0800 Subject: [PATCH 02/21] Move llm related commands to special/iocommands. 
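The mechanism worth calling out in this patch is how the final query is pulled
out of the model's prose: sql_using_llm() asks the model to wrap its answer in
a ```sql code fence and then extracts it with a regex. A self-contained
demonstration of that extraction, using the same pattern as the diff below
(the reply text is invented for illustration):

    import re

    # Same pattern as sql_using_llm() below: capture the body of a ```sql fence.
    _pattern = r"```sql\n(.*?)\n```"

    reply = (
        "The urls table holds one row per visit.\n"
        "```sql\n"
        "SELECT url, count(*) AS visits FROM urls GROUP BY url ORDER BY 2 DESC;\n"
        "```"
    )
    match = re.search(_pattern, reply, re.DOTALL)
    sql = match.group(1).strip() if match else ""  # empty string when no fence was found
    print(sql)  # SELECT url, count(*) AS visits FROM urls GROUP BY url ORDER BY 2 DESC;

re.DOTALL is what lets the lazy (.*?) span the newlines inside the fence.
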
--- litecli/main.py | 29 +++++++ litecli/packages/special/dbcommands.py | 70 ---------------- litecli/packages/special/iocommands.py | 110 +++++++++++++++++++++++-- litecli/packages/special/main.py | 8 ++ 4 files changed, 139 insertions(+), 78 deletions(-) diff --git a/litecli/main.py b/litecli/main.py index a0607ab..b7bcbcb 100644 --- a/litecli/main.py +++ b/litecli/main.py @@ -347,6 +347,27 @@ def handle_editor_command(self, text): continue return text + def handle_llm_command(self, text): + if not special.is_llm_command(text): + return text + + cur = self.sqlexecute.conn.cursor() + try: + question = special.get_llm_question(text) + context, sql = special.sql_using_llm(cur=cur, question=question) + except Exception as e: + # Something went wrong. Raise an exception and bail. + raise RuntimeError(e) + while True: + try: + click.echo(context) + text = self.prompt_app.prompt(default=sql) + break + except KeyboardInterrupt: + sql = "" + + return text + def run_cli(self): iterations = 0 sqlexecute = self.sqlexecute @@ -402,6 +423,14 @@ def one_iteration(text=None): self.echo(str(e), err=True, fg="red") return + try: + text = self.handle_llm_command(text) + except RuntimeError as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg="red") + return + if not text.strip(): return diff --git a/litecli/packages/special/dbcommands.py b/litecli/packages/special/dbcommands.py index 57dc702..66ee8e2 100644 --- a/litecli/packages/special/dbcommands.py +++ b/litecli/packages/special/dbcommands.py @@ -6,8 +6,6 @@ import platform import shlex -import click -import llm from litecli import __version__ from litecli.packages.special import iocommands @@ -93,74 +91,6 @@ def show_schema(cur, arg=None, **_): return [(None, tables, headers, status)] -@special_command( - ".ai", - ".llm", - "Ask AI to answer your question.", - arg_type=PARSED_QUERY, - case_sensitive=False, - aliases=("\\ai",), -) -def use_ai(cur, arg=None, **_): - schema_query = """ - SELECT sql FROM sqlite_master - WHERE sql IS NOT NULL - ORDER BY tbl_name, type DESC, name - """ - log.debug(schema_query) - cur.execute(schema_query) - db_schema = "\n".join([x for (x,) in cur.fetchall()]) - tables_query = """ - SELECT name FROM sqlite_master - WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' - ORDER BY 1 - """ - log.debug(tables_query) - cur.execute(tables_query) - sample_row = "SELECT * FROM {table} LIMIT 1" - sample_data = {} - for (table,) in cur.fetchall(): - sample_row_query = sample_row.format(table=table) - cur.execute(sample_row_query) - cols = [x[0] for x in cur.description] - row = cur.fetchone() - if row is None: - continue - sample_data[table] = list(zip(cols, row)) - - sys_prompt = f"""A SQLite database has the following schema: - {db_schema} - - Here is a sample data for each table: {sample_data} - - - Use the provided schema and the sample data to construct a SQL query that - can be run in SQLite3 to answer - - {arg} - - Do NOT include any additional formatting or explanation just the SQL query. 
- """ - log.debug(sys_prompt) - # model = llm.get_model("qwen2.5-coder") - model = llm.get_model("o1-preview") - # model = llm.get_model("o1-mini") - # model = llm.get_model("llama3.2") - # model = llm.get_model("gpt-4o") - resp = model.prompt(sys_prompt) - status = "" - headers = "" - click.echo(resp.text()) - click.echo("Ok to execute?") - ans = click.prompt("y/n") - if ans == "y": - results = cur.execute(resp.text()) - else: - results = None - - return [(None, results, headers, status)] - - @special_command( ".databases", ".databases", diff --git a/litecli/packages/special/iocommands.py b/litecli/packages/special/iocommands.py index eeba814..a932d7f 100644 --- a/litecli/packages/special/iocommands.py +++ b/litecli/packages/special/iocommands.py @@ -1,22 +1,26 @@ from __future__ import unicode_literals -import os -import re + import locale import logging -import subprocess +import os +import re import shlex +import subprocess from io import open from time import sleep +from typing import Optional, Tuple import click +import llm import sqlparse from configobj import ConfigObj +from litecli.packages.prompt_utils import confirm_destructive_query + from . import export -from .main import special_command, NO_QUERY, PARSED_QUERY from .favoritequeries import FavoriteQueries +from .main import NO_QUERY, PARSED_QUERY, special_command from .utils import handle_cd_command -from litecli.packages.prompt_utils import confirm_destructive_query use_expanded_output = False PAGER_ENABLED = True @@ -27,6 +31,8 @@ written_to_pipe_once_process = False favoritequeries = FavoriteQueries(ConfigObj()) +log = logging.getLogger(__name__) + @export def set_favorite_queries(config): @@ -95,9 +101,6 @@ def is_expanded_output(): return use_expanded_output -_logger = logging.getLogger(__name__) - - @export def editor_command(command): """ @@ -171,6 +174,97 @@ def open_external_editor(filename=None, sql=None): return (query, message) +@export +def is_llm_command(command) -> bool: + """ + Is this an llm/ai command? 
+ """ + return ( + command.strip().startswith("\\llm") + or command.strip().startswith("\\ai") + or command.strip().startswith(".llm") + or command.strip().startswith(".ai") + ) + + +@export +def get_llm_question(command) -> Optional[str]: + """ + Remove the llm/ai prefix + """ + command = command.removeprefix("\\llm") + command = command.removeprefix("\\ai") + command = command.removeprefix(".llm") + command = command.removeprefix(".ai") + return command + + +@export +def sql_using_llm(cur, question=None) -> Tuple[str, Optional[str]]: + _pattern = r"```sql\n(.*?)\n```" + schema_query = """ + SELECT sql FROM sqlite_master + WHERE sql IS NOT NULL + ORDER BY tbl_name, type DESC, name + """ + tables_query = """ + SELECT name FROM sqlite_master + WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + sample_row_query = "SELECT * FROM {table} LIMIT 1" + log.debug(schema_query) + cur.execute(schema_query) + db_schema = "\n".join([x for (x,) in cur.fetchall()]) + log.debug(tables_query) + cur.execute(tables_query) + sample_data = {} + for (table,) in cur.fetchall(): + sample_row = sample_row_query.format(table=table) + cur.execute(sample_row) + cols = [x[0] for x in cur.description] + row = cur.fetchone() + if row is None: # Skip empty tables + continue + sample_data[table] = list(zip(cols, row)) + + sys_prompt = f"""A SQLite database has the following schema: + {db_schema} + + Here is a sample data for each table: {sample_data} + + Use the provided schema and the sample data to construct a SQL query that + can be run in SQLite3 to answer + + {question} + + Explain the reason for choosing each table in the SQL query you have + written. Include a brief explanation of any built in SQLite3 functions. + Finally include the sql query in a code fence such as this one: + + ```sql + SELECT count(*) FROM table_name; + ``` + """ + log.debug(sys_prompt) + # model = llm.get_model("llama3.3") + # model = llm.get_model("qwq") + # model = llm.get_model("o1-preview") + # model = llm.get_model("o1-mini") + # model = llm.get_model("llama3.2") + model = llm.get_model("gpt-4o") + # model = llm.get_model("claude-3.5-haiku") + resp = model.prompt(sys_prompt) + result = resp.text() + match = re.search(_pattern, result, re.DOTALL) + if match: + sql = match.group(1).strip() + else: + sql = "" + + return result, sql + + @special_command( "\\f", "\\f [name [args..]]", diff --git a/litecli/packages/special/main.py b/litecli/packages/special/main.py index 49abdf0..9544811 100644 --- a/litecli/packages/special/main.py +++ b/litecli/packages/special/main.py @@ -152,5 +152,13 @@ def quit(*_args): arg_type=NO_QUERY, case_sensitive=True, ) +@special_command( + "\\llm", + "\\ai", + "Use LLM to construct a SQL query.", + arg_type=NO_QUERY, + case_sensitive=False, + aliases=(".ai", ".llm"), +) def stub(): raise NotImplementedError From c49fe59349117f57812da67798292acaad7bb96e Mon Sep 17 00:00:00 2001 From: Amjith Ramanujam Date: Fri, 13 Dec 2024 22:46:30 -0800 Subject: [PATCH 03/21] Add concise to prompt. --- litecli/packages/special/iocommands.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litecli/packages/special/iocommands.py b/litecli/packages/special/iocommands.py index a932d7f..510fffb 100644 --- a/litecli/packages/special/iocommands.py +++ b/litecli/packages/special/iocommands.py @@ -239,7 +239,7 @@ def sql_using_llm(cur, question=None) -> Tuple[str, Optional[str]]: {question} Explain the reason for choosing each table in the SQL query you have - written. 
Include a brief explanation of any built in SQLite3 functions. + written. Keep the explanation concise and to the point. Finally include the sql query in a code fence such as this one: ```sql From 3b0aed544bc5366cb0b1ad1f52b646f345dc0cf4 Mon Sep 17 00:00:00 2001 From: Amjith Ramanujam Date: Fri, 13 Dec 2024 22:48:01 -0800 Subject: [PATCH 04/21] Add llm as optional dep for ai extra. --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 5caeb84..9ee326d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,8 @@ build-backend = "setuptools.build_meta" litecli = "litecli.main:cli" [project.optional-dependencies] +ai = ["llm"] + dev = [ "behave>=1.2.6", "coverage>=7.2.7", From 6485fa6d126bec212b83167cd413989dcaae0292 Mon Sep 17 00:00:00 2001 From: Amjith Ramanujam Date: Sun, 15 Dec 2024 19:57:53 -0800 Subject: [PATCH 05/21] Fix the grammar in the prompt. --- litecli/packages/special/iocommands.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/litecli/packages/special/iocommands.py b/litecli/packages/special/iocommands.py index 510fffb..acf3d98 100644 --- a/litecli/packages/special/iocommands.py +++ b/litecli/packages/special/iocommands.py @@ -231,7 +231,7 @@ def sql_using_llm(cur, question=None) -> Tuple[str, Optional[str]]: sys_prompt = f"""A SQLite database has the following schema: {db_schema} - Here is a sample data for each table: {sample_data} + Here is a sample row of data from each table: {sample_data} Use the provided schema and the sample data to construct a SQL query that can be run in SQLite3 to answer @@ -253,6 +253,7 @@ def sql_using_llm(cur, question=None) -> Tuple[str, Optional[str]]: # model = llm.get_model("o1-mini") # model = llm.get_model("llama3.2") model = llm.get_model("gpt-4o") + # model = llm.get_model("gemini-2.0-flash-exp") # model = llm.get_model("claude-3.5-haiku") resp = model.prompt(sys_prompt) result = resp.text() From 75ff16536b92595a994d6fd9e7b808e64cd33fc2 Mon Sep 17 00:00:00 2001 From: Amjith Ramanujam Date: Sun, 15 Dec 2024 22:39:44 -0800 Subject: [PATCH 06/21] Add a test to ensure mulitple columns are suggested during completion. --- tests/test_completion_engine.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_completion_engine.py b/tests/test_completion_engine.py index b04e184..86053d1 100644 --- a/tests/test_completion_engine.py +++ b/tests/test_completion_engine.py @@ -357,6 +357,18 @@ def test_sub_select_multiple_col_name_completion(): ) +def test_suggested_multiple_column_names(): + suggestions = suggest_type("SELECT id, from users", "SELECT id, ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "users", None)]}, + {"type": "function", "schema": []}, + {"type": "alias", "aliases": ["users"]}, + {"type": "keyword"}, + ] + ) + + def test_sub_select_dot_col_name_completion(): suggestions = suggest_type("SELECT * FROM (SELECT t. 
FROM tabl t", "SELECT * FROM (SELECT t.") assert sorted_dicts(suggestions) == sorted_dicts( From 27e2a69d3584763a462b2cba14740457e58c55c0 Mon Sep 17 00:00:00 2001 From: Amjith Ramanujam Date: Mon, 30 Dec 2024 09:05:03 -0800 Subject: [PATCH 07/21] WIP --- litecli/main.py | 128 +++++++++--------- litecli/packages/completion_engine.py | 5 + litecli/packages/special/__init__.py | 1 + litecli/packages/special/iocommands.py | 94 ------------- litecli/packages/special/llm.py | 176 +++++++++++++++++++++++++ pyproject.toml | 2 + 6 files changed, 244 insertions(+), 162 deletions(-) create mode 100644 litecli/packages/special/llm.py diff --git a/litecli/main.py b/litecli/main.py index b7bcbcb..0ff8c11 100644 --- a/litecli/main.py +++ b/litecli/main.py @@ -347,27 +347,6 @@ def handle_editor_command(self, text): continue return text - def handle_llm_command(self, text): - if not special.is_llm_command(text): - return text - - cur = self.sqlexecute.conn.cursor() - try: - question = special.get_llm_question(text) - context, sql = special.sql_using_llm(cur=cur, question=question) - except Exception as e: - # Something went wrong. Raise an exception and bail. - raise RuntimeError(e) - while True: - try: - click.echo(context) - text = self.prompt_app.prompt(default=sql) - break - except KeyboardInterrupt: - sql = "" - - return text - def run_cli(self): iterations = 0 sqlexecute = self.sqlexecute @@ -406,6 +385,47 @@ def get_continuation(width, line_number, is_soft_wrap): def show_suggestion_tip(): return iterations < 2 + def output_res(res, start): + result_count = 0 + mutating = False + for title, cur, headers, status in res: + logger.debug("headers: %r", headers) + logger.debug("rows: %r", cur) + logger.debug("status: %r", status) + threshold = 1000 + if is_select(status) and cur and cur.rowcount > threshold: + self.echo( + "The result set has more than {} rows.".format(threshold), + fg="red", + ) + if not confirm("Do you want to continue?"): + self.echo("Aborted!", err=True, fg="red") + break + + if self.auto_vertical_output: + max_width = self.prompt_app.output.get_size().columns + else: + max_width = None + + formatted = self.format_output(title, cur, headers, special.is_expanded_output(), max_width) + + t = time() - start + try: + if result_count > 0: + self.echo("") + try: + self.output(formatted, status) + except KeyboardInterrupt: + pass + self.echo("Time: %0.03fs" % t) + except KeyboardInterrupt: + pass + + start = time() + result_count += 1 + mutating = mutating or is_mutating(status) + return mutating + def one_iteration(text=None): if text is None: try: @@ -423,13 +443,21 @@ def one_iteration(text=None): self.echo(str(e), err=True, fg="red") return - try: - text = self.handle_llm_command(text) - except RuntimeError as e: - logger.error("sql: %r, error: %r", text, e) - logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg="red") - return + if special.is_llm_command(text): + try: + start = time() + cur = self.sqlexecute.conn.cursor() + context, sql = special.handle_llm(text, cur) + if context: + click.echo(context) + text = self.prompt_app.prompt(default=sql) + except special.FinishIteration as e: + return output_res(e.results, start) if e.results else None + except RuntimeError as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg="red") + return if not text.strip(): return @@ -444,9 +472,6 @@ def one_iteration(text=None): self.echo("Wise choice!") return - # Keep track of 
whether or not the query is mutating. In case - # of a multi-statement query, the overall query is considered - # mutating if any one of the component statements is mutating mutating = False try: @@ -463,44 +488,11 @@ def one_iteration(text=None): res = sqlexecute.run(text) self.formatter.query = text successful = True - result_count = 0 - for title, cur, headers, status in res: - logger.debug("headers: %r", headers) - logger.debug("rows: %r", cur) - logger.debug("status: %r", status) - threshold = 1000 - if is_select(status) and cur and cur.rowcount > threshold: - self.echo( - "The result set has more than {} rows.".format(threshold), - fg="red", - ) - if not confirm("Do you want to continue?"): - self.echo("Aborted!", err=True, fg="red") - break - - if self.auto_vertical_output: - max_width = self.prompt_app.output.get_size().columns - else: - max_width = None - - formatted = self.format_output(title, cur, headers, special.is_expanded_output(), max_width) - - t = time() - start - try: - if result_count > 0: - self.echo("") - try: - self.output(formatted, status) - except KeyboardInterrupt: - pass - self.echo("Time: %0.03fs" % t) - except KeyboardInterrupt: - pass - - start = time() - result_count += 1 - mutating = mutating or is_mutating(status) special.unset_once_if_written() + # Keep track of whether or not the query is mutating. In case + # of a multi-statement query, the overall query is considered + # mutating if any one of the component statements is mutating + mutating = output_res(res, start) special.unset_pipe_once_if_written() except EOFError as e: raise e diff --git a/litecli/packages/completion_engine.py b/litecli/packages/completion_engine.py index 05b70ac..68c1392 100644 --- a/litecli/packages/completion_engine.py +++ b/litecli/packages/completion_engine.py @@ -118,6 +118,11 @@ def suggest_special(text): else: return [{"type": "table", "schema": []}] + if cmd in [".llm", ".ai", "\\llm", "\\ai"]: + word_before_cursor = last_word(arg, include="many_punctuations") + + return [{"type": "llm", "subcommand": word_before_cursor}] + return [{"type": "keyword"}, {"type": "special"}] diff --git a/litecli/packages/special/__init__.py b/litecli/packages/special/__init__.py index 5924d09..0338c36 100644 --- a/litecli/packages/special/__init__.py +++ b/litecli/packages/special/__init__.py @@ -12,3 +12,4 @@ def export(defn): from . import dbcommands from . import iocommands +from . import llm diff --git a/litecli/packages/special/iocommands.py b/litecli/packages/special/iocommands.py index acf3d98..ddc927a 100644 --- a/litecli/packages/special/iocommands.py +++ b/litecli/packages/special/iocommands.py @@ -8,10 +8,8 @@ import subprocess from io import open from time import sleep -from typing import Optional, Tuple import click -import llm import sqlparse from configobj import ConfigObj @@ -174,98 +172,6 @@ def open_external_editor(filename=None, sql=None): return (query, message) -@export -def is_llm_command(command) -> bool: - """ - Is this an llm/ai command? 
- """ - return ( - command.strip().startswith("\\llm") - or command.strip().startswith("\\ai") - or command.strip().startswith(".llm") - or command.strip().startswith(".ai") - ) - - -@export -def get_llm_question(command) -> Optional[str]: - """ - Remove the llm/ai prefix - """ - command = command.removeprefix("\\llm") - command = command.removeprefix("\\ai") - command = command.removeprefix(".llm") - command = command.removeprefix(".ai") - return command - - -@export -def sql_using_llm(cur, question=None) -> Tuple[str, Optional[str]]: - _pattern = r"```sql\n(.*?)\n```" - schema_query = """ - SELECT sql FROM sqlite_master - WHERE sql IS NOT NULL - ORDER BY tbl_name, type DESC, name - """ - tables_query = """ - SELECT name FROM sqlite_master - WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' - ORDER BY 1 - """ - sample_row_query = "SELECT * FROM {table} LIMIT 1" - log.debug(schema_query) - cur.execute(schema_query) - db_schema = "\n".join([x for (x,) in cur.fetchall()]) - log.debug(tables_query) - cur.execute(tables_query) - sample_data = {} - for (table,) in cur.fetchall(): - sample_row = sample_row_query.format(table=table) - cur.execute(sample_row) - cols = [x[0] for x in cur.description] - row = cur.fetchone() - if row is None: # Skip empty tables - continue - sample_data[table] = list(zip(cols, row)) - - sys_prompt = f"""A SQLite database has the following schema: - {db_schema} - - Here is a sample row of data from each table: {sample_data} - - Use the provided schema and the sample data to construct a SQL query that - can be run in SQLite3 to answer - - {question} - - Explain the reason for choosing each table in the SQL query you have - written. Keep the explanation concise and to the point. - Finally include the sql query in a code fence such as this one: - - ```sql - SELECT count(*) FROM table_name; - ``` - """ - log.debug(sys_prompt) - # model = llm.get_model("llama3.3") - # model = llm.get_model("qwq") - # model = llm.get_model("o1-preview") - # model = llm.get_model("o1-mini") - # model = llm.get_model("llama3.2") - model = llm.get_model("gpt-4o") - # model = llm.get_model("gemini-2.0-flash-exp") - # model = llm.get_model("claude-3.5-haiku") - resp = model.prompt(sys_prompt) - result = resp.text() - match = re.search(_pattern, result, re.DOTALL) - if match: - sql = match.group(1).strip() - else: - sql = "" - - return result, sql - - @special_command( "\\f", "\\f [name [args..]]", diff --git a/litecli/packages/special/llm.py b/litecli/packages/special/llm.py new file mode 100644 index 0000000..4903391 --- /dev/null +++ b/litecli/packages/special/llm.py @@ -0,0 +1,176 @@ +from collections import defaultdict +import logging +import re +import sys +from typing import Optional, Tuple +from runpy import run_module + +import click +import llm +from llm.cli import cli + +from . import export +from .main import parse_special_command + +log = logging.getLogger(__name__) +LLM_CLI_COMMANDS = list(cli.commands.keys()) +SUBCOMMANDS = [] +MODELS = [] + + +def list_all_commands(cmd, prefix=""): + """Recursively list all commands and subcommands. + + Args: + cmd (click.Command or click.Group): The Click command/group to inspect. + prefix (str): The command prefix for nested commands. 
+ """ + results = defaultdict(list) + if isinstance(cmd, click.Group): + for name, subcmd in cmd.commands.items(): + results[name].append(subcmd) + if isinstance(subcmd, click.Group): + list_all_commands(subcmd, prefix=f"{full_command} ") + else: + # It's a single command without subcommands + print(prefix + cmd.name) + + +@export +class FinishIteration(Exception): + def __init__(self, results=None): + self.results = results + + +USAGE = """ +Use an LLM to create SQL queries to answer questions from your database. +Examples: + +# Ask a question. +> \\llm Most visited urls? + +# List available models +> \\llm models +gpt-4o +gpt-3.5-turbo +qwq + +# Change default model +> \\llm models default llama3 + +# Set api key (not required for local models) +> \\llm keys set openai sg-1234 +API key set for openai. + +# Install a model plugin +> \\llm install llm-ollama +llm-ollama installed. + +# Models directory +# https://llm.datasette.io/en/stable/plugins/directory.html +""" + + +@export +def handle_llm(text, cur) -> Tuple[str, Optional[str]]: + cmd, verbose, arg = parse_special_command(text) + + if not arg.strip(): # No question provided. Print usage and bail. + output = [(None, None, None, USAGE)] + raise FinishIteration(output) + + parts = arg.split() + + if parts[0].startswith("-") or parts[0] in LLM_CLI_COMMANDS: + # If the first argument is a flag or a valid llm command then + # invoke the llm cli. + sys.argv = ["llm"] + parts + try: + run_module("llm", run_name="__main__") + except SystemExit: + raise FinishIteration(None) + + try: + context, sql = sql_using_llm(cur=cur, question=arg, verbose=verbose) + if not verbose: + context = "" + return context, sql + except Exception as e: + # Something went wrong. Raise an exception and bail. + raise RuntimeError(e) + + +@export +def is_llm_command(command) -> bool: + """ + Is this an llm/ai command? + """ + cmd, _, _ = parse_special_command(command) + return cmd in ("\\llm", "\\ai", ".llm", ".ai") + + +@export +def sql_using_llm(cur, question=None, verbose=False) -> Tuple[str, Optional[str]]: + _pattern = r"```sql\n(.*?)\n```" + schema_query = """ + SELECT sql FROM sqlite_master + WHERE sql IS NOT NULL + ORDER BY tbl_name, type DESC, name + """ + tables_query = """ + SELECT name FROM sqlite_master + WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + sample_row_query = "SELECT * FROM {table} LIMIT 1" + log.debug(schema_query) + cur.execute(schema_query) + db_schema = "\n".join([x for (x,) in cur.fetchall()]) + log.debug(tables_query) + cur.execute(tables_query) + sample_data = {} + for (table,) in cur.fetchall(): + sample_row = sample_row_query.format(table=table) + cur.execute(sample_row) + cols = [x[0] for x in cur.description] + row = cur.fetchone() + if row is None: # Skip empty tables + continue + sample_data[table] = list(zip(cols, row)) + + sys_prompt = f"""A SQLite database has the following schema: + {db_schema} + + Here is a sample row of data from each table: {sample_data} + + Use the provided schema and the sample data to construct a SQL query that + can be run in SQLite3 to answer + + {question} + + Explain the reason for choosing each table in the SQL query you have + written. Keep the explanation concise and to the point. 
+ Finally include the sql query in a code fence such as this one: + + ```sql + SELECT count(*) FROM table_name; + ``` + """ + log.debug(sys_prompt) + # model = llm.get_model("llama3.3") + # model = llm.get_model("qwq") + # model = llm.get_model("o1-preview") + # model = llm.get_model("o1-mini") + # model = llm.get_model("llama3.2") + model = llm.get_model("gpt-4o") + # model = llm.get_model("gemini-2.0-flash-exp") + # model = llm.get_model("claude-3.5-haiku") + resp = model.prompt(sys_prompt) + result = resp.text() + match = re.search(_pattern, result, re.DOTALL) + if match: + sql = match.group(1).strip() + else: + sql = "" + + return result, sql diff --git a/pyproject.toml b/pyproject.toml index 9ee326d..fa8b624 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,8 @@ dependencies = [ "prompt-toolkit>=3.0.3,<4.0.0", "pygments>=1.6", "sqlparse>=0.4.4", + "setuptools", # Required by llm commands to install models + "pip", ] [build-system] From 135143c14f58cf731e6315c9e72cfacb071b6b5d Mon Sep 17 00:00:00 2001 From: Amjith Ramanujam Date: Tue, 31 Dec 2024 21:03:15 -0800 Subject: [PATCH 08/21] Initial llm completion impl. --- litecli/packages/completion_engine.py | 4 +- litecli/packages/special/llm.py | 54 ++++++++++++++++++++------- litecli/sqlcompleter.py | 11 ++++++ 3 files changed, 53 insertions(+), 16 deletions(-) diff --git a/litecli/packages/completion_engine.py b/litecli/packages/completion_engine.py index 68c1392..2d9a033 100644 --- a/litecli/packages/completion_engine.py +++ b/litecli/packages/completion_engine.py @@ -119,9 +119,7 @@ def suggest_special(text): return [{"type": "table", "schema": []}] if cmd in [".llm", ".ai", "\\llm", "\\ai"]: - word_before_cursor = last_word(arg, include="many_punctuations") - - return [{"type": "llm", "subcommand": word_before_cursor}] + return [{"type": "llm"}] return [{"type": "keyword"}, {"type": "special"}] diff --git a/litecli/packages/special/llm.py b/litecli/packages/special/llm.py index 4903391..614e627 100644 --- a/litecli/packages/special/llm.py +++ b/litecli/packages/special/llm.py @@ -1,9 +1,8 @@ -from collections import defaultdict import logging import re import sys -from typing import Optional, Tuple from runpy import run_module +from typing import Optional, Tuple import click import llm @@ -14,26 +13,55 @@ log = logging.getLogger(__name__) LLM_CLI_COMMANDS = list(cli.commands.keys()) -SUBCOMMANDS = [] -MODELS = [] -def list_all_commands(cmd, prefix=""): - """Recursively list all commands and subcommands. +def build_command_tree(cmd): + """Recursively build a command tree for a Click app. Args: cmd (click.Command or click.Group): The Click command/group to inspect. - prefix (str): The command prefix for nested commands. + + Returns: + dict: A nested dictionary representing the command structure. """ - results = defaultdict(list) + tree = {} if isinstance(cmd, click.Group): for name, subcmd in cmd.commands.items(): - results[name].append(subcmd) - if isinstance(subcmd, click.Group): - list_all_commands(subcmd, prefix=f"{full_command} ") + # Recursively build the tree for subcommands + tree[name] = build_command_tree(subcmd) else: - # It's a single command without subcommands - print(prefix + cmd.name) + # Leaf command with no subcommands + tree = None + return tree + + +# Generate the tree +COMMAND_TREE = build_command_tree(cli) + + +def get_completions(tokens, tree=COMMAND_TREE): + """Get autocompletions for the current command tokens. + + Args: + tree (dict): The command tree. 
+ tokens (list): List of tokens (command arguments). + + Returns: + list: List of possible completions. + """ + current_tree = tree + for token in tokens: + if token.startswith("-"): + # Skip options (flags) + continue + if token in current_tree: + current_tree = current_tree[token] + else: + # No completions available + return [] + + # Return possible completions (keys of the current tree level) + return list(current_tree.keys()) if current_tree else [] @export diff --git a/litecli/sqlcompleter.py b/litecli/sqlcompleter.py index b154432..c01a6f1 100644 --- a/litecli/sqlcompleter.py +++ b/litecli/sqlcompleter.py @@ -9,6 +9,7 @@ from .packages.completion_engine import suggest_type from .packages.parseutils import last_word from .packages.special.iocommands import favoritequeries +from .packages.special import llm from .packages.filepaths import parse_path, complete_path, suggest_path _logger = logging.getLogger(__name__) @@ -529,6 +530,16 @@ def get_completions(self, document, complete_event): elif suggestion["type"] == "file_name": file_names = self.find_files(word_before_cursor) completions.extend(file_names) + elif suggestion["type"] == "llm": + tokens = document.text.split() + possible_entries = llm.get_completions(tokens[1:]) + subcommands = self.find_matches( + word_before_cursor, + possible_entries, + start_only=False, + fuzzy=True, + ) + completions.extend(subcommands) return completions From e8707b724b75ac47bb07c9e93e7da6ab1c8e45bf Mon Sep 17 00:00:00 2001 From: Amjith Ramanujam Date: Wed, 1 Jan 2025 07:31:13 -0800 Subject: [PATCH 09/21] Now with working autocompletion. --- litecli/packages/special/llm.py | 8 ++++++-- litecli/sqlcompleter.py | 7 +++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/litecli/packages/special/llm.py b/litecli/packages/special/llm.py index 614e627..11dcaba 100644 --- a/litecli/packages/special/llm.py +++ b/litecli/packages/special/llm.py @@ -13,6 +13,7 @@ log = logging.getLogger(__name__) LLM_CLI_COMMANDS = list(cli.commands.keys()) +MODELS = {x.model_id: None for x in llm.get_models()} def build_command_tree(cmd): @@ -27,8 +28,11 @@ def build_command_tree(cmd): tree = {} if isinstance(cmd, click.Group): for name, subcmd in cmd.commands.items(): - # Recursively build the tree for subcommands - tree[name] = build_command_tree(subcmd) + if cmd.name == "models" and name == "default": + tree[name] = MODELS + else: + # Recursively build the tree for subcommands + tree[name] = build_command_tree(subcmd) else: # Leaf command with no subcommands tree = None diff --git a/litecli/sqlcompleter.py b/litecli/sqlcompleter.py index c01a6f1..d6d21c7 100644 --- a/litecli/sqlcompleter.py +++ b/litecli/sqlcompleter.py @@ -531,8 +531,11 @@ def get_completions(self, document, complete_event): file_names = self.find_files(word_before_cursor) completions.extend(file_names) elif suggestion["type"] == "llm": - tokens = document.text.split() - possible_entries = llm.get_completions(tokens[1:]) + if not word_before_cursor: + tokens = document.text.split()[1:] + else: + tokens = document.text.split()[1:-1] + possible_entries = llm.get_completions(tokens) subcommands = self.find_matches( word_before_cursor, possible_entries, From 647e1f2277b04c17a2843c6f6eef47858742b404 Mon Sep 17 00:00:00 2001 From: Amjith Ramanujam Date: Sat, 4 Jan 2025 20:11:50 -0800 Subject: [PATCH 10/21] Install llm and restart litecli. 
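The install path avoids shelling out: pip is driven in-process through runpy,
and litecli then re-executes itself with os.execv so the freshly installed llm
package is importable by the new process. A condensed sketch of that idiom,
mirroring the handle_llm() change below (the helper name is ours, not the
patch's):

    import os
    import sys
    from runpy import run_module

    def install_and_restart(package):
        """Install a package in-process, then restart the current program."""
        original_exe, original_args = sys.executable, sys.argv
        sys.argv = ["pip", "install", "--quiet", package]
        try:
            run_module("pip", run_name="__main__")  # equivalent to `python -m pip ...`
        except SystemExit:
            pass  # pip signals completion by raising SystemExit
        # Replace the current process image; the restart picks up the new package.
        os.execv(original_exe, [original_exe] + original_args)

Re-exec is the simplest way to make the new package and its entry points
visible, since importing a library installed mid-session into the already
running interpreter is unreliable.
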
--- litecli/packages/special/llm.py | 42 ++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/litecli/packages/special/llm.py b/litecli/packages/special/llm.py index 11dcaba..514f0a0 100644 --- a/litecli/packages/special/llm.py +++ b/litecli/packages/special/llm.py @@ -1,3 +1,4 @@ +import os import logging import re import sys @@ -5,15 +6,21 @@ from typing import Optional, Tuple import click -import llm -from llm.cli import cli + +try: + import llm + from llm.cli import cli + + LLM_CLI_COMMANDS = list(cli.commands.keys()) + MODELS = {x.model_id: None for x in llm.get_models()} +except ImportError: + llm = None + cli = None from . import export from .main import parse_special_command log = logging.getLogger(__name__) -LLM_CLI_COMMANDS = list(cli.commands.keys()) -MODELS = {x.model_id: None for x in llm.get_models()} def build_command_tree(cmd): @@ -53,19 +60,18 @@ def get_completions(tokens, tree=COMMAND_TREE): Returns: list: List of possible completions. """ - current_tree = tree for token in tokens: if token.startswith("-"): # Skip options (flags) continue - if token in current_tree: - current_tree = current_tree[token] + if tree and token in tree: + tree = tree[token] else: # No completions available return [] # Return possible completions (keys of the current tree level) - return list(current_tree.keys()) if current_tree else [] + return list(tree.keys()) if tree else [] @export @@ -107,6 +113,26 @@ def __init__(self, results=None): def handle_llm(text, cur) -> Tuple[str, Optional[str]]: cmd, verbose, arg = parse_special_command(text) + if llm is None: + original_exe = sys.executable + original_args = sys.argv + # LLM is not installed. + # Offer to install it. + if click.confirm("This feature requires additional libraries. Install LLM library?", default=False): + click.echo("Installing LLM library. Please wait...") + sys.argv = ["pip", "install", "--quiet", "llm"] + try: + run_module("pip", run_name="__main__") + except SystemExit: + # output = [(None, None, None, "Please restart litecli to use this feature.")] + # raise FinishIteration(output) + pass + if click.confirm("LLM library installed. Would you like to restart litecli now?", default=True): + click.echo("Restarting litecli...") + os.execv(original_exe, [original_exe] + original_args) + + raise FinishIteration(None) + if not arg.strip(): # No question provided. Print usage and bail. output = [(None, None, None, USAGE)] raise FinishIteration(output) From bf7e550d5099ffe97914b0d3834af8879c9c148b Mon Sep 17 00:00:00 2001 From: Amjith Ramanujam Date: Sat, 4 Jan 2025 23:41:30 -0800 Subject: [PATCH 11/21] Using llm template. 
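The inline f-string prompt is replaced by a saved llm template named
"litecli"; the schema, sample data, and question are injected per call with
--param, and the response is captured by redirecting stdout around
run_module("llm", ...). The $-placeholders behave like Python's
string.Template substitution — a conceptual sketch under that assumption
(llm's actual template machinery may differ; the sample values are made up):

    from string import Template

    # Trimmed-down version of the PROMPT this patch saves as the "litecli" template.
    prompt = Template(
        "A SQLite database has the following schema:\n\n$db_schema\n\n"
        "Here is a sample row of data from each table: $sample_data\n\n"
        "Construct a SQL query that can be run in SQLite3 to answer:\n\n$question\n"
    )
    rendered = prompt.safe_substitute(
        db_schema="CREATE TABLE urls(url TEXT, visited_at TEXT);",
        sample_data={"urls": [("url", "https://example.org"), ("visited_at", "2024-12-07")]},
        question="Most visited urls?",
    )
    print(rendered)

On the command line, the argv built in sql_using_llm() below amounts to:
llm --template litecli --param db_schema ... --param sample_data ...
--param question ...
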
--- litecli/packages/special/llm.py | 121 +++++++++++++++++++------------- 1 file changed, 73 insertions(+), 48 deletions(-) diff --git a/litecli/packages/special/llm.py b/litecli/packages/special/llm.py index 514f0a0..7abbf80 100644 --- a/litecli/packages/special/llm.py +++ b/litecli/packages/special/llm.py @@ -1,6 +1,9 @@ -import os +import contextlib +import io import logging +import os import re +import shlex import sys from runpy import run_module from typing import Optional, Tuple @@ -108,36 +111,62 @@ def __init__(self, results=None): # https://llm.datasette.io/en/stable/plugins/directory.html """ +PROMPT = """A SQLite database has the following schema: + +$db_schema + +Here is a sample row of data from each table: $sample_data + +Use the provided schema and the sample data to construct a SQL query that +can be run in SQLite3 to answer + +$question + +Explain the reason for choosing each table in the SQL query you have +written. Keep the explanation concise. +Finally include a sql query in a code fence such as this one: + +```sql +SELECT count(*) FROM table_name; +``` +""" + + +def initialize_llm(): + # Initialize the LLM library. + # Create a template called litecli with the default prompt. + original_exe = sys.executable + original_args = sys.argv + if click.confirm("This feature requires additional libraries. Install LLM library?", default=True): + click.echo("Installing LLM library. Please wait...") + sys.argv = ["pip", "install", "--quiet", "llm"] + try: + run_module("pip", run_name="__main__") + except SystemExit: + pass + sys.argv = ["llm", PROMPT, "--save", "litecli"] # TODO: check if the template already exists + try: + run_module("llm", run_name="__main__") + except SystemExit: + pass + click.echo("Restarting litecli...") + os.execv(original_exe, [original_exe] + original_args) + @export def handle_llm(text, cur) -> Tuple[str, Optional[str]]: - cmd, verbose, arg = parse_special_command(text) + _, verbose, arg = parse_special_command(text) + # LLM is not installed. if llm is None: - original_exe = sys.executable - original_args = sys.argv - # LLM is not installed. - # Offer to install it. - if click.confirm("This feature requires additional libraries. Install LLM library?", default=False): - click.echo("Installing LLM library. Please wait...") - sys.argv = ["pip", "install", "--quiet", "llm"] - try: - run_module("pip", run_name="__main__") - except SystemExit: - # output = [(None, None, None, "Please restart litecli to use this feature.")] - # raise FinishIteration(output) - pass - if click.confirm("LLM library installed. Would you like to restart litecli now?", default=True): - click.echo("Restarting litecli...") - os.execv(original_exe, [original_exe] + original_args) - + initialize_llm() raise FinishIteration(None) if not arg.strip(): # No question provided. Print usage and bail. 
output = [(None, None, None, USAGE)] raise FinishIteration(output) - parts = arg.split() + parts = shlex.split(arg) if parts[0].startswith("-") or parts[0] in LLM_CLI_COMMANDS: # If the first argument is a flag or a valid llm command then @@ -184,6 +213,7 @@ def sql_using_llm(cur, question=None, verbose=False) -> Tuple[str, Optional[str] log.debug(schema_query) cur.execute(schema_query) db_schema = "\n".join([x for (x,) in cur.fetchall()]) + log.debug(tables_query) cur.execute(tables_query) sample_data = {} @@ -196,35 +226,30 @@ def sql_using_llm(cur, question=None, verbose=False) -> Tuple[str, Optional[str] continue sample_data[table] = list(zip(cols, row)) - sys_prompt = f"""A SQLite database has the following schema: - {db_schema} - - Here is a sample row of data from each table: {sample_data} - - Use the provided schema and the sample data to construct a SQL query that - can be run in SQLite3 to answer - - {question} - - Explain the reason for choosing each table in the SQL query you have - written. Keep the explanation concise and to the point. - Finally include the sql query in a code fence such as this one: + sys.argv = [ + "llm", + "--no-stream", + "--template", + "litecli", + "--param", + "db_schema", + db_schema, + "--param", + "sample_data", + sample_data, + "--param", + "question", + question, + " ", # Dummy argument to prevent llm from waiting on stdin + ] + buffer = io.StringIO() + with contextlib.redirect_stdout(buffer): + try: + run_module("llm", run_name="__main__") + except SystemExit: + pass - ```sql - SELECT count(*) FROM table_name; - ``` - """ - log.debug(sys_prompt) - # model = llm.get_model("llama3.3") - # model = llm.get_model("qwq") - # model = llm.get_model("o1-preview") - # model = llm.get_model("o1-mini") - # model = llm.get_model("llama3.2") - model = llm.get_model("gpt-4o") - # model = llm.get_model("gemini-2.0-flash-exp") - # model = llm.get_model("claude-3.5-haiku") - resp = model.prompt(sys_prompt) - result = resp.text() + result = buffer.getvalue() match = re.search(_pattern, result, re.DOTALL) if match: sql = match.group(1).strip() From 74395021cb8eae595014af13eada15d1a725c0f2 Mon Sep 17 00:00:00 2001 From: Amjith Ramanujam Date: Sun, 5 Jan 2025 20:43:07 -0800 Subject: [PATCH 12/21] Tiny refactors. --- litecli/packages/special/dbcommands.py | 2 +- litecli/packages/special/iocommands.py | 3 +-- litecli/packages/special/llm.py | 25 ++++++++++++++++++------- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/litecli/packages/special/dbcommands.py b/litecli/packages/special/dbcommands.py index 66ee8e2..315f6c7 100644 --- a/litecli/packages/special/dbcommands.py +++ b/litecli/packages/special/dbcommands.py @@ -235,7 +235,7 @@ def describe(cur, arg, **_): @special_command( - ".itables_query", + ".import", ".import filename table", "Import data from filename into an existing table", arg_type=PARSED_QUERY, diff --git a/litecli/packages/special/iocommands.py b/litecli/packages/special/iocommands.py index ddc927a..ec65672 100644 --- a/litecli/packages/special/iocommands.py +++ b/litecli/packages/special/iocommands.py @@ -13,8 +13,7 @@ import sqlparse from configobj import ConfigObj -from litecli.packages.prompt_utils import confirm_destructive_query - +from ..prompt_utils import confirm_destructive_query from . 
import export from .favoritequeries import FavoriteQueries from .main import NO_QUERY, PARSED_QUERY, special_command diff --git a/litecli/packages/special/llm.py b/litecli/packages/special/llm.py index 7abbf80..9974ce0 100644 --- a/litecli/packages/special/llm.py +++ b/litecli/packages/special/llm.py @@ -111,6 +111,7 @@ def __init__(self, results=None): # https://llm.datasette.io/en/stable/plugins/directory.html """ +_SQL_CODE_FENCE = r"```sql\n(.*?)\n```" PROMPT = """A SQLite database has the following schema: $db_schema @@ -171,11 +172,23 @@ def handle_llm(text, cur) -> Tuple[str, Optional[str]]: if parts[0].startswith("-") or parts[0] in LLM_CLI_COMMANDS: # If the first argument is a flag or a valid llm command then # invoke the llm cli. + # Check if there is a SQL fenced code and return it. sys.argv = ["llm"] + parts - try: - run_module("llm", run_name="__main__") - except SystemExit: - raise FinishIteration(None) + buffer = io.StringIO() + with contextlib.redirect_stdout(buffer): + try: + run_module("llm", run_name="__main__") + except SystemExit: + pass + result = buffer.getvalue() + match = re.search(_SQL_CODE_FENCE, result, re.DOTALL) + if match: + sql = match.group(1).strip() + else: + output = [(None, None, None, result)] + raise FinishIteration(output) + + return result if verbose else "", sql try: context, sql = sql_using_llm(cur=cur, question=arg, verbose=verbose) @@ -198,7 +211,6 @@ def is_llm_command(command) -> bool: @export def sql_using_llm(cur, question=None, verbose=False) -> Tuple[str, Optional[str]]: - _pattern = r"```sql\n(.*?)\n```" schema_query = """ SELECT sql FROM sqlite_master WHERE sql IS NOT NULL @@ -228,7 +240,6 @@ def sql_using_llm(cur, question=None, verbose=False) -> Tuple[str, Optional[str] sys.argv = [ "llm", - "--no-stream", "--template", "litecli", "--param", @@ -250,7 +261,7 @@ def sql_using_llm(cur, question=None, verbose=False) -> Tuple[str, Optional[str] pass result = buffer.getvalue() - match = re.search(_pattern, result, re.DOTALL) + match = re.search(_SQL_CODE_FENCE, result, re.DOTALL) if match: sql = match.group(1).strip() else: From 381096ec492b0bc71e98cb28c9ff58d9cae5b634 Mon Sep 17 00:00:00 2001 From: Amjith Ramanujam Date: Mon, 13 Jan 2025 18:24:18 -0800 Subject: [PATCH 13/21] Handle various ways to invoke llm. Decide when to pass the context and when to capture output. --- litecli/packages/special/llm.py | 69 ++++++++++++++++++++++++++------- 1 file changed, 54 insertions(+), 15 deletions(-) diff --git a/litecli/packages/special/llm.py b/litecli/packages/special/llm.py index 9974ce0..867d274 100644 --- a/litecli/packages/special/llm.py +++ b/litecli/packages/special/llm.py @@ -156,6 +156,14 @@ def initialize_llm(): @export def handle_llm(text, cur) -> Tuple[str, Optional[str]]: + """This function handles the special command `\\llm`. + + If it deals with a question that results in a SQL query then it will return + the query. + If it deals with a subcommand like `models` or `keys` then it will raise + FinishIteration() which will be caught by the main loop AND print any + output that was supplied (or None). + """ _, verbose, arg = parse_special_command(text) # LLM is not installed. @@ -169,26 +177,57 @@ def handle_llm(text, cur) -> Tuple[str, Optional[str]]: parts = shlex.split(arg) - if parts[0].startswith("-") or parts[0] in LLM_CLI_COMMANDS: - # If the first argument is a flag or a valid llm command then - # invoke the llm cli. - # Check if there is a SQL fenced code and return it. 
+ # If the parts has `-c` then capture the output and check for fenced SQL. + # User is continuing a previous question. + # eg: \llm -m ollama -c "Show ony the top 5 results" + if "-c" in parts: + capture_output = True + use_context = False + # If the parts has `pormpt` command without `-c` then use context to the prompt. + # \llm -m ollama prompt "Most visited urls?" + elif "prompt" in parts: # User might invoke prompt with an option flag in the first argument. + capture_output = True + use_context = True + # If the parts starts with any of the known LLM_CLI_COMMANDS then invoke + # the llm and don't capture output. This is to handle commands like `models` or `keys`. + elif parts[0] in LLM_CLI_COMMANDS: + capture_output = False + use_context = False + # If the parts doesn't have any known LLM_CLI_COMMANDS then the user is + # invoking a question. eg: \llm -m ollama "Most visited urls?" + elif not set(parts).intersection(LLM_CLI_COMMANDS): + capture_output = True + use_context = True + # User invoked llm with a question without `prompt` subcommand. Capture the + # output and check for fenced SQL. eg: \llm "Most visited urls?" + else: + capture_output = True + use_context = True + + if not use_context: sys.argv = ["llm"] + parts - buffer = io.StringIO() - with contextlib.redirect_stdout(buffer): + if capture_output: + buffer = io.StringIO() + with contextlib.redirect_stdout(buffer): + try: + run_module("llm", run_name="__main__") + except SystemExit: + pass + result = buffer.getvalue() + match = re.search(_SQL_CODE_FENCE, result, re.DOTALL) + if match: + sql = match.group(1).strip() + else: + output = [(None, None, None, result)] + raise FinishIteration(output) + + return result if verbose else "", sql + else: try: run_module("llm", run_name="__main__") except SystemExit: pass - result = buffer.getvalue() - match = re.search(_SQL_CODE_FENCE, result, re.DOTALL) - if match: - sql = match.group(1).strip() - else: - output = [(None, None, None, result)] - raise FinishIteration(output) - - return result if verbose else "", sql + raise FinishIteration(None) try: context, sql = sql_using_llm(cur=cur, question=arg, verbose=verbose) From 75821297a85cda8caa9236728c683bd1aa08ff38 Mon Sep 17 00:00:00 2001 From: Amjith Ramanujam Date: Mon, 13 Jan 2025 20:18:43 -0800 Subject: [PATCH 14/21] Handle Ctrl-C to abort a command without quitting. --- litecli/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litecli/main.py b/litecli/main.py index 0ff8c11..a9eced3 100644 --- a/litecli/main.py +++ b/litecli/main.py @@ -451,6 +451,8 @@ def one_iteration(text=None): if context: click.echo(context) text = self.prompt_app.prompt(default=sql) + except KeyboardInterrupt: + return except special.FinishIteration as e: return output_res(e.results, start) if e.results else None except RuntimeError as e: From 12146c79afd3e2c174ca90358ad3b493f2cf192f Mon Sep 17 00:00:00 2001 From: Amjith Ramanujam Date: Mon, 13 Jan 2025 21:08:23 -0800 Subject: [PATCH 15/21] Fix up the usage to include quotes. --- litecli/packages/special/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litecli/packages/special/llm.py b/litecli/packages/special/llm.py index 867d274..7f8171c 100644 --- a/litecli/packages/special/llm.py +++ b/litecli/packages/special/llm.py @@ -88,7 +88,7 @@ def __init__(self, results=None): Examples: # Ask a question. -> \\llm Most visited urls? +> \\llm 'Most visited urls?' 
# List available models > \\llm models From 06eacba36f6aab7d09f7568c0b33db2c5303dd40 Mon Sep 17 00:00:00 2001 From: Amjith Ramanujam Date: Sun, 19 Jan 2025 20:37:58 -0800 Subject: [PATCH 16/21] Abstract the external calls to a separate function. --- litecli/packages/special/llm.py | 98 ++++++++++++++++++++------------- 1 file changed, 61 insertions(+), 37 deletions(-) diff --git a/litecli/packages/special/llm.py b/litecli/packages/special/llm.py index 7f8171c..3f4cab3 100644 --- a/litecli/packages/special/llm.py +++ b/litecli/packages/special/llm.py @@ -26,6 +26,40 @@ log = logging.getLogger(__name__) +def run_external_cmd(cmd, *args, capture_output=False, restart_cli=False, raise_exception=True): + original_exe = sys.executable + original_args = sys.argv + + try: + sys.argv = [cmd] + list(args) + code = 0 + + if capture_output: + buffer = io.StringIO() + redirect = contextlib.redirect_stdout(buffer) + else: + # Use nullcontext to do nothing when not capturing output + redirect = contextlib.nullcontext() + + with redirect: + try: + run_module(cmd, run_name="__main__") + except SystemExit as e: + code = e.code + if code != 0 and raise_exception: + raise + + if restart_cli and code == 0: + os.execv(original_exe, [original_exe] + original_args) + + if capture_output: + return code, buffer.getvalue() + else: + return code, "" + finally: + sys.argv = original_args + + def build_command_tree(cmd): """Recursively build a command tree for a Click app. @@ -135,23 +169,24 @@ def __init__(self, results=None): def initialize_llm(): # Initialize the LLM library. - # Create a template called litecli with the default prompt. - original_exe = sys.executable - original_args = sys.argv if click.confirm("This feature requires additional libraries. Install LLM library?", default=True): click.echo("Installing LLM library. Please wait...") - sys.argv = ["pip", "install", "--quiet", "llm"] - try: - run_module("pip", run_name="__main__") - except SystemExit: - pass - sys.argv = ["llm", PROMPT, "--save", "litecli"] # TODO: check if the template already exists - try: - run_module("llm", run_name="__main__") - except SystemExit: - pass - click.echo("Restarting litecli...") - os.execv(original_exe, [original_exe] + original_args) + run_external_cmd("pip", "install", "--quiet", "llm", restart_cli=True) + ensure_litecli_template() + + +def ensure_litecli_template(replace=False): + """ + Create a template called litecli with the default prompt. + """ + if not replace: + # Check if it already exists. + code, _ = run_external_cmd("llm", "templates", "show", "litecli", capture_output=True, raise_exception=False) + if code == 0: # Template already exists. No need to create it. + return + + run_external_cmd("llm", PROMPT, "--save", "litecli") + return @export @@ -177,6 +212,7 @@ def handle_llm(text, cur) -> Tuple[str, Optional[str]]: parts = shlex.split(arg) + restart = False # If the parts has `-c` then capture the output and check for fenced SQL. # User is continuing a previous question. # eg: \llm -m ollama -c "Show ony the top 5 results" @@ -188,6 +224,10 @@ def handle_llm(text, cur) -> Tuple[str, Optional[str]]: elif "prompt" in parts: # User might invoke prompt with an option flag in the first argument. capture_output = True use_context = True + elif "install" in parts or "uninstall" in parts: + capture_output = False + use_context = False + restart = True # If the parts starts with any of the known LLM_CLI_COMMANDS then invoke # the llm and don't capture output. This is to handle commands like `models` or `keys`. 
elif parts[0] in LLM_CLI_COMMANDS: @@ -205,15 +245,9 @@ def handle_llm(text, cur) -> Tuple[str, Optional[str]]: use_context = True if not use_context: - sys.argv = ["llm"] + parts + args = parts if capture_output: - buffer = io.StringIO() - with contextlib.redirect_stdout(buffer): - try: - run_module("llm", run_name="__main__") - except SystemExit: - pass - result = buffer.getvalue() + _, result = run_external_cmd("llm", *args, capture_output=capture_output) match = re.search(_SQL_CODE_FENCE, result, re.DOTALL) if match: sql = match.group(1).strip() @@ -223,13 +257,11 @@ def handle_llm(text, cur) -> Tuple[str, Optional[str]]: return result if verbose else "", sql else: - try: - run_module("llm", run_name="__main__") - except SystemExit: - pass + run_external_cmd("llm", *args, restart_cli=restart) raise FinishIteration(None) try: + ensure_litecli_template() context, sql = sql_using_llm(cur=cur, question=arg, verbose=verbose) if not verbose: context = "" @@ -277,8 +309,7 @@ def sql_using_llm(cur, question=None, verbose=False) -> Tuple[str, Optional[str] continue sample_data[table] = list(zip(cols, row)) - sys.argv = [ - "llm", + args = [ "--template", "litecli", "--param", @@ -292,14 +323,7 @@ def sql_using_llm(cur, question=None, verbose=False) -> Tuple[str, Optional[str] question, " ", # Dummy argument to prevent llm from waiting on stdin ] - buffer = io.StringIO() - with contextlib.redirect_stdout(buffer): - try: - run_module("llm", run_name="__main__") - except SystemExit: - pass - - result = buffer.getvalue() + _, result = run_external_cmd("llm", *args, capture_output=True) match = re.search(_SQL_CODE_FENCE, result, re.DOTALL) if match: sql = match.group(1).strip() From 1953f5c2cf47db0ec92fbbd55918da918656d22e Mon Sep 17 00:00:00 2001 From: Amjith Ramanujam Date: Sun, 19 Jan 2025 21:51:55 -0800 Subject: [PATCH 17/21] Add some tests to the llm functionality. --- litecli/packages/special/llm.py | 2 +- tests/test_llm_special.py | 162 ++++++++++++++++++++++++++++++++ 2 files changed, 163 insertions(+), 1 deletion(-) create mode 100644 tests/test_llm_special.py diff --git a/litecli/packages/special/llm.py b/litecli/packages/special/llm.py index 3f4cab3..0cf4c6f 100644 --- a/litecli/packages/special/llm.py +++ b/litecli/packages/special/llm.py @@ -215,7 +215,7 @@ def handle_llm(text, cur) -> Tuple[str, Optional[str]]: restart = False # If the parts has `-c` then capture the output and check for fenced SQL. # User is continuing a previous question. - # eg: \llm -m ollama -c "Show ony the top 5 results" + # eg: \llm -m ollama -c "Show only the top 5 results" if "-c" in parts: capture_output = True use_context = False diff --git a/tests/test_llm_special.py b/tests/test_llm_special.py new file mode 100644 index 0000000..2f3b010 --- /dev/null +++ b/tests/test_llm_special.py @@ -0,0 +1,162 @@ +import pytest +from unittest.mock import patch +from litecli.packages.special.llm import handle_llm, FinishIteration, USAGE + + +@patch("litecli.packages.special.llm.initialize_llm") +@patch("litecli.packages.special.llm.llm", new=None) +def test_llm_command_without_install(mock_initialize_llm, executor): + """ + Test that handle_llm initializes llm when it is None and raises FinishIteration. 
+    """
+    test_text = r"\llm"
+    cur_mock = executor
+
+    with pytest.raises(FinishIteration) as exc_info:
+        handle_llm(test_text, cur_mock)
+
+    mock_initialize_llm.assert_called_once()
+    assert exc_info.value.args[0] is None
+
+
+@patch("litecli.packages.special.llm.llm")
+def test_llm_command_without_args(mock_llm, executor):
+    r"""
+    Invoking \llm without any arguments should print the usage and raise
+    FinishIteration.
+    """
+    assert mock_llm is not None
+    test_text = r"\llm"
+    cur_mock = executor
+
+    with pytest.raises(FinishIteration) as exc_info:
+        handle_llm(test_text, cur_mock)
+
+    assert exc_info.value.args[0] == [(None, None, None, USAGE)]
+
+
+@patch("litecli.packages.special.llm.llm")
+@patch("litecli.packages.special.llm.run_external_cmd")
+def test_llm_command_with_c_flag(mock_run_cmd, mock_llm, executor):
+    # Suppose the LLM returns some text without fenced SQL
+    mock_run_cmd.return_value = (0, "Hello, I have no SQL for you today.")
+
+    test_text = r"\llm -c 'Something interesting?'"
+
+    with pytest.raises(FinishIteration) as exc_info:
+        handle_llm(test_text, executor)
+
+    # We expect no code fence => FinishIteration with that output
+    assert exc_info.value.args[0] == [(None, None, None, "Hello, I have no SQL for you today.")]
+
+
+@patch("litecli.packages.special.llm.llm")
+@patch("litecli.packages.special.llm.run_external_cmd")
+def test_llm_command_with_c_flag_and_fenced_sql(mock_run_cmd, mock_llm, executor):
+    # The luscious SQL is inside triple backticks
+    return_text = "Here is your query:\n" "```sql\nSELECT * FROM table;\n```"
+    mock_run_cmd.return_value = (0, return_text)
+
+    test_text = r"\llm -c 'Rewrite the SQL without CTE'"
+
+    result, sql = handle_llm(test_text, executor)
+
+    # handle_llm returns (context, sql). The context is only populated in
+    # verbose mode (the \llm+ variant); this test uses the plain \llm
+    # command, so we expect an empty string for the context and the SQL
+    # extracted from the fenced code block.
+    assert result == ""
+    assert sql == "SELECT * FROM table;"
+
+
+@patch("litecli.packages.special.llm.llm")
+@patch("litecli.packages.special.llm.run_external_cmd")
+def test_llm_command_known_subcommand(mock_run_cmd, mock_llm, executor):
+    """
+    If parts[0] is in LLM_CLI_COMMANDS, we do NOT capture output; we just
+    call run_external_cmd and then raise FinishIteration.
+    """
+    # Let's assume 'models' is in LLM_CLI_COMMANDS
+    test_text = r"\llm models"
+
+    with pytest.raises(FinishIteration) as exc_info:
+        handle_llm(test_text, executor)
+
+    # We check that run_external_cmd was called with these arguments:
+    mock_run_cmd.assert_called_once_with("llm", "models", restart_cli=False)
+    # And the function should raise FinishIteration(None)
+    assert exc_info.value.args[0] is None
+
+
+@patch("litecli.packages.special.llm.llm")
+@patch("litecli.packages.special.llm.run_external_cmd")
+def test_llm_command_with_install_flag(mock_run_cmd, mock_llm, executor):
+    """
+    If 'install' or 'uninstall' is in the parts, we do not capture output but restart the CLI.
+    """
+    test_text = r"\llm install openai"
+
+    with pytest.raises(FinishIteration) as exc_info:
+        handle_llm(test_text, executor)
+
+    # We expect a restart
+    mock_run_cmd.assert_called_once_with("llm", "install", "openai", restart_cli=True)
+    assert exc_info.value.args[0] is None
+
+
+@patch("litecli.packages.special.llm.llm")
+@patch("litecli.packages.special.llm.ensure_litecli_template")
+@patch("litecli.packages.special.llm.sql_using_llm")
+def test_llm_command_with_prompt(mock_sql_using_llm, mock_ensure_template, mock_llm, executor):
+    r"""
+    \llm prompt "some question"
+    Should use context, capture output, and call sql_using_llm.
+    """
+    # Mock out the return from sql_using_llm
+    mock_sql_using_llm.return_value = ("context from LLM", "SELECT 1;")
+
+    test_text = r"\llm prompt 'Magic happening here?'"
+    context, sql = handle_llm(test_text, executor)
+
+    # ensure_litecli_template should be called
+    mock_ensure_template.assert_called_once()
+    # sql_using_llm should be called with question=arg; handle_llm strips
+    # the leading \llm, so arg here is the rest of the command line:
+    # "prompt 'Magic happening here?'" (parsed internally with shlex.split).
+    mock_sql_using_llm.assert_called()
+    assert context == ""
+    assert sql == "SELECT 1;"
+
+
+@patch("litecli.packages.special.llm.llm")
+@patch("litecli.packages.special.llm.ensure_litecli_template")
+@patch("litecli.packages.special.llm.sql_using_llm")
+def test_llm_command_question_with_context(mock_sql_using_llm, mock_ensure_template, mock_llm, executor):
+    """
+    If arg doesn't contain any known command, it's treated as a question => capture output + context.
+    """
+    mock_sql_using_llm.return_value = ("You have context!", "SELECT 2;")
+
+    test_text = r"\llm 'Top 10 downloads by size.'"
+    context, sql = handle_llm(test_text, executor)
+
+    mock_ensure_template.assert_called_once()
+    mock_sql_using_llm.assert_called()
+    assert context == ""
+    assert sql == "SELECT 2;"
+
+
+@patch("litecli.packages.special.llm.llm")
+@patch("litecli.packages.special.llm.ensure_litecli_template")
+@patch("litecli.packages.special.llm.sql_using_llm")
+def test_llm_command_question_verbose(mock_sql_using_llm, mock_ensure_template, mock_llm, executor):
+    r"""
+    Invoking \llm+ returns the context and the SQL query.
+    """
+    mock_sql_using_llm.return_value = ("Verbose context, oh yeah!", "SELECT 42;")
+
+    test_text = r"\llm+ 'Top 10 downloads by size.'"
+    context, sql = handle_llm(test_text, executor)
+
+    assert context == "Verbose context, oh yeah!"
+    assert sql == "SELECT 42;"

From 0d9e93106f345945c7b51af29c9dc7815a01ec25 Mon Sep 17 00:00:00 2001
From: Amjith Ramanujam
Date: Mon, 20 Jan 2025 12:33:12 -0800
Subject: [PATCH 18/21] Update changelog.

---
 CHANGELOG.md | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f1e8c68..c4e29c4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,11 @@
+## 1.14.0 - 2025-01-20
+
+### Features
+
+* Add LLM feature to ask an LLM to create a SQL query.
+  - This adds a new `\llm` special command
+  - eg: `\llm "Who is the largest customer based on revenue?"`
+
 ## 1.13.2 - 2024-11-24
 
 ### Internal

From ddfe5250910bbd8a8917fb05ef3d9adfb51d12f8 Mon Sep 17 00:00:00 2001
From: Amjith Ramanujam
Date: Mon, 20 Jan 2025 14:54:34 -0800
Subject: [PATCH 19/21] Add llm to the dev deps.
---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index fa8b624..ade230a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,6 +42,7 @@ dev = [
     "pytest-cov>=4.1.0",
     "tox>=4.8.0",
     "pdbpp>=0.10.3",
+    "llm>=0.19.0",
 ]
 
 [tool.setuptools.packages.find]

From 23fdbd65e6346020bd12360c26168575d3f0da8d Mon Sep 17 00:00:00 2001
From: Amjith Ramanujam
Date: Mon, 20 Jan 2025 14:57:25 -0800
Subject: [PATCH 20/21] Change min python version to 3.9

---
 .github/workflows/ci.yml      | 2 +-
 .github/workflows/publish.yml | 2 +-
 CHANGELOG.md                  | 4 ++++
 pyproject.toml                | 2 +-
 4 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 8e327a9..b1586e9 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -12,7 +12,7 @@ jobs:
 
     strategy:
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
 
     steps:
       - uses: actions/checkout@v3
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 6073ec5..f58a4e4 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -13,7 +13,7 @@ jobs:
 
     strategy:
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
 
     steps:
       - uses: actions/checkout@v4
diff --git a/CHANGELOG.md b/CHANGELOG.md
index c4e29c4..bebe66b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,10 @@
 - This adds a new `\llm` special command
 - eg: `\llm "Who is the largest customer based on revenue?"`
 
+### Internal
+
+* Change minimum required Python version to 3.9+
+
 ## 1.13.2 - 2024-11-24
 
 ### Internal
diff --git a/pyproject.toml b/pyproject.toml
index ade230a..ba9a9a9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ name = "litecli"
 dynamic = ["version"]
 description = "CLI for SQLite Databases with auto-completion and syntax highlighting."
 readme = "README.md"
-requires-python = ">=3.7"
+requires-python = ">=3.9"
 license = { text = "BSD" }
 authors = [{ name = "dbcli", email = "litecli-users@googlegroups.com" }]
 urls = { "homepage" = "https://github.com/dbcli/litecli" }

From f901c7e4de3a333de1b23fe5d9cc0c569c3dd044 Mon Sep 17 00:00:00 2001
From: Amjith Ramanujam
Date: Mon, 20 Jan 2025 14:59:13 -0800
Subject: [PATCH 21/21] Include python 3.13 in the test matrix.

---
 .github/workflows/ci.yml      | 2 +-
 .github/workflows/publish.yml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index b1586e9..2b71bcb 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -12,7 +12,7 @@ jobs:
 
     strategy:
       matrix:
-        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
 
     steps:
       - uses: actions/checkout@v3
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index f58a4e4..0491657 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -13,7 +13,7 @@ jobs:
 
     strategy:
      matrix:
-        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
 
     steps:
       - uses: actions/checkout@v4
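
A note on the pattern PATCH 16 introduces: run_external_cmd re-enters another
tool's command-line entry point inside the current interpreter via
runpy.run_module, temporarily swapping sys.argv, optionally redirecting stdout
into a StringIO, and reading the SystemExit raised by the tool's entry point
as its exit status. The sketch below is a minimal, self-contained illustration
of that capture path, not litecli code: the helper name run_module_capture is
hypothetical, and the standard-library platform module (whose __main__ block
prints a platform string and then calls sys.exit(0)) stands in for pip/llm
only so the example runs with no extra dependencies.

    import contextlib
    import io
    import sys
    from runpy import run_module


    def run_module_capture(cmd, *args):
        # Run the equivalent of `python -m <cmd> <args>` in-process and
        # return (exit_code, captured_stdout).
        original_args = sys.argv
        try:
            sys.argv = [cmd] + list(args)  # the target CLI reads sys.argv
            buffer = io.StringIO()
            code = 0
            with contextlib.redirect_stdout(buffer):
                try:
                    run_module(cmd, run_name="__main__")
                except SystemExit as e:
                    # CLI entry points report status via sys.exit();
                    # a bare sys.exit() yields e.code == None, i.e. success.
                    code = e.code if e.code is not None else 0
            return code, buffer.getvalue()
        finally:
            sys.argv = original_args  # always restore argv for the host app


    if __name__ == "__main__":
        # Equivalent to `python -m platform`, but without a subprocess.
        code, output = run_module_capture("platform")
        print(code, repr(output))

This is the same shape ensure_litecli_template relies on above:
run_external_cmd("llm", "templates", "show", "litecli", capture_output=True,
raise_exception=False) probes for an existing template purely by exit code,
and the in-process call keeps that probe cheap and independent of whether an
`llm` executable is resolvable on PATH.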