diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index b22273d..855ba7b 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,18 +1,6 @@ { "name": "Python 3.10", "image": "mcr.microsoft.com/devcontainers/python:3.10", - "runArgs": [ - "--runtime", - "nvidia", - "--gpus", - "all", - // optional but make sure CUDA workloads are available - "--env", - "NVIDIA_VISIBLE_DEVICES=all", - // optional but make sure CUDA workloads are available - "--env", - "NVIDIA_DRIVER_CAPABILITIES=compute,utility" - ], "customizations": { "vscode": { "extensions": [ @@ -26,13 +14,5 @@ "hbenl.vscode-test-explorer" ] } - }, - "hostRequirements": { - "gpu": "optional" - }, - "remoteEnv": { - // optional but make sure CUDA workloads are available - "NVIDIA_VISIBLE_DEVICES": "all", - "NVIDIA_DRIVER_CAPABILITIES": "compute,utility" } } diff --git a/.devcontainer/py3.11/devcontainer.json b/.devcontainer/py3.11/devcontainer.json new file mode 100644 index 0000000..42bb225 --- /dev/null +++ b/.devcontainer/py3.11/devcontainer.json @@ -0,0 +1,38 @@ +{ + "name": "Python 3.11", + "image": "mcr.microsoft.com/devcontainers/python:3.11", + "runArgs": [ + "--runtime", + "nvidia", + "--gpus", + "all", + // optional but make sure CUDA workloads are available + "--env", + "NVIDIA_VISIBLE_DEVICES=all", + // optional but make sure CUDA workloads are available + "--env", + "NVIDIA_DRIVER_CAPABILITIES=compute,utility" + ], + "customizations": { + "vscode": { + "extensions": [ + "ms-python.black-formatter", + "ms-python.flake8", + "ms-python.isort", + "ms-python.vscode-pylance", + "ms-python.python", + "ms-python.debugpy", + "ms-python.vscode-python-envs", + "hbenl.vscode-test-explorer" + ] + } + }, + "hostRequirements": { + "gpu": "optional" + }, + "remoteEnv": { + // optional but make sure CUDA workloads are available + "NVIDIA_VISIBLE_DEVICES": "all", + "NVIDIA_DRIVER_CAPABILITIES": "compute,utility" + } +} diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 36c77c2..a929f23 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -49,6 +49,10 @@ jobs: - name: Test # Using default directory for models + # COVERAGE_CORE=pytrace: Workaround for Python 3.11 segfault with SQLite extensions + C tracer + # See: https://github.com/nedbat/coveragepy/issues/1665 + env: + COVERAGE_CORE: ${{ matrix.python-version == '3.11' && 'pytrace' || '' }} run: | pytest --cov --cov-branch --cov-report=xml -v -m "not slow" ./tests diff --git a/src/sqlite_rag/cli.py b/src/sqlite_rag/cli.py index 6574889..7035ff5 100644 --- a/src/sqlite_rag/cli.py +++ b/src/sqlite_rag/cli.py @@ -439,12 +439,12 @@ def reset( def search( ctx: typer.Context, query: str, - limit: int = typer.Option(10, help="Number of results to return"), + limit: int = typer.Option(5, help="Number of results to return"), debug: bool = typer.Option( False, "-d", "--debug", - help="Print extra debug information with modern formatting", + help="Print extra debug information with sentence-level details", ), peek: bool = typer.Option( False, "--peek", help="Print debug information using compact table format" diff --git a/src/sqlite_rag/database.py b/src/sqlite_rag/database.py index 8ad1791..f15450f 100644 --- a/src/sqlite_rag/database.py +++ b/src/sqlite_rag/database.py @@ -76,18 +76,32 @@ def _create_schema(conn: sqlite3.Connection, settings: Settings): ) # TODO: this table is not ready for sqlite-sync, it uses the id AUTOINCREMENT - cursor.execute( + cursor.executescript( """ CREATE TABLE IF NOT EXISTS chunks ( id INTEGER PRIMARY KEY AUTOINCREMENT, document_id TEXT, content TEXT, - embedding BLOB, - FOREIGN KEY (document_id) REFERENCES documents (id) ON DELETE CASCADE + embedding BLOB ); + CREATE INDEX IF NOT EXISTS idx_chunks_document_id ON chunks (document_id); """ ) + cursor.executescript( + """ + CREATE TABLE IF NOT EXISTS sentences ( + id TEXT PRIMARY KEY, + chunk_id INTEGER, + content TEXT, + embedding BLOB, + start_offset INTEGER, + end_offset INTEGER + ); + CREATE INDEX IF NOT EXISTS idx_sentences_chunk_id ON sentences (chunk_id); + """ + ) + cursor.execute( """ CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(content, content='chunks', content_rowid='id'); @@ -95,9 +109,17 @@ def _create_schema(conn: sqlite3.Connection, settings: Settings): ) cursor.execute( - f""" - SELECT vector_init('chunks', 'embedding', 'type={settings.vector_type},dimension={settings.embedding_dim},{settings.other_vector_options}'); - """ + """ + SELECT vector_init('chunks', 'embedding', ?); + """, + (settings.get_vector_init_options(),), + ) + # TODO: same configuration as chunks (or different options?) + cursor.execute( + """ + SELECT vector_init('sentences', 'embedding', ?); + """, + (settings.get_vector_init_options(),), ) conn.commit() diff --git a/src/sqlite_rag/engine.py b/src/sqlite_rag/engine.py index c2372d2..1496437 100644 --- a/src/sqlite_rag/engine.py +++ b/src/sqlite_rag/engine.py @@ -2,9 +2,12 @@ import re import sqlite3 from pathlib import Path +from typing import List from sqlite_rag.logger import Logger from sqlite_rag.models.document_result import DocumentResult +from sqlite_rag.models.sentence_result import SentenceResult +from sqlite_rag.sentence_splitter import SentenceSplitter from .chunker import Chunker from .models.document import Document @@ -15,10 +18,17 @@ class Engine: # Considered a good default to normilize the score for RRF DEFAULT_RRF_K = 60 - def __init__(self, conn: sqlite3.Connection, settings: Settings, chunker: Chunker): + def __init__( + self, + conn: sqlite3.Connection, + settings: Settings, + chunker: Chunker, + sentence_splitter: SentenceSplitter, + ): self._conn = conn self._settings = settings self._chunker = chunker + self._sentence_splitter = sentence_splitter self._logger = Logger() def load_model(self): @@ -30,7 +40,7 @@ def load_model(self): self._conn.execute( "SELECT llm_model_load(?, ?);", - (self._settings.model_path, self._settings.model_options), + (self._settings.model_path, self._settings.other_model_options), ) def process(self, document: Document) -> Document: @@ -46,6 +56,11 @@ def process(self, document: Document) -> Document: chunk.title = document.get_title() chunk.embedding = self.generate_embedding(chunk.get_embedding_text()) + sentences = self._sentence_splitter.split(chunk) + for sentence in sentences: + sentence.embedding = self.generate_embedding(sentence.content) + chunk.sentences = sentences + document.chunks = chunks return document @@ -72,6 +87,7 @@ def quantize(self) -> None: cursor = self._conn.cursor() cursor.execute("SELECT vector_quantize('chunks', 'embedding');") + cursor.execute("SELECT vector_quantize('sentences', 'embedding');") self._conn.commit() self._logger.debug("Quantization completed.") @@ -81,21 +97,25 @@ def quantize_preload(self) -> None: cursor = self._conn.cursor() cursor.execute("SELECT vector_quantize_preload('chunks', 'embedding');") + cursor.execute("SELECT vector_quantize_preload('sentences', 'embedding');") def quantize_cleanup(self) -> None: """Clean up internal structures related to a previously quantized table/column.""" cursor = self._conn.cursor() cursor.execute("SELECT vector_quantize_cleanup('chunks', 'embedding');") + cursor.execute("SELECT vector_quantize_cleanup('sentences', 'embedding');") self._conn.commit() def create_new_context(self) -> None: - """""" + """Create a new LLM context with optional runtime overrides.""" cursor = self._conn.cursor() + context_options = self._settings.get_embeddings_context_options() cursor.execute( - "SELECT llm_context_create(?);", (self._settings.model_context_options,) + "SELECT llm_context_create(?);", + (context_options,), ) def free_context(self) -> None: @@ -104,13 +124,38 @@ def free_context(self) -> None: cursor.execute("SELECT llm_context_free();") - def search(self, query: str, top_k: int = 10) -> list[DocumentResult]: - """Semantic search and full-text search sorted with Reciprocal Rank Fusion.""" - query_embedding = self.generate_embedding(query) + def search(self, query, top_k: int = 10) -> list[DocumentResult]: + """Semantic search and full-text search sorted with Reciprocal Rank Fusion + with top matching sentences to highlight.""" + semantic_query = query + if self._settings.use_prompt_templates: + semantic_query = self._settings.prompt_template_retrieval_query.format( + content=query + ) # Clean up and split into words # '*' is used to match while typing - query = " ".join(re.findall(r"\b\w+\b", query.lower())) + "*" + fts_query = " ".join(re.findall(r"\b\w+\b", query.lower())) + "*" + + query_embedding = self.generate_embedding(semantic_query) + + results = self.search_documents(query_embedding, fts_query, top_k=top_k) + + # Refine chunks with top sentences + for result in results: + result.sentences = self.search_sentences( + query_embedding, result.chunk_id, top_k=self._settings.top_k_sentences + ) + + return results + + def search_documents( + self, query_embedding: bytes, fts_query: str, top_k: int + ) -> list[DocumentResult]: + """Semantic search and full-text search sorted with Reciprocal Rank Fusion.""" + # invalid query + if query_embedding == b"" or fts_query.strip() == "": + return [] vector_scan_type = ( "vector_quantize_scan" @@ -119,8 +164,7 @@ def search(self, query: str, top_k: int = 10) -> list[DocumentResult]: ) cursor = self._conn.cursor() - # TODO: understand how to sort results depending on the distance metric - # Eg, for cosine distance, higher is better (closer to 1) + cursor.execute( f""" -- sqlite-vector KNN vector search results @@ -163,7 +207,8 @@ def search(self, query: str, top_k: int = 10) -> list[DocumentResult]: documents.uri, documents.content as document_content, documents.metadata, - chunks.content AS snippet, + chunks.id AS chunk_id, + chunks.content AS chunk_content, vec_rank, fts_rank, combined_rank, @@ -176,7 +221,7 @@ def search(self, query: str, top_k: int = 10) -> list[DocumentResult]: ; """, # nosec B608 { - "query": query, + "query": fts_query, "query_embedding": query_embedding, "k": top_k, "rrf_k": Engine.DEFAULT_RRF_K, @@ -186,7 +231,7 @@ def search(self, query: str, top_k: int = 10) -> list[DocumentResult]: ) rows = cursor.fetchall() - return [ + results = [ DocumentResult( document=Document( id=row["id"], @@ -194,7 +239,8 @@ def search(self, query: str, top_k: int = 10) -> list[DocumentResult]: content=row["document_content"], metadata=json.loads(row["metadata"]) if row["metadata"] else {}, ), - snippet=row["snippet"], + chunk_id=row["chunk_id"], + chunk_content=row["chunk_content"], vec_rank=row["vec_rank"], fts_rank=row["fts_rank"], combined_rank=row["combined_rank"], @@ -204,6 +250,72 @@ def search(self, query: str, top_k: int = 10) -> list[DocumentResult]: for row in rows ] + return results + + def search_sentences( + self, query_embedding: bytes, chunk_id: int, top_k: int + ) -> List[SentenceResult]: + """Semantic search for sentences within a chunk.""" + vector_scan_type = ( + "vector_quantize_scan_stream" + if self._settings.quantize_scan + else "vector_full_scan_stream" + ) + + cursor = self._conn.cursor() + + cursor.execute( + f""" + WITH vec_matches AS ( + SELECT + v.rowid AS sentence_id, + row_number() OVER (ORDER BY v.distance) AS rank_number, + v.distance + FROM {vector_scan_type}('sentences', 'embedding', :query_embedding) AS v + JOIN sentences ON sentences.rowid = v.rowid + WHERE sentences.chunk_id = :chunk_id + LIMIT :top_k + ) + SELECT + sentence_id, + -- Extract sentence directly from document content + COALESCE( + substr(chunks.content, sentences.start_offset + 1, sentences.end_offset - sentences.start_offset), + "" + ) AS content, + sentences.start_offset AS sentence_start_offset, + sentences.end_offset AS sentence_end_offset, + rank_number, + distance + FROM vec_matches + JOIN sentences ON sentences.rowid = vec_matches.sentence_id + JOIN chunks ON chunks.id = sentences.chunk_id + ORDER BY rank_number ASC + """, # nosec B608 + { + "query_embedding": query_embedding, + "top_k": top_k, + "chunk_id": chunk_id, + }, + ) + + rows = cursor.fetchall() + sentences = [] + for row in rows: + sentences.append( + SentenceResult( + id=row["sentence_id"], + chunk_id=chunk_id, + content=row["content"].strip(), + rank=row["rank_number"], + distance=row["distance"], + start_offset=row["sentence_start_offset"], + end_offset=row["sentence_end_offset"], + ) + ) + + return sentences[:top_k] + def versions(self) -> dict: """Get versions of the loaded extensions.""" cursor = self._conn.cursor() diff --git a/src/sqlite_rag/formatters.py b/src/sqlite_rag/formatters.py index 255f3f2..66a7ae9 100644 --- a/src/sqlite_rag/formatters.py +++ b/src/sqlite_rag/formatters.py @@ -2,12 +2,19 @@ """Output formatters for CLI search results.""" from abc import ABC, abstractmethod -from typing import List, Optional +from typing import List import typer from .models.document_result import DocumentResult +# Display constants +BOX_CONTENT_WIDTH = 75 +BOX_TOTAL_WIDTH = 77 +SNIPPET_MAX_LENGTH = 400 +SENTENCE_PREVIEW_LENGTH = 50 +MAX_SENTENCES_DISPLAY = 5 + class SearchResultFormatter(ABC): """Base class for search result formatters.""" @@ -40,7 +47,10 @@ def _get_file_icon(self, uri: str) -> str: return "📄" def _clean_and_wrap_snippet( - self, snippet: str, width: int = 75, max_length: int = 400 + self, + snippet: str, + width: int = BOX_CONTENT_WIDTH, + max_length: int = SNIPPET_MAX_LENGTH, ) -> List[str]: """Clean snippet and wrap to specified width with max length limit.""" # Clean the snippet @@ -69,7 +79,9 @@ def _clean_and_wrap_snippet( return lines - def _format_uri_display(self, uri: str, icon: str, max_width: int = 75) -> str: + def _format_uri_display( + self, uri: str, icon: str, max_width: int = BOX_CONTENT_WIDTH + ) -> str: """Format URI for display with icon and truncation.""" if not uri: return "" @@ -82,7 +94,15 @@ def _format_uri_display(self, uri: str, icon: str, max_width: int = 75) -> str: class BoxedFormatter(SearchResultFormatter): - """Base class for boxed result formatters.""" + """Boxed formatter for search results with optional debug information.""" + + def __init__(self, show_debug: bool = False): + """Initialize formatter. + + Args: + show_debug: Whether to show debug information and sentence details + """ + self.show_debug = show_debug def format_results(self, results: List[DocumentResult], query: str) -> None: if not results: @@ -98,52 +118,39 @@ def format_results(self, results: List[DocumentResult], query: str) -> None: def _format_single_result(self, doc: DocumentResult, idx: int) -> None: """Format a single result with box layout.""" icon = self._get_file_icon(doc.document.uri or "") - snippet_lines = self._clean_and_wrap_snippet( - doc.snippet, width=75, max_length=400 - ) + snippet_text = doc.get_preview(max_chars=SNIPPET_MAX_LENGTH) + snippet_lines = self._clean_and_wrap_snippet(snippet_text) - # Draw the result box header - header = f"┌─ Result #{idx} " + "─" * (67 - len(str(idx))) + # Draw box header + header = f"┌─ Result #{idx} " + "─" * (BOX_TOTAL_WIDTH - 10 - len(str(idx))) typer.echo(header) - # Display URI if available + # Display URI and debug info if doc.document.uri: - uri_display = self._format_uri_display(doc.document.uri, icon, 75) - typer.echo(f"│ {uri_display:<75}│") + uri_display = self._format_uri_display(doc.document.uri, icon) + typer.echo(f"│ {uri_display:<{BOX_CONTENT_WIDTH}}│") - # Add debug info if needed - debug_line = self._get_debug_line(doc) - if debug_line: - typer.echo(debug_line) + if self.show_debug: + self._print_debug_line(doc) - typer.echo("├" + "─" * 77 + "┤") - elif self._should_show_debug(): - debug_line = self._get_debug_line(doc) - if debug_line: - typer.echo(debug_line) - typer.echo("├" + "─" * 77 + "┤") + typer.echo("├" + "─" * BOX_TOTAL_WIDTH + "┤") + elif self.show_debug: + self._print_debug_line(doc) + typer.echo("├" + "─" * BOX_TOTAL_WIDTH + "┤") # Display snippet for line in snippet_lines: - typer.echo(f"│ {line:<75} │") - - typer.echo("└" + "─" * 77 + "┘") - typer.echo() - - def _get_debug_line(self, doc: DocumentResult) -> Optional[str]: - """Get debug information line. Override in subclasses.""" - return None + typer.echo(f"│ {line:<{BOX_CONTENT_WIDTH}} │") - def _should_show_debug(self) -> bool: - """Whether to show debug information. Override in subclasses.""" - return False + # Display sentence details in debug mode + if self.show_debug and doc.sentences: + self._print_sentence_details(doc) + typer.echo("└" + "─" * BOX_TOTAL_WIDTH + "┘") + typer.echo() -class BoxedDebugFormatter(BoxedFormatter): - """Modern detailed formatter with debug information in boxes.""" - - def _get_debug_line(self, doc: DocumentResult) -> str: - """Format debug metrics line.""" + def _print_debug_line(self, doc: DocumentResult) -> None: + """Print debug metrics line.""" combined = ( f"{doc.combined_rank:.5f}" if doc.combined_rank is not None else "N/A" ) @@ -157,10 +164,36 @@ def _get_debug_line(self, doc: DocumentResult) -> str: if doc.fts_rank is not None else "N/A" ) - return f"│ Combined: {combined} │ Vector: {vec_info} │ FTS: {fts_info}" + debug_line = f"│ Combined: {combined} │ Vector: {vec_info} │ FTS: {fts_info}" + typer.echo(debug_line) + + def _print_sentence_details(self, doc: DocumentResult) -> None: + """Print sentence-level details.""" + typer.echo("├" + "─" * BOX_TOTAL_WIDTH + "┤") + typer.echo(f"│ Sentences:{' ' * (BOX_CONTENT_WIDTH - 10)}│") + + for sentence in doc.sentences[:MAX_SENTENCES_DISPLAY]: + distance_str = ( + f"{sentence.distance:.6f}" if sentence.distance is not None else "N/A" + ) + rank_str = f"#{sentence.rank}" if sentence.rank is not None else "N/A" + + # Extract sentence preview + if sentence.start_offset is not None and sentence.end_offset is not None: + sentence_text = doc.chunk_content[ + sentence.start_offset : sentence.end_offset + ].strip() + sentence_preview = sentence_text.replace("\n", " ").replace("\r", "") + if len(sentence_preview) > SENTENCE_PREVIEW_LENGTH: + sentence_preview = ( + sentence_preview[: SENTENCE_PREVIEW_LENGTH - 3] + "..." + ) + else: + sentence_preview = "[No offset info]" - def _should_show_debug(self) -> bool: - return True + # Format and print sentence line + sentence_line = f"│ {rank_str:>3} ({distance_str}) | {sentence_preview}" + typer.echo(sentence_line.ljust(BOX_TOTAL_WIDTH + 1) + " │") class TableDebugFormatter(SearchResultFormatter): @@ -199,8 +232,11 @@ def _print_table_header(self) -> None: def _print_table_row(self, idx: int, doc: DocumentResult) -> None: """Print a single table row.""" + # Get snippet from DocumentResult (handles sentence-based preview automatically) + snippet = doc.get_preview(max_chars=52) + # Clean snippet display - snippet = doc.snippet.replace("\n", " ").replace("\r", "") + snippet = snippet.replace("\n", " ").replace("\r", "") snippet = snippet[:49] + "..." if len(snippet) > 52 else snippet # Clean URI display @@ -227,10 +263,15 @@ def _print_table_row(self, idx: int, doc: DocumentResult) -> None: def get_formatter( debug: bool = False, table_view: bool = False ) -> SearchResultFormatter: - """Factory function to get the appropriate formatter.""" + """Factory function to get the appropriate formatter. + + Args: + debug: Show debug information and sentence details + table_view: Use table format instead of boxed format + + Returns: + SearchResultFormatter instance + """ if table_view: return TableDebugFormatter() - elif debug: - return BoxedDebugFormatter() - else: - return BoxedFormatter() + return BoxedFormatter(show_debug=debug) diff --git a/src/sqlite_rag/models/chunk.py b/src/sqlite_rag/models/chunk.py index 15bb26b..89b987e 100644 --- a/src/sqlite_rag/models/chunk.py +++ b/src/sqlite_rag/models/chunk.py @@ -1,4 +1,6 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field + +from sqlite_rag.models.sentence import Sentence @dataclass @@ -6,7 +8,8 @@ class Chunk: id: int | None = None document_id: int | None = None # The human readable content of the chunk - # (not the representation of the embedding vector) + # (it does not represent the embedding vector which + # may be altered with prompt or overlap text) content: str = "" embedding: str | bytes = b"" @@ -14,6 +17,8 @@ class Chunk: head_overlap_text: str = "" title: str | None = None + sentences: list[Sentence] = field(default_factory=list) + def get_embedding_text(self) -> str: """Get the content used to generate the embedding from. It can be enriched with overlap text and prompt instructions, diff --git a/src/sqlite_rag/models/document.py b/src/sqlite_rag/models/document.py index e8e4685..535b08b 100644 --- a/src/sqlite_rag/models/document.py +++ b/src/sqlite_rag/models/document.py @@ -18,7 +18,7 @@ class Document: created_at: datetime | None = None updated_at: datetime | None = None - chunks: list["Chunk"] = field(default_factory=list) + chunks: list[Chunk] = field(default_factory=list) def hash(self) -> str: """Generate a hash for the document content using SHA-3 for maximum collision resistance""" diff --git a/src/sqlite_rag/models/document_result.py b/src/sqlite_rag/models/document_result.py index 2a89298..07b364f 100644 --- a/src/sqlite_rag/models/document_result.py +++ b/src/sqlite_rag/models/document_result.py @@ -1,13 +1,14 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from .document import Document +from .sentence_result import SentenceResult @dataclass class DocumentResult: document: Document - snippet: str + chunk_id: int combined_rank: float vec_rank: float | None = None @@ -15,3 +16,54 @@ class DocumentResult: vec_distance: float | None = None fts_score: float | None = None + + chunk_content: str = "" + + # highlight sentences + sentences: list[SentenceResult] = field(default_factory=list) + + def get_preview( + self, top_k_sentences: int = 3, max_chars: int = 400, gap: str = "[...]" + ) -> str: + """Build preview from top ranked sentences with [...] for gaps. + + Args: + top_k_sentences: Number of top sentences to include in preview + max_chars: Maximum total characters for preview + + Returns: + Preview string with top sentences and [...] separators. + Falls back to truncated chunk_content if sentences have no offsets. + """ + top_sentences = self.sentences[:top_k_sentences] if self.sentences else [] + + if not top_sentences: + # Fallback: no sentences, return truncated chunk content + return self.chunk_content[:max_chars] + + # Sort by start_offset to maintain document order + top_sentences = sorted( + top_sentences, + key=lambda s: s.start_offset if s.start_offset is not None else -1, + ) + + preview_parts = [] + total_chars = 0 + prev_end_offset = None + + for sentence in top_sentences: + sentence_text = sentence.content + + if prev_end_offset is not None and sentence.start_offset is not None: + gap_size = sentence.start_offset - prev_end_offset + if gap_size > 10: + preview_parts.append(gap) + total_chars += len(gap) + + preview_parts.append(sentence_text) + total_chars += len(sentence_text) + prev_end_offset = sentence.end_offset + + preview = " ".join(preview_parts) + + return preview[: max_chars - 3] + "..." if len(preview) > max_chars else preview diff --git a/src/sqlite_rag/models/sentence.py b/src/sqlite_rag/models/sentence.py new file mode 100644 index 0000000..064b233 --- /dev/null +++ b/src/sqlite_rag/models/sentence.py @@ -0,0 +1,10 @@ +from dataclasses import dataclass + + +@dataclass +class Sentence: + id: int | None = None + content: str = "" + embedding: str | bytes = b"" + start_offset: int | None = None + end_offset: int | None = None diff --git a/src/sqlite_rag/models/sentence_result.py b/src/sqlite_rag/models/sentence_result.py new file mode 100644 index 0000000..d2ffa1d --- /dev/null +++ b/src/sqlite_rag/models/sentence_result.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass + + +@dataclass +class SentenceResult: + id: int | None = None + chunk_id: int | None = None + + content: str = "" + + rank: float | None = None + distance: float | None = None + + start_offset: int | None = None + end_offset: int | None = None diff --git a/src/sqlite_rag/repository.py b/src/sqlite_rag/repository.py index be1bc50..005bf80 100644 --- a/src/sqlite_rag/repository.py +++ b/src/sqlite_rag/repository.py @@ -32,11 +32,27 @@ def add_document(self, document: Document) -> str: "INSERT INTO chunks (document_id, content, embedding) VALUES (?, ?, ?)", (document_id, chunk.content, chunk.embedding), ) + + chunk_id = cursor.lastrowid + cursor.execute( - "INSERT INTO chunks_fts (rowid, content) VALUES (last_insert_rowid(), ?)", - (chunk.content,), + "INSERT INTO chunks_fts (rowid, content) VALUES (?, ?)", + (chunk_id, chunk.content), ) + for sentence in chunk.sentences: + cursor.execute( + "INSERT INTO sentences (id, chunk_id, content, embedding, start_offset, end_offset) VALUES (?, ?, ?, ?, ?, ?)", + ( + str(uuid4()), + chunk_id, + sentence.content, + sentence.embedding, + sentence.start_offset, + sentence.end_offset, + ), + ) + self._conn.commit() return document_id diff --git a/src/sqlite_rag/sentence_splitter.py b/src/sqlite_rag/sentence_splitter.py new file mode 100644 index 0000000..c177504 --- /dev/null +++ b/src/sqlite_rag/sentence_splitter.py @@ -0,0 +1,50 @@ +import re +from typing import List + +from sqlite_rag.models.chunk import Chunk +from sqlite_rag.models.sentence import Sentence + + +class SentenceSplitter: + MIN_CHARS_PER_SENTENCE = 20 + + def split(self, chunk: Chunk) -> List[Sentence]: + """Split chunk into sentences.""" + # Split on: sentence endings, semicolons, or paragraph breaks + sentence_regex = re.compile(r'(?<=[.!?;])(?:"|\')?\s+(?=[A-Z])|[\n]{2,}') + + sentences = [] + last_end = 0 + text = chunk.content + + for match in sentence_regex.finditer(text): + segment = text[last_end : match.end()] + + segment = segment.strip() + if len(segment) > self.MIN_CHARS_PER_SENTENCE: + sentences.append( + Sentence( + content=segment, + start_offset=last_end, + end_offset=last_end + len(segment), + ) + ) + + # Position after the current match + last_end = match.end() + + # Last segment + if last_end < len(text): + segment = text[last_end:] + + segment = segment.strip() + if len(segment) > self.MIN_CHARS_PER_SENTENCE: + sentences.append( + Sentence( + content=segment, + start_offset=last_end, + end_offset=last_end + len(segment), + ) + ) + + return sentences diff --git a/src/sqlite_rag/settings.py b/src/sqlite_rag/settings.py index ef41fb2..42b39fc 100644 --- a/src/sqlite_rag/settings.py +++ b/src/sqlite_rag/settings.py @@ -15,11 +15,14 @@ class Settings: "./models/unsloth/embeddinggemma-300m-GGUF/embeddinggemma-300M-Q8_0.gguf" ) # See: https://github.com/sqliteai/sqlite-ai/blob/main/API.md#llm_model_loadpath-text-options-text - model_options: str = "" + other_model_options: str = "" + # See: https://github.com/sqliteai/sqlite-ai/blob/main/API.md#llm_context_createoptions-text - model_context_options: str = ( - "generate_embedding=1,normalize_embedding=1,pooling_type=mean,embedding_type=INT8" - ) + other_model_context_options: str = "" + + # How the model pools token embeddings into a single embedding + # Options: "mean", "max", "min", "last", "first" + pooling_type: str = "mean" # Allow the sqlite-ai extension to use the GPU # See: https://github.com/sqliteai/sqlite-ai @@ -27,14 +30,15 @@ class Settings: vector_type: str = "INT8" embedding_dim: int = 768 + other_vector_options: str = ( "distance=cosine" # e.g. distance=metric,other=value,... ) # It includes the overlap size and the prompt template length - chunk_size: int = 512 + chunk_size: int = 2048 # Tokens overlap between chunks - chunk_overlap: int = 61 + chunk_overlap: int = 256 # # Search settings @@ -46,7 +50,7 @@ class Settings: quantize_preload: bool = False # Weights for combining FTS and vector search results - weight_fts: float = 1.0 + weight_fts: float = 1.5 weight_vec: float = 1.0 # @@ -61,7 +65,7 @@ class Settings: # Template to index documents for retrieval, use `{title}` with the title or the string `"none"` prompt_template_retrieval_document: str = "title: {title} | text: {content}" - prompt_template_retrieval_query: str = "task: search result | query: {content}" + prompt_template_retrieval_query: str = 'title: "none" | text: {content}' # # Index settings @@ -71,6 +75,31 @@ class Settings: max_document_size_bytes: int = 5 * 1024 * 1024 # 5 MB # Zero means no limit max_chunks_per_document: int = 1000 + # Number of top sentences to return per document + top_k_sentences: int = 3 + + def get_embeddings_context_options(self) -> str: + """Get the context options for embeddings generation.""" + options = { + "n_ctx": self.chunk_size, + "embedding_type": self.vector_type, + "pooling_type": self.pooling_type, + "generate_embedding": 1, + "normalize_embedding": 1, + } + + return ",".join(f"{k}={v}" for k, v in options.items()) + ( + f",{self.other_model_context_options}" + if self.other_model_context_options + else "" + ) + + def get_vector_init_options(self) -> str: + """Get the vector init options for the vector store.""" + options = {"type": self.vector_type, "dimension": self.embedding_dim} + return ",".join(f"{k}={v}" for k, v in options.items()) + ( + f",{self.other_vector_options}" if self.other_vector_options else "" + ) class SettingsManager: @@ -177,4 +206,5 @@ def has_critical_changes( new_settings.model_path != current_settings.model_path or new_settings.embedding_dim != current_settings.embedding_dim or new_settings.vector_type != current_settings.vector_type + or new_settings.pooling_type != current_settings.pooling_type ) diff --git a/src/sqlite_rag/sqliterag.py b/src/sqlite_rag/sqliterag.py index 8be35b6..6aa1037 100644 --- a/src/sqlite_rag/sqliterag.py +++ b/src/sqlite_rag/sqliterag.py @@ -6,6 +6,7 @@ from sqlite_rag.extractor import Extractor from sqlite_rag.logger import Logger from sqlite_rag.models.document_result import DocumentResult +from sqlite_rag.sentence_splitter import SentenceSplitter from .chunker import Chunker from .database import Database @@ -25,7 +26,12 @@ def __init__(self, connection: sqlite3.Connection, settings: Settings): self._repository = Repository(self._conn, settings) self._chunker = Chunker(self._conn, settings) - self._engine = Engine(self._conn, settings, chunker=self._chunker) + self._engine = Engine( + self._conn, + settings, + chunker=self._chunker, + sentence_splitter=SentenceSplitter(), + ) self._extractor = Extractor() self.ready = False @@ -310,9 +316,6 @@ def search( if new_context: self._engine.create_new_context() - if self._settings.use_prompt_templates: - query = self._settings.prompt_template_retrieval_query.format(content=query) - return self._engine.search(query, top_k=top_k) def get_settings(self) -> dict: diff --git a/tests/conftest.py b/tests/conftest.py index 477832d..2bd096a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,13 @@ import sqlite3 import tempfile +from collections.abc import Generator import pytest from sqlite_rag.chunker import Chunker from sqlite_rag.database import Database from sqlite_rag.engine import Engine +from sqlite_rag.sentence_splitter import SentenceSplitter from sqlite_rag.settings import Settings @@ -25,12 +27,24 @@ def db_conn(): @pytest.fixture -def engine(db_conn) -> Engine: +def engine(db_conn) -> Generator[Engine, None, None]: conn, settings = db_conn - engine = Engine(conn, settings, chunker=Chunker(conn, settings)) + engine = Engine( + conn, + settings, + chunker=Chunker(conn, settings), + sentence_splitter=SentenceSplitter(), + ) engine.load_model() engine.quantize() engine.create_new_context() - return engine + yield engine + + # Cleanup resources to prevent segfaults in Python 3.11 + # Must explicitly free resources before garbage collection + try: + engine.close() + except Exception: + pass diff --git a/tests/integration/test_engine.py b/tests/integration/test_engine.py index 9b99ff6..d289198 100644 --- a/tests/integration/test_engine.py +++ b/tests/integration/test_engine.py @@ -1,12 +1,19 @@ import random import string +from sqlite3 import OperationalError import pytest +from sqlite_rag.chunker import Chunker +from sqlite_rag.engine import Engine +from sqlite_rag.models.document import Document +from sqlite_rag.repository import Repository +from sqlite_rag.sentence_splitter import SentenceSplitter + class TestEngine: @pytest.mark.slow - def test_stress_embedding_generation(self, engine): + def test_stress_embedding_generation(self, engine: Engine): """Test embedding generation with a large number of chunks to not fail and to never generate duplicated embeddings.""" @@ -26,3 +33,290 @@ def random_string(length=30): # Assert assert len(result_chunks) == 1000 + + +class TestEngineQuantization: + def test_quantize_embedding(self, engine: Engine): + """Test quantize called for chunks and sentences embeddings.""" + engine.quantize() + + # If no exception is raised, the test passes + engine.search("test query") + + def test_quantize_cleanup(self, engine: Engine): + """Test quantize cleanup works without errors.""" + engine.quantize() + engine.quantize_cleanup() + + with pytest.raises(OperationalError) as exc_info: + engine.search("test query") + assert "Ensure that vector_quantize() has been called" in str(exc_info.value) + + +class TestEngineSearch: + def test_search(self, engine: Engine): + # Arrange + doc1 = Document( + content="The quick brown fox jumps over the lazy dog.", + uri="document1.txt", + ) + doc2 = Document( + content="How much wood would a woodchuck chuck if a woodchuck could chuck wood?", + uri="document2.txt", + ) + + engine.create_new_context() + engine.process(doc1) + engine.process(doc2) + + repository = Repository(engine._conn, engine._settings) + repository.add_document(doc1) + repository.add_document(doc2) + + # Act + results = engine.search("quick brown fox") + + # Assert + assert len(results) > 0 + assert results[0].document.uri == "document1.txt" + + +class TestEngineSearchDocuments: + def test_search_with_empty_database(self, engine: Engine): + results = engine.search_documents(b"132456", "myquery", top_k=5) + + assert len(results) == 0 + + def test_search_with_invalid_query(self, engine: Engine): + results = engine.search_documents(b"", "", top_k=5) + + assert len(results) == 0 + + def test_search_with_semantic_and_fts(self, db_conn): + # Arrange + conn, settings = db_conn + + engine = Engine(conn, settings, Chunker(conn, settings), SentenceSplitter()) + engine.load_model() + engine.create_new_context() + + doc1 = Document( + content="The quick brown fox jumps over the lazy dog.", + uri="document1.txt", + ) + doc2 = Document( + content="How much wood would a woodchuck chuck if a woodchuck could chuck wood?", + uri="document2.txt", + ) + doc3 = Document( + content="This document discusses about woodcutters and wood.", + uri="document3.txt", + ) + + engine.process(doc1) + engine.process(doc2) + engine.process(doc3) + + repository = Repository(conn, settings) + repository.add_document(doc1) + repository.add_document(doc2) + doc3_id = repository.add_document(doc3) + + embedding = engine.generate_embedding("about lumberjack") + engine.quantize() + + # Act + results = engine.search_documents(embedding, "about lumberjack", top_k=5) + + assert len(results) > 0 + assert doc3_id == results[0].document.id + + def test_search_semantic_result(self, db_conn): + # Arrange + conn, settings = db_conn + + engine = Engine(conn, settings, Chunker(conn, settings), SentenceSplitter()) + engine.load_model() + engine.create_new_context() + + doc1 = Document( + content="The quick brown fox jumps over the lazy dog.", + uri="document1.txt", + ) + doc2 = Document( + content="How much wood would a woodchuck chuck if a woodchuck could chuck wood?", + uri="document2.txt", + ) + doc3 = Document( + content="This document discusses about woodcutters and wood.", + uri="document3.txt", + ) + + engine.process(doc1) + engine.process(doc2) + engine.process(doc3) + + repository = Repository(conn, settings) + repository.add_document(doc1) + repository.add_document(doc2) + doc3_id = repository.add_document(doc3) + + embedding = engine.generate_embedding("about lumberjack") + engine.quantize() + + # Act + results = engine.search_documents(embedding, "about lumberjack", top_k=5) + + assert len(results) > 0 + assert doc3_id == results[0].document.id + + def test_search_fts_results(self, db_conn): + # Arrange + conn, settings = db_conn + + engine = Engine(conn, settings, Chunker(conn, settings), SentenceSplitter()) + engine.load_model() + engine.create_new_context() + + doc1 = Document( + content="The quick brown fox jumps over the lazy dog.", + uri="document1.txt", + ) + doc2 = Document( + content="How much wood would a woodchuck chuck if a woodchuck could chuck wood?", + uri="document2.txt", + ) + doc3 = Document( + content="This document discusses about woodcutters and wood.", + uri="document3.txt", + ) + + engine.process(doc1) + engine.process(doc2) + engine.process(doc3) + + repository = Repository(conn, settings) + doc1_id = repository.add_document(doc1) + repository.add_document(doc2) + repository.add_document(doc3) + + embedding = engine.generate_embedding("quick brown fox") + engine.quantize() + + # Act + results = engine.search_documents(embedding, "quick brown fox", top_k=5) + + assert len(results) > 0 + assert doc1_id == results[0].document.id + assert results[0].fts_rank + assert results[0].fts_rank == 1 + assert results[0].fts_score + + def test_search_without_quantization(self, db_conn): + # Arrange + conn, settings = db_conn + settings.quantize_scan = False + + engine = Engine(conn, settings, Chunker(conn, settings), SentenceSplitter()) + engine.load_model() + + doc = Document( + content="The quick brown fox jumps over the lazy dog.", + uri="document1.txt", + ) + + engine.create_new_context() + engine.process(doc) + + repository = Repository(conn, settings) + doc_id = repository.add_document(doc) + + embedding = engine.generate_embedding("wood lumberjack") + + # Act + results = engine.search_documents(embedding, "wood lumberjack", top_k=5) + + assert len(results) > 0 + assert doc_id == results[0].document.id + + def test_search_exact_match(self, db_conn): + conn, settings = db_conn + # cosin distance for searching embedding is exact 0.0 when strings match + settings.other_vector_options = "distance=cosine" + settings.use_prompt_templates = False + + engine = Engine(conn, settings, Chunker(conn, settings), SentenceSplitter()) + engine.load_model() + engine.create_new_context() + + doc1 = Document( + content="The quick brown fox jumps over the lazy dog", + uri="document1.txt", + ) + doc2 = Document( + content="How much wood would a woodchuck chuck if a woodchuck could chuck wood?", + uri="document2.txt", + ) + + engine.process(doc1) + engine.process(doc2) + + repository = Repository(conn, settings) + doc1_id = repository.add_document(doc1) + repository.add_document(doc2) + + embedding = engine.generate_embedding( + "The quick brown fox jumps over the lazy dog" + ) + engine.quantize() + + # Act + results = engine.search_documents( + embedding, "The quick brown fox jumps over the lazy dog", top_k=5 + ) + + assert len(results) > 0 + assert doc1_id == results[0].document.id + assert 0.0 == results[0].vec_distance + + +class TestEngineSearchSentences: + def test_search_sentences(self, db_conn): + conn, settings = db_conn + settings.use_prompt_templates = False + settings.quantize_scan = False + + engine = Engine(conn, settings, Chunker(conn, settings), SentenceSplitter()) + engine.load_model() + engine.create_new_context() + + doc = Document( + content=( + """The quick brown fox jumps over the lazy dog. + A stitch in time saves nine. + An apple a day keeps the doctor away. + """ + ), + uri="document1.txt", + ) + + engine.process(doc) + + repository = Repository(conn, settings) + doc_id = repository.add_document(doc) + + cursor = conn.execute("SELECT id FROM chunks WHERE document_id = ?", (doc_id,)) + chunk_id = cursor.fetchone()[0] + + embedding = engine.generate_embedding("stitch time") + + # Act + results = engine.search_sentences( + embedding, + chunk_id, + top_k=1, + ) + + assert len(results) > 0 + assert results[0].start_offset == 61 # it's the second sentence + assert results[0].end_offset == 89 diff --git a/tests/models/test_document_result.py b/tests/models/test_document_result.py new file mode 100644 index 0000000..1246ce6 --- /dev/null +++ b/tests/models/test_document_result.py @@ -0,0 +1,253 @@ +from sqlite_rag.models.document import Document +from sqlite_rag.models.document_result import DocumentResult +from sqlite_rag.models.sentence_result import SentenceResult + + +class TestDocumentResult: + def test_get_preview_no_sentences(self): + doc = Document(uri="test.txt", content="test content") + result = DocumentResult( + document=doc, + chunk_id=1, + combined_rank=1.0, + sentences=[], + ) + + preview = result.get_preview(max_chars=100) + assert preview == "" + + def test_get_preview_with_single_sentence(self): + doc = Document(uri="test.txt", content="test content") + + sentences = [ + SentenceResult( + chunk_id=1, + id=2, + content="Second sentence there.", + rank=1, + distance=0.1, + start_offset=15, + end_offset=36, + ), + ] + + result = DocumentResult( + document=doc, + chunk_id=1, + combined_rank=1.0, + sentences=sentences, + ) + + preview = result.get_preview(max_chars=400) + assert preview == "Second sentence there." + + def test_get_preview_with_gaps(self): + """Test get_preview adds [...] separator for gaps.""" + doc = Document(uri="test.txt", content="test content") + + sentences = [ + SentenceResult( + chunk_id=1, + id=1, + content="First sentence at the beginning.", + rank=1, + distance=0.1, + start_offset=0, + end_offset=32, + ), + SentenceResult( + chunk_id=1, + id=3, + content="Last sentence at the end.", + rank=2, + distance=0.2, + start_offset=75, + end_offset=103, + ), + ] + + result = DocumentResult( + document=doc, + chunk_id=1, + combined_rank=1.0, + sentences=sentences, + ) + + preview = result.get_preview(max_chars=400) + assert ( + "First sentence at the beginning. [...] Last sentence at the end." + == preview + ) + + def test_get_preview_respects_max_chars(self): + """Test get_preview truncates when exceeding max_chars.""" + doc = Document(uri="test.txt", content="test content") + content = "A very long sentence that exceeds the maximum character limit. " * 10 + + sentences = [ + SentenceResult( + chunk_id=1, + id=1, + content=content, + rank=1, + distance=0.1, + start_offset=0, + end_offset=200, + ), + ] + + result = DocumentResult( + document=doc, + chunk_id=1, + combined_rank=1.0, + sentences=sentences, + ) + + preview = result.get_preview(max_chars=50) + assert len(preview) <= 50 + + def test_get_preview_with_multiple_consecutive_and_ordered_sentences(self): + doc = Document(uri="test.txt", content="test content") + + sentences = [ + SentenceResult( + chunk_id=1, + id=1, + content="First sentence.", + rank=1, + distance=0.1, + start_offset=0, + end_offset=15, + ), + SentenceResult( + chunk_id=1, + id=2, + content="Second sentence.", + rank=2, + distance=0.2, + start_offset=16, + end_offset=32, + ), + ] + + result = DocumentResult( + document=doc, + chunk_id=1, + combined_rank=1.0, + sentences=sentences, + ) + + preview = result.get_preview(max_chars=400) + assert preview == "First sentence. Second sentence." + + def test_get_preview_orders_sentences_by_offset(self): + """Test get_preview reorders sentences by start_offset (document order).""" + doc = Document(uri="test.txt", content="test content") + + # Sentences in reverse rank order (rank 1 is last in document) + sentences = [ + SentenceResult( + chunk_id=1, + id=3, + content="Third sentence.", + rank=1, # higher rank but appears latter in document + distance=0.1, + start_offset=66, + end_offset=82, + ), + SentenceResult( + chunk_id=1, + id=1, + content="First sentence.", + rank=2, + distance=0.2, + start_offset=0, + end_offset=15, + ), + ] + + result = DocumentResult( + document=doc, + chunk_id=1, + combined_rank=1.0, + sentences=sentences, + ) + + preview = result.get_preview(max_chars=400) + # Should be in document order despite rank order + assert "First sentence. [...] Third sentence." == preview + + def test_get_preview_limits_to_top_k_sentences(self): + """Test get_preview respects top_k_sentences parameter.""" + doc = Document(uri="test.txt", content="test content") + + # 5 sentences, but only top 2 should be used + sentences = [ + SentenceResult( + chunk_id=1, + id=1, + content="First.", + rank=1, + distance=0.1, + start_offset=0, + end_offset=6, + ), + SentenceResult( + chunk_id=1, + id=2, + content="Second.", + rank=2, + distance=0.2, + start_offset=7, + end_offset=14, + ), + SentenceResult( + chunk_id=1, + id=3, + content="Third.", + rank=3, + distance=0.3, + start_offset=15, + end_offset=21, + ), + SentenceResult( + chunk_id=1, + id=4, + content="Fourth.", + rank=4, + distance=0.4, + start_offset=22, + end_offset=29, + ), + SentenceResult( + chunk_id=1, + id=5, + content="Fifth.", + rank=5, + distance=0.5, + start_offset=30, + end_offset=36, + ), + ] + + result = DocumentResult( + document=doc, + chunk_id=1, + combined_rank=1.0, + sentences=sentences, + ) + + preview = result.get_preview(top_k_sentences=2, max_chars=400) + assert "First." in preview + assert "Second." in preview + assert "Third" not in preview + assert "Fourth" not in preview + assert "Fifth" not in preview + + # Test with default top_k=3 + preview_default = result.get_preview(max_chars=400) + assert "First." in preview_default + assert "Second." in preview_default + assert "Third." in preview_default + assert "Fourth" not in preview_default + assert "Fifth" not in preview_default diff --git a/tests/test_chunker.py b/tests/test_chunker.py index 8c54949..792e21c 100644 --- a/tests/test_chunker.py +++ b/tests/test_chunker.py @@ -322,9 +322,9 @@ def test_chunk_size_equals_overlap(self, mock_conn): chunker = Chunker(mock_conn, settings) text = "This is a test sentence that should be handled gracefully." - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError) as exc_info: chunker.chunk(Document(content=text)) - assert "Chunk size must be greater than chunk overlap." in str(excinfo.value) + assert "Chunk size must be greater than chunk overlap." in str(exc_info.value) def test_very_small_chunk_size(self, mock_conn): """Test with chunk_size = 1.""" diff --git a/tests/test_engine.py b/tests/test_engine.py index 2f38f35..1cdd3f3 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -1,10 +1,9 @@ import pytest -from sqlite_rag.chunker import Chunker from sqlite_rag.engine import Engine from sqlite_rag.models.chunk import Chunk from sqlite_rag.models.document import Document -from sqlite_rag.repository import Repository +from sqlite_rag.models.sentence import Sentence from sqlite_rag.settings import Settings @@ -32,8 +31,10 @@ def test_process_uses_get_embedding_text(self, mocker): mock_conn = mocker.Mock() mock_chunker = mocker.Mock() mock_chunker.chunk.return_value = [mock_chunk] + mock_sentence_splitter = mocker.Mock() + mock_sentence_splitter.split.return_value = [] - engine = Engine(mock_conn, settings, mock_chunker) + engine = Engine(mock_conn, settings, mock_chunker, mock_sentence_splitter) # Mock generate_embedding completely mock_generate = mocker.patch.object( @@ -65,8 +66,10 @@ def test_process_with_max_chunks_per_document( settings = Settings(max_chunks_per_document=max_chunks_per_document) mock_chunker = mocker.Mock() mock_chunker.chunk.return_value = chunks + mock_sentence_splitter = mocker.Mock() + mock_sentence_splitter.split.return_value = [] - engine = Engine(mock_conn, settings, mock_chunker) + engine = Engine(mock_conn, settings, mock_chunker, mock_sentence_splitter) mock_generate_embedding = mocker.patch.object(engine, "generate_embedding") mock_generate_embedding = mocker.spy( @@ -84,183 +87,175 @@ def test_process_with_max_chunks_per_document( chunks = call_args[0][0] # First argument assert len(chunks) == expected_chunk_count - -class TestEngineSearch: - def test_search_with_empty_database(self, engine): - results = engine.search("nonexistent query", top_k=5) - - assert len(results) == 0 - - def test_search_with_semantic_and_fts(self, db_conn): + def test_process_with_sentences(self, mocker): # Arrange - conn, settings = db_conn - - engine = Engine(conn, settings, Chunker(conn, settings)) - engine.load_model() - engine.create_new_context() + chunks = [Chunk(content="Chunk 1"), Chunk(content="Chunk 2")] - doc1 = Document( - content="The quick brown fox jumps over the lazy dog.", - uri="document1.txt", - ) - doc2 = Document( - content="How much wood would a woodchuck chuck if a woodchuck could chuck wood?", - uri="document2.txt", - ) - doc3 = Document( - content="This document discusses about woodcutters and wood.", - uri="document3.txt", - ) + mock_conn = mocker.Mock() + settings = Settings() + mock_chunker = mocker.Mock() + mock_chunker.chunk.return_value = chunks + mock_sentence_splitter = mocker.Mock() + # return different number of sentences per chunk + mock_sentence_splitter.split.side_effect = [ + [Sentence(content="Sentence 1.1")], + [Sentence(content="Sentence 2.1"), Sentence(content="Sentence 2.2")], + ] - engine.process(doc1) - engine.process(doc2) - engine.process(doc3) + engine = Engine(mock_conn, settings, mock_chunker, mock_sentence_splitter) - repository = Repository(conn, settings) - repository.add_document(doc1) - repository.add_document(doc2) - doc3_id = repository.add_document(doc3) + mock_generate_embedding = mocker.patch.object(engine, "generate_embedding") + mock_generate_embedding = mocker.spy( + mock_generate_embedding, "generate_embedding" + ) + mock_generate_embedding.return_value = chunks - engine.quantize() + document = Document(content="Test document content") # Act - results = engine.search("wood lumberjack", top_k=5) + engine.process(document) - assert len(results) > 0 - assert doc3_id == results[0].document.id + # Assert + assert len(document.chunks) == 2 + assert len(document.chunks[0].sentences) == 1 + assert len(document.chunks[1].sentences) == 2 - def test_search_semantic_result(self, db_conn): + def test_process_without_sentences(self, mocker): # Arrange - conn, settings = db_conn + chunks = [Chunk(content="Chunk 1")] - engine = Engine(conn, settings, Chunker(conn, settings)) - engine.load_model() - engine.create_new_context() - - doc1 = Document( - content="The quick brown fox jumps over the lazy dog.", - uri="document1.txt", - ) - doc2 = Document( - content="How much wood would a woodchuck chuck if a woodchuck could chuck wood?", - uri="document2.txt", - ) - doc3 = Document( - content="This document discusses about woodcutters and wood.", - uri="document3.txt", - ) + mock_conn = mocker.Mock() + settings = Settings() + mock_chunker = mocker.Mock() + mock_chunker.chunk.return_value = chunks + mock_sentence_splitter = mocker.Mock() + mock_sentence_splitter.split.return_value = [] - engine.process(doc1) - engine.process(doc2) - engine.process(doc3) + engine = Engine(mock_conn, settings, mock_chunker, mock_sentence_splitter) - repository = Repository(conn, settings) - repository.add_document(doc1) - repository.add_document(doc2) - doc3_id = repository.add_document(doc3) + mock_generate_embedding = mocker.patch.object(engine, "generate_embedding") + mock_generate_embedding = mocker.spy( + mock_generate_embedding, "generate_embedding" + ) + mock_generate_embedding.return_value = chunks - engine.quantize() + document = Document(content="Test document content") # Act - results = engine.search("about lumberjack", top_k=5) + engine.process(document) - assert len(results) > 0 - assert doc3_id == results[0].document.id + # Assert + assert len(document.chunks) == 1 + assert len(document.chunks[0].sentences) == 0 - def test_search_fts_results(self, db_conn): + def test_search(self, mocker): # Arrange - conn, settings = db_conn - - engine = Engine(conn, settings, Chunker(conn, settings)) - engine.load_model() - engine.create_new_context() + mock_conn = mocker.Mock() + settings = Settings() + engine = Engine(mock_conn, settings, mocker.Mock(), mocker.Mock()) - doc1 = Document( - content="The quick brown fox jumps over the lazy dog.", - uri="document1.txt", + mock_generate = mocker.patch.object( + engine, "generate_embedding", return_value=b"embedding" ) - doc2 = Document( - content="How much wood would a woodchuck chuck if a woodchuck could chuck wood?", - uri="document2.txt", + mock_search_docs = mocker.patch.object( + engine, + "search_documents", + return_value=[ + mocker.Mock(chunk_id=1, sentences=[]), + mocker.Mock(chunk_id=2, sentences=[]), + ], ) - doc3 = Document( - content="This document discusses about woodcutters and wood.", - uri="document3.txt", + mock_search_sents = mocker.patch.object( + engine, "search_sentences", return_value=[] ) - engine.process(doc1) - engine.process(doc2) - engine.process(doc3) - - repository = Repository(conn, settings) - doc1_id = repository.add_document(doc1) - repository.add_document(doc2) - repository.add_document(doc3) - - engine.quantize() - # Act - results = engine.search("quick brown fox", top_k=5) + engine.search("test query", top_k=5) - assert len(results) > 0 - assert doc1_id == results[0].document.id + # Assert + mock_generate.assert_called_once_with('title: "none" | text: test query') + mock_search_docs.assert_called_once_with(b"embedding", "test query*", top_k=5) + assert mock_search_sents.call_count == 2 + mock_search_sents.assert_any_call( + b"embedding", 1, top_k=settings.top_k_sentences + ) + mock_search_sents.assert_any_call( + b"embedding", 2, top_k=settings.top_k_sentences + ) - def test_search_without_quantization(self, db_conn): + def test_search_uses_retrieval_query_template(self, mocker): # Arrange - conn, settings = db_conn - settings.quantize_scan = False - - engine = Engine(conn, settings, Chunker(conn, settings)) - engine.load_model() + template = "task: search | Do something with {content}" - doc = Document( - content="The quick brown fox jumps over the lazy dog.", - uri="document1.txt", - ) + settings = Settings(prompt_template_retrieval_query=template) - engine.create_new_context() - engine.process(doc) + mock_conn = mocker.Mock() + engine = Engine(mock_conn, settings, mocker.Mock(), mocker.Mock()) - repository = Repository(conn, settings) - doc_id = repository.add_document(doc) + mock_generate = mocker.patch.object( + engine, "generate_embedding", return_value=b"embedding" + ) + mock_search_docs = mocker.patch.object( + engine, + "search_documents", + return_value=[ + mocker.Mock(chunk_id=1, sentences=[]), + ], + ) + mock_search_sents = mocker.patch.object( + engine, "search_sentences", return_value=[] + ) # Act - results = engine.search("wood lumberjack") - - assert len(results) > 0 - assert doc_id == results[0].document.id + query = "test query" + engine.search(query, top_k=10) - def test_search_exact_match(self, db_conn): - conn, settings = db_conn - # cosin distance for searching embedding is exact 0.0 when strings match - settings.other_vector_options = "distance=cosine" - settings.use_prompt_templates = False + expected_fts_query = query + "*" - engine = Engine(conn, settings, Chunker(conn, settings)) - engine.load_model() - engine.create_new_context() - - doc1 = Document( - content="The quick brown fox jumps over the lazy dog", - uri="document1.txt", + # Assert + # Is called with the formatted template + mock_generate.assert_called_once_with( + "task: search | Do something with test query" ) - doc2 = Document( - content="How much wood would a woodchuck chuck if a woodchuck could chuck wood?", - uri="document2.txt", + mock_search_docs.assert_called_once_with( + b"embedding", expected_fts_query, top_k=10 + ) + mock_search_sents.assert_called_once_with( + b"embedding", 1, top_k=settings.top_k_sentences ) - engine.process(doc1) - engine.process(doc2) + @pytest.mark.parametrize("use_prompt_templates", [True, False]) + def test_search_with_prompt_template(self, mocker, use_prompt_templates): + # Arrange + settings = Settings( + use_prompt_templates=use_prompt_templates, + prompt_template_retrieval_query="task: search result | query: {content}", + ) - repository = Repository(conn, settings) - doc1_id = repository.add_document(doc1) - repository.add_document(doc2) + mock_conn = mocker.Mock() + engine = Engine(mock_conn, settings, mocker.Mock(), mocker.Mock()) - engine.quantize() + mock_generate_embedding = mocker.patch.object( + engine, "generate_embedding", return_value=b"embedding" + ) + mocker.patch.object( + engine, + "search_documents", + return_value=[ + mocker.Mock(chunk_id=1, sentences=[]), + ], + ) + mocker.patch.object(engine, "search_sentences", return_value=[]) # Act - results = engine.search("The quick brown fox jumps over the lazy dog") + query = "test query" + engine.search(query) + + # Assert - verify engine.search was called with correct formatted query + expected_semantic_query = ( + "task: search result | query: test query" + if use_prompt_templates + else "test query" + ) - assert len(results) > 0 - assert doc1_id == results[0].document.id - assert 0.0 == results[0].vec_distance + mock_generate_embedding.assert_called_once_with(expected_semantic_query) diff --git a/tests/test_formatters.py b/tests/test_formatters.py new file mode 100644 index 0000000..c738c91 --- /dev/null +++ b/tests/test_formatters.py @@ -0,0 +1,299 @@ +from sqlite_rag.formatters import ( + BoxedFormatter, + TableDebugFormatter, + get_formatter, +) +from sqlite_rag.models.document import Document +from sqlite_rag.models.document_result import DocumentResult +from sqlite_rag.models.sentence_result import SentenceResult + + +class TestGetFormatter: + """Test the get_formatter factory function.""" + + def test_get_formatter_default(self): + """Test getting formatter with default parameters.""" + formatter = get_formatter() + assert isinstance(formatter, BoxedFormatter) + assert formatter.show_debug is False + + def test_get_formatter_debug(self): + """Test getting formatter with debug=True.""" + formatter = get_formatter(debug=True) + assert isinstance(formatter, BoxedFormatter) + assert formatter.show_debug is True + + def test_get_formatter_table_view(self): + """Test getting table formatter.""" + formatter = get_formatter(table_view=True) + assert isinstance(formatter, TableDebugFormatter) + + def test_get_formatter_table_view_takes_precedence(self): + """Test that table_view takes precedence over debug.""" + formatter = get_formatter(debug=True, table_view=True) + assert isinstance(formatter, TableDebugFormatter) + # Table formatter doesn't have show_debug attribute + + +class TestSearchResultFormatter: + """Test base SearchResultFormatter methods.""" + + def setup_method(self): + """Set up test fixtures.""" + self.formatter = BoxedFormatter() + + def test_get_file_icon_python(self): + """Test getting icon for Python files.""" + assert self.formatter._get_file_icon("test.py") == "🐍" + assert self.formatter._get_file_icon("test.pyx") == "🐍" + + def test_get_file_icon_javascript(self): + """Test getting icon for JavaScript/TypeScript files.""" + assert self.formatter._get_file_icon("test.js") == "⚡" + assert self.formatter._get_file_icon("test.ts") == "⚡" + assert self.formatter._get_file_icon("test.jsx") == "⚡" + assert self.formatter._get_file_icon("test.tsx") == "⚡" + + def test_get_file_icon_markdown(self): + """Test getting icon for Markdown files.""" + assert self.formatter._get_file_icon("README.md") == "📄" + assert self.formatter._get_file_icon("doc.markdown") == "📄" + + def test_get_file_icon_case_insensitive(self): + """Test that file icon detection is case insensitive.""" + assert self.formatter._get_file_icon("TEST.PY") == "🐍" + assert self.formatter._get_file_icon("Test.Js") == "⚡" + + def test_get_file_icon_empty_uri(self): + """Test getting icon for empty URI.""" + assert self.formatter._get_file_icon("") == "📝" + + def test_get_file_icon_unknown_extension(self): + """Test getting default icon for unknown extensions.""" + assert self.formatter._get_file_icon("test.xyz") == "📄" + + def test_clean_and_wrap_snippet_basic(self): + """Test basic snippet cleaning and wrapping.""" + snippet = "This is a simple test snippet." + result = self.formatter._clean_and_wrap_snippet(snippet, width=30) + assert len(result) > 0 + assert all(len(line) <= 30 for line in result) + + def test_clean_and_wrap_snippet_removes_newlines(self): + """Test that newlines and carriage returns are removed.""" + snippet = "Line 1\nLine 2\r\nLine 3" + result = self.formatter._clean_and_wrap_snippet(snippet) + combined = " ".join(result) + assert "\n" not in combined + assert "\r" not in combined + assert "Line 1 Line 2 Line 3" == combined + + def test_clean_and_wrap_snippet_truncates_long_text(self): + """Test that long snippets are truncated.""" + snippet = "A" * 500 + result = self.formatter._clean_and_wrap_snippet(snippet, max_length=100) + combined = "".join(result) + assert len(combined) <= 103 # 100 + "..." + assert combined.endswith("...") + + def test_format_uri_display_basic(self): + """Test basic URI formatting.""" + uri_display = self.formatter._format_uri_display( + "path/to/file.py", "🐍", max_width=50 + ) + assert uri_display == "🐍 path/to/file.py" + + def test_format_uri_display_truncates_long_uri(self): + """Test that long URIs are truncated.""" + long_uri = "very/long/path/" * 10 + "file.py" + uri_display = self.formatter._format_uri_display(long_uri, "🐍", max_width=50) + assert len(uri_display) <= 50 + assert uri_display.startswith("🐍 ...") + + def test_format_uri_display_empty_uri(self): + """Test formatting empty URI.""" + assert self.formatter._format_uri_display("", "🐍") == "" + + +class TestBoxedFormatter: + """Test BoxedFormatter functionality.""" + + def test_init_default(self): + """Test BoxedFormatter initialization with default parameters.""" + formatter = BoxedFormatter() + assert formatter.show_debug is False + + def test_init_with_debug(self): + """Test BoxedFormatter initialization with debug enabled.""" + formatter = BoxedFormatter(show_debug=True) + assert formatter.show_debug is True + + def test_format_results_empty(self, mocker): + """Test formatting with empty results.""" + formatter = BoxedFormatter() + mock_echo = mocker.patch("typer.echo") + formatter.format_results([], "test query") + mock_echo.assert_called_once_with("No documents found matching the query.") + + def test_format_results_with_results(self, mocker): + """Test formatting with actual results.""" + doc = Document(uri="test.py", content="test content") + result = DocumentResult( + document=doc, + chunk_id=1, + chunk_content="This is test content.", + combined_rank=0.95, + vec_rank=1, + fts_rank=2, + vec_distance=0.1, + fts_score=5.0, + ) + + formatter = BoxedFormatter() + mock_echo = mocker.patch("typer.echo") + formatter.format_results([result], "test query") + # Should print header, result box, and empty line + assert mock_echo.call_count > 3 + # Check that it prints the search results header + first_call = mock_echo.call_args_list[0][0][0] + assert "Search Results" in first_call + assert "1 matches" in first_call + + def test_format_results_with_debug(self, mocker): + """Test formatting with debug information.""" + doc = Document(uri="test.py", content="test content") + result = DocumentResult( + document=doc, + chunk_id=1, + chunk_content="This is test content.", + combined_rank=0.95, + vec_rank=1, + fts_rank=2, + vec_distance=0.123456, + fts_score=5.678901, + ) + + formatter = BoxedFormatter(show_debug=True) + mock_echo = mocker.patch("typer.echo") + formatter.format_results([result], "test query") + # Check that debug info is printed + output = "\n".join( + [ + str(call.args[0]) if call.args else "" + for call in mock_echo.call_args_list + ] + ) + assert "Combined:" in output + assert "Vector:" in output + assert "FTS:" in output + + def test_format_results_with_sentences_in_debug_mode(self, mocker): + """Test formatting with sentence details in debug mode.""" + doc = Document(uri="test.py", content="test content") + sentences = [ + SentenceResult( + id=1, + chunk_id=1, + content="First sentence.", + rank=1, + distance=0.1, + start_offset=0, + end_offset=15, + ), + SentenceResult( + id=2, + chunk_id=1, + content="Second sentence.", + rank=2, + distance=0.2, + start_offset=16, + end_offset=32, + ), + ] + result = DocumentResult( + document=doc, + chunk_id=1, + chunk_content="First sentence. Second sentence.", + combined_rank=0.95, + sentences=sentences, + ) + + formatter = BoxedFormatter(show_debug=True) + mock_echo = mocker.patch("typer.echo") + formatter.format_results([result], "test query") + output = "\n".join( + [ + str(call.args[0]) if call.args else "" + for call in mock_echo.call_args_list + ] + ) + assert "Sentences:" in output + + def test_format_results_without_sentences_in_non_debug_mode(self, mocker): + """Test that sentences are not shown in non-debug mode.""" + doc = Document(uri="test.py", content="test content") + sentences = [ + SentenceResult( + id=1, + chunk_id=1, + content="First sentence.", + rank=1, + distance=0.1, + start_offset=0, + end_offset=15, + ), + ] + result = DocumentResult( + document=doc, + chunk_id=1, + chunk_content="First sentence.", + combined_rank=0.95, + sentences=sentences, + ) + + formatter = BoxedFormatter(show_debug=False) + mock_echo = mocker.patch("typer.echo") + formatter.format_results([result], "test query") + output = "\n".join( + [ + str(call.args[0]) if call.args else "" + for call in mock_echo.call_args_list + ] + ) + assert "Sentences:" not in output + + +class TestTableDebugFormatter: + """Test TableDebugFormatter functionality.""" + + def test_format_results_empty(self, mocker): + """Test table formatting with empty results.""" + formatter = TableDebugFormatter() + mock_echo = mocker.patch("typer.echo") + formatter.format_results([], "test query") + mock_echo.assert_called_once_with("No documents found matching the query.") + + def test_format_results_with_results(self, mocker): + """Test table formatting with actual results.""" + doc = Document(uri="test.py", content="test content") + result = DocumentResult( + document=doc, + chunk_id=1, + chunk_content="This is test content.", + combined_rank=0.95, + vec_rank=1, + fts_rank=2, + vec_distance=0.1, + fts_score=5.0, + ) + + formatter = TableDebugFormatter() + mock_echo = mocker.patch("typer.echo") + formatter.format_results([result], "test query") + # Should print header, table header, separator, and row + assert mock_echo.call_count >= 4 + # Check that headers are printed + output = "\n".join([str(call[0][0]) for call in mock_echo.call_args_list]) + assert "Preview" in output + assert "URI" in output + assert "C.Rank" in output diff --git a/tests/test_sentence_splitter.py b/tests/test_sentence_splitter.py new file mode 100644 index 0000000..7030b68 --- /dev/null +++ b/tests/test_sentence_splitter.py @@ -0,0 +1,67 @@ +from sqlite_rag.models.chunk import Chunk +from sqlite_rag.sentence_splitter import SentenceSplitter + + +class TestSentenceSplitter: + def test_split(self): + + splitter = SentenceSplitter() + + chunk = Chunk( + id=1, + document_id=1, + title="Test Chunk", + content="This is the first sentence.\nHere is the second sentence! And what about the third?", + embedding=b"", + sentences=[], + ) + + sentences = splitter.split(chunk) + + assert len(sentences) == 3 + assert sentences[0].content == "This is the first sentence." + assert sentences[0].start_offset == 0 + assert sentences[0].end_offset == 27 + + assert sentences[1].content == "Here is the second sentence!" + assert sentences[1].start_offset == 28 + assert sentences[1].end_offset == 28 + 28 + + assert sentences[2].content == "And what about the third?" + assert sentences[2].start_offset == 57 + assert sentences[2].end_offset == 57 + 25 + + def test_split_empty(self): + splitter = SentenceSplitter() + + chunk = Chunk( + id=1, + document_id=1, + title="Empty Chunk", + content="", + embedding=b"", + sentences=[], + ) + + sentences = splitter.split(chunk) + + assert len(sentences) == 0 + + def test_split_no_punctuation(self): + splitter = SentenceSplitter() + + chunk = Chunk( + id=1, + document_id=1, + title="No Punctuation Chunk", + content="This is a sentence without punctuation and another one follows it", + embedding=b"", + sentences=[], + ) + + sentences = splitter.split(chunk) + + assert len(sentences) == 1 + assert sentences[0].content == chunk.content + assert sentences[0].start_offset == 0 + assert sentences[0].end_offset == len(chunk.content) diff --git a/tests/test_settings.py b/tests/test_settings.py index c8b6e3e..a26f0eb 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -9,7 +9,7 @@ def test_store_settings(self, db_conn): settings_manager = SettingsManager(db_conn[0]) settings = Settings( model_path="test_model", - model_options="test_config", + other_model_options="test_config", embedding_dim=768, vector_type="test_store", chunk_overlap=100, @@ -23,7 +23,7 @@ def test_store_settings(self, db_conn): assert stored_settings is not None assert stored_settings.model_path == "test_model" - assert stored_settings.model_options == "test_config" + assert stored_settings.other_model_options == "test_config" assert stored_settings.embedding_dim == 768 assert stored_settings.vector_type == "test_store" assert stored_settings.chunk_overlap == 100 @@ -34,7 +34,7 @@ def test_store_settings_when_exist(self, db_conn): settings_manager = SettingsManager(db_conn[0]) settings = Settings( model_path="test_model", - model_options="test_config", + other_model_options="test_config", embedding_dim=768, vector_type="test_store", chunk_overlap=100, @@ -47,7 +47,7 @@ def test_store_settings_when_exist(self, db_conn): # Store again with different values new_settings = Settings( model_path="new_model", - model_options="new_config", + other_model_options="new_config", embedding_dim=512, vector_type="new_store", chunk_overlap=50, @@ -60,7 +60,7 @@ def test_store_settings_when_exist(self, db_conn): assert stored_settings is not None assert stored_settings.model_path == "new_model" - assert stored_settings.model_options == "new_config" + assert stored_settings.other_model_options == "new_config" assert stored_settings.embedding_dim == 512 assert stored_settings.vector_type == "new_store" assert stored_settings.chunk_overlap == 50 @@ -82,7 +82,7 @@ def test_load_settings_with_defaults(self, db_conn): assert loaded_settings is not None assert loaded_settings.model_path == settings.model_path - assert loaded_settings.model_options == settings.model_options + assert loaded_settings.other_model_options == settings.other_model_options assert loaded_settings.embedding_dim == settings.embedding_dim assert loaded_settings.vector_type == settings.vector_type assert loaded_settings.chunk_overlap == settings.chunk_overlap diff --git a/tests/test_sqlite_rag.py b/tests/test_sqlite_rag.py index 19fedb7..3cdddc0 100644 --- a/tests/test_sqlite_rag.py +++ b/tests/test_sqlite_rag.py @@ -6,7 +6,6 @@ import pytest from sqlite_rag import SQLiteRag -from sqlite_rag.settings import Settings class TestSQLiteRagAdd: @@ -821,53 +820,3 @@ def test_search_samples_exact_match_by_scan_type(self, quantize_scan: bool): # Second result should have distance > 0 second_result = results[1] assert second_result.vec_distance and second_result.vec_distance > 0.0 - - def test_search_uses_retrieval_query_template(self, mocker): - template = "task: search | Do something with {content}" - - settings = {"prompt_template_retrieval_query": template} - - rag = SQLiteRag.create(":memory:", settings=settings) - - mock_engine = mocker.Mock() - mock_engine.search.return_value = [] - - rag._engine = mock_engine - - query = "test query" - rag.search(query) - - # Assert that engine.search was called with the formatted template - expected_query = rag._settings.prompt_template_retrieval_query.format( - content=query - ) - mock_engine.search.assert_called_once_with(expected_query, top_k=10) - - @pytest.mark.parametrize("use_prompt_templates", [True, False]) - def test_search_with_prompt_template(self, mocker, use_prompt_templates): - # Arrange - settings = Settings( - use_prompt_templates=use_prompt_templates, - prompt_template_retrieval_query="task: search result | query: {content}", - ) - - # Mock engine and its search method - mock_engine = mocker.Mock() - mock_engine.search.return_value = [] # Empty search results - - # Create SQLiteRag instance with mocked dependencies - rag = SQLiteRag(mocker.Mock(), settings) - rag._engine = mock_engine - - mocker.patch.object(rag, "_ensure_initialized") - - # Act - rag.search("test query", new_context=False) - - # Assert - verify engine.search was called with correct formatted query - expected_query = ( - "task: search result | query: test query" - if use_prompt_templates - else "test query" - ) - mock_engine.search.assert_called_once_with(expected_query, top_k=10)