Skip to content
Open
285 changes: 257 additions & 28 deletions src/rotator_library/providers/copilot_auth_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
import time
import asyncio
import logging
import re
from pathlib import Path
from ..utils.paths import get_oauth_dir
from typing import Dict, Any, Optional, Union
import tempfile
import shutil
from dataclasses import dataclass, field
from glob import glob

import httpx
from rich.console import Console
Expand All @@ -32,23 +36,26 @@
console = Console()


@dataclass
class CopilotCredentialSetupResult:
"""Standardized result for Copilot credential setup operations."""

success: bool
file_path: Optional[str] = None
email: Optional[str] = None
tier: Optional[str] = None
project_id: Optional[str] = None
is_update: bool = False
error: Optional[str] = None
credentials: Optional[Dict[str, Any]] = field(default=None, repr=False)


class CopilotAuthBase:
"""
GitHub Copilot OAuth2 authentication using Device Flow.

This provider uses GitHub's Device Authorization Grant flow, which is
more suitable for CLI applications than the web-based Authorization Code flow.

Key differences from GoogleOAuthBase:
- Uses GitHub Device Flow (polls for authorization)
- Two-token system: GitHub OAuth token + Copilot API token
- Copilot API tokens expire quickly (~30 min) and need frequent refresh

Subclasses may override:
- ENV_PREFIX: Prefix for environment variables (default: "COPILOT")
- REFRESH_EXPIRY_BUFFER_SECONDS: Time buffer before token expiry

Supports both github.com and GitHub Enterprise deployments.
"""

# GitHub Copilot OAuth Client ID (from VS Code Copilot extension)
Expand Down Expand Up @@ -273,17 +280,20 @@ def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
return True

async def _refresh_copilot_token(
self, path: str, creds: Dict[str, Any], force: bool = False
self, path: Optional[str], creds: Dict[str, Any], force: bool = False
) -> Dict[str, Any]:
"""
Refresh the Copilot API token using the GitHub OAuth token.

The GitHub OAuth token (refresh_token) is long-lived.
The Copilot API token (access_token) expires after ~30 minutes.
"""
async with await self._get_lock(path):
lock_key = path or f"in-memory://copilot/{id(creds)}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using id(creds) for an in-memory lock key is clever, but be aware that it might lead to redundant locks if the creds dictionary is copied or recreated (e.g., during serialization/deserialization cycles). If the credentials have a unique identifier like an email, that might be a more stable key.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair point — id() can be recycled if the dict gets garbage collected and recreated. In practice the in-memory path only triggers for env-based credentials which are long-lived singletons, but using the email from _proxy_metadata would be more robust. Open to changing if preferred.

display_name = Path(path).name if path else "in-memory credential"

async with await self._get_lock(lock_key):
# Skip if token is still valid (unless forced)
cached_creds = self._credentials_cache.get(path, creds)
cached_creds = self._credentials_cache.get(lock_key, creds)
if not force and not self._is_token_expired(cached_creds):
return cached_creds

Expand All @@ -302,7 +312,7 @@ async def _refresh_copilot_token(
urls = self._get_urls(domain)

lib_logger.debug(
f"Refreshing {self.ENV_PREFIX} Copilot API token for '{Path(path).name}' (forced: {force})..."
f"Refreshing {self.ENV_PREFIX} Copilot API token for '{display_name}' (forced: {force})..."
)

async with httpx.AsyncClient() as client:
Expand All @@ -319,10 +329,14 @@ async def _refresh_copilot_token(

if response.status_code == 401:
lib_logger.warning(
f"GitHub token invalid for '{Path(path).name}' (HTTP 401). "
f"GitHub token invalid for '{display_name}' (HTTP 401). "
f"Token may have been revoked. Starting re-authentication..."
)
return await self.initialize_token(path)
if path:
return await self.initialize_token(path)
raise ValueError(
"GitHub token invalid for in-memory credential and cannot re-auth without a file path"
)

response.raise_for_status()
token_data = response.json()
Expand All @@ -338,9 +352,12 @@ async def _refresh_copilot_token(
creds["_proxy_metadata"] = {}
creds["_proxy_metadata"]["last_check_timestamp"] = time.time()

await self._save_credentials(path, creds)
if path:
await self._save_credentials(path, creds)
else:
self._credentials_cache[lock_key] = creds
lib_logger.debug(
f"Successfully refreshed {self.ENV_PREFIX} Copilot API token for '{Path(path).name}'."
f"Successfully refreshed {self.ENV_PREFIX} Copilot API token for '{display_name}'."
)
return creds

Expand Down Expand Up @@ -396,9 +413,13 @@ async def initialize_token(
)

try:
creds = (
await self._load_credentials(creds_or_path) if path else creds_or_path
)
if path:
creds: Dict[str, Any] = await self._load_credentials(path)
elif isinstance(creds_or_path, dict):
creds = creds_or_path
else:
raise ValueError("Invalid credential input for Copilot initialization")

needs_auth = False
reason = ""

Expand Down Expand Up @@ -545,10 +566,13 @@ async def initialize_token(
)
if user_response.is_success:
user_info = user_response.json()
resolved_identity = (
user_info.get("email")
or user_info.get("login")
or "unknown"
)
new_creds["_proxy_metadata"]["email"] = (
user_info.get(
"email", user_info.get("login", "unknown")
)
resolved_identity
)
except Exception as e:
lib_logger.warning(f"Failed to fetch user info: {e}")
Expand Down Expand Up @@ -591,7 +615,12 @@ async def get_user_info(
) -> Dict[str, Any]:
"""Get user info from cached metadata or API."""
path = creds_or_path if isinstance(creds_or_path, str) else None
creds = await self._load_credentials(creds_or_path) if path else creds_or_path
if path:
creds: Dict[str, Any] = await self._load_credentials(path)
elif isinstance(creds_or_path, dict):
creds = creds_or_path
else:
return {"email": "unknown"}

if creds.get("_proxy_metadata", {}).get("email"):
return {"email": creds["_proxy_metadata"]["email"]}
Expand All @@ -615,8 +644,10 @@ async def get_user_info(
)
if response.is_success:
user_info = response.json()
email = user_info.get(
"email", user_info.get("login", "unknown")
email = (
user_info.get("email")
or user_info.get("login")
or "unknown"
)
creds["_proxy_metadata"] = {
"email": email,
Expand All @@ -629,3 +660,201 @@ async def get_user_info(
lib_logger.warning(f"Failed to fetch user info: {e}")

return {"email": "unknown"}

def _get_oauth_base_dir(self) -> Path:
"""Return the OAuth credentials base directory."""
return get_oauth_dir()

def _get_provider_file_prefix(self) -> str:
"""Return file prefix for Copilot credential files."""
return "copilot"

def _find_existing_credential_by_email(
self, email: str, base_dir: Optional[Path] = None
) -> Optional[Path]:
"""Find existing credential file by email for deduplication."""
if base_dir is None:
base_dir = self._get_oauth_base_dir()

for cred in self.list_credentials(base_dir):
if cred.get("email", "").lower() == email.lower():
return Path(cred["file_path"])
return None

def _get_next_credential_number(self, base_dir: Optional[Path] = None) -> int:
"""Get next available credential number."""
if base_dir is None:
base_dir = self._get_oauth_base_dir()

prefix = self._get_provider_file_prefix()
pattern = str(base_dir / f"{prefix}_oauth_*.json")

existing_numbers = []
for cred_file in glob(pattern):
match = re.search(r"_oauth_(\d+)\.json$", cred_file)
if match:
existing_numbers.append(int(match.group(1)))

if not existing_numbers:
return 1
return max(existing_numbers) + 1

def _build_credential_path(
self, base_dir: Optional[Path] = None, number: Optional[int] = None
) -> Path:
"""Build path for a new Copilot credential file."""
if base_dir is None:
base_dir = self._get_oauth_base_dir()

if number is None:
number = self._get_next_credential_number(base_dir)

prefix = self._get_provider_file_prefix()
return base_dir / f"{prefix}_oauth_{number}.json"

async def setup_credential(
self, base_dir: Optional[Path] = None
) -> CopilotCredentialSetupResult:
"""Complete credential setup flow: OAuth -> save."""
if base_dir is None:
base_dir = self._get_oauth_base_dir()

base_dir.mkdir(parents=True, exist_ok=True)

try:
temp_creds = {"_proxy_metadata": {"display_name": "new Copilot credential"}}
new_creds = await self.initialize_token(temp_creds)

user_info = await self.get_user_info(new_creds)
email = user_info.get("email")
if not email:
return CopilotCredentialSetupResult(
success=False, error="Could not retrieve email from OAuth response"
)

existing_path = self._find_existing_credential_by_email(email, base_dir)
is_update = existing_path is not None

if is_update:
file_path = existing_path
lib_logger.info(
f"Found existing credential for {email}, updating {file_path.name}"
)
else:
file_path = self._build_credential_path(base_dir)
lib_logger.info(
f"Creating new credential for {email} at {file_path.name}"
)

await self._save_credentials(str(file_path), new_creds)

return CopilotCredentialSetupResult(
success=True,
file_path=str(file_path),
email=email,
is_update=is_update,
credentials=new_creds,
)

except Exception as e:
lib_logger.error(f"Copilot credential setup failed: {e}")
return CopilotCredentialSetupResult(success=False, error=str(e))

def build_env_lines(self, creds: Dict[str, Any], cred_number: int) -> list[str]:
"""Generate .env file lines for a Copilot credential."""
email = creds.get("_proxy_metadata", {}).get("email", "unknown")
prefix = f"COPILOT_{cred_number}"

lines = [
f"# COPILOT Credential #{cred_number} for: {email}",
f"# Exported from: copilot_oauth_{cred_number}.json",
f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}",
"",
f"{prefix}_GITHUB_TOKEN={creds.get('refresh_token', '')}",
f"{prefix}_ENTERPRISE_URL={creds.get('enterprise_url', '')}",
f"{prefix}_EMAIL={email}",
]

return lines

def export_credential_to_env(
self, credential_path: str, output_dir: Optional[Path] = None
) -> Optional[str]:
"""Export a Copilot credential file to .env format."""
try:
cred_path = Path(credential_path)
with open(cred_path, "r") as f:
creds = json.load(f)

email = creds.get("_proxy_metadata", {}).get("email", "unknown")
match = re.search(r"_oauth_(\d+)\.json$", cred_path.name)
cred_number = int(match.group(1)) if match else 1

if output_dir is None:
output_dir = cred_path.parent

safe_email = email.replace("@", "_at_").replace(".", "_")
env_filename = f"copilot_{cred_number}_{safe_email}.env"
env_path = output_dir / env_filename

with open(env_path, "w") as f:
f.write("\n".join(self.build_env_lines(creds, cred_number)))

return str(env_path)
except Exception as e:
lib_logger.error(f"Failed to export Copilot credential: {e}")
return None

def list_credentials(self, base_dir: Optional[Path] = None) -> list[Dict[str, Any]]:
"""List all Copilot credential files."""
if base_dir is None:
base_dir = self._get_oauth_base_dir()

prefix = self._get_provider_file_prefix()
pattern = str(base_dir / f"{prefix}_oauth_*.json")

credentials = []
for cred_file in sorted(glob(pattern)):
try:
with open(cred_file, "r") as f:
creds = json.load(f)

metadata = creds.get("_proxy_metadata", {})
match = re.search(r"_oauth_(\d+)\.json$", cred_file)
number = int(match.group(1)) if match else 0

credentials.append(
{
"file_path": cred_file,
"email": metadata.get("email", "unknown"),
"number": number,
}
)
except Exception as e:
lib_logger.debug(f"Could not read credential file {cred_file}: {e}")

return credentials

def delete_credential(self, credential_path: str) -> bool:
"""Delete a Copilot credential file."""
try:
cred_path = Path(credential_path)
prefix = self._get_provider_file_prefix()

if not cred_path.name.startswith(f"{prefix}_oauth_"):
lib_logger.error(
f"File {cred_path.name} does not appear to be a Copilot credential"
)
return False

if not cred_path.exists():
lib_logger.warning(f"Credential file does not exist: {credential_path}")
return False

self._credentials_cache.pop(credential_path, None)
cred_path.unlink()
lib_logger.info(f"Deleted credential file: {credential_path}")
return True
except Exception as e:
lib_logger.error(f"Failed to delete Copilot credential: {e}")
return False
Loading