Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/codegen/extensions/tools/github/view_pr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tool for viewing PR contents and modified symbols."""

from typing import ClassVar
from typing import Any, ClassVar

from pydantic import Field

Expand All @@ -24,6 +24,14 @@ class ViewPRObservation(Observation):
modified_symbols: list[str] = Field(
description="Names of modified symbols in the PR",
)
github_comments: list[dict[str, Any]] = Field(
description="Comments on the PR",
default_factory=list,
)
github_reviews: list[dict[str, Any]] = Field(
description="Reviews on the PR",
default_factory=list,
)

str_template: ClassVar[str] = "PR #{pr_id}"

Expand All @@ -36,14 +44,16 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation:
pr_id: Number of the PR to get the contents for
"""
try:
patch, file_commit_sha, moddified_symbols = codebase.get_modified_symbols_in_pr(pr_id)
patch, file_commit_sha, modified_symbols, comments, reviews = codebase.get_modified_symbols_in_pr(pr_id)

return ViewPRObservation(
status="success",
pr_id=pr_id,
patch=patch,
file_commit_sha=file_commit_sha,
modified_symbols=moddified_symbols,
modified_symbols=modified_symbols,
github_comments=comments,
github_reviews=reviews,
)

except Exception as e:
Expand All @@ -54,4 +64,6 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation:
patch="",
file_commit_sha={},
modified_symbols=[],
github_comments=[],
github_reviews=[],
)
58 changes: 56 additions & 2 deletions src/codegen/sdk/core/codebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,13 +1527,67 @@ def from_files(
logger.info("Codebase initialization complete")
return codebase

def get_modified_symbols_in_pr(self, pr_id: int) -> tuple[str, dict[str, str], list[str], str]:
def get_modified_symbols_in_pr(self, pr_id: int) -> tuple[str, dict[str, str], list[str], list[dict], list[dict]]:
"""Get all modified symbols in a pull request"""
pr = self._op.get_pull_request(pr_id)
cg_pr = CodegenPR(self._op, self, pr)
patch = cg_pr.get_pr_diff()
commit_sha = cg_pr.get_file_commit_shas()
return patch, commit_sha, cg_pr.modified_symbols, pr.head.ref

# Get comments and reviews
comments = []
reviews = []

try:
# Get PR comments (issue comments)
issue_comments = pr.get_issue_comments()
for comment in issue_comments:
comments.append(
{
"id": comment.id,
"user": comment.user.login,
"body": comment.body,
"created_at": comment.created_at.isoformat(),
"updated_at": comment.updated_at.isoformat() if comment.updated_at else None,
"type": "issue_comment",
}
)

# Get PR review comments (comments on specific lines)
review_comments = pr.get_comments()
for comment in review_comments:
comments.append(
{
"id": comment.id,
"user": comment.user.login,
"body": comment.body,
"created_at": comment.created_at.isoformat(),
"updated_at": comment.updated_at.isoformat() if comment.updated_at else None,
"path": comment.path,
"position": comment.position,
"commit_id": comment.commit_id,
"type": "review_comment",
}
)

# Get PR reviews
pr_reviews = pr.get_reviews()
for review in pr_reviews:
reviews.append(
{
"id": review.id,
"user": review.user.login,
"body": review.body,
"state": review.state,
"submitted_at": review.submitted_at.isoformat() if review.submitted_at else None,
"commit_id": review.commit_id,
"type": "review",
}
)
except Exception as e:
print(f"Error fetching PR comments or reviews: {e}")

return patch, commit_sha, cg_pr.modified_symbols, comments, reviews

def create_pr_comment(self, pr_number: int, body: str) -> None:
"""Create a comment on a pull request"""
Expand Down
Loading