From 52108b24889520a71e476afebae165ab05405df6 Mon Sep 17 00:00:00 2001 From: "codegen-sh[bot]" <131295404+codegen-sh[bot]@users.noreply.github.com> Date: Thu, 20 Mar 2025 02:04:11 +0000 Subject: [PATCH 1/2] Add branch name to view_pr function response --- src/codegen/extensions/tools/github/view_pr.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/codegen/extensions/tools/github/view_pr.py b/src/codegen/extensions/tools/github/view_pr.py index 00c20f7bb..805d80ae6 100644 --- a/src/codegen/extensions/tools/github/view_pr.py +++ b/src/codegen/extensions/tools/github/view_pr.py @@ -24,6 +24,10 @@ class ViewPRObservation(Observation): modified_symbols: list[str] = Field( description="Names of modified symbols in the PR", ) + head_branch: str = Field( + description="Name of the head branch of the PR", + default="", + ) str_template: ClassVar[str] = "PR #{pr_id}" @@ -37,6 +41,10 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation: """ try: patch, file_commit_sha, moddified_symbols = codebase.get_modified_symbols_in_pr(pr_id) + + # Get the PR object to extract the branch name + pr = codebase._op.get_pull_request(pr_id) + head_branch = pr.head.ref if pr else "" return ViewPRObservation( status="success", @@ -44,6 +52,7 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation: patch=patch, file_commit_sha=file_commit_sha, modified_symbols=moddified_symbols, + head_branch=head_branch, ) except Exception as e: @@ -54,4 +63,5 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation: patch="", file_commit_sha={}, modified_symbols=[], + head_branch="", ) From 4a280c154dc9217e474aa966be65c28d75ac9ee5 Mon Sep 17 00:00:00 2001 From: "codegen-sh[bot]" <131295404+codegen-sh[bot]@users.noreply.github.com> Date: Thu, 20 Mar 2025 02:04:55 +0000 Subject: [PATCH 2/2] Automated pre-commit update --- src/codegen/extensions/tools/github/view_pr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/codegen/extensions/tools/github/view_pr.py b/src/codegen/extensions/tools/github/view_pr.py index 805d80ae6..87f9229d7 100644 --- a/src/codegen/extensions/tools/github/view_pr.py +++ b/src/codegen/extensions/tools/github/view_pr.py @@ -41,7 +41,7 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation: """ try: patch, file_commit_sha, moddified_symbols = codebase.get_modified_symbols_in_pr(pr_id) - + # Get the PR object to extract the branch name pr = codebase._op.get_pull_request(pr_id) head_branch = pr.head.ref if pr else ""