diff --git a/src/codegen/extensions/tools/github/view_pr.py b/src/codegen/extensions/tools/github/view_pr.py index 00c20f7bb..f359fe504 100644 --- a/src/codegen/extensions/tools/github/view_pr.py +++ b/src/codegen/extensions/tools/github/view_pr.py @@ -1,6 +1,6 @@ """Tool for viewing PR contents and modified symbols.""" -from typing import ClassVar +from typing import Any, ClassVar from pydantic import Field @@ -24,6 +24,14 @@ class ViewPRObservation(Observation): modified_symbols: list[str] = Field( description="Names of modified symbols in the PR", ) + github_comments: list[dict[str, Any]] = Field( + description="Comments on the PR", + default_factory=list, + ) + github_reviews: list[dict[str, Any]] = Field( + description="Reviews on the PR", + default_factory=list, + ) str_template: ClassVar[str] = "PR #{pr_id}" @@ -36,14 +44,16 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation: pr_id: Number of the PR to get the contents for """ try: - patch, file_commit_sha, moddified_symbols = codebase.get_modified_symbols_in_pr(pr_id) + patch, file_commit_sha, modified_symbols, comments, reviews = codebase.get_modified_symbols_in_pr(pr_id) return ViewPRObservation( status="success", pr_id=pr_id, patch=patch, file_commit_sha=file_commit_sha, - modified_symbols=moddified_symbols, + modified_symbols=modified_symbols, + github_comments=comments, + github_reviews=reviews, ) except Exception as e: @@ -54,4 +64,6 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation: patch="", file_commit_sha={}, modified_symbols=[], + github_comments=[], + github_reviews=[], ) diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py index 01950a7a2..f36081cc2 100644 --- a/src/codegen/sdk/core/codebase.py +++ b/src/codegen/sdk/core/codebase.py @@ -1527,13 +1527,67 @@ def from_files( logger.info("Codebase initialization complete") return codebase - def get_modified_symbols_in_pr(self, pr_id: int) -> tuple[str, dict[str, str], list[str], str]: + def get_modified_symbols_in_pr(self, pr_id: int) -> tuple[str, dict[str, str], list[str], list[dict], list[dict]]: """Get all modified symbols in a pull request""" pr = self._op.get_pull_request(pr_id) cg_pr = CodegenPR(self._op, self, pr) patch = cg_pr.get_pr_diff() commit_sha = cg_pr.get_file_commit_shas() - return patch, commit_sha, cg_pr.modified_symbols, pr.head.ref + + # Get comments and reviews + comments = [] + reviews = [] + + try: + # Get PR comments (issue comments) + issue_comments = pr.get_issue_comments() + for comment in issue_comments: + comments.append( + { + "id": comment.id, + "user": comment.user.login, + "body": comment.body, + "created_at": comment.created_at.isoformat(), + "updated_at": comment.updated_at.isoformat() if comment.updated_at else None, + "type": "issue_comment", + } + ) + + # Get PR review comments (comments on specific lines) + review_comments = pr.get_comments() + for comment in review_comments: + comments.append( + { + "id": comment.id, + "user": comment.user.login, + "body": comment.body, + "created_at": comment.created_at.isoformat(), + "updated_at": comment.updated_at.isoformat() if comment.updated_at else None, + "path": comment.path, + "position": comment.position, + "commit_id": comment.commit_id, + "type": "review_comment", + } + ) + + # Get PR reviews + pr_reviews = pr.get_reviews() + for review in pr_reviews: + reviews.append( + { + "id": review.id, + "user": review.user.login, + "body": review.body, + "state": review.state, + "submitted_at": review.submitted_at.isoformat() if review.submitted_at else None, + "commit_id": review.commit_id, + "type": "review", + } + ) + except Exception as e: + print(f"Error fetching PR comments or reviews: {e}") + + return patch, commit_sha, cg_pr.modified_symbols, comments, reviews def create_pr_comment(self, pr_number: int, body: str) -> None: """Create a comment on a pull request"""