From b5ffa5f8aac678af1b45e39064354c0190b80338 Mon Sep 17 00:00:00 2001 From: Filip Christiansen <22807962+filipchristiansen@users.noreply.github.com> Date: Fri, 17 Jan 2025 17:09:58 +0100 Subject: [PATCH] Refactor CLI to wrap async logic with a sync command function (#136) - Change main() to a synchronous Click command - Introduce _async_main() for async ingest logic - Use asyncio.run(...) to properly await the async function --- src/gitingest/cli.py | 41 ++++++++-- src/main.py | 2 +- tests/test_flow_integration.py | 145 +++++++++++++++++++++++++++++++++ 3 files changed, 182 insertions(+), 6 deletions(-) create mode 100644 tests/test_flow_integration.py diff --git a/src/gitingest/cli.py b/src/gitingest/cli.py index ef7761b9..a21a4533 100644 --- a/src/gitingest/cli.py +++ b/src/gitingest/cli.py @@ -2,6 +2,8 @@ # pylint: disable=no-value-for-parameter +import asyncio + import click from config import MAX_FILE_SIZE @@ -9,12 +11,42 @@ @click.command() -@click.argument("source", type=str, required=True) +@click.argument("source", type=str, default=".") @click.option("--output", "-o", default=None, help="Output file path (default: .txt in current directory)") @click.option("--max-size", "-s", default=MAX_FILE_SIZE, help="Maximum file size to process in bytes") @click.option("--exclude-pattern", "-e", multiple=True, help="Patterns to exclude") @click.option("--include-pattern", "-i", multiple=True, help="Patterns to include") -async def main( +def main( + source: str, + output: str | None, + max_size: int, + exclude_pattern: tuple[str, ...], + include_pattern: tuple[str, ...], +): + """ + Main entry point for the CLI. This function is called when the CLI is run as a script. + + It calls the async main function to run the command. + + Parameters + ---------- + source : str + The source directory or repository to analyze. + output : str | None + The path where the output file will be written. 
If not specified, the output will be written + to a file named `<repo_name>.txt` in the current directory. + max_size : int + The maximum file size to process, in bytes. Files larger than this size will be ignored. + exclude_pattern : tuple[str, ...] + A tuple of patterns to exclude during the analysis. Files matching these patterns will be ignored. + include_pattern : tuple[str, ...] + A tuple of patterns to include during the analysis. Only files matching these patterns will be processed. + """ + # Main entry point for the CLI. This function is called when the CLI is run as a script. + asyncio.run(_async_main(source, output, max_size, exclude_pattern, include_pattern)) + + +async def _async_main( source: str, output: str | None, max_size: int, @@ -24,9 +56,8 @@ async def main( """ Analyze a directory or repository and create a text dump of its contents. - This command analyzes the contents of a specified source directory or repository, - applies custom include and exclude patterns, and generates a text summary of the analysis - which is then written to an output file. + This command analyzes the contents of a specified source directory or repository, applies custom include and + exclude patterns, and generates a text summary of the analysis which is then written to an output file. Parameters ---------- diff --git a/src/main.py b/src/main.py index 556b3e1d..e11cebcc 100644 --- a/src/main.py +++ b/src/main.py @@ -175,7 +175,7 @@ async def rate_limit_exception_handler(request: Request, exc: Exception) -> Resp app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed_hosts) # Set up template rendering -templates = Jinja2Templates(directory="templates") +templates = Jinja2Templates(directory="src/templates") @app.get("/health") diff --git a/tests/test_flow_integration.py b/tests/test_flow_integration.py new file mode 100644 index 00000000..a8b84f57 --- /dev/null +++ b/tests/test_flow_integration.py @@ -0,0 +1,145 @@ +""" +Integration tests for GitIngest. 
+These tests cover core functionalities, edge cases, and concurrency handling. +""" + +import shutil +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient + +from main import app + +BASE_DIR = Path(__file__).resolve().parent.parent +TEMPLATE_DIR = BASE_DIR / "src" / "templates" + + +@pytest.fixture(scope="module") +def test_client(): + """Create a test client fixture.""" + with TestClient(app) as client: + client.headers.update({"Host": "localhost"}) + yield client + + +@pytest.fixture(scope="module", autouse=True) +def mock_templates(): + """Mock Jinja2 template rendering to bypass actual file loading.""" + with patch("starlette.templating.Jinja2Templates.TemplateResponse") as mock_template: + mock_template.return_value = "Mocked Template Response" + yield mock_template + + +def cleanup_temp_directories(): + temp_dir = Path("/tmp/gitingest") + if temp_dir.exists(): + try: + shutil.rmtree(temp_dir) + except PermissionError as e: + print(f"Error cleaning up {temp_dir}: {e}") + + +@pytest.fixture(scope="module", autouse=True) +def cleanup(): + """Cleanup temporary directories after tests.""" + yield + cleanup_temp_directories() + + +@pytest.mark.asyncio +async def test_remote_repository_analysis(test_client): # pylint: disable=redefined-outer-name + """Test the complete flow of analyzing a remote repository.""" + form_data = { + "input_text": "https://github.com/octocat/Hello-World", + "max_file_size": "243", + "pattern_type": "exclude", + "pattern": "", + } + + response = test_client.post("/", data=form_data) + assert response.status_code == 200, f"Form submission failed: {response.text}" + assert "Mocked Template Response" in response.text + + +@pytest.mark.asyncio +async def test_invalid_repository_url(test_client): # pylint: disable=redefined-outer-name + """Test handling of an invalid repository URL.""" + form_data = { + "input_text": 
"https://github.com/nonexistent/repo", + "max_file_size": "243", + "pattern_type": "exclude", + "pattern": "", + } + + response = test_client.post("/", data=form_data) + assert response.status_code == 200, f"Request failed: {response.text}" + assert "Mocked Template Response" in response.text + + +@pytest.mark.asyncio +async def test_large_repository(test_client): # pylint: disable=redefined-outer-name + """Simulate analysis of a large repository with nested folders.""" + form_data = { + "input_text": "https://github.com/large/repo-with-many-files", + "max_file_size": "243", + "pattern_type": "exclude", + "pattern": "", + } + + response = test_client.post("/", data=form_data) + assert response.status_code == 200, f"Request failed: {response.text}" + assert "Mocked Template Response" in response.text + + +@pytest.mark.asyncio +async def test_concurrent_requests(test_client): # pylint: disable=redefined-outer-name + """Test handling of multiple concurrent requests.""" + + def make_request(): + form_data = { + "input_text": "https://github.com/octocat/Hello-World", + "max_file_size": "243", + "pattern_type": "exclude", + "pattern": "", + } + response = test_client.post("/", data=form_data) + assert response.status_code == 200, f"Request failed: {response.text}" + assert "Mocked Template Response" in response.text + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(make_request) for _ in range(5)] + for future in futures: + future.result() + + +@pytest.mark.asyncio +async def test_large_file_handling(test_client): # pylint: disable=redefined-outer-name + """Test handling of repositories with large files.""" + form_data = { + "input_text": "https://github.com/octocat/Hello-World", + "max_file_size": "1", + "pattern_type": "exclude", + "pattern": "", + } + + response = test_client.post("/", data=form_data) + assert response.status_code == 200, f"Request failed: {response.text}" + assert "Mocked Template Response" in response.text + + 
+@pytest.mark.asyncio +async def test_repository_with_patterns(test_client): # pylint: disable=redefined-outer-name + """Test repository analysis with include/exclude patterns.""" + form_data = { + "input_text": "https://github.com/octocat/Hello-World", + "max_file_size": "243", + "pattern_type": "include", + "pattern": "*.md", + } + + response = test_client.post("/", data=form_data) + assert response.status_code == 200, f"Request failed: {response.text}" + assert "Mocked Template Response" in response.text