Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion src/sqlite_rag/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,12 @@ def main(
def show_settings(ctx: typer.Context):
"""Show current settings"""
rag_context = ctx.obj["rag_context"]
rag = rag_context.get_rag(require_existing=True)
try:
rag = rag_context.get_rag(require_existing=True)
except FileNotFoundError:
typer.echo("Database not found. No settings available.")
raise typer.Exit(1)

current_settings = rag.get_settings()

typer.echo("Current settings:")
Expand Down Expand Up @@ -242,17 +247,44 @@ def add(
help="Optional metadata in JSON format to associate with the document",
metavar="JSON",
),
only_extensions: Optional[str] = typer.Option(
None,
"--only",
help="Only process these file extensions from supported list (comma-separated, e.g. 'py,js')",
),
exclude_extensions: Optional[str] = typer.Option(
None,
"--exclude",
help="File extensions to exclude (comma-separated, e.g. 'py,js')",
),
):
"""Add a file path to the database"""
rag_context = ctx.obj["rag_context"]
start_time = time.time()

only_list = (
[e.strip().lstrip(".").lower() for e in only_extensions.split(",") if e.strip()]
if only_extensions
else None
)
exclude_list = (
[
e.strip().lstrip(".").lower()
for e in exclude_extensions.split(",")
if e.strip()
]
if exclude_extensions
else None
)

rag = rag_context.get_rag()
rag.add(
path,
recursive=recursive,
use_relative_paths=use_relative_paths,
metadata=json.loads(metadata or "{}"),
only_extensions=only_list,
exclude_extensions=exclude_list,
)

elapsed_time = time.time() - start_time
Expand Down
107 changes: 69 additions & 38 deletions src/sqlite_rag/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,67 @@

class FileReader:
extensions = [
".c",
".cpp",
".css",
".csv",
".docx",
".go",
".h",
".hpp",
".html",
".java",
".js",
".json",
".kt",
".md",
".mdx",
".mjs",
".pdf",
".php",
".pptx",
".py",
".rb",
".rs",
".svelte",
".swift",
".ts",
".tsx",
".txt",
".vue",
".xml",
".xlsx",
".yaml",
".yml",
"c",
"cpp",
"css",
"csv",
"docx",
"go",
"h",
"hpp",
"html",
"java",
"js",
"json",
"kt",
"md",
"mdx",
"mjs",
"pdf",
"php",
"pptx",
"py",
"rb",
"rs",
"svelte",
"swift",
"ts",
"tsx",
"txt",
"vue",
"xml",
"xlsx",
"yaml",
"yml",
]

@staticmethod
def is_supported(path: Path) -> bool:
"""Check if the file extension is supported"""
return path.suffix.lower() in FileReader.extensions
def is_supported(
    path: Path,
    only_extensions: Optional[list[str]] = None,
    exclude_extensions: Optional[list[str]] = None,
) -> bool:
    """Decide whether a file should be processed, based on its extension.

    The extension is taken from ``path.suffix``, lower-cased and compared
    without its leading dot. Entries of both filter lists are normalized
    the same way, so ``"PY"``, ``".py"`` and ``"py"`` are equivalent.

    Parameters:
        path (Path): The file path to check.
        only_extensions (Optional[list[str]]): When non-empty, narrows the
            supported extensions down to those also present in this list.
        exclude_extensions (Optional[list[str]]): When non-empty, extensions
            in this list are rejected, even if ``only_extensions`` names them.

    Returns:
        bool: True if the extension is supported, allowed by
        ``only_extensions`` and not listed in ``exclude_extensions``.
    """

    def _normalize(values):
        # Same canonical form as the file suffix: lower-case, no leading dot.
        return {value.lower().lstrip(".") for value in values}

    suffix = path.suffix.lower().lstrip(".")

    allowed = set(FileReader.extensions)
    if only_extensions:
        # An "only" filter can restrict the supported set but never extend it.
        allowed = allowed.intersection(_normalize(only_extensions))

    blocked = _normalize(exclude_extensions) if exclude_extensions else set()

    # Removing the blocked entries last makes exclusion win over "only".
    return suffix in allowed - blocked

@staticmethod
def parse_file(path: Path, max_document_size_bytes: Optional[int] = None) -> str:
Expand All @@ -65,12 +88,19 @@ def parse_file(path: Path, max_document_size_bytes: Optional[int] = None) -> str
raise ValueError(f"Failed to parse file {path}") from exc

@staticmethod
def collect_files(path: Path, recursive: bool = False) -> list[Path]:
def collect_files(
path: Path,
recursive: bool = False,
only_extensions: Optional[list[str]] = None,
exclude_extensions: Optional[list[str]] = None,
) -> list[Path]:
"""Collect files from the path, optionally recursively"""
if not path.exists():
raise FileNotFoundError(f"{path} does not exist.")

if path.is_file() and FileReader.is_supported(path):
if path.is_file() and FileReader.is_supported(
path, only_extensions, exclude_extensions
):
return [path]

files_to_process = []
Expand All @@ -83,7 +113,8 @@ def collect_files(path: Path, recursive: bool = False) -> list[Path]:
files_to_process = [
f
for f in files_to_process
if f.is_file() and FileReader.is_supported(f)
if f.is_file()
and FileReader.is_supported(f, only_extensions, exclude_extensions)
]

return files_to_process
20 changes: 18 additions & 2 deletions src/sqlite_rag/sqliterag.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,32 @@ def add(
recursive: bool = False,
use_relative_paths: bool = False,
metadata: dict = {},
only_extensions: Optional[list[str]] = None,
exclude_extensions: Optional[list[str]] = None,
) -> int:
"""Add the file content into the database"""
"""Add the file content into the database

Args:
path: File or directory path to add
recursive: Recursively add files in directories
use_relative_paths: Store relative paths instead of absolute paths
metadata: Metadata to associate with documents
only_extensions: Only process these file extensions from the supported list (e.g. ['py', 'js'])
exclude_extensions: Skip these file extensions (e.g. ['py', 'js'])
"""
self._ensure_initialized()

if not Path(path).exists():
raise FileNotFoundError(f"{path} does not exist.")

parent = Path(path).parent

files_to_process = FileReader.collect_files(Path(path), recursive=recursive)
files_to_process = FileReader.collect_files(
Path(path),
recursive=recursive,
only_extensions=only_extensions,
exclude_extensions=exclude_extensions,
)

self._engine.create_new_context()

Expand Down
83 changes: 83 additions & 0 deletions tests/integration/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,86 @@ def test_change_database_path(self):
assert result.exit_code == 0

assert f"Database: {tmp_db.name}" in result.stdout

def test_add_with_exclude_extensions(self):
    """`add --exclude py,js` must leave the .py and .js files out of the run."""
    fixtures = {
        "file1.txt": "This is a text file.",
        "file2.md": "# This is a markdown file.",
        "file3.py": "print('Hello, world!')",
        "file4.js": "console.log('Hello, world!');",
    }

    with tempfile.TemporaryDirectory() as tmp_dir:
        for filename, content in fixtures.items():
            (Path(tmp_dir) / filename).write_text(content)

        with tempfile.NamedTemporaryFile(suffix=".tempdb") as tmp_db:
            cli_args = [
                "--database",
                tmp_db.name,
                "add",
                tmp_dir,
                "--exclude",
                "py,js",
            ]
            result = CliRunner().invoke(app, cli_args)

            assert result.exit_code == 0

            # Only the .txt and .md files are expected in the output.
            output = result.stdout
            assert "Processing 2 files" in output
            for kept in ("file1.txt", "file2.md"):
                assert kept in output
            for skipped in ("file3.py", "file4.js"):
                assert skipped not in output

def test_add_with_only_extensions(self):
    """`add --only md,txt` must process just the .md and .txt files."""
    fixtures = {
        "file1.txt": "This is a text file.",
        "file2.md": "# This is a markdown file.",
        "file3.py": "print('Hello, world!')",
        "file4.js": "console.log('Hello, world!');",
    }

    with tempfile.TemporaryDirectory() as tmp_dir:
        for filename, content in fixtures.items():
            (Path(tmp_dir) / filename).write_text(content)

        with tempfile.NamedTemporaryFile(suffix=".tempdb") as tmp_db:
            cli_args = [
                "--database",
                tmp_db.name,
                "add",
                tmp_dir,
                "--only",
                "md,txt",
            ]
            result = CliRunner().invoke(app, cli_args)

            assert result.exit_code == 0

            # Only the .txt and .md files are expected in the output.
            output = result.stdout
            assert "Processing 2 files" in output
            for kept in ("file1.txt", "file2.md"):
                assert kept in output
            for skipped in ("file3.py", "file4.js"):
                assert skipped not in output

def test_add_with_only_and_exclude_extensions_are_normilized(self):
    """Filter values with leading dots or stray spaces must still match.

    `--only` and `--exclude` are passed in non-canonical form, and the .py
    extension appears in both lists; the test expects it to be excluded.
    """
    fixtures = {
        "file1.txt": "This is a text file.",
        "file2.md": "# This is a markdown file.",
        "file3.py": "print('Hello, world!')",
    }

    with tempfile.TemporaryDirectory() as tmp_dir:
        for filename, content in fixtures.items():
            (Path(tmp_dir) / filename).write_text(content)

        with tempfile.NamedTemporaryFile(suffix=".tempdb") as tmp_db:
            cli_args = [
                "--database",
                tmp_db.name,
                "add",
                tmp_dir,
                "--only",
                ".md, .txt,py",
                "--exclude",
                ".py ",  # wins over --only
            ]
            result = CliRunner().invoke(app, cli_args)

            assert result.exit_code == 0

            # Only the .txt and .md files are expected in the output.
            output = result.stdout
            assert "Processing 2 files" in output
            for kept in ("file1.txt", "file2.md"):
                assert kept in output
            assert "file3.py" not in output
Loading