Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 48 additions & 9 deletions src/codegen/git/clients/git_repo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,18 @@
repo_config: RepoConfig
gh_client: GithubClient
_repo: Repository
_supports_draft_prs: bool | None = None

def __init__(self, repo_config: RepoConfig, access_token: str | None = None) -> None:
self.repo_config = repo_config
self.gh_client = self._create_github_client(token=access_token or SecretsConfig().github_token)

Check failure on line 37 in src/codegen/git/clients/git_repo_client.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument "token" to "_create_github_client" of "GitRepoClient" has incompatible type "str | None"; expected "str" [arg-type]
self._repo = self._create_client()

def _create_github_client(self, token: str) -> GithubClient:
return GithubClient(token=token)

def _create_client(self) -> Repository:
client = self.gh_client.get_repo_by_full_name(self.repo_config.full_name)

Check failure on line 44 in src/codegen/git/clients/git_repo_client.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 1 to "get_repo_by_full_name" of "GithubClient" has incompatible type "str | None"; expected "str" [arg-type]
if not client:
msg = f"Repo {self.repo_config.full_name} not found!"
raise ValueError(msg)
Expand All @@ -58,6 +59,31 @@
def default_branch(self) -> str:
return self.repo.default_branch

def accepts_draft_prs(self) -> bool:
"""Determines if a repository supports draft PRs.

This uses a heuristic based on repository visibility and plan features.
Public repositories always support draft PRs.
For private repositories, we use a cached result if available to avoid repeated checks.

Returns:
bool: True if the repository supports draft PRs, False otherwise.
"""
# If we've already checked, return the cached result
if self._supports_draft_prs is not None:
return self._supports_draft_prs

# Public repositories always support draft PRs
if self.repo.visibility == "public":
self._supports_draft_prs = True
return True

# For private repositories, we'll use a conservative approach
# and assume they don't support draft PRs by default
# This can be refined in the future with more specific checks
self._supports_draft_prs = False
return False

####################################################################################################################
# CONTENTS
####################################################################################################################
Expand Down Expand Up @@ -176,9 +202,9 @@
def get_or_create_pull(
self,
head_branch_name: str,
base_branch_name: str | None = None, # type: ignore[assignment]

Check failure on line 205 in src/codegen/git/clients/git_repo_client.py

View workflow job for this annotation

GitHub Actions / mypy

error: Unused "type: ignore" comment [unused-ignore]
title: str | None = None, # type: ignore[assignment]

Check failure on line 206 in src/codegen/git/clients/git_repo_client.py

View workflow job for this annotation

GitHub Actions / mypy

error: Unused "type: ignore" comment [unused-ignore]
body: str | None = None, # type: ignore[assignment]

Check failure on line 207 in src/codegen/git/clients/git_repo_client.py

View workflow job for this annotation

GitHub Actions / mypy

error: Unused "type: ignore" comment [unused-ignore]
) -> PullRequest | None:
pull = self.get_pull_by_branch_and_state(head_branch_name=head_branch_name, base_branch_name=base_branch_name)
if pull:
Expand All @@ -199,19 +225,32 @@
if base_branch_name is None:
base_branch_name = self.default_branch

# draft PRs are not supported on all private repos
# TODO: check repo plan features instead of this heuristic
if self.repo.visibility == "private":
logger.info(f"Repo {self.repo.name} is private. Disabling draft PRs.")
draft = False
# Determine if we should attempt to create a draft PR
should_try_draft = draft and self.accepts_draft_prs()

try:
pr = self.repo.create_pull(title=title or f"Draft PR for {head_branch_name}", body=body or "", head=head_branch_name, base=base_branch_name, draft=draft)
logger.info(f"Created pull request for head branch: {head_branch_name} at {pr.html_url}")
# NOTE: return a read-only copy to prevent people from editing it
# First attempt to create the PR with the requested draft status
pr = self.repo.create_pull(title=title or f"PR for {head_branch_name}", body=body or "", head=head_branch_name, base=base_branch_name, draft=should_try_draft)
logger.info(f"Created {'draft ' if should_try_draft else ''}pull request for head branch: {head_branch_name} at {pr.html_url}")
# Return a read-only copy to prevent people from editing it
return self.repo.get_pull(pr.number)
except GithubException as ge:
logger.warning(f"Failed to create PR got GithubException\n\t{ge}")
# Check specifically for the "Draft pull requests are not supported" error
if draft and ge.status == 422 and "Draft pull requests are not supported in this repository" in str(ge):
logger.info(f"Draft PRs not supported in repository {self.repo.name}. Trying to create a regular PR instead.")
# Update our cached knowledge about draft PR support
self._supports_draft_prs = False

# Try again with draft=False
try:
pr = self.repo.create_pull(title=title or f"PR for {head_branch_name}", body=body or "", head=head_branch_name, base=base_branch_name, draft=False)
logger.info(f"Created regular pull request for head branch: {head_branch_name} at {pr.html_url}")
# Return a read-only copy
return self.repo.get_pull(pr.number)
except Exception as e:
logger.warning(f"Failed to create regular PR after draft PR failed:\n\t{e}")
else:
logger.warning(f"Failed to create PR got GithubException\n\t{ge}")
except Exception as e:
logger.warning(f"Failed to create PR:\n\t{e}")

Expand All @@ -229,7 +268,7 @@
body="",
)
# TODO: handle PR not mergeable due to merge conflicts
merge = squash_pr.merge(commit_message=squash_commit_msg, commit_title=squash_commit_title, merge_method="squash") # type: ignore[arg-type]

Check failure on line 271 in src/codegen/git/clients/git_repo_client.py

View workflow job for this annotation

GitHub Actions / mypy

error: Item "None" of "PullRequest | None" has no attribute "merge" [union-attr]

def edit_pull(self, pull: PullRequest, title: Opt[str] = NotSet, body: Opt[str] = NotSet, state: Opt[str] = NotSet) -> None:
writable_pr = self.repo.get_pull(pull.number)
Expand Down Expand Up @@ -444,7 +483,7 @@
####################################################################################################################

def search_issues(self, query: str, **kwargs) -> list[Issue]:
return self.gh_client.client.search_issues(query, **kwargs)

Check failure on line 486 in src/codegen/git/clients/git_repo_client.py

View workflow job for this annotation

GitHub Actions / mypy

error: Incompatible return value type (got "PaginatedList[Issue]", expected "list[Issue]") [return-value]

def search_prs(self, query: str, **kwargs) -> list[PullRequest]:
return self.gh_client.client.search_issues(query, **kwargs)

Check failure on line 489 in src/codegen/git/clients/git_repo_client.py

View workflow job for this annotation

GitHub Actions / mypy

error: Incompatible return value type (got "PaginatedList[Issue]", expected "list[PullRequest]") [return-value]
Loading