diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 68eedb48..9dcf517a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -76,3 +76,10 @@ repos:
rev: v1.36.4
hooks:
- id: djlint-reformat-jinja
+
+ - repo: https://github.com/igorshubovych/markdownlint-cli
+ rev: v0.43.0
+ hooks:
+ - id: markdownlint
+ description: "Lint markdown files."
+ args: ["--disable=line-length"]
diff --git a/README.md b/README.md
index 6d0747a2..7e02c467 100644
--- a/README.md
+++ b/README.md
@@ -1,32 +1,13 @@
-[](https://gitingest.com)
+# GitIngest
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+[License](https://github.com/cyclotruc/gitingest/blob/main/LICENSE)
+[PyPI version](https://badge.fury.io/py/gitingest)
+[Downloads](https://pepy.tech/project/gitingest)
+[GitHub issues](https://github.com/cyclotruc/gitingest/issues)
+[Code style: black](https://github.com/psf/black)
+[Discord](https://discord.com/invite/zerRaGK9EC)
-# GitIngest
+[gitingest.com](https://gitingest.com)
Turn any Git repository into a prompt-friendly text ingest for LLMs.
@@ -92,15 +73,15 @@ By default, this won't write a file but can be enabled with the `output` argumen
1. Build the image:
-``` bash
-docker build -t gitingest .
-```
+ ``` bash
+ docker build -t gitingest .
+ ```
2. Run the container:
-``` bash
-docker run -d --name gitingest -p 8000:8000 gitingest
-```
+ ``` bash
+ docker run -d --name gitingest -p 8000:8000 gitingest
+ ```
The application will be available at `http://localhost:8000`
Ensure environment variables are set before running the application or deploying it via Docker.
@@ -135,22 +116,20 @@ ALLOWED_HOSTS="gitingest.local,localhost"
1. Clone the repository
-```bash
-git clone https://github.com/cyclotruc/gitingest.git
-cd gitingest
-```
+ ```bash
+ git clone https://github.com/cyclotruc/gitingest.git
+ cd gitingest
+ ```
2. Install dependencies
-```bash
-pip install -r requirements.txt
-```
+ ```bash
+ pip install -r requirements.txt
+ ```
3. Run the application:
-```bash
-cd src
-uvicorn main:app --reload
-```
-
-The frontend will be available at `localhost:8000`
+ ```bash
+ cd src
+ uvicorn main:app --reload
+ ```
diff --git a/src/config.py b/src/config.py
index b918fb2a..8da41da8 100644
--- a/src/config.py
+++ b/src/config.py
@@ -1,7 +1,7 @@
-MAX_DISPLAY_SIZE = 300_000
-TMP_BASE_PATH = "../tmp"
+MAX_DISPLAY_SIZE: int = 300_000
+TMP_BASE_PATH: str = "../tmp"
-EXAMPLE_REPOS = [
+EXAMPLE_REPOS: list[dict[str, str]] = [
{"name": "Gitingest", "url": "https://github.com/cyclotruc/gitingest"},
{"name": "FastAPI", "url": "https://github.com/tiangolo/fastapi"},
{"name": "Flask", "url": "https://github.com/pallets/flask"},
diff --git a/src/gitingest/cli.py b/src/gitingest/cli.py
index c5f8a493..f275efac 100644
--- a/src/gitingest/cli.py
+++ b/src/gitingest/cli.py
@@ -1,19 +1,9 @@
-import os
-
import click
from gitingest.ingest import ingest
from gitingest.ingest_from_query import MAX_FILE_SIZE
-def normalize_pattern(pattern: str) -> str:
- pattern = pattern.strip()
- pattern = pattern.lstrip(os.sep)
- if pattern.endswith(os.sep):
- pattern += "*"
- return pattern
-
-
@click.command()
@click.argument("source", type=str, required=True)
@click.option("--output", "-o", default=None, help="Output file path (default: .txt in current directory)")
diff --git a/src/gitingest/clone.py b/src/gitingest/clone.py
index 97a990a2..a91b6a99 100644
--- a/src/gitingest/clone.py
+++ b/src/gitingest/clone.py
@@ -3,7 +3,7 @@
from gitingest.utils import AsyncTimeoutError, async_timeout
-CLONE_TIMEOUT = 20
+CLONE_TIMEOUT: int = 20
@dataclass
@@ -14,67 +14,6 @@ class CloneConfig:
branch: str | None = None
-async def check_repo_exists(url: str) -> bool:
- """
- Check if a repository exists at the given URL using an HTTP HEAD request.
-
- Parameters
- ----------
- url : str
- The URL of the repository.
-
- Returns
- -------
- bool
- True if the repository exists, False otherwise.
- """
- proc = await asyncio.create_subprocess_exec(
- "curl",
- "-I",
- url,
- stdout=asyncio.subprocess.PIPE,
- stderr=asyncio.subprocess.PIPE,
- )
- stdout, _ = await proc.communicate()
- if proc.returncode != 0:
- return False
- # Check if stdout contains "404" status code
- stdout_str = stdout.decode()
- return "HTTP/1.1 404" not in stdout_str and "HTTP/2 404" not in stdout_str
-
-
-async def run_git_command(*args: str) -> tuple[bytes, bytes]:
- """
- Executes a git command asynchronously and captures its output.
-
- Parameters
- ----------
- *args : str
- The git command and its arguments to execute.
-
- Returns
- -------
- Tuple[bytes, bytes]
- A tuple containing the stdout and stderr of the git command.
-
- Raises
- ------
- RuntimeError
- If the git command exits with a non-zero status.
- """
- proc = await asyncio.create_subprocess_exec(
- *args,
- stdout=asyncio.subprocess.PIPE,
- stderr=asyncio.subprocess.PIPE,
- )
- stdout, stderr = await proc.communicate()
- if proc.returncode != 0:
- error_message = stderr.decode().strip()
- raise RuntimeError(f"Git command failed: {' '.join(args)}\nError: {error_message}")
-
- return stdout, stderr
-
-
@async_timeout(CLONE_TIMEOUT)
async def clone_repo(config: CloneConfig) -> tuple[bytes, bytes]:
"""
@@ -116,7 +55,7 @@ async def clone_repo(config: CloneConfig) -> tuple[bytes, bytes]:
raise ValueError("The 'local_path' parameter is required.")
# Check if the repository exists
- if not await check_repo_exists(url):
+ if not await _check_repo_exists(url):
raise ValueError("Repository not found, make sure it is public")
try:
@@ -124,21 +63,82 @@ async def clone_repo(config: CloneConfig) -> tuple[bytes, bytes]:
# Scenario 1: Clone and checkout a specific commit
# Clone the repository without depth to ensure full history for checkout
clone_cmd = ["git", "clone", "--single-branch", url, local_path]
- await run_git_command(*clone_cmd)
+ await _run_git_command(*clone_cmd)
# Checkout the specific commit
checkout_cmd = ["git", "-C", local_path, "checkout", commit]
- return await run_git_command(*checkout_cmd)
+ return await _run_git_command(*checkout_cmd)
if branch and branch.lower() not in ("main", "master"):
# Scenario 2: Clone a specific branch with shallow depth
clone_cmd = ["git", "clone", "--depth=1", "--single-branch", "--branch", branch, url, local_path]
- return await run_git_command(*clone_cmd)
+ return await _run_git_command(*clone_cmd)
# Scenario 3: Clone the default branch with shallow depth
clone_cmd = ["git", "clone", "--depth=1", "--single-branch", url, local_path]
- return await run_git_command(*clone_cmd)
+ return await _run_git_command(*clone_cmd)
except (RuntimeError, asyncio.TimeoutError, AsyncTimeoutError):
raise # Re-raise the exception
+
+
+async def _check_repo_exists(url: str) -> bool:
+ """
+ Check if a repository exists at the given URL using an HTTP HEAD request.
+
+ Parameters
+ ----------
+ url : str
+ The URL of the repository.
+
+ Returns
+ -------
+ bool
+ True if the repository exists, False otherwise.
+ """
+ proc = await asyncio.create_subprocess_exec(
+ "curl",
+ "-I",
+ url,
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE,
+ )
+ stdout, _ = await proc.communicate()
+ if proc.returncode != 0:
+ return False
+ # Check if stdout contains "404" status code
+ stdout_str = stdout.decode()
+ return "HTTP/1.1 404" not in stdout_str and "HTTP/2 404" not in stdout_str
+
+
+async def _run_git_command(*args: str) -> tuple[bytes, bytes]:
+ """
+ Execute a git command asynchronously and capture its output.
+
+ Parameters
+ ----------
+ *args : str
+ The git command and its arguments to execute.
+
+ Returns
+ -------
+ tuple[bytes, bytes]
+ A tuple containing the stdout and stderr of the git command.
+
+ Raises
+ ------
+ RuntimeError
+ If the git command exits with a non-zero status.
+ """
+ proc = await asyncio.create_subprocess_exec(
+ *args,
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE,
+ )
+ stdout, stderr = await proc.communicate()
+ if proc.returncode != 0:
+ error_message = stderr.decode().strip()
+ raise RuntimeError(f"Git command failed: {' '.join(args)}\nError: {error_message}")
+
+ return stdout, stderr
diff --git a/src/gitingest/ingest_from_query.py b/src/gitingest/ingest_from_query.py
index 51cca8d2..886afa26 100644
--- a/src/gitingest/ingest_from_query.py
+++ b/src/gitingest/ingest_from_query.py
@@ -10,7 +10,7 @@
MAX_TOTAL_SIZE_BYTES = 500 * 1024 * 1024 # 500 MB
-def should_include(path: str, base_path: str, include_patterns: list[str]) -> bool:
+def _should_include(path: str, base_path: str, include_patterns: list[str]) -> bool:
rel_path = path.replace(base_path, "").lstrip(os.sep)
include = False
for pattern in include_patterns:
@@ -19,17 +19,15 @@ def should_include(path: str, base_path: str, include_patterns: list[str]) -> bo
return include
-def should_exclude(path: str, base_path: str, ignore_patterns: list[str]) -> bool:
+def _should_exclude(path: str, base_path: str, ignore_patterns: list[str]) -> bool:
rel_path = path.replace(base_path, "").lstrip(os.sep)
for pattern in ignore_patterns:
- if pattern == "":
- continue
- if fnmatch(rel_path, pattern):
+ if pattern and fnmatch(rel_path, pattern):
return True
return False
-def is_safe_symlink(symlink_path: str, base_path: str) -> bool:
+def _is_safe_symlink(symlink_path: str, base_path: str) -> bool:
"""Check if a symlink points to a location within the base directory."""
try:
target_path = os.path.realpath(symlink_path)
@@ -40,7 +38,7 @@ def is_safe_symlink(symlink_path: str, base_path: str) -> bool:
return False
-def is_text_file(file_path: str) -> bool:
+def _is_text_file(file_path: str) -> bool:
"""Determines if a file is likely a text file based on its content."""
try:
with open(file_path, "rb") as file:
@@ -50,7 +48,7 @@ def is_text_file(file_path: str) -> bool:
return False
-def read_file_content(file_path: str) -> str:
+def _read_file_content(file_path: str) -> str:
try:
with open(file_path, encoding="utf-8", errors="ignore") as f:
return f.read()
@@ -58,7 +56,7 @@ def read_file_content(file_path: str) -> str:
return f"Error reading file: {str(e)}"
-def scan_directory(
+def _scan_directory(
path: str,
query: dict[str, Any],
seen_paths: set[str] | None = None,
@@ -68,6 +66,7 @@ def scan_directory(
"""Recursively analyzes a directory and its contents with safety limits."""
if seen_paths is None:
seen_paths = set()
+
if stats is None:
stats = {"total_files": 0, "total_size": 0}
@@ -109,18 +108,18 @@ def scan_directory(
for item in os.listdir(path):
item_path = os.path.join(path, item)
- if should_exclude(item_path, base_path, ignore_patterns):
+ if _should_exclude(item_path, base_path, ignore_patterns):
continue
is_file = os.path.isfile(item_path)
if is_file and query["include_patterns"]:
- if not should_include(item_path, base_path, include_patterns):
+ if not _should_include(item_path, base_path, include_patterns):
result["ignore_content"] = True
continue
# Handle symlinks
if os.path.islink(item_path):
- if not is_safe_symlink(item_path, base_path):
+ if not _is_safe_symlink(item_path, base_path):
print(f"Skipping symlink that points outside base directory: {item_path}")
continue
real_path = os.path.realpath(item_path)
@@ -141,8 +140,8 @@ def scan_directory(
print(f"Maximum file limit ({MAX_FILES}) reached")
return result
- is_text = is_text_file(real_path)
- content = read_file_content(real_path) if is_text else "[Non-text file]"
+ is_text = _is_text_file(real_path)
+ content = _read_file_content(real_path) if is_text else "[Non-text file]"
child = {
"name": item,
@@ -156,7 +155,7 @@ def scan_directory(
result["file_count"] += 1
elif os.path.isdir(real_path):
- subdir = scan_directory(
+ subdir = _scan_directory(
path=real_path,
query=query,
seen_paths=seen_paths,
@@ -185,8 +184,8 @@ def scan_directory(
print(f"Maximum file limit ({MAX_FILES}) reached")
return result
- is_text = is_text_file(item_path)
- content = read_file_content(item_path) if is_text else "[Non-text file]"
+ is_text = _is_text_file(item_path)
+ content = _read_file_content(item_path) if is_text else "[Non-text file]"
child = {
"name": item,
@@ -200,7 +199,7 @@ def scan_directory(
result["file_count"] += 1
elif os.path.isdir(item_path):
- subdir = scan_directory(
+ subdir = _scan_directory(
path=item_path,
query=query,
seen_paths=seen_paths,
@@ -219,7 +218,7 @@ def scan_directory(
return result
-def extract_files_content(
+def _extract_files_content(
query: dict[str, Any],
node: dict[str, Any],
max_file_size: int,
@@ -243,12 +242,12 @@ def extract_files_content(
)
elif node["type"] == "directory":
for child in node["children"]:
- extract_files_content(query=query, node=child, max_file_size=max_file_size, files=files)
+ _extract_files_content(query=query, node=child, max_file_size=max_file_size, files=files)
return files
-def create_file_content_string(files: list[dict[str, Any]]) -> str:
+def _create_file_content_string(files: list[dict[str, Any]]) -> str:
"""Creates a formatted string of file contents with separators."""
output = ""
separator = "=" * 48 + "\n"
@@ -278,7 +277,7 @@ def create_file_content_string(files: list[dict[str, Any]]) -> str:
return output
-def create_summary_string(query: dict[str, Any], nodes: dict[str, Any]) -> str:
+def _create_summary_string(query: dict[str, Any], nodes: dict[str, Any]) -> str:
"""Creates a summary string with file counts and content size."""
if "user_name" in query:
summary = f"Repository: {query['user_name']}/{query['repo_name']}\n"
@@ -297,7 +296,7 @@ def create_summary_string(query: dict[str, Any], nodes: dict[str, Any]) -> str:
return summary
-def create_tree_structure(query: dict[str, Any], node: dict[str, Any], prefix: str = "", is_last: bool = True) -> str:
+def _create_tree_structure(query: dict[str, Any], node: dict[str, Any], prefix: str = "", is_last: bool = True) -> str:
"""Creates a tree-like string representation of the file structure."""
tree = ""
@@ -314,12 +313,12 @@ def create_tree_structure(query: dict[str, Any], node: dict[str, Any], prefix: s
new_prefix = prefix + ("    " if is_last else "│   ") if node["name"] else prefix
children = node["children"]
for i, child in enumerate(children):
- tree += create_tree_structure(query, child, new_prefix, i == len(children) - 1)
+ tree += _create_tree_structure(query, child, new_prefix, i == len(children) - 1)
return tree
-def generate_token_string(context_string: str) -> str | None:
+def _generate_token_string(context_string: str) -> str | None:
"""Returns the number of tokens in a text string."""
formatted_tokens = ""
try:
@@ -340,16 +339,16 @@ def generate_token_string(context_string: str) -> str | None:
return formatted_tokens
-def ingest_single_file(path: str, query: dict[str, Any]) -> tuple[str, str, str]:
+def _ingest_single_file(path: str, query: dict[str, Any]) -> tuple[str, str, str]:
if not os.path.isfile(path):
raise ValueError(f"Path {path} is not a file")
file_size = os.path.getsize(path)
- is_text = is_text_file(path)
+ is_text = _is_text_file(path)
if not is_text:
raise ValueError(f"File {path} is not a text file")
- content = read_file_content(path)
+ content = _read_file_content(path)
if file_size > query["max_file_size"]:
content = "[Content ignored: file too large]"
@@ -366,26 +365,26 @@ def ingest_single_file(path: str, query: dict[str, Any]) -> tuple[str, str, str]
f"Lines: {len(content.splitlines()):,}\n"
)
- files_content = create_file_content_string([file_info])
+ files_content = _create_file_content_string([file_info])
tree = "Directory structure:\n└── " + os.path.basename(path)
- formatted_tokens = generate_token_string(files_content)
+ formatted_tokens = _generate_token_string(files_content)
if formatted_tokens:
summary += f"\nEstimated tokens: {formatted_tokens}"
return summary, tree, files_content
-def ingest_directory(path: str, query: dict[str, Any]) -> tuple[str, str, str]:
- nodes = scan_directory(path=path, query=query)
+def _ingest_directory(path: str, query: dict[str, Any]) -> tuple[str, str, str]:
+ nodes = _scan_directory(path=path, query=query)
if not nodes:
raise ValueError(f"No files found in {path}")
- files = extract_files_content(query=query, node=nodes, max_file_size=query["max_file_size"])
- summary = create_summary_string(query, nodes)
- tree = "Directory structure:\n" + create_tree_structure(query, nodes)
- files_content = create_file_content_string(files)
+ files = _extract_files_content(query=query, node=nodes, max_file_size=query["max_file_size"])
+ summary = _create_summary_string(query, nodes)
+ tree = "Directory structure:\n" + _create_tree_structure(query, nodes)
+ files_content = _create_file_content_string(files)
- formatted_tokens = generate_token_string(tree + files_content)
+ formatted_tokens = _generate_token_string(tree + files_content)
if formatted_tokens:
summary += f"\nEstimated tokens: {formatted_tokens}"
@@ -394,11 +393,11 @@ def ingest_directory(path: str, query: dict[str, Any]) -> tuple[str, str, str]:
def ingest_from_query(query: dict[str, Any]) -> tuple[str, str, str]:
"""Main entry point for analyzing a codebase directory or single file."""
- path = os.path.join(query["local_path"], query["subpath"].lstrip(os.sep))
- if not os.path.exists(path) and not os.path.exists(os.path.dirname(path)):
- raise ValueError(f"{query['subpath']} cannot be found")
+ path = f"{query['local_path']}{query['subpath']}"
+ if not os.path.exists(path):
+ raise ValueError(f"{query['slug']} cannot be found")
if query.get("type") == "blob":
- return ingest_single_file(path, query)
+ return _ingest_single_file(path, query)
- return ingest_directory(path, query)
+ return _ingest_directory(path, query)
diff --git a/src/gitingest/parse_query.py b/src/gitingest/parse_query.py
index 2e6470e1..477520a7 100644
--- a/src/gitingest/parse_query.py
+++ b/src/gitingest/parse_query.py
@@ -6,11 +6,11 @@
from gitingest.ignore_patterns import DEFAULT_IGNORE_PATTERNS
-TMP_BASE_PATH = "../tmp"
+TMP_BASE_PATH: str = "../tmp"
HEX_DIGITS = set(string.hexdigits)
-def parse_url(url: str) -> dict[str, Any]:
+def _parse_url(url: str) -> dict[str, Any]:
url = url.split(" ")[0]
url = unquote(url) # Decode URL-encoded characters
@@ -69,14 +69,14 @@ def _is_valid_git_commit_hash(commit: str) -> bool:
return len(commit) == 40 and all(c in HEX_DIGITS for c in commit)
-def normalize_pattern(pattern: str) -> str:
+def _normalize_pattern(pattern: str) -> str:
pattern = pattern.lstrip(os.sep)
if pattern.endswith(os.sep):
pattern += "*"
return pattern
-def parse_patterns(pattern: list[str] | str) -> list[str]:
+def _parse_patterns(pattern: list[str] | str) -> list[str]:
patterns = pattern if isinstance(pattern, list) else [pattern]
patterns = [p.strip() for p in patterns]
@@ -87,10 +87,10 @@ def parse_patterns(pattern: list[str] | str) -> list[str]:
"underscore (_), dot (.), forward slash (/), plus (+), and asterisk (*) are allowed."
)
- return [normalize_pattern(p) for p in patterns]
+ return [_normalize_pattern(p) for p in patterns]
-def override_ignore_patterns(ignore_patterns: list[str], include_patterns: list[str]) -> list[str]:
+def _override_ignore_patterns(ignore_patterns: list[str], include_patterns: list[str]) -> list[str]:
"""
Removes patterns from ignore_patterns that are present in include_patterns using set difference.
@@ -109,7 +109,7 @@ def override_ignore_patterns(ignore_patterns: list[str], include_patterns: list[
return list(set(ignore_patterns) - set(include_patterns))
-def parse_path(path: str) -> dict[str, Any]:
+def _parse_path(path: str) -> dict[str, Any]:
query = {
"url": None,
"local_path": os.path.abspath(path),
@@ -151,19 +151,19 @@ def parse_query(
"""
# Determine the parsing method based on the source type
if from_web or source.startswith("https://") or "github.com" in source:
- query = parse_url(source)
+ query = _parse_url(source)
else:
- query = parse_path(source)
+ query = _parse_path(source)
# Process ignore patterns
ignore_patterns_list = DEFAULT_IGNORE_PATTERNS.copy()
if ignore_patterns:
- ignore_patterns_list += parse_patterns(ignore_patterns)
+ ignore_patterns_list += _parse_patterns(ignore_patterns)
# Process include patterns and override ignore patterns accordingly
if include_patterns:
- parsed_include = parse_patterns(include_patterns)
- ignore_patterns_list = override_ignore_patterns(ignore_patterns_list, include_patterns=parsed_include)
+ parsed_include = _parse_patterns(include_patterns)
+ ignore_patterns_list = _override_ignore_patterns(ignore_patterns_list, include_patterns=parsed_include)
else:
parsed_include = None
diff --git a/src/gitingest/tests/test_clone.py b/src/gitingest/tests/test_clone.py
index 585ba6eb..e3b81289 100644
--- a/src/gitingest/tests/test_clone.py
+++ b/src/gitingest/tests/test_clone.py
@@ -2,7 +2,7 @@
import pytest
-from gitingest.clone import CloneConfig, check_repo_exists, clone_repo
+from gitingest.clone import CloneConfig, _check_repo_exists, clone_repo
@pytest.mark.asyncio
@@ -14,9 +14,8 @@ async def test_clone_repo_with_commit() -> None:
branch="main",
)
- with patch("gitingest.clone.check_repo_exists", return_value=True) as mock_check:
- with patch("gitingest.clone.run_git_command", new_callable=AsyncMock) as mock_exec:
-
+ with patch("gitingest.clone._check_repo_exists", return_value=True) as mock_check:
+ with patch("gitingest.clone._run_git_command", new_callable=AsyncMock) as mock_exec:
mock_process = AsyncMock()
mock_process.communicate.return_value = (b"output", b"error")
mock_exec.return_value = mock_process
@@ -29,8 +28,8 @@ async def test_clone_repo_with_commit() -> None:
async def test_clone_repo_without_commit() -> None:
query = CloneConfig(url="https://github.com/user/repo", local_path="/tmp/repo", commit=None, branch="main")
- with patch("gitingest.clone.check_repo_exists", return_value=True) as mock_check:
- with patch("gitingest.clone.run_git_command", new_callable=AsyncMock) as mock_exec:
+ with patch("gitingest.clone._check_repo_exists", return_value=True) as mock_check:
+ with patch("gitingest.clone._run_git_command", new_callable=AsyncMock) as mock_exec:
mock_process = AsyncMock()
mock_process.communicate.return_value = (b"output", b"error")
mock_exec.return_value = mock_process
@@ -48,7 +47,7 @@ async def test_clone_repo_nonexistent_repository() -> None:
commit=None,
branch="main",
)
- with patch("gitingest.clone.check_repo_exists", return_value=False) as mock_check:
+ with patch("gitingest.clone._check_repo_exists", return_value=False) as mock_check:
with pytest.raises(ValueError, match="Repository not found"):
await clone_repo(clone_config)
mock_check.assert_called_once_with(clone_config.url)
@@ -65,13 +64,13 @@ async def test_check_repo_exists() -> None:
# Test existing repository
mock_process.returncode = 0
- assert await check_repo_exists(url) is True
+ assert await _check_repo_exists(url) is True
# Test non-existing repository (404 response)
mock_process.communicate.return_value = (b"HTTP/1.1 404 Not Found\n", b"")
mock_process.returncode = 0
- assert await check_repo_exists(url) is False
+ assert await _check_repo_exists(url) is False
# Test failed request
mock_process.returncode = 1
- assert await check_repo_exists(url) is False
+ assert await _check_repo_exists(url) is False
diff --git a/src/gitingest/tests/test_ingest.py b/src/gitingest/tests/test_ingest.py
index fa8369a7..53257a1e 100644
--- a/src/gitingest/tests/test_ingest.py
+++ b/src/gitingest/tests/test_ingest.py
@@ -3,7 +3,7 @@
import pytest
-from gitingest.ingest_from_query import extract_files_content, scan_directory
+from gitingest.ingest_from_query import _extract_files_content, _scan_directory
# Test fixtures
@@ -74,7 +74,7 @@ def temp_directory(tmp_path: Path) -> Path:
def test_scan_directory(temp_directory: Path, sample_query: dict[str, Any]) -> None:
- result = scan_directory(str(temp_directory), query=sample_query)
+ result = _scan_directory(str(temp_directory), query=sample_query)
if result is None:
assert False, "Result is None"
@@ -85,10 +85,10 @@ def test_scan_directory(temp_directory: Path, sample_query: dict[str, Any]) -> N
def test_extract_files_content(temp_directory: Path, sample_query: dict[str, Any]) -> None:
- nodes = scan_directory(str(temp_directory), query=sample_query)
+ nodes = _scan_directory(str(temp_directory), query=sample_query)
if nodes is None:
assert False, "Nodes is None"
- files = extract_files_content(query=sample_query, node=nodes, max_file_size=1_000_000)
+ files = _extract_files_content(query=sample_query, node=nodes, max_file_size=1_000_000)
assert len(files) == 8 # All .txt and .py files
# Check for presence of key files
diff --git a/src/gitingest/tests/test_parse_query.py b/src/gitingest/tests/test_parse_query.py
index 71ff71ef..b87856d6 100644
--- a/src/gitingest/tests/test_parse_query.py
+++ b/src/gitingest/tests/test_parse_query.py
@@ -1,7 +1,7 @@
import pytest
from gitingest.ignore_patterns import DEFAULT_IGNORE_PATTERNS
-from gitingest.parse_query import parse_query, parse_url
+from gitingest.parse_query import _parse_url, parse_query
def test_parse_url_valid() -> None:
@@ -11,7 +11,7 @@ def test_parse_url_valid() -> None:
"https://bitbucket.org/user/repo",
]
for url in test_cases:
- result = parse_url(url)
+ result = _parse_url(url)
assert result["user_name"] == "user"
assert result["repo_name"] == "repo"
assert result["url"] == url
@@ -20,7 +20,7 @@ def test_parse_url_valid() -> None:
def test_parse_url_invalid() -> None:
url = "https://only-domain.com"
with pytest.raises(ValueError, match="Invalid repository URL"):
- parse_url(url)
+ _parse_url(url)
def test_parse_query_basic() -> None:
diff --git a/src/gitingest/utils.py b/src/gitingest/utils.py
index 8406d5cd..82b8e303 100644
--- a/src/gitingest/utils.py
+++ b/src/gitingest/utils.py
@@ -1,4 +1,3 @@
-## Async Timeout decorator
import asyncio
import functools
from collections.abc import Awaitable, Callable
@@ -13,6 +12,7 @@ class AsyncTimeoutError(Exception):
def async_timeout(seconds: int = 10) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]:
+ # Decorator factory: raise AsyncTimeoutError if the wrapped coroutine runs longer than `seconds`.
def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
@functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
diff --git a/src/main.py b/src/main.py
index 18de770c..d00367b8 100644
--- a/src/main.py
+++ b/src/main.py
@@ -13,65 +13,124 @@
from routers import download, dynamic, index
from server_utils import limiter
+# Load environment variables from .env file
load_dotenv()
+# Initialize the FastAPI application
app = FastAPI()
app.state.limiter = limiter
-# Define a wrapper handler with the correct signature
async def rate_limit_exception_handler(request: Request, exc: Exception) -> Response:
+ """
+ Custom exception handler for rate-limiting errors.
+
+ Parameters
+ ----------
+ request : Request
+ The incoming HTTP request.
+ exc : Exception
+ The exception raised, expected to be RateLimitExceeded.
+
+ Returns
+ -------
+ Response
+ A response indicating that the rate limit has been exceeded.
+ """
if isinstance(exc, RateLimitExceeded):
- # Delegate to the actual handler
+ # Delegate to the default rate limit handler
return _rate_limit_exceeded_handler(request, exc)
- # Optionally, handle other exceptions or re-raise
+ # Re-raise other exceptions
raise exc
-# Register the wrapper handler
+# Register the custom exception handler for rate limits
app.add_exception_handler(RateLimitExceeded, rate_limit_exception_handler)
+# Mount static files to serve CSS, JS, and other static assets
app.mount("/static", StaticFiles(directory="static"), name="static")
-app_analytics_key = os.getenv("API_ANALYTICS_KEY")
-if app_analytics_key:
- app.add_middleware(Analytics, api_key=app_analytics_key)
-# Define the default allowed hosts
-default_allowed_hosts = ["gitingest.com", "*.gitingest.com", "localhost", "127.0.0.1"]
+# Set up API analytics middleware if an API key is provided
+if app_analytics_key := os.getenv("API_ANALYTICS_KEY"):
+ app.add_middleware(Analytics, api_key=app_analytics_key)
-# Fetch allowed hosts from the environment variable or use the default
+# Fetch allowed hosts from the environment or use the default values
allowed_hosts = os.getenv("ALLOWED_HOSTS")
if allowed_hosts:
allowed_hosts = allowed_hosts.split(",")
else:
+ # Define the default allowed hosts for the application
+ default_allowed_hosts = ["gitingest.com", "*.gitingest.com", "localhost", "127.0.0.1"]
allowed_hosts = default_allowed_hosts
+# Add middleware to enforce allowed hosts
app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed_hosts)
+
+# Set up template rendering
templates = Jinja2Templates(directory="templates")
@app.get("/health")
async def health_check() -> dict[str, str]:
+ """
+ Health check endpoint to verify that the server is running.
+
+ Returns
+ -------
+ dict[str, str]
+ A JSON object with a "status" key indicating the server's health status.
+ """
return {"status": "healthy"}
@app.head("/")
async def head_root() -> HTMLResponse:
- """Mirror the headers and status code of the index page"""
+ """
+ Respond to HTTP HEAD requests for the root URL.
+
+ Mirrors the headers and status code of the index page.
+
+ Returns
+ -------
+ HTMLResponse
+ An empty HTML response with appropriate headers.
+ """
return HTMLResponse(content=None, headers={"content-type": "text/html; charset=utf-8"})
@app.get("/api/", response_class=HTMLResponse)
@app.get("/api", response_class=HTMLResponse)
async def api_docs(request: Request) -> HTMLResponse:
+ """
+ Render the API documentation page.
+
+ Parameters
+ ----------
+ request : Request
+ The incoming HTTP request.
+
+ Returns
+ -------
+ HTMLResponse
+ A rendered HTML page displaying API documentation.
+ """
return templates.TemplateResponse("api.jinja", {"request": request})
@app.get("/robots.txt")
async def robots() -> FileResponse:
+ """
+ Serve the `robots.txt` file to guide search engine crawlers.
+
+ Returns
+ -------
+ FileResponse
+ The `robots.txt` file located in the static directory.
+ """
return FileResponse("static/robots.txt")
+# Include routers for modular endpoints
app.include_router(index)
app.include_router(download)
app.include_router(dynamic)
diff --git a/src/process_query.py b/src/process_query.py
index f55068cb..470b675b 100644
--- a/src/process_query.py
+++ b/src/process_query.py
@@ -12,6 +12,21 @@
def print_query(url: str, max_file_size: int, pattern_type: str, pattern: str) -> None:
+ """
+ Print a formatted summary of the query details, including the URL, file size,
+ and pattern information, for easier debugging or logging.
+
+ Parameters
+ ----------
+ url : str
+ The URL associated with the query.
+ max_file_size : int
+ The maximum file size allowed for the query, in bytes.
+ pattern_type : str
+ Specifies the type of pattern to use, either "include" or "exclude".
+ pattern : str
+ The actual pattern string to include or exclude in the query.
+ """
print(f"{Colors.WHITE}{url:<20}{Colors.END}", end="")
if int(max_file_size / 1024) != 50:
print(f" | {Colors.YELLOW}Size: {int(max_file_size/1024)}kb{Colors.END}", end="")
@@ -22,12 +37,46 @@ def print_query(url: str, max_file_size: int, pattern_type: str, pattern: str) -
def print_error(url: str, e: Exception, max_file_size: int, pattern_type: str, pattern: str) -> None:
+ """
+ Print a formatted error message including the URL, file size, pattern details, and the exception encountered,
+ for debugging or logging purposes.
+
+ Parameters
+ ----------
+ url : str
+ The URL associated with the query that caused the error.
+ e : Exception
+ The exception raised during the query or process.
+ max_file_size : int
+ The maximum file size allowed for the query, in bytes.
+ pattern_type : str
+ Specifies the type of pattern to use, either "include" or "exclude".
+ pattern : str
+ The actual pattern string to include or exclude in the query.
+ """
print(f"{Colors.BROWN}WARN{Colors.END}: {Colors.RED}<- {Colors.END}", end="")
print_query(url, max_file_size, pattern_type, pattern)
print(f" | {Colors.RED}{e}{Colors.END}")
def print_success(url: str, max_file_size: int, pattern_type: str, pattern: str, summary: str) -> None:
+ """
+ Print a formatted success message, including the URL, file size, pattern details, and a summary with estimated
+ tokens, for debugging or logging purposes.
+
+ Parameters
+ ----------
+ url : str
+ The URL associated with the successful query.
+ max_file_size : int
+ The maximum file size allowed for the query, in bytes.
+ pattern_type : str
+ Specifies the type of pattern to use, either "include" or "exclude".
+ pattern : str
+ The actual pattern string to include or exclude in the query.
+ summary : str
+ A summary of the query result, including details like estimated tokens.
+ """
estimated_tokens = summary[summary.index("Estimated tokens:") + len("Estimated ") :]
print(f"{Colors.GREEN}INFO{Colors.END}: {Colors.GREEN}<- {Colors.END}", end="")
print_query(url, max_file_size, pattern_type, pattern)
@@ -42,6 +91,32 @@ async def process_query(
pattern: str = "",
is_index: bool = False,
) -> _TemplateResponse:
+ """
+ Process a query by parsing input, cloning a repository, and generating a summary.
+
+ Handle user input, process GitHub repository data, and prepare
+ a response for rendering a template with the processed results or an error message.
+
+ Parameters
+ ----------
+ request : Request
+ The HTTP request object.
+ input_text : str
+ Input text provided by the user, typically a GitHub repository URL or slug.
+ slider_position : int
+ Position of the slider, representing the maximum file size in the query.
+ pattern_type : str, optional
+ Type of pattern to use, either "include" or "exclude" (default is "exclude").
+ pattern : str, optional
+ Pattern to include or exclude in the query, depending on the pattern type.
+ is_index : bool, optional
+ Flag indicating whether the request is for the index page (default is False).
+
+ Returns
+ -------
+ _TemplateResponse
+ Rendered template response containing the processed results or an error message.
+ """
template = "index.jinja" if is_index else "github.jinja"
max_file_size = logSliderToSize(slider_position)
diff --git a/src/server_utils.py b/src/server_utils.py
index 2a6e186f..7af4b854 100644
--- a/src/server_utils.py
+++ b/src/server_utils.py
@@ -1,19 +1,29 @@
import math
-## Rate Limiter
from slowapi import Limiter
from slowapi.util import get_remote_address
+# Initialize a rate limiter
limiter = Limiter(key_func=get_remote_address)
-## Logarithmic slider to file size conversion
def logSliderToSize(position: int) -> int:
- """Convert slider position to file size in KB"""
+ """
+ Convert a slider position to a file size in bytes using a logarithmic scale.
+
+ Parameters
+ ----------
+ position : int
+ Slider position ranging from 0 to 500.
+
+ Returns
+ -------
+ int
+ File size in bytes corresponding to the slider position.
+ """
maxp = 500
minv = math.log(1)
- maxv = math.log(102400)
-
+ maxv = math.log(102_400)
return round(math.exp(minv + (maxv - minv) * pow(position / maxp, 1.5))) * 1024