diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 68eedb48..9dcf517a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -76,3 +76,10 @@ repos: rev: v1.36.4 hooks: - id: djlint-reformat-jinja + + - repo: https://github.com/igorshubovych/markdownlint-cli + rev: v0.43.0 + hooks: + - id: markdownlint + description: "Lint markdown files." + args: ["--disable=line-length"] diff --git a/README.md b/README.md index 6d0747a2..7e02c467 100644 --- a/README.md +++ b/README.md @@ -1,32 +1,13 @@ -[![Image](./docs/frontpage.png "GitIngest main page")](https://gitingest.com) +# GitIngest - - - License - - - - PyPI version - - - - Downloads - - - - GitHub issues - - - - Code style: black - - - - - Discord - +[![License](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/cyclotruc/gitingest/blob/main/LICENSE) +[![PyPI version](https://badge.fury.io/py/gitingest.svg)](https://badge.fury.io/py/gitingest) +[![Downloads](https://pepy.tech/badge/gitingest)](https://pepy.tech/project/gitingest) +[![GitHub issues](https://img.shields.io/github/issues/cyclotruc/gitingest)](https://github.com/cyclotruc/gitingest/issues) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) +[![Discord](https://dcbadge.limes.pink/api/server/https://discord.com/invite/zerRaGK9EC)](https://discord.com/invite/zerRaGK9EC) -# GitIngest +[![Image](./docs/frontpage.png "GitIngest main page")](https://gitingest.com) Turn any Git repository into a prompt-friendly text ingest for LLMs. @@ -92,15 +73,15 @@ By default, this won't write a file but can be enabled with the `output` argumen 1. Build the image: -``` bash -docker build -t gitingest . -``` + ``` bash + docker build -t gitingest . + ``` 2. Run the container: -``` bash -docker run -d --name gitingest -p 8000:8000 gitingest -``` + ``` bash + docker run -d --name gitingest -p 8000:8000 gitingest + ``` The application will be available at `http://localhost:8000` Ensure environment variables are set before running the application or deploying it via Docker. @@ -135,22 +116,20 @@ ALLOWED_HOSTS="gitingest.local,localhost" 1. Clone the repository -```bash -git clone https://github.com/cyclotruc/gitingest.git -cd gitingest -``` + ```bash + git clone https://github.com/cyclotruc/gitingest.git + cd gitingest + ``` 2. Install dependencies -```bash -pip install -r requirements.txt -``` + ```bash + pip install -r requirements.txt + ``` 3. Run the application: -```bash -cd src -uvicorn main:app --reload -``` - -The frontend will be available at `localhost:8000` + ```bash + cd src + uvicorn main:app --reload + ``` diff --git a/src/config.py b/src/config.py index b918fb2a..8da41da8 100644 --- a/src/config.py +++ b/src/config.py @@ -1,7 +1,7 @@ -MAX_DISPLAY_SIZE = 300_000 -TMP_BASE_PATH = "../tmp" +MAX_DISPLAY_SIZE: int = 300_000 +TMP_BASE_PATH: str = "../tmp" -EXAMPLE_REPOS = [ +EXAMPLE_REPOS: list[dict[str, str]] = [ {"name": "Gitingest", "url": "https://github.com/cyclotruc/gitingest"}, {"name": "FastAPI", "url": "https://github.com/tiangolo/fastapi"}, {"name": "Flask", "url": "https://github.com/pallets/flask"}, diff --git a/src/gitingest/cli.py b/src/gitingest/cli.py index c5f8a493..f275efac 100644 --- a/src/gitingest/cli.py +++ b/src/gitingest/cli.py @@ -1,19 +1,9 @@ -import os - import click from gitingest.ingest import ingest from gitingest.ingest_from_query import MAX_FILE_SIZE -def normalize_pattern(pattern: str) -> str: - pattern = pattern.strip() - pattern = pattern.lstrip(os.sep) - if pattern.endswith(os.sep): - pattern += "*" - return pattern - - @click.command() @click.argument("source", type=str, required=True) @click.option("--output", "-o", default=None, help="Output file path (default: .txt in current directory)") diff --git a/src/gitingest/clone.py b/src/gitingest/clone.py index 97a990a2..a91b6a99 100644 --- a/src/gitingest/clone.py +++ b/src/gitingest/clone.py @@ -3,7 +3,7 @@ from gitingest.utils import AsyncTimeoutError, async_timeout -CLONE_TIMEOUT = 20 +CLONE_TIMEOUT: int = 20 @dataclass @@ -14,67 +14,6 @@ class CloneConfig: branch: str | None = None -async def check_repo_exists(url: str) -> bool: - """ - Check if a repository exists at the given URL using an HTTP HEAD request. - - Parameters - ---------- - url : str - The URL of the repository. - - Returns - ------- - bool - True if the repository exists, False otherwise. - """ - proc = await asyncio.create_subprocess_exec( - "curl", - "-I", - url, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - stdout, _ = await proc.communicate() - if proc.returncode != 0: - return False - # Check if stdout contains "404" status code - stdout_str = stdout.decode() - return "HTTP/1.1 404" not in stdout_str and "HTTP/2 404" not in stdout_str - - -async def run_git_command(*args: str) -> tuple[bytes, bytes]: - """ - Executes a git command asynchronously and captures its output. - - Parameters - ---------- - *args : str - The git command and its arguments to execute. - - Returns - ------- - Tuple[bytes, bytes] - A tuple containing the stdout and stderr of the git command. - - Raises - ------ - RuntimeError - If the git command exits with a non-zero status. - """ - proc = await asyncio.create_subprocess_exec( - *args, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - stdout, stderr = await proc.communicate() - if proc.returncode != 0: - error_message = stderr.decode().strip() - raise RuntimeError(f"Git command failed: {' '.join(args)}\nError: {error_message}") - - return stdout, stderr - - @async_timeout(CLONE_TIMEOUT) async def clone_repo(config: CloneConfig) -> tuple[bytes, bytes]: """ @@ -116,7 +55,7 @@ async def clone_repo(config: CloneConfig) -> tuple[bytes, bytes]: raise ValueError("The 'local_path' parameter is required.") # Check if the repository exists - if not await check_repo_exists(url): + if not await _check_repo_exists(url): raise ValueError("Repository not found, make sure it is public") try: @@ -124,21 +63,82 @@ async def clone_repo(config: CloneConfig) -> tuple[bytes, bytes]: # Scenario 1: Clone and checkout a specific commit # Clone the repository without depth to ensure full history for checkout clone_cmd = ["git", "clone", "--single-branch", url, local_path] - await run_git_command(*clone_cmd) + await _run_git_command(*clone_cmd) # Checkout the specific commit checkout_cmd = ["git", "-C", local_path, "checkout", commit] - return await run_git_command(*checkout_cmd) + return await _run_git_command(*checkout_cmd) if branch and branch.lower() not in ("main", "master"): # Scenario 2: Clone a specific branch with shallow depth clone_cmd = ["git", "clone", "--depth=1", "--single-branch", "--branch", branch, url, local_path] - return await run_git_command(*clone_cmd) + return await _run_git_command(*clone_cmd) # Scenario 3: Clone the default branch with shallow depth clone_cmd = ["git", "clone", "--depth=1", "--single-branch", url, local_path] - return await run_git_command(*clone_cmd) + return await _run_git_command(*clone_cmd) except (RuntimeError, asyncio.TimeoutError, AsyncTimeoutError): raise # Re-raise the exception + + +async def _check_repo_exists(url: str) -> bool: + """ + Check if a repository exists at the given URL using an HTTP HEAD request. + + Parameters + ---------- + url : str + The URL of the repository. + + Returns + ------- + bool + True if the repository exists, False otherwise. + """ + proc = await asyncio.create_subprocess_exec( + "curl", + "-I", + url, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, _ = await proc.communicate() + if proc.returncode != 0: + return False + # Check if stdout contains "404" status code + stdout_str = stdout.decode() + return "HTTP/1.1 404" not in stdout_str and "HTTP/2 404" not in stdout_str + + +async def _run_git_command(*args: str) -> tuple[bytes, bytes]: + """ + Executes a git command asynchronously and captures its output. + + Parameters + ---------- + *args : str + The git command and its arguments to execute. + + Returns + ------- + Tuple[bytes, bytes] + A tuple containing the stdout and stderr of the git command. + + Raises + ------ + RuntimeError + If the git command exits with a non-zero status. + """ + proc = await asyncio.create_subprocess_exec( + *args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + if proc.returncode != 0: + error_message = stderr.decode().strip() + raise RuntimeError(f"Git command failed: {' '.join(args)}\nError: {error_message}") + + return stdout, stderr diff --git a/src/gitingest/ingest_from_query.py b/src/gitingest/ingest_from_query.py index 51cca8d2..886afa26 100644 --- a/src/gitingest/ingest_from_query.py +++ b/src/gitingest/ingest_from_query.py @@ -10,7 +10,7 @@ MAX_TOTAL_SIZE_BYTES = 500 * 1024 * 1024 # 500 MB -def should_include(path: str, base_path: str, include_patterns: list[str]) -> bool: +def _should_include(path: str, base_path: str, include_patterns: list[str]) -> bool: rel_path = path.replace(base_path, "").lstrip(os.sep) include = False for pattern in include_patterns: @@ -19,17 +19,15 @@ def should_include(path: str, base_path: str, include_patterns: list[str]) -> bo return include -def should_exclude(path: str, base_path: str, ignore_patterns: list[str]) -> bool: +def _should_exclude(path: str, base_path: str, ignore_patterns: list[str]) -> bool: rel_path = path.replace(base_path, "").lstrip(os.sep) for pattern in ignore_patterns: - if pattern == "": - continue - if fnmatch(rel_path, pattern): + if pattern and fnmatch(rel_path, pattern): return True return False -def is_safe_symlink(symlink_path: str, base_path: str) -> bool: +def _is_safe_symlink(symlink_path: str, base_path: str) -> bool: """Check if a symlink points to a location within the base directory.""" try: target_path = os.path.realpath(symlink_path) @@ -40,7 +38,7 @@ def is_safe_symlink(symlink_path: str, base_path: str) -> bool: return False -def is_text_file(file_path: str) -> bool: +def _is_text_file(file_path: str) -> bool: """Determines if a file is likely a text file based on its content.""" try: with open(file_path, "rb") as file: @@ -50,7 +48,7 @@ def is_text_file(file_path: str) -> bool: return False -def read_file_content(file_path: str) -> str: +def _read_file_content(file_path: str) -> str: try: with open(file_path, encoding="utf-8", errors="ignore") as f: return f.read() @@ -58,7 +56,7 @@ def read_file_content(file_path: str) -> str: return f"Error reading file: {str(e)}" -def scan_directory( +def _scan_directory( path: str, query: dict[str, Any], seen_paths: set[str] | None = None, @@ -68,6 +66,7 @@ def scan_directory( """Recursively analyzes a directory and its contents with safety limits.""" if seen_paths is None: seen_paths = set() + if stats is None: stats = {"total_files": 0, "total_size": 0} @@ -109,18 +108,18 @@ def scan_directory( for item in os.listdir(path): item_path = os.path.join(path, item) - if should_exclude(item_path, base_path, ignore_patterns): + if _should_exclude(item_path, base_path, ignore_patterns): continue is_file = os.path.isfile(item_path) if is_file and query["include_patterns"]: - if not should_include(item_path, base_path, include_patterns): + if not _should_include(item_path, base_path, include_patterns): result["ignore_content"] = True continue # Handle symlinks if os.path.islink(item_path): - if not is_safe_symlink(item_path, base_path): + if not _is_safe_symlink(item_path, base_path): print(f"Skipping symlink that points outside base directory: {item_path}") continue real_path = os.path.realpath(item_path) @@ -141,8 +140,8 @@ def scan_directory( print(f"Maximum file limit ({MAX_FILES}) reached") return result - is_text = is_text_file(real_path) - content = read_file_content(real_path) if is_text else "[Non-text file]" + is_text = _is_text_file(real_path) + content = _read_file_content(real_path) if is_text else "[Non-text file]" child = { "name": item, @@ -156,7 +155,7 @@ def scan_directory( result["file_count"] += 1 elif os.path.isdir(real_path): - subdir = scan_directory( + subdir = _scan_directory( path=real_path, query=query, seen_paths=seen_paths, @@ -185,8 +184,8 @@ def scan_directory( print(f"Maximum file limit ({MAX_FILES}) reached") return result - is_text = is_text_file(item_path) - content = read_file_content(item_path) if is_text else "[Non-text file]" + is_text = _is_text_file(item_path) + content = _read_file_content(item_path) if is_text else "[Non-text file]" child = { "name": item, @@ -200,7 +199,7 @@ def scan_directory( result["file_count"] += 1 elif os.path.isdir(item_path): - subdir = scan_directory( + subdir = _scan_directory( path=item_path, query=query, seen_paths=seen_paths, @@ -219,7 +218,7 @@ def scan_directory( return result -def extract_files_content( +def _extract_files_content( query: dict[str, Any], node: dict[str, Any], max_file_size: int, @@ -243,12 +242,12 @@ def extract_files_content( ) elif node["type"] == "directory": for child in node["children"]: - extract_files_content(query=query, node=child, max_file_size=max_file_size, files=files) + _extract_files_content(query=query, node=child, max_file_size=max_file_size, files=files) return files -def create_file_content_string(files: list[dict[str, Any]]) -> str: +def _create_file_content_string(files: list[dict[str, Any]]) -> str: """Creates a formatted string of file contents with separators.""" output = "" separator = "=" * 48 + "\n" @@ -278,7 +277,7 @@ def create_file_content_string(files: list[dict[str, Any]]) -> str: return output -def create_summary_string(query: dict[str, Any], nodes: dict[str, Any]) -> str: +def _create_summary_string(query: dict[str, Any], nodes: dict[str, Any]) -> str: """Creates a summary string with file counts and content size.""" if "user_name" in query: summary = f"Repository: {query['user_name']}/{query['repo_name']}\n" @@ -297,7 +296,7 @@ def create_summary_string(query: dict[str, Any], nodes: dict[str, Any]) -> str: return summary -def create_tree_structure(query: dict[str, Any], node: dict[str, Any], prefix: str = "", is_last: bool = True) -> str: +def _create_tree_structure(query: dict[str, Any], node: dict[str, Any], prefix: str = "", is_last: bool = True) -> str: """Creates a tree-like string representation of the file structure.""" tree = "" @@ -314,12 +313,12 @@ def create_tree_structure(query: dict[str, Any], node: dict[str, Any], prefix: s new_prefix = prefix + (" " if is_last else "│ ") if node["name"] else prefix children = node["children"] for i, child in enumerate(children): - tree += create_tree_structure(query, child, new_prefix, i == len(children) - 1) + tree += _create_tree_structure(query, child, new_prefix, i == len(children) - 1) return tree -def generate_token_string(context_string: str) -> str | None: +def _generate_token_string(context_string: str) -> str | None: """Returns the number of tokens in a text string.""" formatted_tokens = "" try: @@ -340,16 +339,16 @@ def generate_token_string(context_string: str) -> str | None: return formatted_tokens -def ingest_single_file(path: str, query: dict[str, Any]) -> tuple[str, str, str]: +def _ingest_single_file(path: str, query: dict[str, Any]) -> tuple[str, str, str]: if not os.path.isfile(path): raise ValueError(f"Path {path} is not a file") file_size = os.path.getsize(path) - is_text = is_text_file(path) + is_text = _is_text_file(path) if not is_text: raise ValueError(f"File {path} is not a text file") - content = read_file_content(path) + content = _read_file_content(path) if file_size > query["max_file_size"]: content = "[Content ignored: file too large]" @@ -366,26 +365,26 @@ def ingest_single_file(path: str, query: dict[str, Any]) -> tuple[str, str, str] f"Lines: {len(content.splitlines()):,}\n" ) - files_content = create_file_content_string([file_info]) + files_content = _create_file_content_string([file_info]) tree = "Directory structure:\n└── " + os.path.basename(path) - formatted_tokens = generate_token_string(files_content) + formatted_tokens = _generate_token_string(files_content) if formatted_tokens: summary += f"\nEstimated tokens: {formatted_tokens}" return summary, tree, files_content -def ingest_directory(path: str, query: dict[str, Any]) -> tuple[str, str, str]: - nodes = scan_directory(path=path, query=query) +def _ingest_directory(path: str, query: dict[str, Any]) -> tuple[str, str, str]: + nodes = _scan_directory(path=path, query=query) if not nodes: raise ValueError(f"No files found in {path}") - files = extract_files_content(query=query, node=nodes, max_file_size=query["max_file_size"]) - summary = create_summary_string(query, nodes) - tree = "Directory structure:\n" + create_tree_structure(query, nodes) - files_content = create_file_content_string(files) + files = _extract_files_content(query=query, node=nodes, max_file_size=query["max_file_size"]) + summary = _create_summary_string(query, nodes) + tree = "Directory structure:\n" + _create_tree_structure(query, nodes) + files_content = _create_file_content_string(files) - formatted_tokens = generate_token_string(tree + files_content) + formatted_tokens = _generate_token_string(tree + files_content) if formatted_tokens: summary += f"\nEstimated tokens: {formatted_tokens}" @@ -394,11 +393,11 @@ def ingest_directory(path: str, query: dict[str, Any]) -> tuple[str, str, str]: def ingest_from_query(query: dict[str, Any]) -> tuple[str, str, str]: """Main entry point for analyzing a codebase directory or single file.""" - path = os.path.join(query["local_path"], query["subpath"].lstrip(os.sep)) - if not os.path.exists(path) and not os.path.exists(os.path.dirname(path)): - raise ValueError(f"{query['subpath']} cannot be found") + path = f"{query['local_path']}{query['subpath']}" + if not os.path.exists(path): + raise ValueError(f"{query['slug']} cannot be found") if query.get("type") == "blob": - return ingest_single_file(path, query) + return _ingest_single_file(path, query) - return ingest_directory(path, query) + return _ingest_directory(path, query) diff --git a/src/gitingest/parse_query.py b/src/gitingest/parse_query.py index 2e6470e1..477520a7 100644 --- a/src/gitingest/parse_query.py +++ b/src/gitingest/parse_query.py @@ -6,11 +6,11 @@ from gitingest.ignore_patterns import DEFAULT_IGNORE_PATTERNS -TMP_BASE_PATH = "../tmp" +TMP_BASE_PATH: str = "../tmp" HEX_DIGITS = set(string.hexdigits) -def parse_url(url: str) -> dict[str, Any]: +def _parse_url(url: str) -> dict[str, Any]: url = url.split(" ")[0] url = unquote(url) # Decode URL-encoded characters @@ -69,14 +69,14 @@ def _is_valid_git_commit_hash(commit: str) -> bool: return len(commit) == 40 and all(c in HEX_DIGITS for c in commit) -def normalize_pattern(pattern: str) -> str: +def _normalize_pattern(pattern: str) -> str: pattern = pattern.lstrip(os.sep) if pattern.endswith(os.sep): pattern += "*" return pattern -def parse_patterns(pattern: list[str] | str) -> list[str]: +def _parse_patterns(pattern: list[str] | str) -> list[str]: patterns = pattern if isinstance(pattern, list) else [pattern] patterns = [p.strip() for p in patterns] @@ -87,10 +87,10 @@ def parse_patterns(pattern: list[str] | str) -> list[str]: "underscore (_), dot (.), forward slash (/), plus (+), and asterisk (*) are allowed." ) - return [normalize_pattern(p) for p in patterns] + return [_normalize_pattern(p) for p in patterns] -def override_ignore_patterns(ignore_patterns: list[str], include_patterns: list[str]) -> list[str]: +def _override_ignore_patterns(ignore_patterns: list[str], include_patterns: list[str]) -> list[str]: """ Removes patterns from ignore_patterns that are present in include_patterns using set difference. @@ -109,7 +109,7 @@ def override_ignore_patterns(ignore_patterns: list[str], include_patterns: list[ return list(set(ignore_patterns) - set(include_patterns)) -def parse_path(path: str) -> dict[str, Any]: +def _parse_path(path: str) -> dict[str, Any]: query = { "url": None, "local_path": os.path.abspath(path), @@ -151,19 +151,19 @@ def parse_query( """ # Determine the parsing method based on the source type if from_web or source.startswith("https://") or "github.com" in source: - query = parse_url(source) + query = _parse_url(source) else: - query = parse_path(source) + query = _parse_path(source) # Process ignore patterns ignore_patterns_list = DEFAULT_IGNORE_PATTERNS.copy() if ignore_patterns: - ignore_patterns_list += parse_patterns(ignore_patterns) + ignore_patterns_list += _parse_patterns(ignore_patterns) # Process include patterns and override ignore patterns accordingly if include_patterns: - parsed_include = parse_patterns(include_patterns) - ignore_patterns_list = override_ignore_patterns(ignore_patterns_list, include_patterns=parsed_include) + parsed_include = _parse_patterns(include_patterns) + ignore_patterns_list = _override_ignore_patterns(ignore_patterns_list, include_patterns=parsed_include) else: parsed_include = None diff --git a/src/gitingest/tests/test_clone.py b/src/gitingest/tests/test_clone.py index 585ba6eb..e3b81289 100644 --- a/src/gitingest/tests/test_clone.py +++ b/src/gitingest/tests/test_clone.py @@ -2,7 +2,7 @@ import pytest -from gitingest.clone import CloneConfig, check_repo_exists, clone_repo +from gitingest.clone import CloneConfig, _check_repo_exists, clone_repo @pytest.mark.asyncio @@ -14,9 +14,8 @@ async def test_clone_repo_with_commit() -> None: branch="main", ) - with patch("gitingest.clone.check_repo_exists", return_value=True) as mock_check: - with patch("gitingest.clone.run_git_command", new_callable=AsyncMock) as mock_exec: - + with patch("gitingest.clone._check_repo_exists", return_value=True) as mock_check: + with patch("gitingest.clone._run_git_command", new_callable=AsyncMock) as mock_exec: mock_process = AsyncMock() mock_process.communicate.return_value = (b"output", b"error") mock_exec.return_value = mock_process @@ -29,8 +28,8 @@ async def test_clone_repo_with_commit() -> None: async def test_clone_repo_without_commit() -> None: query = CloneConfig(url="https://github.com/user/repo", local_path="/tmp/repo", commit=None, branch="main") - with patch("gitingest.clone.check_repo_exists", return_value=True) as mock_check: - with patch("gitingest.clone.run_git_command", new_callable=AsyncMock) as mock_exec: + with patch("gitingest.clone._check_repo_exists", return_value=True) as mock_check: + with patch("gitingest.clone._run_git_command", new_callable=AsyncMock) as mock_exec: mock_process = AsyncMock() mock_process.communicate.return_value = (b"output", b"error") mock_exec.return_value = mock_process @@ -48,7 +47,7 @@ async def test_clone_repo_nonexistent_repository() -> None: commit=None, branch="main", ) - with patch("gitingest.clone.check_repo_exists", return_value=False) as mock_check: + with patch("gitingest.clone._check_repo_exists", return_value=False) as mock_check: with pytest.raises(ValueError, match="Repository not found"): await clone_repo(clone_config) mock_check.assert_called_once_with(clone_config.url) @@ -65,13 +64,13 @@ async def test_check_repo_exists() -> None: # Test existing repository mock_process.returncode = 0 - assert await check_repo_exists(url) is True + assert await _check_repo_exists(url) is True # Test non-existing repository (404 response) mock_process.communicate.return_value = (b"HTTP/1.1 404 Not Found\n", b"") mock_process.returncode = 0 - assert await check_repo_exists(url) is False + assert await _check_repo_exists(url) is False # Test failed request mock_process.returncode = 1 - assert await check_repo_exists(url) is False + assert await _check_repo_exists(url) is False diff --git a/src/gitingest/tests/test_ingest.py b/src/gitingest/tests/test_ingest.py index fa8369a7..53257a1e 100644 --- a/src/gitingest/tests/test_ingest.py +++ b/src/gitingest/tests/test_ingest.py @@ -3,7 +3,7 @@ import pytest -from gitingest.ingest_from_query import extract_files_content, scan_directory +from gitingest.ingest_from_query import _extract_files_content, _scan_directory # Test fixtures @@ -74,7 +74,7 @@ def temp_directory(tmp_path: Path) -> Path: def test_scan_directory(temp_directory: Path, sample_query: dict[str, Any]) -> None: - result = scan_directory(str(temp_directory), query=sample_query) + result = _scan_directory(str(temp_directory), query=sample_query) if result is None: assert False, "Result is None" @@ -85,10 +85,10 @@ def test_scan_directory(temp_directory: Path, sample_query: dict[str, Any]) -> N def test_extract_files_content(temp_directory: Path, sample_query: dict[str, Any]) -> None: - nodes = scan_directory(str(temp_directory), query=sample_query) + nodes = _scan_directory(str(temp_directory), query=sample_query) if nodes is None: assert False, "Nodes is None" - files = extract_files_content(query=sample_query, node=nodes, max_file_size=1_000_000) + files = _extract_files_content(query=sample_query, node=nodes, max_file_size=1_000_000) assert len(files) == 8 # All .txt and .py files # Check for presence of key files diff --git a/src/gitingest/tests/test_parse_query.py b/src/gitingest/tests/test_parse_query.py index 71ff71ef..b87856d6 100644 --- a/src/gitingest/tests/test_parse_query.py +++ b/src/gitingest/tests/test_parse_query.py @@ -1,7 +1,7 @@ import pytest from gitingest.ignore_patterns import DEFAULT_IGNORE_PATTERNS -from gitingest.parse_query import parse_query, parse_url +from gitingest.parse_query import _parse_url, parse_query def test_parse_url_valid() -> None: @@ -11,7 +11,7 @@ def test_parse_url_valid() -> None: "https://bitbucket.org/user/repo", ] for url in test_cases: - result = parse_url(url) + result = _parse_url(url) assert result["user_name"] == "user" assert result["repo_name"] == "repo" assert result["url"] == url @@ -20,7 +20,7 @@ def test_parse_url_valid() -> None: def test_parse_url_invalid() -> None: url = "https://only-domain.com" with pytest.raises(ValueError, match="Invalid repository URL"): - parse_url(url) + _parse_url(url) def test_parse_query_basic() -> None: diff --git a/src/gitingest/utils.py b/src/gitingest/utils.py index 8406d5cd..82b8e303 100644 --- a/src/gitingest/utils.py +++ b/src/gitingest/utils.py @@ -1,4 +1,3 @@ -## Async Timeout decorator import asyncio import functools from collections.abc import Awaitable, Callable @@ -13,6 +12,7 @@ class AsyncTimeoutError(Exception): def async_timeout(seconds: int = 10) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]: + # Async Timeout decorator def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: @functools.wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: diff --git a/src/main.py b/src/main.py index 18de770c..d00367b8 100644 --- a/src/main.py +++ b/src/main.py @@ -13,65 +13,124 @@ from routers import download, dynamic, index from server_utils import limiter +# Load environment variables from .env file load_dotenv() +# Initialize the FastAPI application app = FastAPI() app.state.limiter = limiter -# Define a wrapper handler with the correct signature async def rate_limit_exception_handler(request: Request, exc: Exception) -> Response: + """ + Custom exception handler for rate-limiting errors. + + Parameters + ---------- + request : Request + The incoming HTTP request. + exc : Exception + The exception raised, expected to be RateLimitExceeded. + + Returns + ------- + Response + A response indicating that the rate limit has been exceeded. + """ if isinstance(exc, RateLimitExceeded): - # Delegate to the actual handler + # Delegate to the default rate limit handler return _rate_limit_exceeded_handler(request, exc) - # Optionally, handle other exceptions or re-raise + # Re-raise other exceptions raise exc -# Register the wrapper handler +# Register the custom exception handler for rate limits app.add_exception_handler(RateLimitExceeded, rate_limit_exception_handler) +# Mount static files to serve CSS, JS, and other static assets app.mount("/static", StaticFiles(directory="static"), name="static") -app_analytics_key = os.getenv("API_ANALYTICS_KEY") -if app_analytics_key: - app.add_middleware(Analytics, api_key=app_analytics_key) -# Define the default allowed hosts -default_allowed_hosts = ["gitingest.com", "*.gitingest.com", "localhost", "127.0.0.1"] +# Set up API analytics middleware if an API key is provided +if app_analytics_key := os.getenv("API_ANALYTICS_KEY"): + app.add_middleware(Analytics, api_key=app_analytics_key) -# Fetch allowed hosts from the environment variable or use the default +# Fetch allowed hosts from the environment or use the default values allowed_hosts = os.getenv("ALLOWED_HOSTS") if allowed_hosts: allowed_hosts = allowed_hosts.split(",") else: + # Define the default allowed hosts for the application + default_allowed_hosts = ["gitingest.com", "*.gitingest.com", "localhost", "127.0.0.1"] allowed_hosts = default_allowed_hosts +# Add middleware to enforce allowed hosts app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed_hosts) + +# Set up template rendering templates = Jinja2Templates(directory="templates") @app.get("/health") async def health_check() -> dict[str, str]: + """ + Health check endpoint to verify that the server is running. + + Returns + ------- + dict[str, str] + A JSON object with a "status" key indicating the server's health status. + """ return {"status": "healthy"} @app.head("/") async def head_root() -> HTMLResponse: - """Mirror the headers and status code of the index page""" + """ + Respond to HTTP HEAD requests for the root URL. + + Mirrors the headers and status code of the index page. + + Returns + ------- + HTMLResponse + An empty HTML response with appropriate headers. + """ return HTMLResponse(content=None, headers={"content-type": "text/html; charset=utf-8"}) @app.get("/api/", response_class=HTMLResponse) @app.get("/api", response_class=HTMLResponse) async def api_docs(request: Request) -> HTMLResponse: + """ + Render the API documentation page. + + Parameters + ---------- + request : Request + The incoming HTTP request. + + Returns + ------- + HTMLResponse + A rendered HTML page displaying API documentation. + """ return templates.TemplateResponse("api.jinja", {"request": request}) @app.get("/robots.txt") async def robots() -> FileResponse: + """ + Serve the `robots.txt` file to guide search engine crawlers. + + Returns + ------- + FileResponse + The `robots.txt` file located in the static directory. + """ return FileResponse("static/robots.txt") +# Include routers for modular endpoints app.include_router(index) app.include_router(download) app.include_router(dynamic) diff --git a/src/process_query.py b/src/process_query.py index f55068cb..470b675b 100644 --- a/src/process_query.py +++ b/src/process_query.py @@ -12,6 +12,21 @@ def print_query(url: str, max_file_size: int, pattern_type: str, pattern: str) -> None: + """ + Print a formatted summary of the query details, including the URL, file size, + and pattern information, for easier debugging or logging. + + Parameters + ---------- + url : str + The URL associated with the query. + max_file_size : int + The maximum file size allowed for the query, in bytes. + pattern_type : str + Specifies the type of pattern to use, either "include" or "exclude". + pattern : str + The actual pattern string to include or exclude in the query. + """ print(f"{Colors.WHITE}{url:<20}{Colors.END}", end="") if int(max_file_size / 1024) != 50: print(f" | {Colors.YELLOW}Size: {int(max_file_size/1024)}kb{Colors.END}", end="") @@ -22,12 +37,46 @@ def print_query(url: str, max_file_size: int, pattern_type: str, pattern: str) - def print_error(url: str, e: Exception, max_file_size: int, pattern_type: str, pattern: str) -> None: + """ + Print a formatted error message including the URL, file size, pattern details, and the exception encountered, + for debugging or logging purposes. + + Parameters + ---------- + url : str + The URL associated with the query that caused the error. + e : Exception + The exception raised during the query or process. + max_file_size : int + The maximum file size allowed for the query, in bytes. + pattern_type : str + Specifies the type of pattern to use, either "include" or "exclude". + pattern : str + The actual pattern string to include or exclude in the query. + """ print(f"{Colors.BROWN}WARN{Colors.END}: {Colors.RED}<- {Colors.END}", end="") print_query(url, max_file_size, pattern_type, pattern) print(f" | {Colors.RED}{e}{Colors.END}") def print_success(url: str, max_file_size: int, pattern_type: str, pattern: str, summary: str) -> None: + """ + Print a formatted success message, including the URL, file size, pattern details, and a summary with estimated + tokens, for debugging or logging purposes. + + Parameters + ---------- + url : str + The URL associated with the successful query. + max_file_size : int + The maximum file size allowed for the query, in bytes. + pattern_type : str + Specifies the type of pattern to use, either "include" or "exclude". + pattern : str + The actual pattern string to include or exclude in the query. + summary : str + A summary of the query result, including details like estimated tokens. + """ estimated_tokens = summary[summary.index("Estimated tokens:") + len("Estimated ") :] print(f"{Colors.GREEN}INFO{Colors.END}: {Colors.GREEN}<- {Colors.END}", end="") print_query(url, max_file_size, pattern_type, pattern) @@ -42,6 +91,32 @@ async def process_query( pattern: str = "", is_index: bool = False, ) -> _TemplateResponse: + """ + Process a query by parsing input, cloning a repository, and generating a summary. + + Handle user input, process GitHub repository data, and prepare + a response for rendering a template with the processed results or an error message. + + Parameters + ---------- + request : Request + The HTTP request object. + input_text : str + Input text provided by the user, typically a GitHub repository URL or slug. + slider_position : int + Position of the slider, representing the maximum file size in the query. + pattern_type : str, optional + Type of pattern to use, either "include" or "exclude" (default is "exclude"). + pattern : str, optional + Pattern to include or exclude in the query, depending on the pattern type. + is_index : bool, optional + Flag indicating whether the request is for the index page (default is False). + + Returns + ------- + _TemplateResponse + Rendered template response containing the processed results or an error message. + """ template = "index.jinja" if is_index else "github.jinja" max_file_size = logSliderToSize(slider_position) diff --git a/src/server_utils.py b/src/server_utils.py index 2a6e186f..7af4b854 100644 --- a/src/server_utils.py +++ b/src/server_utils.py @@ -1,19 +1,29 @@ import math -## Rate Limiter from slowapi import Limiter from slowapi.util import get_remote_address +# Initialize a rate limiter limiter = Limiter(key_func=get_remote_address) -## Logarithmic slider to file size conversion def logSliderToSize(position: int) -> int: - """Convert slider position to file size in KB""" + """ + Convert a slider position to a file size in bytes using a logarithmic scale. + + Parameters + ---------- + position : int + Slider position ranging from 0 to 500. + + Returns + ------- + int + File size in bytes corresponding to the slider position. + """ maxp = 500 minv = math.log(1) - maxv = math.log(102400) - + maxv = math.log(102_400) return round(math.exp(minv + (maxv - minv) * pow(position / maxp, 1.5))) * 1024