diff --git a/src/gitingest/cli.py b/src/gitingest/cli.py index 2163b0e1..549b5945 100644 --- a/src/gitingest/cli.py +++ b/src/gitingest/cli.py @@ -16,12 +16,14 @@ @click.option("--max-size", "-s", default=MAX_FILE_SIZE, help="Maximum file size to process in bytes") @click.option("--exclude-pattern", "-e", multiple=True, help="Patterns to exclude") @click.option("--include-pattern", "-i", multiple=True, help="Patterns to include") +@click.option("--branch", "-b", default=None, help="Branch to clone and ingest") def main( source: str, output: str | None, max_size: int, exclude_pattern: tuple[str, ...], include_pattern: tuple[str, ...], + branch: str | None, ): """ Main entry point for the CLI. This function is called when the CLI is run as a script. @@ -41,9 +43,11 @@ def main( A tuple of patterns to exclude during the analysis. Files matching these patterns will be ignored. include_pattern : tuple[str, ...] A tuple of patterns to include during the analysis. Only files matching these patterns will be processed. + branch : str | None + The branch to clone (optional). """ # Main entry point for the CLI. This function is called when the CLI is run as a script. - asyncio.run(_async_main(source, output, max_size, exclude_pattern, include_pattern)) + asyncio.run(_async_main(source, output, max_size, exclude_pattern, include_pattern, branch)) async def _async_main( @@ -52,6 +56,7 @@ async def _async_main( max_size: int, exclude_pattern: tuple[str, ...], include_pattern: tuple[str, ...], + branch: str | None, ) -> None: """ Analyze a directory or repository and create a text dump of its contents. @@ -72,6 +77,8 @@ async def _async_main( A tuple of patterns to exclude during the analysis. Files matching these patterns will be ignored. include_pattern : tuple[str, ...] A tuple of patterns to include during the analysis. Only files matching these patterns will be processed. + branch : str | None + The branch to clone (optional). Raises ------ @@ -85,7 +92,7 @@ async def _async_main( if not output: output = OUTPUT_FILE_PATH - summary, _, _ = await ingest(source, max_size, include_patterns, exclude_patterns, output=output) + summary, _, _ = await ingest(source, max_size, include_patterns, exclude_patterns, branch, output=output) click.echo(f"Analysis complete! Output written to: {output}") click.echo("\nSummary:") diff --git a/src/gitingest/repository_ingest.py b/src/gitingest/repository_ingest.py index f92c1c2d..57be89da 100644 --- a/src/gitingest/repository_ingest.py +++ b/src/gitingest/repository_ingest.py @@ -15,6 +15,7 @@ async def ingest( max_file_size: int = 10 * 1024 * 1024, # 10 MB include_patterns: set[str] | str | None = None, exclude_patterns: set[str] | str | None = None, + branch: str | None = None, output: str | None = None, ) -> tuple[str, str, str]: """ @@ -35,6 +36,8 @@ async def ingest( Pattern or set of patterns specifying which files to include. If `None`, all files are included. exclude_patterns : set[str] | str | None, optional Pattern or set of patterns specifying which files to exclude. If `None`, no files are excluded. + branch : str | None, optional + The branch to clone and ingest. If `None`, the default branch is used. output : str | None, optional File path where the summary and content should be written. If `None`, the results are not written to a file. @@ -61,17 +64,23 @@ async def ingest( ) if parsed_query.url: + selected_branch = branch if branch else parsed_query.branch # prioritize branch argument + parsed_query.branch = selected_branch + # Extract relevant fields for CloneConfig clone_config = CloneConfig( url=parsed_query.url, local_path=str(parsed_query.local_path), commit=parsed_query.commit, - branch=parsed_query.branch, + branch=selected_branch, ) clone_result = clone_repo(clone_config) if inspect.iscoroutine(clone_result): - asyncio.run(clone_result) + if asyncio.get_event_loop().is_running(): + await clone_result + else: + asyncio.run(clone_result) else: raise TypeError("clone_repo did not return a coroutine as expected.") diff --git a/tests/test_repository_clone.py b/tests/test_repository_clone.py index de417bea..380ad5d0 100644 --- a/tests/test_repository_clone.py +++ b/tests/test_repository_clone.py @@ -6,6 +6,7 @@ """ import asyncio +import os from unittest.mock import AsyncMock, patch import pytest @@ -306,3 +307,58 @@ async def test_clone_repo_with_timeout() -> None: mock_exec.side_effect = asyncio.TimeoutError with pytest.raises(AsyncTimeoutError, match="Operation timed out after"): await clone_repo(clone_config) + + +@pytest.mark.asyncio +async def test_clone_specific_branch(tmp_path): + """ + Test cloning a specific branch of a repository. + + Given a valid repository URL and a branch name: + When `clone_repo` is called, + Then the repository should be cloned and checked out at that branch. + """ + repo_url = "https://github.com/cyclotruc/gitingest.git" + branch_name = "main" + local_path = tmp_path / "gitingest" + + config = CloneConfig(url=repo_url, local_path=str(local_path), branch=branch_name) + await clone_repo(config) + + # Assertions + assert local_path.exists(), "The repository was not cloned successfully." + assert local_path.is_dir(), "The cloned repository path is not a directory." + + # Check the current branch + current_branch = os.popen(f"git -C {local_path} branch --show-current").read().strip() + assert current_branch == branch_name, f"Expected branch '{branch_name}', got '{current_branch}'." + + +@pytest.mark.asyncio +async def test_clone_branch_with_slashes(tmp_path): + """ + Test cloning a branch with slashes in the name. + + Given a valid repository URL and a branch name with slashes: + When `clone_repo` is called, + Then the repository should be cloned and checked out at that branch. + """ + repo_url = "https://github.com/user/repo" + branch_name = "fix/in-operator" + local_path = tmp_path / "gitingest" + + clone_config = CloneConfig(url=repo_url, local_path=str(local_path), branch=branch_name) + with patch("gitingest.repository_clone._check_repo_exists", return_value=True): + with patch("gitingest.repository_clone._run_git_command", new_callable=AsyncMock) as mock_exec: + await clone_repo(clone_config) + + mock_exec.assert_called_once_with( + "git", + "clone", + "--depth=1", + "--single-branch", + "--branch", + "fix/in-operator", + clone_config.url, + clone_config.local_path, + )