From 40e179ecad75882f947f02d4375c0105029c781d Mon Sep 17 00:00:00 2001 From: Daniele Briggi <=> Date: Tue, 7 Oct 2025 11:07:19 +0000 Subject: [PATCH 1/4] feat(add): #8 option to only/exclude files by ext --- src/sqlite_rag/cli.py | 16 ++++++ src/sqlite_rag/reader.py | 99 ++++++++++++++++++++------------ src/sqlite_rag/sqliterag.py | 20 ++++++- tests/test_reader.py | 95 ++++++++++++++++++++++++++++++- tests/test_sqlite_rag.py | 110 ++++++++++++++++++++++++++++++++++++ 5 files changed, 298 insertions(+), 42 deletions(-) diff --git a/src/sqlite_rag/cli.py b/src/sqlite_rag/cli.py index c99154f..e2108e5 100644 --- a/src/sqlite_rag/cli.py +++ b/src/sqlite_rag/cli.py @@ -242,17 +242,33 @@ def add( help="Optional metadata in JSON format to associate with the document", metavar="JSON", ), + only_extensions: Optional[str] = typer.Option( + None, + "--only", + help="Only process these file extensions from supported list (comma-separated, e.g. 'py,js')", + ), + exclude_extensions: Optional[str] = typer.Option( + None, + "--exclude", + help="File extensions to exclude (comma-separated, e.g. 
'py,js')", + ), ): """Add a file path to the database""" rag_context = ctx.obj["rag_context"] start_time = time.time() + # Parse extension lists + only_list = only_extensions.split(",") if only_extensions else None + exclude_list = exclude_extensions.split(",") if exclude_extensions else None + rag = rag_context.get_rag() rag.add( path, recursive=recursive, use_relative_paths=use_relative_paths, metadata=json.loads(metadata or "{}"), + only_extensions=only_list, + exclude_extensions=exclude_list, ) elapsed_time = time.time() - start_time diff --git a/src/sqlite_rag/reader.py b/src/sqlite_rag/reader.py index 4633a96..33ab9af 100644 --- a/src/sqlite_rag/reader.py +++ b/src/sqlite_rag/reader.py @@ -6,44 +6,61 @@ class FileReader: extensions = [ - ".c", - ".cpp", - ".css", - ".csv", - ".docx", - ".go", - ".h", - ".hpp", - ".html", - ".java", - ".js", - ".json", - ".kt", - ".md", - ".mdx", - ".mjs", - ".pdf", - ".php", - ".pptx", - ".py", - ".rb", - ".rs", - ".svelte", - ".swift", - ".ts", - ".tsx", - ".txt", - ".vue", - ".xml", - ".xlsx", - ".yaml", - ".yml", + "c", + "cpp", + "css", + "csv", + "docx", + "go", + "h", + "hpp", + "html", + "java", + "js", + "json", + "kt", + "md", + "mdx", + "mjs", + "pdf", + "php", + "pptx", + "py", + "rb", + "rs", + "svelte", + "swift", + "ts", + "tsx", + "txt", + "vue", + "xml", + "xlsx", + "yaml", + "yml", ] @staticmethod - def is_supported(path: Path) -> bool: + def is_supported( + path: Path, + only_extensions: Optional[list[str]] = None, + exclude_extensions: Optional[list[str]] = None, + ) -> bool: """Check if the file extension is supported""" - return path.suffix.lower() in FileReader.extensions + extension = path.suffix.lower().lstrip(".") + + supported_extensions = set(FileReader.extensions) + exclude_set = set() + + # Only keep those that are in both lists + if only_extensions: + only_set = {ext.lower().lstrip(".") for ext in only_extensions} + supported_extensions &= only_set + + if exclude_extensions: + exclude_set = 
{ext.lower().lstrip(".") for ext in exclude_extensions} + + return extension in supported_extensions and extension not in exclude_set @staticmethod def parse_file(path: Path, max_document_size_bytes: Optional[int] = None) -> str: @@ -65,12 +82,19 @@ def parse_file(path: Path, max_document_size_bytes: Optional[int] = None) -> str raise ValueError(f"Failed to parse file {path}") from exc @staticmethod - def collect_files(path: Path, recursive: bool = False) -> list[Path]: + def collect_files( + path: Path, + recursive: bool = False, + only_extensions: Optional[list[str]] = None, + exclude_extensions: Optional[list[str]] = None, + ) -> list[Path]: """Collect files from the path, optionally recursively""" if not path.exists(): raise FileNotFoundError(f"{path} does not exist.") - if path.is_file() and FileReader.is_supported(path): + if path.is_file() and FileReader.is_supported( + path, only_extensions, exclude_extensions + ): return [path] files_to_process = [] @@ -83,7 +107,8 @@ def collect_files(path: Path, recursive: bool = False) -> list[Path]: files_to_process = [ f for f in files_to_process - if f.is_file() and FileReader.is_supported(f) + if f.is_file() + and FileReader.is_supported(f, only_extensions, exclude_extensions) ] return files_to_process diff --git a/src/sqlite_rag/sqliterag.py b/src/sqlite_rag/sqliterag.py index 84927d6..8be35b6 100644 --- a/src/sqlite_rag/sqliterag.py +++ b/src/sqlite_rag/sqliterag.py @@ -72,8 +72,19 @@ def add( recursive: bool = False, use_relative_paths: bool = False, metadata: dict = {}, + only_extensions: Optional[list[str]] = None, + exclude_extensions: Optional[list[str]] = None, ) -> int: - """Add the file content into the database""" + """Add the file content into the database + + Args: + path: File or directory path to add + recursive: Recursively add files in directories + use_relative_paths: Store relative paths instead of absolute paths + metadata: Metadata to associate with documents + only_extensions: Only process 
these file extensions from the supported list (e.g. ['py', 'js']) + exclude_extensions: Skip these file extensions (e.g. ['py', 'js']) + """ self._ensure_initialized() if not Path(path).exists(): @@ -81,7 +92,12 @@ def add( parent = Path(path).parent - files_to_process = FileReader.collect_files(Path(path), recursive=recursive) + files_to_process = FileReader.collect_files( + Path(path), + recursive=recursive, + only_extensions=only_extensions, + exclude_extensions=exclude_extensions, + ) self._engine.create_new_context() diff --git a/tests/test_reader.py b/tests/test_reader.py index 2c769f5..a1fa266 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -62,13 +62,102 @@ def test_collect_files_recursive_directory(self): assert file2 in files def test_is_supported(self): - unsupported_extensions = [".exe", ".bin", ".jpg", ".png"] + unsupported_extensions = ["exe", "bin", "jpg", "png"] for ext in FileReader.extensions: - assert FileReader.is_supported(Path(f"test{ext}")) + assert FileReader.is_supported(Path(f"test.{ext}")) for ext in unsupported_extensions: - assert not FileReader.is_supported(Path(f"test{ext}")) + assert not FileReader.is_supported(Path(f"test.{ext}")) + + def test_is_supported_with_only_extensions(self): + """Test is_supported with only_extensions parameter""" + # Test with only_extensions - should only allow specified extensions + assert FileReader.is_supported(Path("test.py"), only_extensions=["py", "js"]) + assert FileReader.is_supported(Path("test.js"), only_extensions=["py", "js"]) + assert not FileReader.is_supported( + Path("test.txt"), only_extensions=["py", "js"] + ) + assert not FileReader.is_supported( + Path("test.md"), only_extensions=["py", "js"] + ) + + # Test with dots in extensions (should be normalized) + assert FileReader.is_supported(Path("test.py"), only_extensions=[".py", ".js"]) + assert FileReader.is_supported(Path("test.js"), only_extensions=[".py", ".js"]) + + # Test case insensitive + assert 
FileReader.is_supported(Path("test.py"), only_extensions=["PY", "JS"]) + assert FileReader.is_supported(Path("test.JS"), only_extensions=["py", "js"]) + + def test_is_supported_with_exclude_extensions(self): + """Test is_supported with exclude_extensions parameter""" + # Test basic exclusion - py files should be excluded + assert not FileReader.is_supported(Path("test.py"), exclude_extensions=["py"]) + assert FileReader.is_supported(Path("test.js"), exclude_extensions=["py"]) + assert FileReader.is_supported(Path("test.txt"), exclude_extensions=["py"]) + + # Test with dots in extensions (should be normalized) + assert not FileReader.is_supported(Path("test.py"), exclude_extensions=[".py"]) + assert FileReader.is_supported(Path("test.js"), exclude_extensions=[".py"]) + + # Test case insensitive + assert not FileReader.is_supported(Path("test.py"), exclude_extensions=["PY"]) + assert not FileReader.is_supported(Path("test.PY"), exclude_extensions=["py"]) + + # Test multiple exclusions + assert not FileReader.is_supported( + Path("test.py"), exclude_extensions=["py", "js"] + ) + assert not FileReader.is_supported( + Path("test.js"), exclude_extensions=["py", "js"] + ) + assert FileReader.is_supported( + Path("test.txt"), exclude_extensions=["py", "js"] + ) + + def test_is_supported_with_only_and_exclude_extensions(self): + """Test is_supported with both only_extensions and exclude_extensions""" + # Include py and js, but exclude py - should only allow js + assert not FileReader.is_supported( + Path("test.py"), only_extensions=["py", "js"], exclude_extensions=["py"] + ) + assert FileReader.is_supported( + Path("test.js"), only_extensions=["py", "js"], exclude_extensions=["py"] + ) + assert not FileReader.is_supported( + Path("test.txt"), only_extensions=["py", "js"], exclude_extensions=["py"] + ) + + # Include py, txt, md, but exclude md - should only allow py and txt + assert FileReader.is_supported( + Path("test.py"), + only_extensions=["py", "txt", "md"], + 
exclude_extensions=["md"], + ) + assert FileReader.is_supported( + Path("test.txt"), + only_extensions=["py", "txt", "md"], + exclude_extensions=["md"], + ) + assert not FileReader.is_supported( + Path("test.md"), + only_extensions=["py", "txt", "md"], + exclude_extensions=["md"], + ) + assert not FileReader.is_supported( + Path("test.js"), + only_extensions=["py", "txt", "md"], + exclude_extensions=["md"], + ) + + def test_is_supported_with_unsupported_extensions_in_only(self): + """Test that only_extensions can't add unsupported extensions""" + # .exe is not in FileReader.extensions, so should not be supported even if in only_extensions + assert not FileReader.is_supported( + Path("test.exe"), only_extensions=["exe", "py"] + ) + assert FileReader.is_supported(Path("test.py"), only_extensions=["exe", "py"]) def test_parse_html_into_markdown(self): with tempfile.NamedTemporaryFile(suffix=".html", delete=False) as f: diff --git a/tests/test_sqlite_rag.py b/tests/test_sqlite_rag.py index db5e556..19fedb7 100644 --- a/tests/test_sqlite_rag.py +++ b/tests/test_sqlite_rag.py @@ -302,6 +302,116 @@ def test_add_markdown_with_frontmatter(self): assert "author" in metadata["extracted"] assert metadata["extracted"]["author"] == "Test Author" + def test_add_with_only_extensions(self): + with tempfile.TemporaryDirectory() as temp_dir: + # Create files with supported extensions + py_file = Path(temp_dir) / "script.py" + js_file = Path(temp_dir) / "app.js" + txt_file = Path(temp_dir) / "readme.txt" + md_file = Path(temp_dir) / "docs.md" + + py_file.write_text("print('hello world')") + js_file.write_text("console.log('hello world')") + txt_file.write_text("This is a readme file") + md_file.write_text("# Documentation") + + rag = SQLiteRag.create(":memory:") + + # Add with only_extensions - only process py and js files + processed = rag.add(temp_dir, only_extensions=["py", "js"]) + + assert processed == 2 # Only py and js files should be processed + + conn = rag._conn + cursor = 
conn.execute("SELECT uri FROM documents ORDER BY uri") + docs = cursor.fetchall() + assert len(docs) == 2 + uris = [doc[0] for doc in docs] + assert any("script.py" in uri for uri in uris) + assert any("app.js" in uri for uri in uris) + assert not any("readme.txt" in uri for uri in uris) + assert not any("docs.md" in uri for uri in uris) + + def test_add_with_exclude_extensions(self): + with tempfile.TemporaryDirectory() as temp_dir: + txt_file = Path(temp_dir) / "document.txt" + md_file = Path(temp_dir) / "document.md" + py_file = Path(temp_dir) / "script.py" + + txt_file.write_text("Text document content") + md_file.write_text("# Markdown document") + py_file.write_text("print('hello world')") + + rag = SQLiteRag.create(":memory:") + + # Add with exclude_extensions + processed = rag.add(temp_dir, exclude_extensions=["py"]) + + assert processed == 2 # Only txt and md files should be processed + + conn = rag._conn + cursor = conn.execute("SELECT uri FROM documents ORDER BY uri") + docs = cursor.fetchall() + assert len(docs) == 2 + # Should not contain the py file + uris = [doc[0] for doc in docs] + assert any("document.txt" in uri for uri in uris) + assert any("document.md" in uri for uri in uris) + assert not any("script.py" in uri for uri in uris) + + def test_add_with_only_and_exclude_extensions(self): + with tempfile.TemporaryDirectory() as temp_dir: + py_file = Path(temp_dir) / "script.py" + txt_file = Path(temp_dir) / "document.txt" + md_file = Path(temp_dir) / "readme.md" + js_file = Path(temp_dir) / "app.js" + + py_file.write_text("print('hello world')") + txt_file.write_text("Text document") + md_file.write_text("# Markdown document") + js_file.write_text("console.log('hello')") + + rag = SQLiteRag.create(":memory:") + + # Only .py and .txt files but exclude .py files + processed = rag.add( + temp_dir, only_extensions=["py", "txt"], exclude_extensions=["py"] + ) + + assert ( + processed == 1 + ) # Only txt file (py excluded, md not in only list, js not in 
only list) + + conn = rag._conn + cursor = conn.execute("SELECT uri FROM documents ORDER BY uri") + docs = cursor.fetchall() + uris = [doc[0] for doc in docs] + assert any("document.txt" in uri for uri in uris) + assert not any("script.py" in uri for uri in uris) # Excluded + assert not any("readme.md" in uri for uri in uris) # Not in only list + assert not any("app.js" in uri for uri in uris) # Not in only list + + def test_add_single_file_with_only_extensions(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write("print('hello world')") + temp_file_path = f.name + + rag = SQLiteRag.create(":memory:") + + # With only_extensions=['txt'], should not be processed (py not in list) + processed = rag.add(temp_file_path, only_extensions=["txt"]) + assert processed == 0 + + # With only_extensions=['py'], should be processed + processed = rag.add(temp_file_path, only_extensions=["py"]) + assert processed == 1 + + conn = rag._conn + cursor = conn.execute("SELECT content FROM documents") + doc = cursor.fetchone() + assert doc + assert "print('hello world')" in doc[0] + class TestSQLiteRag: def test_list_documents(self): From 6f6d3d31333c1475e32d1db16b62ed350d8806ac Mon Sep 17 00:00:00 2001 From: Daniele Briggi Date: Tue, 7 Oct 2025 14:11:14 +0200 Subject: [PATCH 2/4] Update src/sqlite_rag/cli.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/sqlite_rag/cli.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/sqlite_rag/cli.py b/src/sqlite_rag/cli.py index e2108e5..e3d1aeb 100644 --- a/src/sqlite_rag/cli.py +++ b/src/sqlite_rag/cli.py @@ -257,9 +257,15 @@ def add( rag_context = ctx.obj["rag_context"] start_time = time.time() - # Parse extension lists - only_list = only_extensions.split(",") if only_extensions else None - exclude_list = exclude_extensions.split(",") if exclude_extensions else None + # Parse and normalize extension lists + only_list = ( + 
[e.strip().lstrip(".").lower() for e in only_extensions.split(",") if e.strip()] + if only_extensions else None + ) + exclude_list = ( + [e.strip().lstrip(".").lower() for e in exclude_extensions.split(",") if e.strip()] + if exclude_extensions else None + ) rag = rag_context.get_rag() rag.add( From 3a44b2f9000c619fc46766499c22356e9fd4b6bb Mon Sep 17 00:00:00 2001 From: Daniele Briggi <=> Date: Tue, 7 Oct 2025 12:53:17 +0000 Subject: [PATCH 3/4] chore(tests): add cli tests --- src/sqlite_rag/cli.py | 13 ++++-- src/sqlite_rag/reader.py | 8 +++- tests/integration/test_cli.py | 83 +++++++++++++++++++++++++++++++++++ 3 files changed, 99 insertions(+), 5 deletions(-) diff --git a/src/sqlite_rag/cli.py b/src/sqlite_rag/cli.py index e3d1aeb..cb98a61 100644 --- a/src/sqlite_rag/cli.py +++ b/src/sqlite_rag/cli.py @@ -257,14 +257,19 @@ def add( rag_context = ctx.obj["rag_context"] start_time = time.time() - # Parse and normalize extension lists only_list = ( [e.strip().lstrip(".").lower() for e in only_extensions.split(",") if e.strip()] - if only_extensions else None + if only_extensions + else None ) exclude_list = ( - [e.strip().lstrip(".").lower() for e in exclude_extensions.split(",") if e.strip()] - if exclude_extensions else None + [ + e.strip().lstrip(".").lower() + for e in exclude_extensions.split(",") + if e.strip() + ] + if exclude_extensions + else None ) rag = rag_context.get_rag() diff --git a/src/sqlite_rag/reader.py b/src/sqlite_rag/reader.py index 33ab9af..6950f91 100644 --- a/src/sqlite_rag/reader.py +++ b/src/sqlite_rag/reader.py @@ -46,7 +46,13 @@ def is_supported( only_extensions: Optional[list[str]] = None, exclude_extensions: Optional[list[str]] = None, ) -> bool: - """Check if the file extension is supported""" + """Check if the file extension is supported. + + Parameters: + path (Path): The file path to check. + only_extensions (Optional[list[str]]): If provided, only files with these extensions are considered. 
+ exclude_extensions (Optional[list[str]]): If provided, files with these extensions are excluded. + """ extension = path.suffix.lower().lstrip(".") supported_extensions = set(FileReader.extensions) diff --git a/tests/integration/test_cli.py b/tests/integration/test_cli.py index e0de2b2..facc137 100644 --- a/tests/integration/test_cli.py +++ b/tests/integration/test_cli.py @@ -113,3 +113,86 @@ def test_change_database_path(self): assert result.exit_code == 0 assert f"Database: {tmp_db.name}" in result.stdout + + def test_add_with_exclude_extensions(self): + with tempfile.TemporaryDirectory() as tmp_dir: + (Path(tmp_dir) / "file1.txt").write_text("This is a text file.") + (Path(tmp_dir) / "file2.md").write_text("# This is a markdown file.") + (Path(tmp_dir) / "file3.py").write_text("print('Hello, world!')") + (Path(tmp_dir) / "file4.js").write_text("console.log('Hello, world!');") + + with tempfile.NamedTemporaryFile(suffix=".tempdb") as tmp_db: + runner = CliRunner() + + result = runner.invoke( + app, + ["--database", tmp_db.name, "add", tmp_dir, "--exclude", "py,js"], + ) + assert result.exit_code == 0 + + # Check that only .txt and .md files were added + assert "Processing 2 files" in result.stdout + assert "file1.txt" in result.stdout + assert "file2.md" in result.stdout + assert "file3.py" not in result.stdout + assert "file4.js" not in result.stdout + + def test_add_with_only_extensions(self): + with tempfile.TemporaryDirectory() as tmp_dir: + (Path(tmp_dir) / "file1.txt").write_text("This is a text file.") + (Path(tmp_dir) / "file2.md").write_text("# This is a markdown file.") + (Path(tmp_dir) / "file3.py").write_text("print('Hello, world!')") + (Path(tmp_dir) / "file4.js").write_text("console.log('Hello, world!');") + + with tempfile.NamedTemporaryFile(suffix=".tempdb") as tmp_db: + runner = CliRunner() + + result = runner.invoke( + app, + [ + "--database", + tmp_db.name, + "add", + tmp_dir, + "--only", + "md,txt", + ], + ) + assert result.exit_code == 0 + + 
# Check that only .txt and .md files were added + assert "Processing 2 files" in result.stdout + assert "file1.txt" in result.stdout + assert "file2.md" in result.stdout + assert "file3.py" not in result.stdout + assert "file4.js" not in result.stdout + + def test_add_with_only_and_exclude_extensions_are_normilized(self): + with tempfile.TemporaryDirectory() as tmp_dir: + (Path(tmp_dir) / "file1.txt").write_text("This is a text file.") + (Path(tmp_dir) / "file2.md").write_text("# This is a markdown file.") + (Path(tmp_dir) / "file3.py").write_text("print('Hello, world!')") + + with tempfile.NamedTemporaryFile(suffix=".tempdb") as tmp_db: + runner = CliRunner() + + result = runner.invoke( + app, + [ + "--database", + tmp_db.name, + "add", + tmp_dir, + "--only", + ".md, .txt,py", + "--exclude", + ".py ", # wins over --only + ], + ) + assert result.exit_code == 0 + + # Check that only .txt and .md files were added + assert "Processing 2 files" in result.stdout + assert "file1.txt" in result.stdout + assert "file2.md" in result.stdout + assert "file3.py" not in result.stdout From fe65f7bf6e395000ac8fdfd03b49899e0450eb3a Mon Sep 17 00:00:00 2001 From: Daniele Briggi <=> Date: Tue, 7 Oct 2025 13:05:10 +0000 Subject: [PATCH 4/4] fix(cli): settings command shows a message when the db file is missing --- src/sqlite_rag/cli.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/sqlite_rag/cli.py b/src/sqlite_rag/cli.py index cb98a61..6574889 100644 --- a/src/sqlite_rag/cli.py +++ b/src/sqlite_rag/cli.py @@ -96,7 +96,12 @@ def main( def show_settings(ctx: typer.Context): """Show current settings""" rag_context = ctx.obj["rag_context"] - rag = rag_context.get_rag(require_existing=True) + try: + rag = rag_context.get_rag(require_existing=True) + except FileNotFoundError: + typer.echo("Database not found. No settings available.") + raise typer.Exit(1) + current_settings = rag.get_settings() typer.echo("Current settings:")