diff --git a/src/codegen/extensions/tools/github/view_pr.py b/src/codegen/extensions/tools/github/view_pr.py index 00c20f7bb..87f9229d7 100644 --- a/src/codegen/extensions/tools/github/view_pr.py +++ b/src/codegen/extensions/tools/github/view_pr.py @@ -24,6 +24,10 @@ class ViewPRObservation(Observation): modified_symbols: list[str] = Field( description="Names of modified symbols in the PR", ) + head_branch: str = Field( + description="Name of the head branch of the PR", + default="", + ) str_template: ClassVar[str] = "PR #{pr_id}" @@ -38,12 +42,17 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation: try: patch, file_commit_sha, moddified_symbols = codebase.get_modified_symbols_in_pr(pr_id) + # Get the PR object to extract the branch name + pr = codebase._op.get_pull_request(pr_id) + head_branch = pr.head.ref if pr else "" + return ViewPRObservation( status="success", pr_id=pr_id, patch=patch, file_commit_sha=file_commit_sha, modified_symbols=moddified_symbols, + head_branch=head_branch, ) except Exception as e: @@ -54,4 +63,5 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation: patch="", file_commit_sha={}, modified_symbols=[], + head_branch="", )