diff --git a/.env.example b/.env.example index 72351421..22c61142 100644 --- a/.env.example +++ b/.env.example @@ -85,6 +85,32 @@ # Path to your iFlow credential file (e.g., ~/.iflow/oauth_creds.json). #IFLOW_OAUTH_1="" +# --- OpenAI Codex (ChatGPT OAuth) --- +# One-time import from Codex CLI auth files (copied into oauth_creds/openai_codex_oauth_*.json) +#OPENAI_CODEX_OAUTH_1="~/.codex/auth.json" + +# Stateless env credentials (legacy single account) +#OPENAI_CODEX_ACCESS_TOKEN="" +#OPENAI_CODEX_REFRESH_TOKEN="" +#OPENAI_CODEX_EXPIRY_DATE="0" +#OPENAI_CODEX_ID_TOKEN="" +#OPENAI_CODEX_ACCOUNT_ID="" +#OPENAI_CODEX_EMAIL="" + +# Stateless env credentials (numbered multi-account) +#OPENAI_CODEX_1_ACCESS_TOKEN="" +#OPENAI_CODEX_1_REFRESH_TOKEN="" +#OPENAI_CODEX_1_EXPIRY_DATE="0" +#OPENAI_CODEX_1_ID_TOKEN="" +#OPENAI_CODEX_1_ACCOUNT_ID="" +#OPENAI_CODEX_1_EMAIL="" + +# OpenAI Codex routing/config +#OPENAI_CODEX_API_BASE="https://chatgpt.com/backend-api" +#OPENAI_CODEX_OAUTH_PORT=1455 +#OPENAI_CODEX_MODELS='["gpt-5.1-codex","gpt-5-codex"]' +#ROTATION_MODE_OPENAI_CODEX=sequential + # ------------------------------------------------------------------------------ # | [ADVANCED] Provider-Specific Settings | @@ -162,6 +188,7 @@ # # Provider Defaults: # - antigravity: sequential (free tier accounts with daily quotas) +# - openai_codex: sequential (account-level quota behavior) # - All others: balanced # # Example: @@ -401,8 +428,24 @@ # ------------------------------------------------------------------------------ # # OAuth callback port for Antigravity interactive re-authentication. -# Default: 8085 (same as Gemini CLI, shared) -# ANTIGRAVITY_OAUTH_PORT=8085 +# Default: 51121 +# ANTIGRAVITY_OAUTH_PORT=51121 + +# ------------------------------------------------------------------------------ +# | [ADVANCED] iFlow OAuth Configuration | +# ------------------------------------------------------------------------------ +# +# OAuth callback port for iFlow interactive re-authentication. +# Default: 11451 +# IFLOW_OAUTH_PORT=11451 + +# ------------------------------------------------------------------------------ +# | [ADVANCED] OpenAI Codex OAuth Configuration | +# ------------------------------------------------------------------------------ +# +# OAuth callback port for OpenAI Codex interactive authentication. +# Default: 1455 +# OPENAI_CODEX_OAUTH_PORT=1455 # ------------------------------------------------------------------------------ # | [ADVANCED] Debugging / Logging | diff --git a/.gitignore b/.gitignore index 3711fdfd..0a308835 100644 --- a/.gitignore +++ b/.gitignore @@ -126,9 +126,10 @@ staged_changes.txt launcher_config.json quota_viewer_config.json cache/antigravity/thought_signatures.json -logs/ -cache/ +/logs/ +/cache/ *.env -oauth_creds/ +/oauth_creds/ +/usage/ diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index f7ebbde5..fbfafd73 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -22,9 +22,9 @@ This architecture cleanly separates the API interface from the resilience logic, This library is the heart of the project, containing all the logic for managing a pool of API keys, tracking their usage, and handling provider interactions to ensure application resilience. -### 2.1. `client.py` - The `RotatingClient` +### 2.1. `client/rotating_client.py` - The `RotatingClient` -The `RotatingClient` is the central class that orchestrates all operations. It is designed as a long-lived, async-native object. +The `RotatingClient` is the central class that orchestrates all operations. It is now a slim facade that delegates to modular components (executor, filters, transforms) while remaining a long-lived, async-native object. #### Initialization @@ -35,7 +35,7 @@ client = RotatingClient( api_keys=api_keys, oauth_credentials=oauth_credentials, max_retries=2, - usage_file_path="key_usage.json", + usage_file_path="usage.json", configure_logging=True, global_timeout=30, abort_on_callback_error=True, @@ -50,7 +50,7 @@ client = RotatingClient( - `api_keys` (`Optional[Dict[str, List[str]]]`, default: `None`): A dictionary mapping provider names to a list of API keys. - `oauth_credentials` (`Optional[Dict[str, List[str]]]`, default: `None`): A dictionary mapping provider names to a list of file paths to OAuth credential JSON files. - `max_retries` (`int`, default: `2`): The number of times to retry a request with the *same key* if a transient server error occurs. -- `usage_file_path` (`str`, default: `"key_usage.json"`): The path to the JSON file where usage statistics are persisted. +- `usage_file_path` (`str`, optional): Base path for usage persistence (defaults to `usage/` in the data directory). The client stores per-provider files under `usage/usage_.json`. - `configure_logging` (`bool`, default: `True`): If `True`, configures the library's logger to propagate logs to the root logger. - `global_timeout` (`int`, default: `30`): A hard time limit (in seconds) for the entire request lifecycle. - `abort_on_callback_error` (`bool`, default: `True`): If `True`, any exception raised by `pre_request_callback` will abort the request. @@ -96,9 +96,9 @@ The `_safe_streaming_wrapper` is a critical component for stability. It: * **Error Interception**: Detects if a chunk contains an API error (like a quota limit) instead of content, and raises a specific `StreamedAPIError`. * **Quota Handling**: If a specific "quota exceeded" error is detected mid-stream multiple times, it can terminate the stream gracefully to prevent infinite retry loops on oversized inputs. -### 2.2. `usage_manager.py` - Stateful Concurrency & Usage Management +### 2.2. `usage/manager.py` - Stateful Concurrency & Usage Management -This class is the stateful core of the library, managing concurrency, usage tracking, cooldowns, and quota resets. +This class is the stateful core of the library, managing concurrency, usage tracking, cooldowns, and quota resets. Usage tracking now lives in the `rotator_library/usage/` package with per-provider managers and `usage/usage_.json` storage. #### Key Concepts @@ -205,15 +205,17 @@ The `CredentialManager` class (`credential_manager.py`) centralizes the lifecycl On startup (unless `SKIP_OAUTH_INIT_CHECK=true`), the manager performs a comprehensive sweep: -1. **System-Wide Scan**: Searches for OAuth credential files in standard locations: +1. **System-Wide Scan / Import Sources**: - `~/.gemini/` → All `*.json` files (typically `credentials.json`) - `~/.qwen/` → All `*.json` files (typically `oauth_creds.json`) - - `~/.iflow/` → All `*. json` files + - `~/.iflow/` → All `*.json` files + - `~/.codex/auth.json` + `~/.codex-accounts.json` → OpenAI Codex first-run import sources 2. **Local Import**: Valid credentials are **copied** (not moved) to the project's `oauth_creds/` directory with standardized names: - - `gemini_cli_oauth_1.json`, `gemini_cli_oauth_2.json`, etc. + - `gemini_cli_oauth_1.json`, `gemini_cli_oauth_2.json`, etc. - `qwen_code_oauth_1.json`, `qwen_code_oauth_2.json`, etc. - `iflow_oauth_1.json`, `iflow_oauth_2.json`, etc. + - `openai_codex_oauth_1.json`, `openai_codex_oauth_2.json`, etc. 3. **Intelligent Deduplication**: - The manager inspects each credential file for a `_proxy_metadata` field containing the user's email or ID @@ -292,6 +294,24 @@ IFLOW_EMAIL IFLOW_API_KEY ``` +**OpenAI Codex Environment Variables:** +``` +OPENAI_CODEX_ACCESS_TOKEN +OPENAI_CODEX_REFRESH_TOKEN +OPENAI_CODEX_EXPIRY_DATE +OPENAI_CODEX_ID_TOKEN +OPENAI_CODEX_ACCOUNT_ID +OPENAI_CODEX_EMAIL + +# Numbered multi-account format +OPENAI_CODEX_1_ACCESS_TOKEN +OPENAI_CODEX_1_REFRESH_TOKEN +OPENAI_CODEX_1_EXPIRY_DATE +OPENAI_CODEX_1_ID_TOKEN +OPENAI_CODEX_1_ACCOUNT_ID +OPENAI_CODEX_1_EMAIL +``` + **How it works:** - If the manager finds (e.g.) `GEMINI_CLI_ACCESS_TOKEN` or `GEMINI_CLI_1_ACCESS_TOKEN`, it constructs an in-memory credential object that mimics the file structure - The credential is referenced internally as `env://gemini_cli/0` (legacy) or `env://gemini_cli/1` (numbered) @@ -304,9 +324,11 @@ IFLOW_API_KEY env://{provider}/{index} Examples: -- env://gemini_cli/1 → GEMINI_CLI_1_ACCESS_TOKEN, etc. -- env://gemini_cli/0 → GEMINI_CLI_ACCESS_TOKEN (legacy single credential) -- env://antigravity/1 → ANTIGRAVITY_1_ACCESS_TOKEN, etc. +- env://gemini_cli/1 → GEMINI_CLI_1_ACCESS_TOKEN, etc. +- env://gemini_cli/0 → GEMINI_CLI_ACCESS_TOKEN (legacy single credential) +- env://antigravity/1 → ANTIGRAVITY_1_ACCESS_TOKEN, etc. +- env://openai_codex/1 → OPENAI_CODEX_1_ACCESS_TOKEN, etc. +- env://openai_codex/0 → OPENAI_CODEX_ACCESS_TOKEN (legacy single credential) ``` #### 2.6.3. Credential Tool Integration @@ -314,7 +336,7 @@ Examples: The `credential_tool.py` provides a user-friendly CLI interface to the `CredentialManager`: **Key Functions:** -1. **OAuth Setup**: Wraps provider-specific `AuthBase` classes (`GeminiAuthBase`, `QwenAuthBase`, `IFlowAuthBase`) to handle interactive login flows +1. **OAuth Setup**: Wraps provider-specific `AuthBase` classes (`GeminiAuthBase`, `QwenAuthBase`, `IFlowAuthBase`, `OpenAICodexAuthBase`) to handle interactive login flows 2. **Credential Export**: Reads local `.json` files and generates `.env` format output for stateless deployment 3. **API Key Management**: Adds or updates `PROVIDER_API_KEY_N` entries in the `.env` file @@ -419,7 +441,7 @@ The `CooldownManager` handles IP or account-level rate limiting that affects all - All subsequent `acquire_key()` calls for that provider will wait until the cooldown expires -### 2.10. Credential Prioritization System (`client.py` & `usage_manager.py`) +### 2.10. Credential Prioritization System (`client/rotating_client.py` & `usage/manager.py`) The library now includes an intelligent credential prioritization system that automatically detects credential tiers and ensures optimal credential selection for each request. @@ -762,7 +784,7 @@ Acquiring key for model antigravity/claude-opus-4.5. Tried keys: 0/12(17,cd:3,fc ``` **Persistence:** -Cycle state is persisted in `key_usage.json` under the `__fair_cycle__` key. +Cycle state is persisted alongside usage data in `usage/usage_.json`. ### 2.20. Custom Caps @@ -1426,12 +1448,13 @@ Each OAuth provider uses a local callback server during authentication. The call | Gemini CLI | 8085 | `GEMINI_CLI_OAUTH_PORT` | | Antigravity | 51121 | `ANTIGRAVITY_OAUTH_PORT` | | iFlow | 11451 | `IFLOW_OAUTH_PORT` | +| OpenAI Codex | 1455 | `OPENAI_CODEX_OAUTH_PORT` | **Configuration Methods:** 1. **Via TUI Settings Menu:** - Main Menu → `4. View Provider & Advanced Settings` → `1. Launch Settings Tool` - - Select the provider (Gemini CLI, Antigravity, or iFlow) + - Select the provider (Gemini CLI, Antigravity, iFlow, or OpenAI Codex) - Modify the `*_OAUTH_PORT` setting - Use "Reset to Default" to restore the original port @@ -1441,6 +1464,7 @@ Each OAuth provider uses a local callback server during authentication. The call GEMINI_CLI_OAUTH_PORT=8085 ANTIGRAVITY_OAUTH_PORT=51121 IFLOW_OAUTH_PORT=11451 + OPENAI_CODEX_OAUTH_PORT=1455 ``` **When to Change Ports:** @@ -1528,7 +1552,7 @@ The following providers use `TimeoutConfig`: | `iflow_provider.py` | `acompletion()` | `streaming()` | | `qwen_code_provider.py` | `acompletion()` | `streaming()` | -**Note:** iFlow, Qwen Code, and Gemini CLI providers always use streaming internally (even for non-streaming requests), aggregating chunks into a complete response. Only Antigravity has a true non-streaming path. +**Note:** iFlow, Qwen Code, Gemini CLI, and OpenAI Codex providers always use streaming internally (even for non-streaming requests), aggregating chunks into a complete response. Only Antigravity has a true non-streaming path. #### Tuning Recommendations @@ -1649,7 +1673,23 @@ QUOTA_GROUPS_GEMINI_CLI_3_FLASH="gemini-3-flash-preview" * **Schema Cleaning**: Similar to Qwen, it aggressively sanitizes tool schemas to prevent 400 errors. * **Dedicated Logging**: Implements `_IFlowFileLogger` to capture raw chunks for debugging proprietary API behaviors. -### 3.4. Google Gemini (`gemini_provider.py`) +### 3.4. OpenAI Codex (`openai_codex_provider.py`) + +* **Auth Base**: Uses `OpenAICodexAuthBase` with Authorization Code + PKCE, queue-based refresh/re-auth, and local-first credential persistence (`oauth_creds/openai_codex_oauth_*.json`). +* **First-Run Import**: `CredentialManager` imports from `~/.codex/auth.json` and `~/.codex-accounts.json` when no local/OpenAI Codex env creds exist. +* **Endpoint Translation**: Implements OpenAI-compatible `/v1/chat/completions` by transforming chat payloads into Codex Responses payloads and calling `POST /codex/responses`. +* **SSE Translation**: Maps Codex SSE event families (e.g. `response.output_item.*`, `response.output_text.delta`, `response.function_call_arguments.*`, `response.completed`) into LiteLLM/OpenAI chunk objects. +* **Rotation Compatibility**: Emits typed `httpx.HTTPStatusError` for transport/status failures and includes provider-specific `parse_quota_error()` for retry/cooldown extraction (`Retry-After`, `error.resets_at`). +* **Default Rotation**: `sequential` (account-level quota behavior). + +**OpenAI Codex Troubleshooting Notes:** + +- **Malformed JWT payload**: If access/id tokens cannot be decoded, account/email metadata can be missing; re-authenticate to rebuild token metadata. +- **Missing account-id claim**: Requests require `chatgpt-account-id`; if absent, refresh/re-auth to repopulate `_proxy_metadata.account_id`. +- **Callback port conflicts**: Change `OPENAI_CODEX_OAUTH_PORT` when port `1455` is already in use. +- **Header mismatch / 403**: Ensure provider sends `Authorization`, `chatgpt-account-id`, and expected Codex headers (`OpenAI-Beta`, `originator`) when routing to `/codex/responses`. + +### 3.5. Google Gemini (`gemini_provider.py`) * **Thinking Parameter**: Automatically handles the `thinking` parameter transformation required for Gemini 2.5 models (`thinking` -> `gemini-2.5-pro` reasoning parameter). * **Safety Settings**: Ensures default safety settings (blocking nothing) are applied if not provided, preventing over-sensitive refusals. @@ -1773,7 +1813,7 @@ The system follows a strict hierarchy of survival: 2. **Credential Management (Level 2)**: OAuth tokens are cached in memory first. If credential files are deleted, the proxy continues using cached tokens. If a token refresh succeeds but the file cannot be written, the new token is buffered for retry and saved on shutdown. -3. **Usage Tracking (Level 3)**: Usage statistics (`key_usage.json`) are maintained in memory via `ResilientStateWriter`. If the file is deleted, the system tracks usage internally and attempts to recreate the file on the next save interval. Pending writes are flushed on shutdown. +3. **Usage Tracking (Level 3)**: Usage statistics (`usage/usage_.json`) are maintained in memory via `ResilientStateWriter`. If the file is deleted, the system tracks usage internally and attempts to recreate the file on the next save interval. Pending writes are flushed on shutdown. 4. **Provider Cache (Level 4)**: The provider cache tracks disk health and continues operating in memory-only mode if disk writes fail. Has its own shutdown mechanism. @@ -1813,7 +1853,7 @@ INFO:rotator_library.resilient_io:Shutdown flush: all 2 write(s) succeeded This architecture supports a robust development workflow: - **Log Cleanup**: You can safely run `rm -rf logs/` while the proxy is serving traffic. The system will recreate the directory structure on the next request. -- **Config Reset**: Deleting `key_usage.json` resets the persistence layer, but the running instance preserves its current in-memory counts for load balancing consistency. +- **Config Reset**: Deleting `usage/usage_.json` resets the persistence layer, but the running instance preserves its current in-memory counts for load balancing consistency. - **File Recovery**: If you delete a critical file, the system attempts directory auto-recreation before every write operation. - **Safe Exit**: Ctrl+C triggers graceful shutdown with final data flush attempt. diff --git a/Dockerfile b/Dockerfile index aafcb117..fe209886 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Build stage -FROM python:3.11-slim AS builder +FROM python:3.12-slim AS builder WORKDIR /app @@ -21,7 +21,7 @@ COPY src/rotator_library ./src/rotator_library RUN pip install --no-cache-dir --user -r requirements.txt # Production stage -FROM python:3.11-slim +FROM python:3.12-slim WORKDIR /app diff --git a/PLAN-openai-codex.md b/PLAN-openai-codex.md new file mode 100644 index 00000000..0a528c3f --- /dev/null +++ b/PLAN-openai-codex.md @@ -0,0 +1,459 @@ +# PLAN: OpenAI Codex OAuth + Multi-Account Support (Revised) + +## Goal +Add first-class `openai_codex` support to LLM-API-Key-Proxy with: +- OAuth login + token refresh +- file/env credential loading +- multi-account rotation via existing `UsageManager` +- OpenAI-compatible `/v1/chat/completions` served through Codex Responses backend +- first-run import from existing Codex CLI credentials (`~/.codex/auth.json`, `~/.codex-accounts.json`) + +--- + +## Review updates applied in this revision + +- Aligned with current local-first architecture: **local managed creds stay in `oauth_creds/`**, not `~/.openai_codex`. +- Reduced MVP risk: **no cross-provider OAuth base refactor in phase 1**. +- Added protocol validation gate (headers/endpoints/SSE event taxonomy) before implementation. +- Expanded wiring checklist to all known hardcoded OAuth provider lists (credential tool, launcher TUI, settings tool). +- Added explicit `env://openai_codex/N` parity requirements and test-harness bootstrap work. + +--- + +## 0) Scope decisions + preflight validation (must lock before coding) + +### 0.1 Provider identity and defaults + +- [x] Provider key: `openai_codex` +- [x] OAuth env prefix: `OPENAI_CODEX` +- [x] Default API base: `https://chatgpt.com/backend-api` +- [x] Responses endpoint path: `/codex/responses` +- [x] Default rotation mode for provider: `sequential` +- [x] Callback env var: `OPENAI_CODEX_OAUTH_PORT` +- [x] JWT parsing strategy: unverified base64url decode (no `PyJWT` dependency) + +### 0.2 Architecture alignment (critical) + +- [x] Keep **local managed credentials** in project data dir: `oauth_creds/openai_codex_oauth_N.json` + - [x] Match existing patterns in `src/rotator_library/utils/paths.py` and other auth bases + - [x] Do **not** introduce a new default managed dir under `~/.openai_codex` for MVP +- [x] Treat `~/.codex/*` only as **import source**, never as primary writable store + +### 0.3 Protocol truth capture (before implementation) + +- [x] Capture one successful non-stream + stream Codex call and confirm: + - [x] Auth endpoint(s) and token exchange params + - [x] Required request headers (`chatgpt-account-id`, `OpenAI-Beta`, `originator`, etc.) + - [x] SSE event names/payload shapes + - [x] Error body format for 401/403/429/5xx +- [x] Save representative payloads/events as test fixtures under `tests/fixtures/openai_codex/` + +--- + +## 1) OAuth + credential plumbing + +## 1.1 Add OpenAI Codex auth base (MVP approach: provider-specific class) + +- [x] Create `src/rotator_library/providers/openai_codex_auth_base.py` +- [x] Base implementation strategy for MVP: + - [x] Adapt proven queue/refresh/reauth approach from `qwen_auth_base.py` / `iflow_auth_base.py` + - [x] **Do not** refactor `GoogleOAuthBase` or create shared `oauth_base.py` in phase 1 + +### 1.1.1 Core lifecycle and queue infrastructure + +- [x] Implement credential cache/locking/queue internals: + - [x] `_credentials_cache`, `_load_credentials()`, `_save_credentials()` + - [x] `_refresh_locks`, `_locks_lock`, `_get_lock()` + - [x] `_refresh_queue`, `_reauth_queue` + - [x] `_queue_refresh()`, `_process_refresh_queue()`, `_process_reauth_queue()` + - [x] `_refresh_failures`, `_next_refresh_after` (backoff tracking) + - [x] `_queued_credentials`, `_unavailable_credentials`, TTL cleanup +- [x] Implement `is_credential_available(path)` with: + - [x] re-auth queue exclusion + - [x] true-expiry check (not proactive buffer) +- [x] Implement `proactively_refresh(credential_identifier)` queue-based behavior + +### 1.1.2 OAuth flow and refresh behavior + +- [x] Interactive OAuth with PKCE + state + - [x] Local callback: `http://localhost:{OPENAI_CODEX_OAUTH_PORT}/auth/callback` + - [x] `ReauthCoordinator` integration (single interactive flow globally) +- [x] Token exchange endpoint: `https://auth.openai.com/oauth/token` +- [x] Authorization endpoint: `https://auth.openai.com/oauth/authorize` +- [x] Refresh flow (`grant_type=refresh_token`) with retry/backoff (3 attempts) +- [x] Refresh error handling: + - [x] `400 invalid_grant` => queue re-auth + raise `CredentialNeedsReauthError` + - [x] `401/403` => queue re-auth + raise `CredentialNeedsReauthError` + - [x] `429` => honor `Retry-After` + - [x] `5xx` => exponential backoff retry + +### 1.1.3 Safe persistence semantics (critical) + +- [x] `_save_credentials()` uses `safe_write_json(..., secure_permissions=True)` +- [x] For rotating refresh-token safety: + - [x] Write-to-disk success required before cache mutation for refreshed tokens + - [x] Avoid stale-cache overwrite scenarios +- [x] Env-backed credentials (`_proxy_metadata.loaded_from_env=true`) skip disk writes safely + +### 1.1.4 JWT and metadata extraction + +- [x] Add unverified JWT decode helper (base64url payload decode with padding) +- [x] Extract from access token (fallback to `id_token`): + - [x] `account_id` claim: `https://api.openai.com/auth.chatgpt_account_id` + - [x] email claim fallback chain: `email` -> `sub` + - [x] `exp` for token expiry +- [x] Maintain metadata under `_proxy_metadata`: + - [x] `email`, `account_id`, `last_check_timestamp` + - [x] `loaded_from_env`, `env_credential_index` + +### 1.1.5 Env credential support + +- [x] Support both formats in `_load_from_env()`: + - [x] legacy single: `OPENAI_CODEX_ACCESS_TOKEN`, `OPENAI_CODEX_REFRESH_TOKEN`, ... + - [x] numbered: `OPENAI_CODEX_1_ACCESS_TOKEN`, `OPENAI_CODEX_1_REFRESH_TOKEN`, ... +- [x] Implement `_parse_env_credential_path(path)` for `env://openai_codex/N` +- [x] Ensure `_load_credentials()` works for file paths **and** `env://` virtual paths + +### 1.1.6 Public methods expected by tooling/runtime + +- [x] `setup_credential()` +- [x] `initialize_token(path_or_creds, force_interactive=False)` +- [x] `get_user_info(creds_or_path)` +- [x] `get_auth_header(credential_identifier)` +- [x] `list_credentials(base_dir)` +- [x] `delete_credential(path)` +- [x] `build_env_lines(creds, cred_number)` +- [x] `export_credential_to_env(credential_path, base_dir)` (used by credential tool export flows) +- [x] `_get_provider_file_prefix() -> "openai_codex"` + +### 1.1.7 Credential schema (`openai_codex_oauth_N.json`) + +```json +{ + "access_token": "eyJhbGciOi...", + "refresh_token": "rt_...", + "id_token": "eyJhbGciOi...", + "expiry_date": 1739400000000, + "token_uri": "https://auth.openai.com/oauth/token", + "_proxy_metadata": { + "email": "user@example.com", + "account_id": "acct_...", + "last_check_timestamp": 1739396400.0, + "loaded_from_env": false, + "env_credential_index": null + } +} +``` + +> Note: client metadata like `client_id` should be class constants unless Codex token refresh explicitly requires persisted values. + +--- + +## 1.2 First-run import from Codex CLI credentials (CredentialManager integration) + +- [x] Update `src/rotator_library/credential_manager.py` to add Codex import helper + - [x] Trigger only when: + - [x] provider is `openai_codex` + - [x] no local `oauth_creds/openai_codex_oauth_*.json` + - [x] no env-based OpenAI Codex credentials already selected +- [x] Import sources (read-only): + - [x] `~/.codex/auth.json` (single account) + - [x] `~/.codex-accounts.json` (multi-account) +- [x] Normalize imported records to proxy schema +- [x] Extract and store `account_id` + email from JWT claims during import +- [x] Skip malformed entries gracefully with warnings +- [x] Preserve original source files untouched +- [x] Log import summary (count + identifiers) + +--- + +## 1.3 Wire registries and discovery maps + +- [x] Update `src/rotator_library/provider_factory.py` + - [x] Import `OpenAICodexAuthBase` + - [x] Add `"openai_codex": OpenAICodexAuthBase` to `PROVIDER_MAP` +- [x] Update `src/rotator_library/credential_manager.py` + - [x] Add to `DEFAULT_OAUTH_DIRS`: `"openai_codex": Path.home() / ".codex"` (source import context) + - [x] Add to `ENV_OAUTH_PROVIDERS`: `"openai_codex": "OPENAI_CODEX"` + +--- + +## 1.4 Wire credential UI, launcher UI, and settings UI + +### 1.4.1 Credential tool updates (`src/rotator_library/credential_tool.py`) + +- [x] Add to `OAUTH_FRIENDLY_NAMES`: `"openai_codex": "OpenAI Codex"` +- [x] Add to OAuth provider lists: + - [x] `_get_oauth_credentials_summary()` hardcoded list + - [x] `combine_all_credentials()` hardcoded list +- [x] Add to OAuth-only exclusions in API-key flow: + - [x] `oauth_only_providers` in `setup_api_key()` +- [x] Add to setup display mapping in `setup_new_credential()` +- [x] Export support: + - [x] Add OpenAI Codex export option(s) or refactor export menu to provider-driven generic flow + - [x] Ensure combine/export features call new auth-base methods + +### 1.4.2 Launcher TUI updates (`src/proxy_app/launcher_tui.py`) + +- [x] Add `"openai_codex": "OPENAI_CODEX"` to `env_oauth_providers` in `SettingsDetector.detect_credentials()` + +### 1.4.3 Settings tool updates (`src/proxy_app/settings_tool.py`) + +- [x] Import Codex default callback port from auth class with fallback constant +- [x] Add provider settings block for `openai_codex`: + - [x] `OPENAI_CODEX_OAUTH_PORT` +- [x] Register `openai_codex` in `PROVIDER_SETTINGS_MAP` + +--- + +## 1.5 Provider plugin auto-registration verification + +- [x] Create `src/rotator_library/providers/openai_codex_provider.py` + - [x] Confirm `providers/__init__.py` auto-registers as `openai_codex` +- [x] Verify name consistency across all maps/lists: + - [x] `PROVIDER_MAP` (`provider_factory.py`) + - [x] `DEFAULT_OAUTH_DIRS` / `ENV_OAUTH_PROVIDERS` (`credential_manager.py`) + - [x] `OAUTH_FRIENDLY_NAMES` + hardcoded OAuth lists (`credential_tool.py`) + - [x] `env_oauth_providers` (`launcher_tui.py`) + - [x] `PROVIDER_SETTINGS_MAP` (`settings_tool.py`) + +--- + +## 2) Codex inference provider (`openai_codex_provider.py`) + +## 2.1 Provider class skeleton + +- [x] Implement `OpenAICodexProvider(OpenAICodexAuthBase, ProviderInterface)` +- [x] Set class behavior: + - [x] `has_custom_logic() -> True` + - [x] `skip_cost_calculation = True` + - [x] `default_rotation_mode = "sequential"` + - [x] `provider_env_name = "openai_codex"` +- [x] `get_models()` model source order: + - [x] `OPENAI_CODEX_MODELS` via `ModelDefinitions` (priority) + - [x] hardcoded sane fallback models + - [x] optional dynamic discovery if Codex endpoint supports model listing + +## 2.2 Credential initialization + metadata cache + +- [x] Implement `initialize_credentials(credential_paths)` startup hook: + - [x] preload credentials (file + `env://`) + - [x] validate expiry and queue refresh where needed + - [x] parse/cache `account_id` and email + - [x] log summary of ready/refreshing/reauth-required credentials + +## 2.3 Non-streaming completion path + +- [x] Implement `acompletion()` for `stream=false` +- [x] Credential handling: + - [x] use `credential_identifier` from client + - [x] support file + `env://` paths consistently (no `os.path.isfile` shortcut assumptions) + - [x] ensure `initialize_token()` called before request when needed +- [x] Transform incoming OpenAI chat payload to Codex Responses payload: + - [x] `messages` -> Codex `input` + - [x] `model`, `temperature`, `top_p`, `max_tokens` + - [x] tools/tool_choice mapping where supported +- [x] Request target: + - [x] `POST ${OPENAI_CODEX_API_BASE or default}/codex/responses` +- [x] Required headers: + - [x] `Authorization: Bearer ` + - [x] `chatgpt-account-id: ` + - [x] protocol-validated beta/originator headers from preflight +- [x] Parse response into `litellm.ModelResponse` + +## 2.4 Streaming path + SSE translation + +- [x] Implement dedicated SSE parser/translator +- [x] Handle expected Codex event families (validated from fixtures): + - [x] `response.created` + - [x] `response.output_item.added` + - [x] `response.content_part.added` + - [x] `response.content_part.delta` + - [x] `response.content_part.done` + - [x] `response.output_item.done` + - [x] `response.completed` + - [x] `response.failed` / `response.incomplete` + - [x] `error` +- [x] Tool-call delta mapping: + - [x] `response.function_call_arguments.delta` + - [x] `response.function_call_arguments.done` +- [x] Emit translated `litellm.ModelResponse` chunks (not raw SSE strings) + - [x] compatible with `RotatingClient._safe_streaming_wrapper()` +- [x] Finish reason mapping: + - [x] stop -> `stop` + - [x] max_output_tokens -> `length` + - [x] tool_calls -> `tool_calls` + - [x] content_filter -> `content_filter` +- [x] Usage extraction from terminal event: + - [x] `input_tokens` -> `usage.prompt_tokens` + - [x] `output_tokens` -> `usage.completion_tokens` + - [x] `total_tokens` -> `usage.total_tokens` +- [x] Unknown events: + - [x] ignore safely with debug logs + - [x] do not break stream unless terminal error condition + +## 2.5 Error classification + rotation compatibility + +- [x] Ensure HTTP errors surface as `httpx.HTTPStatusError` (or equivalent classified exceptions) +- [x] Validate classification in existing `classify_error()` flow (`error_handler.py`): + - [x] 401/403 => authentication/forbidden -> rotate credential + - [x] 429 => rate_limit/quota_exceeded -> cooldown/rotate + - [x] 5xx => server_error -> retry/rotate + - [x] context-length style 400 => `context_window_exceeded` +- [x] Implement `@staticmethod parse_quota_error(error, error_body=None)` on provider + - [x] parse `Retry-After` + - [x] parse Codex-specific quota payload fields if present + +## 2.6 Quota/tier placeholders (MVP-safe defaults) + +- [x] Add conservative placeholders: + - [x] `tier_priorities` + - [x] `usage_reset_configs` + - [x] `model_quota_groups` +- [x] Mark with TODOs for empirical tuning once real quota behavior is observed + +--- + +## 3) Configuration + documentation updates + +## 3.1 `.env.example` + +- [x] Add one-time file import path: + - [x] `OPENAI_CODEX_OAUTH_1` +- [x] Add stateless env credential vars (legacy + numbered): + - [x] `OPENAI_CODEX_ACCESS_TOKEN` + - [x] `OPENAI_CODEX_REFRESH_TOKEN` + - [x] `OPENAI_CODEX_EXPIRY_DATE` + - [x] `OPENAI_CODEX_ID_TOKEN` + - [x] `OPENAI_CODEX_ACCOUNT_ID` + - [x] `OPENAI_CODEX_EMAIL` + - [x] `OPENAI_CODEX_1_*` variants +- [x] Add routing/config vars: + - [x] `OPENAI_CODEX_API_BASE` + - [x] `OPENAI_CODEX_OAUTH_PORT` + - [x] `OPENAI_CODEX_MODELS` + - [x] `ROTATION_MODE_OPENAI_CODEX` + +## 3.2 `README.md` + +- [x] Add OpenAI Codex to OAuth provider lists/tables +- [x] Add setup instructions: + - [x] interactive OAuth flow + - [x] first-run auto-import from `~/.codex/*` + - [x] env-based stateless deployment format +- [x] Add callback-port table row for OpenAI Codex + +## 3.3 `DOCUMENTATION.md` + +- [x] Update credential discovery/import flow to include Codex source files +- [x] Add OpenAI Codex auth/provider architecture section +- [x] Document schema + env vars + runtime refresh/rotation behavior +- [x] Add troubleshooting section: + - [x] malformed JWT payload + - [x] missing account-id claim + - [x] callback port conflicts + - [x] header mismatch / 403 failures + +--- + +## 4) Tests + +## 4.0 Test harness bootstrap (repo currently has no test suite) + +- [x] Add test directory structure: `tests/` +- [x] Add test dependencies (`pytest`, `pytest-asyncio`, `respx` or equivalent) +- [x] Add minimal test run documentation/command + +## 4.1 Auth base tests (`tests/test_openai_codex_auth.py`) + +- [x] JWT decode helper: + - [x] valid token + - [x] malformed token + - [x] missing claims +- [x] expiry logic: + - [x] `_is_token_expired()` with proactive buffer + - [x] `_is_token_truly_expired()` strict expiry +- [x] env loading: + - [x] legacy vars + - [x] numbered vars + - [x] `env://openai_codex/N` parsing +- [x] save/load round-trip with `_proxy_metadata` +- [x] re-auth queue availability behavior (`is_credential_available`) + +## 4.2 Import tests (`tests/test_openai_codex_import.py`) + +- [x] import from `~/.codex/auth.json` format +- [x] import from `~/.codex-accounts.json` format +- [x] skip import when local `openai_codex_oauth_*.json` exists +- [x] malformed source files handled gracefully +- [x] source files never modified + +## 4.3 Provider request mapping tests (`tests/test_openai_codex_provider.py`) + +- [x] chat request mapping to Codex Responses payload +- [x] non-stream response mapping to `ModelResponse` +- [x] header construction includes account-id + auth headers +- [x] env credential identifiers work (no file-only assumptions) + +## 4.4 SSE translation tests (`tests/test_openai_codex_sse.py`) + +- [x] fixture-driven event sequence -> expected chunk sequence +- [x] content deltas +- [x] tool-call deltas +- [x] finish reason mapping +- [x] usage extraction +- [x] error event propagation +- [x] unknown event tolerance + +## 4.5 Wiring regression tests (lightweight) + +- [x] credential discovery recognizes OpenAI Codex env vars +- [x] provider_factory returns OpenAICodexAuthBase +- [x] `providers` auto-registration includes `openai_codex` + +--- + +## 5) Manual smoke-test checklist + +- [x] `python -m rotator_library.credential_tool` shows **OpenAI Codex** in OAuth setup list +- [x] OpenAI Codex is excluded from API-key setup list (`oauth_only_providers`) +- [x] first run with no local creds imports from `~/.codex/*` into `oauth_creds/openai_codex_oauth_*.json` +- [x] env-based `env://openai_codex/N` credentials are detected and used +- [x] `/v1/models` includes `openai_codex/*` models +- [x] `/v1/chat/completions` works for: + - [x] `stream=false` + - [x] `stream=true` +- [x] expired token refresh works (proactive + on-demand) +- [x] invalid refresh token queues re-auth and rotates to next credential +- [x] `is_credential_available()` returns false for re-auth queued / truly expired creds +- [x] multi-account rotation works in: + - [x] `sequential` (default) + - [x] `balanced` (override) +- [x] launcher/settings UIs show Codex OAuth counts and callback-port setting correctly + +--- + +## 6) Optional phase 2 (post-MVP) + +- [ ] Extract common OAuth queue/cache logic into shared base mixin for `google_oauth_base`, `qwen_auth_base`, `iflow_auth_base`, and Codex +- [ ] Refactor credential tool OAuth provider lists/exports to dynamic provider-driven implementation +- [ ] Add `model_info_service` alias mapping for `openai_codex` if pricing/capability enrichment is desired +- [ ] Tune tier priorities/quota windows from observed production behavior +- [ ] Add periodic background reconciliation from external `~/.codex` stores if needed + +--- + +## Proposed implementation order + +1. **Protocol validation gate** — lock endpoints/headers/events from real fixtures +2. **Auth base** — `openai_codex_auth_base.py` (queue + refresh + reauth + env support) +3. **First-run import** — CredentialManager import flow for `~/.codex/*` +4. **Registry/discovery wiring** — provider_factory + credential_manager maps +5. **UI wiring** — credential_tool + launcher_tui + settings_tool +6. **Provider skeleton** — `openai_codex_provider.py`, model list, startup init +7. **Non-streaming completion** — request mapping + response mapping +8. **Streaming translator** — SSE event translation + tool calls + usage +9. **Error/quota integration** — `parse_quota_error`, retry/cooldown compatibility +10. **Tests** — harness + auth/import/provider/SSE/wiring tests +11. **Docs/config** — `.env.example`, `README.md`, `DOCUMENTATION.md` +12. **Manual smoke validation** — end-to-end checklist diff --git a/README.md b/README.md index a7c3c438..1fd78d4e 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ docker run -d \ -v $(pwd)/.env:/app/.env:ro \ -v $(pwd)/oauth_creds:/app/oauth_creds \ -v $(pwd)/logs:/app/logs \ + -v $(pwd)/usage:/app/usage \ -e SKIP_OAUTH_INIT_CHECK=true \ ghcr.io/mirrowel/llm-api-key-proxy:latest ``` @@ -60,13 +61,13 @@ docker run -d \ **Using Docker Compose:** ```bash -# Create your .env file and key_usage.json first, then: +# Create your .env file and usage directory first, then: cp .env.example .env -touch key_usage.json +mkdir usage docker compose up -d ``` -> **Important:** You must create both `.env` and `key_usage.json` files before running Docker Compose. If `key_usage.json` doesn't exist, Docker will create it as a directory instead of a file, causing errors. +> **Important:** Create the `usage/` directory before running Docker Compose so usage stats persist on the host. > **Note:** For OAuth providers, complete authentication locally first using the credential tool, then mount the `oauth_creds/` directory or export credentials to environment variables. @@ -105,6 +106,7 @@ anthropic/claude-3-5-sonnet ← Anthropic API openrouter/anthropic/claude-3-opus ← OpenRouter gemini_cli/gemini-2.5-pro ← Gemini CLI (OAuth) antigravity/gemini-3-pro-preview ← Antigravity (Gemini 3, Claude Opus 4.5) +openai_codex/gpt-5.1-codex ← OpenAI Codex (ChatGPT OAuth) ``` ### Usage Examples @@ -263,7 +265,7 @@ python -m rotator_library.credential_tool | Type | Providers | How to Add | |------|-----------|------------| | **API Keys** | Gemini, OpenAI, Anthropic, OpenRouter, Groq, Mistral, NVIDIA, Cohere, Chutes | Enter key in TUI or add to `.env` | -| **OAuth** | Gemini CLI, Antigravity, Qwen Code, iFlow | Interactive browser login via credential tool | +| **OAuth** | Gemini CLI, Antigravity, Qwen Code, iFlow, OpenAI Codex | Interactive browser login via credential tool | ### The `.env` File @@ -294,7 +296,7 @@ The proxy is powered by a standalone Python library that you can use directly in - **Intelligent key selection** with tiered, model-aware locking - **Deadline-driven requests** with configurable global timeout - **Automatic failover** between keys on errors -- **OAuth support** for Gemini CLI, Antigravity, Qwen, iFlow +- **OAuth support** for Gemini CLI, Antigravity, Qwen, iFlow, OpenAI Codex - **Stateless deployment ready** — load credentials from environment variables ### Basic Usage @@ -335,10 +337,11 @@ The proxy includes a powerful text-based UI for configuration and management. ### TUI Features - **🚀 Run Proxy** — Start the server with saved settings -- **⚙️ Configure Settings** — Host, port, API key, request logging +- **⚙️ Configure Settings** — Host, port, API key, request logging, raw I/O logging - **🔑 Manage Credentials** — Add/edit API keys and OAuth credentials -- **📊 View Status** — See configured providers and credential counts -- **🔧 Advanced Settings** — Custom providers, model definitions, concurrency +- **📊 View Provider & Advanced Settings** — Inspect providers and launch the settings tool +- **📈 View Quota & Usage Stats (Alpha)** — Usage, quota windows, fair-cycle status +- **🔄 Reload Configuration** — Refresh settings without restarting ### Configuration Files @@ -346,6 +349,8 @@ The proxy includes a powerful text-based UI for configuration and management. |------|----------| | `.env` | All credentials and advanced settings | | `launcher_config.json` | TUI-specific settings (host, port, logging) | +| `quota_viewer_config.json` | Quota viewer remotes + per-provider display toggles | +| `usage/usage_.json` | Usage persistence per provider | --- @@ -375,7 +380,7 @@ The proxy includes a powerful text-based UI for configuration and management. 🔑 Credential Management - **Auto-discovery** of API keys from environment variables -- **OAuth discovery** from standard paths (`~/.gemini/`, `~/.qwen/`, `~/.iflow/`) +- **OAuth discovery/import** from standard paths (`~/.gemini/`, `~/.qwen/`, `~/.iflow/`, `~/.codex/`) - **Duplicate detection** warns when same account added multiple times - **Credential prioritization** — paid tier used before free tier - **Stateless deployment** — export OAuth to environment variables @@ -435,6 +440,13 @@ The proxy includes a powerful text-based UI for configuration and management. - Hybrid auth with separate API key fetch - Tool schema cleaning +**OpenAI Codex:** + +- ChatGPT OAuth Authorization Code + PKCE +- Codex Responses backend (`/codex/responses`) behind OpenAI-compatible `/v1/chat/completions` +- First-run import from `~/.codex/auth.json` + `~/.codex-accounts.json` +- Sequential multi-account rotation + env credential parity (`env://openai_codex/N`) + **NVIDIA NIM:** - Dynamic model discovery @@ -446,10 +458,11 @@ The proxy includes a powerful text-based UI for configuration and management. 📝 Logging & Debugging - **Per-request file logging** with `--enable-request-logging` +- **Raw I/O logging** with `--enable-raw-logging` (proxy boundary payloads) - **Unique request directories** with full transaction details - **Streaming chunk capture** for debugging - **Performance metadata** (duration, tokens, model used) -- **Provider-specific logs** for Qwen, iFlow, Antigravity +- **Provider-specific logs** for Qwen, iFlow, Antigravity, OpenAI Codex @@ -748,6 +761,60 @@ Uses OAuth Authorization Code flow with local callback server. +
+OpenAI Codex + +Uses ChatGPT OAuth credentials and routes requests to the Codex Responses backend. + +**Setup:** + +1. Run the credential tool +2. Select "Add OAuth Credential" → "OpenAI Codex" +3. Complete browser auth flow (local callback server) +4. On first run, existing Codex CLI credentials are auto-imported from: + - `~/.codex/auth.json` + - `~/.codex-accounts.json` + +Imported credentials are normalized and stored locally as: + +- `oauth_creds/openai_codex_oauth_1.json` +- `oauth_creds/openai_codex_oauth_2.json` +- ... + +**Features:** + +- OAuth Authorization Code + PKCE +- Automatic refresh + re-auth queueing +- File-based and stateless env credentials (`env://openai_codex/N`) +- Sequential rotation by default (`ROTATION_MODE_OPENAI_CODEX=sequential`) +- OpenAI-compatible `/v1/chat/completions` via Codex Responses backend + +**Environment Variables (stateless mode):** + +```env +# Single credential (legacy) +OPENAI_CODEX_ACCESS_TOKEN="..." +OPENAI_CODEX_REFRESH_TOKEN="..." +OPENAI_CODEX_EXPIRY_DATE="1739400000000" +OPENAI_CODEX_ID_TOKEN="..." +OPENAI_CODEX_ACCOUNT_ID="acct_..." +OPENAI_CODEX_EMAIL="user@example.com" + +# Numbered multi-credential +OPENAI_CODEX_1_ACCESS_TOKEN="..." +OPENAI_CODEX_1_REFRESH_TOKEN="..." +OPENAI_CODEX_1_EXPIRY_DATE="1739400000000" +OPENAI_CODEX_1_ID_TOKEN="..." +OPENAI_CODEX_1_ACCOUNT_ID="acct_..." +OPENAI_CODEX_1_EMAIL="user1@example.com" + +OPENAI_CODEX_API_BASE="https://chatgpt.com/backend-api" +OPENAI_CODEX_OAUTH_PORT=1455 +ROTATION_MODE_OPENAI_CODEX=sequential +``` + +
+
Stateless Deployment (Export to Environment Variables) @@ -779,11 +846,12 @@ For platforms without file persistence (Railway, Render, Vercel): Customize OAuth callback ports if defaults conflict: -| Provider | Default Port | Environment Variable | -| ----------- | ------------ | ------------------------ | -| Gemini CLI | 8085 | `GEMINI_CLI_OAUTH_PORT` | -| Antigravity | 51121 | `ANTIGRAVITY_OAUTH_PORT` | -| iFlow | 11451 | `IFLOW_OAUTH_PORT` | +| Provider | Default Port | Environment Variable | +| ------------ | ------------ | ------------------------- | +| Gemini CLI | 8085 | `GEMINI_CLI_OAUTH_PORT` | +| Antigravity | 51121 | `ANTIGRAVITY_OAUTH_PORT` | +| iFlow | 11451 | `IFLOW_OAUTH_PORT` | +| OpenAI Codex | 1455 | `OPENAI_CODEX_OAUTH_PORT` |
@@ -801,6 +869,7 @@ Options: --host TEXT Host to bind (default: 0.0.0.0) --port INTEGER Port to run on (default: 8000) --enable-request-logging Enable detailed per-request logging + --enable-raw-logging Capture raw proxy I/O payloads --add-credential Launch interactive credential setup tool ``` @@ -813,6 +882,9 @@ python src/proxy_app/main.py --host 127.0.0.1 --port 9000 # Run with logging python src/proxy_app/main.py --enable-request-logging +# Run with raw I/O logging +python src/proxy_app/main.py --enable-raw-logging + # Add credentials without starting proxy python src/proxy_app/main.py --add-credential ``` @@ -850,8 +922,8 @@ The proxy is available as a multi-architecture Docker image (amd64/arm64) from G cp .env.example .env nano .env -# 2. Create key_usage.json file (required before first run) -touch key_usage.json +# 2. Create usage directory (usage_*.json files are created automatically) +mkdir usage # 3. Start the proxy docker compose up -d @@ -860,13 +932,13 @@ docker compose up -d docker compose logs -f ``` -> **Important:** You must create `key_usage.json` before running Docker Compose. If this file doesn't exist on the host, Docker will create it as a directory instead of a file, causing the container to fail. +> **Important:** Create the `usage/` directory before running Docker Compose so usage stats persist on the host. **Manual Docker Run:** ```bash -# Create key_usage.json if it doesn't exist -touch key_usage.json +# Create usage directory if it doesn't exist +mkdir usage docker run -d \ --name llm-api-proxy \ @@ -875,7 +947,7 @@ docker run -d \ -v $(pwd)/.env:/app/.env:ro \ -v $(pwd)/oauth_creds:/app/oauth_creds \ -v $(pwd)/logs:/app/logs \ - -v $(pwd)/key_usage.json:/app/key_usage.json \ + -v $(pwd)/usage:/app/usage \ -e SKIP_OAUTH_INIT_CHECK=true \ -e PYTHONUNBUFFERED=1 \ ghcr.io/mirrowel/llm-api-key-proxy:latest @@ -895,7 +967,7 @@ docker compose -f docker-compose.dev.yml up -d --build | `.env` | Configuration and API keys (read-only) | | `oauth_creds/` | OAuth credential files (persistent) | | `logs/` | Request logs and detailed logging | -| `key_usage.json` | Usage statistics persistence | +| `usage/` | Usage statistics persistence (`usage_*.json`) | **Image Tags:** @@ -958,6 +1030,23 @@ See [VPS Deployment](Deployment%20guide.md#appendix-deploying-to-a-custom-vps) f --- +## Testing + +A lightweight pytest suite is now included under `tests/`. + +```bash +# Install runtime dependencies +pip install -r requirements.txt + +# Optional explicit test dependencies (also safe to run if already included) +pip install -r requirements-dev.txt + +# Run tests +pytest -q +``` + +--- + ## Troubleshooting | Issue | Solution | @@ -966,7 +1055,7 @@ See [VPS Deployment](Deployment%20guide.md#appendix-deploying-to-a-custom-vps) f | `500 Internal Server Error` | Check provider key validity; enable `--enable-request-logging` for details | | All keys on cooldown | All keys failed recently; check `logs/detailed_logs/` for upstream errors | | Model not found | Verify format is `provider/model_name` (e.g., `gemini/gemini-2.5-flash`) | -| OAuth callback failed | Ensure callback port (8085, 51121, 11451) isn't blocked by firewall | +| OAuth callback failed | Ensure callback port (8085, 51121, 11451, 1455) isn't blocked by firewall | | Streaming hangs | Increase `TIMEOUT_READ_STREAMING`; check provider status | **Detailed Logs:** diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index 36458929..becc2606 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -19,8 +19,8 @@ services: - ./oauth_creds:/app/oauth_creds # Mount logs directory for persistent logging - ./logs:/app/logs - # Mount key_usage.json for usage statistics persistence - - ./key_usage.json:/app/key_usage.json + # Mount usage directory for usage statistics persistence + - ./usage:/app/usage # Optionally mount additional .env files (e.g., combined credential files) # - ./antigravity_all_combined.env:/app/antigravity_all_combined.env:ro environment: diff --git a/docker-compose.tls.yml b/docker-compose.tls.yml index e210423f..0c670b3d 100644 --- a/docker-compose.tls.yml +++ b/docker-compose.tls.yml @@ -36,8 +36,8 @@ services: - ./oauth_creds:/app/oauth_creds # Mount logs directory for persistent logging - ./logs:/app/logs - # Mount key_usage.json for usage statistics persistence - - ./key_usage.json:/app/key_usage.json + # Mount usage directory for usage statistics persistence + - ./usage:/app/usage # Optionally mount additional .env files (e.g., combined credential files) # - ./antigravity_all_combined.env:/app/antigravity_all_combined.env:ro environment: diff --git a/docker-compose.yml b/docker-compose.yml index 31964b60..027d5d91 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -17,8 +17,8 @@ services: - ./oauth_creds:/app/oauth_creds # Mount logs directory for persistent logging - ./logs:/app/logs - # Mount key_usage.json for usage statistics persistence - - ./key_usage.json:/app/key_usage.json + # Mount usage directory for usage statistics persistence + - ./usage:/app/usage # Optionally mount additional .env files (e.g., combined credential files) # - ./antigravity_all_combined.env:/app/antigravity_all_combined.env:ro environment: diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..530e83f9 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,3 @@ +pytest +pytest-asyncio +respx diff --git a/requirements.txt b/requirements.txt index 1f5d4985..e5ee231c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,3 +25,8 @@ customtkinter # For building the executable pyinstaller + +# Test dependencies +pytest +pytest-asyncio +respx diff --git a/src/proxy_app/build.py b/src/proxy_app/build.py index a4c5dc07..421c6bbb 100644 --- a/src/proxy_app/build.py +++ b/src/proxy_app/build.py @@ -49,6 +49,8 @@ def main(): "--hidden-import=rotator_library", "--hidden-import=tiktoken_ext.openai_public", "--hidden-import=tiktoken_ext", + # Fix for Rich 14.0+ which lazy-loads Unicode data via dynamic imports + "--collect-submodules=rich._unicode_data", "--collect-data", "litellm", # Optimization: Exclude unused heavy modules diff --git a/src/proxy_app/launcher_tui.py b/src/proxy_app/launcher_tui.py index 68338b02..35461589 100644 --- a/src/proxy_app/launcher_tui.py +++ b/src/proxy_app/launcher_tui.py @@ -190,6 +190,7 @@ def detect_credentials() -> dict: "antigravity": "ANTIGRAVITY", "qwen_code": "QWEN_CODE", "iflow": "IFLOW", + "openai_codex": "OPENAI_CODEX", } for provider, env_prefix in env_oauth_providers.items(): @@ -377,9 +378,9 @@ def show_main_menu(self): self.console.print( Panel( Text.from_markup( - "⚠️ [bold yellow]INITIAL SETUP REQUIRED[/bold yellow]\n\n" + ":warning: [bold yellow]INITIAL SETUP REQUIRED[/bold yellow]\n\n" "The proxy needs initial configuration:\n" - " ❌ No .env file found\n\n" + " :x: No .env file found\n\n" "Why this matters:\n" " • The .env file stores your credentials and settings\n" " • PROXY_API_KEY protects your proxy from unauthorized access\n" @@ -388,7 +389,7 @@ def show_main_menu(self): ' 1. Select option "3. Manage Credentials" to launch the credential tool\n' " 2. The tool will create .env and set up PROXY_API_KEY automatically\n" " 3. You can add provider credentials (API keys or OAuth)\n\n" - "⚠️ Note: The credential tool adds PROXY_API_KEY by default.\n" + ":warning: Note: The credential tool adds PROXY_API_KEY by default.\n" " You can remove it later if you want an unsecured proxy." ), border_style="yellow", @@ -401,12 +402,12 @@ def show_main_menu(self): self.console.print( Panel( Text.from_markup( - "⚠️ [bold red]SECURITY WARNING: PROXY_API_KEY Not Set[/bold red]\n\n" + ":warning: [bold red]SECURITY WARNING: PROXY_API_KEY Not Set[/bold red]\n\n" "Your proxy is currently UNSECURED!\n" "Anyone can access it without authentication.\n\n" "This is a serious security risk if your proxy is accessible\n" "from the internet or untrusted networks.\n\n" - "👉 [bold]Recommended:[/bold] Set PROXY_API_KEY in .env file\n" + ":point_right: [bold]Recommended:[/bold] Set PROXY_API_KEY in .env file\n" ' Use option "2. Configure Proxy Settings" → "3. Set Proxy API Key"\n' ' or option "3. Manage Credentials"' ), @@ -417,15 +418,15 @@ def show_main_menu(self): # Show config self.console.print() - self.console.print("[bold]📋 Proxy Configuration[/bold]") + self.console.print("[bold]:clipboard: Proxy Configuration[/bold]") self.console.print("━" * 70) self.console.print(f" Host: {self.config.config['host']}") self.console.print(f" Port: {self.config.config['port']}") self.console.print( - f" Transaction Logging: {'✅ Enabled' if self.config.config['enable_request_logging'] else '❌ Disabled'}" + f" Transaction Logging: {':white_check_mark: Enabled' if self.config.config['enable_request_logging'] else ':x: Disabled'}" ) self.console.print( - f" Raw I/O Logging: {'✅ Enabled' if self.config.config.get('enable_raw_logging', False) else '❌ Disabled'}" + f" Raw I/O Logging: {':white_check_mark: Enabled' if self.config.config.get('enable_raw_logging', False) else ':x: Disabled'}" ) # Show actual API key value @@ -437,7 +438,7 @@ def show_main_menu(self): # Show status summary self.console.print() - self.console.print("[bold]📊 Status Summary[/bold]") + self.console.print("[bold]:bar_chart: Status Summary[/bold]") self.console.print("━" * 70) provider_count = len(credentials) custom_count = len(custom_bases) @@ -458,24 +459,26 @@ def show_main_menu(self): self.console.print() self.console.print("━" * 70) self.console.print() - self.console.print("[bold]🎯 Main Menu[/bold]") + self.console.print("[bold]:dart: Main Menu[/bold]") self.console.print() if show_warning: - self.console.print(" 1. ▶️ Run Proxy Server") - self.console.print(" 2. ⚙️ Configure Proxy Settings") + self.console.print(" 1. :arrow_forward: Run Proxy Server") + self.console.print(" 2. :gear: Configure Proxy Settings") self.console.print( - " 3. 🔑 Manage Credentials ⬅️ [bold yellow]Start here![/bold yellow]" + " 3. :key: Manage Credentials :arrow_left: [bold yellow]Start here![/bold yellow]" ) else: - self.console.print(" 1. ▶️ Run Proxy Server") - self.console.print(" 2. ⚙️ Configure Proxy Settings") - self.console.print(" 3. 🔑 Manage Credentials") + self.console.print(" 1. :arrow_forward: Run Proxy Server") + self.console.print(" 2. :gear: Configure Proxy Settings") + self.console.print(" 3. :key: Manage Credentials") - self.console.print(" 4. 📊 View Provider & Advanced Settings") - self.console.print(" 5. 📈 View Quota & Usage Stats (Alpha)") - self.console.print(" 6. 🔄 Reload Configuration") - self.console.print(" 7. ℹ️ About") - self.console.print(" 8. 🚪 Exit") + self.console.print(" 4. :bar_chart: View Provider & Advanced Settings") + self.console.print( + " 5. :chart_with_upwards_trend: View Quota & Usage Stats (Alpha)" + ) + self.console.print(" 6. :arrows_counterclockwise: Reload Configuration") + self.console.print(" 7. :information_source: About") + self.console.print(" 8. :door: Exit") self.console.print() self.console.print("━" * 70) @@ -500,7 +503,9 @@ def show_main_menu(self): elif choice == "6": load_dotenv(dotenv_path=_get_env_file(), override=True) self.config = LauncherConfig() # Reload config - self.console.print("\n[green]✅ Configuration reloaded![/green]") + self.console.print( + "\n[green]:white_check_mark: Configuration reloaded![/green]" + ) elif choice == "7": self.show_about() elif choice == "8": @@ -518,7 +523,7 @@ def confirm_setting_change(self, setting_name: str, warning_lines: list) -> bool self.console.print( Panel( Text.from_markup( - f"[bold yellow]⚠️ WARNING: You are about to change the {setting_name}[/bold yellow]\n\n" + f"[bold yellow]:warning: WARNING: You are about to change the {setting_name}[/bold yellow]\n\n" + "\n".join(warning_lines) + "\n\n[bold]If you are not sure about changing this - don't.[/bold]" ), @@ -548,36 +553,39 @@ def show_config_menu(self): self.console.print( Panel.fit( - "[bold cyan]⚙️ Proxy Configuration[/bold cyan]", border_style="cyan" + "[bold cyan]:gear: Proxy Configuration[/bold cyan]", + border_style="cyan", ) ) self.console.print() - self.console.print("[bold]📋 Current Settings[/bold]") + self.console.print("[bold]:clipboard: Current Settings[/bold]") self.console.print("━" * 70) self.console.print(f" Host: {self.config.config['host']}") self.console.print(f" Port: {self.config.config['port']}") self.console.print( - f" Transaction Logging: {'✅ Enabled' if self.config.config['enable_request_logging'] else '❌ Disabled'}" + f" Transaction Logging: {':white_check_mark: Enabled' if self.config.config['enable_request_logging'] else ':x: Disabled'}" ) self.console.print( - f" Raw I/O Logging: {'✅ Enabled' if self.config.config.get('enable_raw_logging', False) else '❌ Disabled'}" + f" Raw I/O Logging: {':white_check_mark: Enabled' if self.config.config.get('enable_raw_logging', False) else ':x: Disabled'}" ) self.console.print( - f" Proxy API Key: {'✅ Set' if os.getenv('PROXY_API_KEY') else '❌ Not Set'}" + f" Proxy API Key: {':white_check_mark: Set' if os.getenv('PROXY_API_KEY') else ':x: Not Set'}" ) self.console.print() self.console.print("━" * 70) self.console.print() - self.console.print("[bold]⚙️ Configuration Options[/bold]") + self.console.print("[bold]:gear: Configuration Options[/bold]") self.console.print() - self.console.print(" 1. 🌐 Set Host IP") + self.console.print(" 1. :globe_with_meridians: Set Host IP") self.console.print(" 2. 🔌 Set Port") - self.console.print(" 3. 🔑 Set Proxy API Key") - self.console.print(" 4. 📝 Toggle Transaction Logging") - self.console.print(" 5. 📋 Toggle Raw I/O Logging") - self.console.print(" 6. 🔄 Reset to Default Settings") + self.console.print(" 3. :key: Set Proxy API Key") + self.console.print(" 4. :memo: Toggle Transaction Logging") + self.console.print(" 5. :clipboard: Toggle Raw I/O Logging") + self.console.print( + " 6. :arrows_counterclockwise: Reset to Default Settings" + ) self.console.print(" 7. ↩️ Back to Main Menu") self.console.print() @@ -609,7 +617,9 @@ def show_config_menu(self): "Enter new host IP", default=self.config.config["host"] ) self.config.update(host=new_host) - self.console.print(f"\n[green]✅ Host updated to: {new_host}[/green]") + self.console.print( + f"\n[green]:white_check_mark: Host updated to: {new_host}[/green]" + ) elif choice == "2": # Show warning and require confirmation confirmed = self.confirm_setting_change( @@ -630,10 +640,10 @@ def show_config_menu(self): if 1 <= new_port <= 65535: self.config.update(port=new_port) self.console.print( - f"\n[green]✅ Port updated to: {new_port}[/green]" + f"\n[green]:white_check_mark: Port updated to: {new_port}[/green]" ) else: - self.console.print("\n[red]❌ Port must be between 1-65535[/red]") + self.console.print("\n[red]:x: Port must be between 1-65535[/red]") elif choice == "3": # Show warning and require confirmation confirmed = self.confirm_setting_change( @@ -641,11 +651,11 @@ def show_config_menu(self): [ "This is the authentication key that applications use to access your proxy.", "", - "[bold red]⚠️ Changing this will BREAK all applications currently configured", + "[bold red]:warning: Changing this will BREAK all applications currently configured", " with the existing API key![/bold red]", "", - "[bold cyan]💡 If you want to add provider API keys (OpenAI, Gemini, etc.),", - ' go to "3. 🔑 Manage Credentials" in the main menu instead.[/bold cyan]', + "[bold cyan]:bulb: If you want to add provider API keys (OpenAI, Gemini, etc.),", + ' go to "3. :key: Manage Credentials" in the main menu instead.[/bold cyan]', ], ) if not confirmed: @@ -661,7 +671,7 @@ def show_config_menu(self): # If setting to empty, show additional warning if not new_key: self.console.print( - "\n[bold red]⚠️ Authentication will be DISABLED - anyone can access your proxy![/bold red]" + "\n[bold red]:warning: Authentication will be DISABLED - anyone can access your proxy![/bold red]" ) Prompt.ask("Press Enter to continue", default="") @@ -669,12 +679,12 @@ def show_config_menu(self): if new_key: self.console.print( - "\n[green]✅ Proxy API Key updated successfully![/green]" + "\n[green]:white_check_mark: Proxy API Key updated successfully![/green]" ) self.console.print(" Updated in .env file") else: self.console.print( - "\n[yellow]⚠️ Proxy API Key cleared - authentication disabled![/yellow]" + "\n[yellow]:warning: Proxy API Key cleared - authentication disabled![/yellow]" ) self.console.print(" Updated in .env file") else: @@ -683,13 +693,13 @@ def show_config_menu(self): current = self.config.config["enable_request_logging"] self.config.update(enable_request_logging=not current) self.console.print( - f"\n[green]✅ Transaction Logging {'enabled' if not current else 'disabled'}![/green]" + f"\n[green]:white_check_mark: Transaction Logging {'enabled' if not current else 'disabled'}![/green]" ) elif choice == "5": current = self.config.config.get("enable_raw_logging", False) self.config.update(enable_raw_logging=not current) self.console.print( - f"\n[green]✅ Raw I/O Logging {'enabled' if not current else 'disabled'}![/green]" + f"\n[green]:white_check_mark: Raw I/O Logging {'enabled' if not current else 'disabled'}![/green]" ) elif choice == "6": # Reset to Default Settings @@ -725,7 +735,7 @@ def show_config_menu(self): else f" Raw I/O Logging {'Disabled':20} → Disabled", f" Proxy API Key {current_api_key[:20]:20} → {default_api_key}", "", - "[bold red]⚠️ This may break applications configured with current settings![/bold red]", + "[bold red]:warning: This may break applications configured with current settings![/bold red]", ] confirmed = self.confirm_setting_change( @@ -744,7 +754,7 @@ def show_config_menu(self): LauncherConfig.update_proxy_api_key(default_api_key) self.console.print( - "\n[green]✅ All settings have been reset to defaults![/green]" + "\n[green]:white_check_mark: All settings have been reset to defaults![/green]" ) self.console.print(f" Host: {default_host}") self.console.print(f" Port: {default_port}") @@ -769,14 +779,14 @@ def show_provider_settings_menu(self): self.console.print( Panel.fit( - "[bold cyan]📊 Provider & Advanced Settings[/bold cyan]", + "[bold cyan]:bar_chart: Provider & Advanced Settings[/bold cyan]", border_style="cyan", ) ) # Configured Providers self.console.print() - self.console.print("[bold]📊 Configured Providers[/bold]") + self.console.print("[bold]:bar_chart: Configured Providers[/bold]") self.console.print("━" * 70) if credentials: for provider, info in credentials.items(): @@ -795,14 +805,16 @@ def show_provider_settings_menu(self): if info["custom"]: display += " (Custom)" - self.console.print(f" ✅ {provider_name:20} {display}") + self.console.print( + f" :white_check_mark: {provider_name:20} {display}" + ) else: self.console.print(" [dim]No providers configured[/dim]") # Custom API Bases if custom_bases: self.console.print() - self.console.print("[bold]🌐 Custom API Bases[/bold]") + self.console.print("[bold]:globe_with_meridians: Custom API Bases[/bold]") self.console.print("━" * 70) for provider, base in custom_bases.items(): self.console.print(f" • {provider:15} {base}") @@ -829,7 +841,7 @@ def show_provider_settings_menu(self): # Model Filters (basic info only) if filters: self.console.print() - self.console.print("[bold]🎯 Model Filters[/bold]") + self.console.print("[bold]:dart: Model Filters[/bold]") self.console.print("━" * 70) for provider, filter_info in filters.items(): status_parts = [] @@ -838,7 +850,7 @@ def show_provider_settings_menu(self): if filter_info["has_ignore"]: status_parts.append("Ignore list") status = " + ".join(status_parts) if status_parts else "None" - self.console.print(f" • {provider:15} ✅ {status}") + self.console.print(f" • {provider:15} :white_check_mark: {status}") # Provider-Specific Settings (deferred to Settings Tool to avoid heavy imports) self.console.print() @@ -852,21 +864,21 @@ def show_provider_settings_menu(self): self.console.print() self.console.print("━" * 70) self.console.print() - self.console.print("[bold]💡 Actions[/bold]") + self.console.print("[bold]:bulb: Actions[/bold]") self.console.print() self.console.print( - " 1. 🔧 Launch Settings Tool (configure advanced settings)" + " 1. :wrench: Launch Settings Tool (configure advanced settings)" ) self.console.print(" 2. ↩️ Back to Main Menu") self.console.print() self.console.print("━" * 70) self.console.print( - "[dim]ℹ️ Advanced settings are stored in .env file.\n Use the Settings Tool to configure them interactively.[/dim]" + "[dim]:information_source: Advanced settings are stored in .env file.\n Use the Settings Tool to configure them interactively.[/dim]" ) self.console.print() self.console.print( - "[dim]⚠️ Note: Settings Tool supports only common configuration types.\n For complex settings, edit .env directly.[/dim]" + "[dim]:warning: Note: Settings Tool supports only common configuration types.\n For complex settings, edit .env directly.[/dim]" ) self.console.print() @@ -959,7 +971,8 @@ def show_about(self): self.console.print( Panel.fit( - "[bold cyan]ℹ️ About LLM API Key Proxy[/bold cyan]", border_style="cyan" + "[bold cyan]:information_source: About LLM API Key Proxy[/bold cyan]", + border_style="cyan", ) ) @@ -1005,7 +1018,7 @@ def show_about(self): ) self.console.print() - self.console.print("[bold]📝 License & Credits[/bold]") + self.console.print("[bold]:memo: License & Credits[/bold]") self.console.print("━" * 70) self.console.print(" Made with ❤️ by the community") self.console.print(" Open source - contributions welcome!") @@ -1024,7 +1037,7 @@ def run_proxy(self): self.console.print( Panel( Text.from_markup( - "⚠️ [bold yellow]Setup Required[/bold yellow]\n\n" + ":warning: [bold yellow]Setup Required[/bold yellow]\n\n" "Cannot start without .env.\n" "Launching credential tool..." ), @@ -1046,7 +1059,7 @@ def run_proxy(self): # Check again after credential tool if not os.getenv("PROXY_API_KEY"): self.console.print( - "\n[red]❌ PROXY_API_KEY still not set. Cannot start proxy.[/red]" + "\n[red]:x: PROXY_API_KEY still not set. Cannot start proxy.[/red]" ) return diff --git a/src/proxy_app/main.py b/src/proxy_app/main.py index 12014bdc..3e4bbbbc 100644 --- a/src/proxy_app/main.py +++ b/src/proxy_app/main.py @@ -603,6 +603,8 @@ async def process_credential(provider: str, path: str, provider_instance): max_concurrent_requests_per_key=max_concurrent_requests_per_key, ) + await client.initialize_usage_managers() + # Log loaded credentials summary (compact, always visible for deployment verification) # _api_summary = ', '.join([f"{p}:{len(c)}" for p, c in api_keys.items()]) if api_keys else "none" # _oauth_summary = ', '.join([f"{p}:{len(c)}" for p, c in oauth_credentials.items()]) if oauth_credentials else "none" @@ -956,7 +958,9 @@ async def chat_completions( is_streaming = request_data.get("stream", False) if is_streaming: - response_generator = client.acompletion(request=request, **request_data) + response_generator = await client.acompletion( + request=request, **request_data + ) return StreamingResponse( streaming_response_wrapper( request, request_data, response_generator, raw_logger @@ -1649,10 +1653,10 @@ def show_onboarding_message(): border_style="cyan", ) ) - console.print("[bold yellow]⚠️ Configuration Required[/bold yellow]\n") + console.print("[bold yellow]:warning: Configuration Required[/bold yellow]\n") console.print("The proxy needs initial configuration:") - console.print(" [red]❌ No .env file found[/red]") + console.print(" [red]:x: No .env file found[/red]") console.print("\n[bold]Why this matters:[/bold]") console.print(" • The .env file stores your credentials and settings") @@ -1665,7 +1669,7 @@ def show_onboarding_message(): console.print(" 3. The proxy will then start normally") console.print( - "\n[bold yellow]⚠️ Note:[/bold yellow] The credential tool adds PROXY_API_KEY by default." + "\n[bold yellow]:warning: Note:[/bold yellow] The credential tool adds PROXY_API_KEY by default." ) console.print(" You can remove it later if you want an unsecured proxy.\n") @@ -1708,13 +1712,15 @@ def show_onboarding_message(): # Verify onboarding is complete if needs_onboarding(): - console.print("\n[bold red]❌ Configuration incomplete.[/bold red]") + console.print("\n[bold red]:x: Configuration incomplete.[/bold red]") console.print( "The proxy still cannot start. Please ensure PROXY_API_KEY is set in .env\n" ) sys.exit(1) else: - console.print("\n[bold green]✅ Configuration complete![/bold green]") + console.print( + "\n[bold green]:white_check_mark: Configuration complete![/bold green]" + ) console.print("\nStarting proxy server...\n") import uvicorn diff --git a/src/proxy_app/quota_viewer.py b/src/proxy_app/quota_viewer.py index 6cdb224d..687f96cd 100644 --- a/src/proxy_app/quota_viewer.py +++ b/src/proxy_app/quota_viewer.py @@ -6,42 +6,6 @@ Connects to a running proxy to display quota and usage statistics. Uses only httpx + rich (no heavy rotator_library imports). - -TODO: Missing Features & Improvements -====================================== - -Display Improvements: -- [ ] Add color legend/help screen explaining status colors and symbols -- [ ] Show credential email/project ID if available (currently just filename) -- [ ] Add keyboard shortcut hints (e.g., "Press ? for help") -- [ ] Support terminal resize / responsive layout - -Global Stats Fix: -- [ ] HACK: Global requests currently set to current period requests only - (see client.py get_quota_stats). This doesn't include archived stats. - Fix requires tracking archived requests per quota group in usage_manager.py - to avoid double-counting models that share quota groups. - -Data & Refresh: -- [ ] Auto-refresh option (configurable interval) -- [ ] Show last refresh timestamp more prominently -- [ ] Cache invalidation when switching between current/global view -- [ ] Support for non-OAuth providers (API keys like nvapi-*, gsk_*, etc.) - -Remote Management: -- [ ] Test connection before saving remote -- [ ] Import/export remote configurations -- [ ] SSH tunnel support for remote proxies - -Quota Groups: -- [ ] Show which models are in each quota group (expandable) -- [ ] Historical quota usage graphs (if data available) -- [ ] Alerts/notifications when quota is low - -Credential Details: -- [ ] Show per-model breakdown within quota groups -- [ ] Edit credential priority/tier manually -- [ ] Disable/enable individual credentials """ import os @@ -62,6 +26,56 @@ from .quota_viewer_config import QuotaViewerConfig +# ============================================================================= +# DISPLAY CONFIGURATION - Adjust these values to customize the layout +# ============================================================================= + +# Summary screen table column widths +TABLE_PROVIDER_WIDTH = 12 +TABLE_CREDS_WIDTH = 3 +TABLE_QUOTA_STATUS_WIDTH = 62 +TABLE_REQUESTS_WIDTH = 5 +TABLE_TOKENS_WIDTH = 20 +TABLE_COST_WIDTH = 6 + +# Quota status formatting in summary screen +QUOTA_NAME_WIDTH = 15 # Width for quota group name (e.g., "claude:") +QUOTA_USAGE_WIDTH = 11 # Width for usage ratio (e.g., "2071/2700") +QUOTA_PCT_WIDTH = 6 # Width for percentage (e.g., "76.7%") +QUOTA_BAR_WIDTH = 10 # Width for progress bar + +# Detail view credential panel formatting +DETAIL_GROUP_NAME_WIDTH = ( + 18 # Width for group name in detail view (handles "g25-flash (daily)") +) +DETAIL_USAGE_WIDTH = ( + 16 # Width for usage ratio in detail view (handles "3000/3000(5000)") +) +DETAIL_PCT_WIDTH = 7 # Width for percentage in detail view + +# ============================================================================= +# STATUS DISPLAY CONFIGURATION +# ============================================================================= + +# Credential status icons and colors: (icon, label, color) +# Using Rich emoji markup :name: for consistent width handling +STATUS_DISPLAY = { + "active": (":white_check_mark:", "Active", "green"), + "cooldown": (":stopwatch:", "Cooldown", "yellow"), + "exhausted": (":no_entry:", "Exhausted", "red"), + "mixed": (":warning:", "Mixed", "yellow"), +} + +# Per-group indicator icons (using Rich emoji markup for proper width handling) +INDICATOR_ICONS = { + "fair_cycle": ":scales:", # ⚖️ - Rich will handle width + "custom_cap": ":bar_chart:", # 📊 + "cooldown": ":stopwatch:", # ⏱️ +} + +# ============================================================================= + + def clear_screen(): """Clear the terminal screen.""" os.system("cls" if os.name == "nt" else "clear") @@ -187,19 +201,127 @@ def format_cooldown(seconds: int) -> str: return f"{hours}h {mins}m" if mins > 0 else f"{hours}h" -def natural_sort_key(item: Dict[str, Any]) -> List: +def natural_sort_key(item: Any) -> List: """ Generate a sort key for natural/numeric sorting. Sorts credentials like proj-1, proj-2, proj-10 correctly instead of alphabetically (proj-1, proj-10, proj-2). + + Handles both dict items (new API format) and strings. """ - identifier = item.get("identifier", "") + if isinstance(item, dict): + identifier = item.get("identifier", item.get("stable_id", "")) + else: + identifier = str(item) # Split into text and numeric parts parts = re.split(r"(\d+)", identifier) return [int(p) if p.isdigit() else p.lower() for p in parts] +def get_credentials_list(prov_stats: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Convert credentials from dict format to list. + + The new API returns credentials as a dict keyed by stable_id. + This function converts it to a list for iteration and sorting. + """ + credentials = prov_stats.get("credentials", {}) + if isinstance(credentials, list): + return credentials + if isinstance(credentials, dict): + return list(credentials.values()) + return [] + + +def get_credential_stats( + cred: Dict[str, Any], view_mode: str = "current" +) -> Dict[str, Any]: + """ + Extract display stats from a credential with field name adaptation. + + Maps new API field names to what the viewer expects: + - totals.request_count -> requests + - totals.last_used_at -> last_used_ts + - totals.approx_cost -> approx_cost + - Derive tokens from totals + """ + totals = cred.get("totals", {}) + + # For global view mode, we'd need global totals (currently same as totals) + if view_mode == "global": + stats_source = cred.get("global", totals) + if stats_source == totals: + stats_source = totals + else: + stats_source = totals + + # Calculate proper token stats + prompt_tokens = stats_source.get("prompt_tokens", 0) + cache_read = stats_source.get("prompt_tokens_cache_read", 0) + output_tokens = stats_source.get("output_tokens", 0) + + # Total input = uncached (prompt_tokens) + cached (cache_read) + input_total = prompt_tokens + cache_read + input_cached = cache_read + input_uncached = prompt_tokens + + cache_pct = round(input_cached / input_total * 100, 1) if input_total > 0 else 0 + + return { + "requests": stats_source.get("request_count", 0), + "last_used_ts": stats_source.get("last_used_at"), + "approx_cost": stats_source.get("approx_cost"), + "tokens": { + "input_cached": input_cached, + "input_uncached": input_uncached, + "input_cache_pct": cache_pct, + "output": output_tokens, + }, + } + + +def provider_sort_key(item: Tuple[str, Dict[str, Any]]) -> Tuple: + """ + Sort key for providers. + + Order: has quota_groups -> has activity -> request count -> credential count + """ + name, stats = item + has_quota_groups = bool(stats.get("quota_groups")) + has_activity = stats.get("total_requests", 0) > 0 + return ( + not has_quota_groups, # False (has groups) sorts first + not has_activity, # False (has activity) sorts first + -stats.get("total_requests", 0), # Higher requests first + -stats.get("credential_count", 0), # Higher creds first + name.lower(), # Alphabetically last + ) + + +def quota_group_sort_key(item: Tuple[str, Dict[str, Any]]) -> Tuple: + """ + Sort key for quota groups. + + Order: by total quota limit (lowest first), then alphabetically. + Groups without limits sort last. + """ + name, group_stats = item + windows = group_stats.get("windows", {}) + + if not windows: + return (float("inf"), name) # No windows = sort last + + # Find minimum total_max across windows + min_limit = float("inf") + for window_stats in windows.values(): + total_max = window_stats.get("total_max", 0) + if total_max > 0: + min_limit = min(min_limit, total_max) + + return (min_limit, name) + + class QuotaViewer: """Main Quota Viewer TUI class.""" @@ -210,7 +332,8 @@ def __init__(self, config: Optional[QuotaViewerConfig] = None): Args: config: Optional config object. If not provided, one will be created. """ - self.console = Console() + # Use emoji_variant="text" for more consistent width calculations + self.console = Console(emoji_variant="text") self.config = config or QuotaViewerConfig() self.config.sync_with_launcher_config() @@ -651,18 +774,18 @@ def show_summary_screen(self): # View mode indicator if self.view_mode == "global": - view_label = "[magenta]📊 Global/Lifetime[/magenta]" + view_label = "[magenta]:bar_chart: Global/Lifetime[/magenta]" else: - view_label = "[cyan]📈 Current Period[/cyan]" + view_label = "[cyan]:chart_with_upwards_trend: Current Period[/cyan]" self.console.print("━" * 78) self.console.print( - f"[bold cyan]📈 Quota & Usage Statistics[/bold cyan] | {view_label}" + f"[bold cyan]:chart_with_upwards_trend: Quota & Usage Statistics[/bold cyan] | {view_label}" ) self.console.print("━" * 78) self.console.print( f"Connected to: [bold]{remote_name}[/bold] ({connection_display}) " - f"[green]✅[/green] | {data_age}" + f"[green]:white_check_mark:[/green] | {data_age}" ) self.console.print() @@ -673,17 +796,18 @@ def show_summary_screen(self): table = Table( box=None, show_header=True, header_style="bold", padding=(0, 1) ) - table.add_column("Provider", style="cyan", min_width=10) - table.add_column("Creds", justify="center", min_width=5) - table.add_column("Quota Status", min_width=28) - table.add_column("Requests", justify="right", min_width=8) - table.add_column("Tokens (in/out)", min_width=20) - table.add_column("Cost", justify="right", min_width=6) + table.add_column("Provider", style="cyan", min_width=TABLE_PROVIDER_WIDTH) + table.add_column("Creds", justify="center", min_width=TABLE_CREDS_WIDTH) + table.add_column("Quota Status", min_width=TABLE_QUOTA_STATUS_WIDTH) + table.add_column("Req.", justify="right", min_width=TABLE_REQUESTS_WIDTH) + table.add_column("Tok. in/out(cached)", min_width=TABLE_TOKENS_WIDTH) + table.add_column("Cost", justify="right", min_width=TABLE_COST_WIDTH) providers = self.cached_stats.get("providers", {}) - provider_list = list(providers.keys()) + # Sort providers: quota_groups -> activity -> requests -> creds + sorted_providers = sorted(providers.items(), key=provider_sort_key) - for idx, (provider, prov_stats) in enumerate(providers.items(), 1): + for idx, (provider, prov_stats) in enumerate(sorted_providers, 1): cred_count = prov_stats.get("credential_count", 0) # Use global stats if in global mode @@ -703,7 +827,7 @@ def show_summary_screen(self): ) output = tokens.get("output", 0) cache_pct = tokens.get("input_cache_pct", 0) - token_str = f"{format_tokens(input_total)}/{format_tokens(output)} ({cache_pct}% cached)" + token_str = f"{format_tokens(input_total)}/{format_tokens(output)} ({cache_pct}%)" # Format cost cost_str = format_cost(cost_value) @@ -712,63 +836,100 @@ def show_summary_screen(self): quota_groups = prov_stats.get("quota_groups", {}) if quota_groups: quota_lines = [] - for group_name, group_stats in quota_groups.items(): - # Use remaining requests (not used) so percentage matches displayed value - total_remaining = group_stats.get("total_requests_remaining", 0) - total_max = group_stats.get("total_requests_max", 0) - total_pct = group_stats.get("total_remaining_pct") - tiers = group_stats.get("tiers", {}) + # Sort quota groups by minimum remaining % (lowest first) + sorted_groups = sorted( + quota_groups.items(), key=quota_group_sort_key + ) - # Format tier info: "5(15)f/2s" = 5 active out of 15 free, 2 standard all active - # Sort by priority (lower number = higher priority, appears first) - tier_parts = [] - sorted_tiers = sorted( - tiers.items(), key=lambda x: x[1].get("priority", 10) - ) - for tier_name, tier_info in sorted_tiers: - if tier_name == "unknown": - continue # Skip unknown tiers in display - total_t = tier_info.get("total", 0) - active_t = tier_info.get("active", 0) - # Use first letter: standard-tier -> s, free-tier -> f - short = tier_name.replace("-tier", "")[0] - - if active_t < total_t: - # Some exhausted - show active(total) - tier_parts.append(f"{active_t}({total_t}){short}") + for group_name, group_stats in sorted_groups: + tiers = group_stats.get("tiers", {}) + windows = group_stats.get("windows", {}) + fc_summary = group_stats.get("fair_cycle_summary", {}) + + if not windows: + # No windows = no data, skip + continue + + # Process each window for this group + for window_name, window_stats in windows.items(): + total_remaining = window_stats.get("total_remaining", 0) + total_max = window_stats.get("total_max", 0) + total_pct = window_stats.get("remaining_pct") + tier_availability = window_stats.get( + "tier_availability", {} + ) + + # Format tier info using per-window availability + # "15(18)s" = 15 available out of 18 standard-tier + tier_parts = [] + # Sort tiers by priority (from provider-level tiers) + sorted_tier_names = sorted( + tier_availability.keys(), + key=lambda t: tiers.get(t, {}).get("priority", 10), + ) + for tier_name in sorted_tier_names: + if tier_name == "unknown": + continue + avail_info = tier_availability[tier_name] + total_t = avail_info.get("total", 0) + available_t = avail_info.get("available", 0) + # Use first letter: standard-tier -> s, free-tier -> f + short = tier_name.replace("-tier", "")[0] + + if available_t < total_t: + tier_parts.append( + f"{available_t}({total_t}){short}" + ) + else: + tier_parts.append(f"{total_t}{short}") + + tier_str = "/".join(tier_parts) if tier_parts else "" + + # Only show tier info if this group has limits + if total_max == 0: + tier_str = "" + + # Build FC summary string if any credentials are FC exhausted + fc_str = "" + fc_exhausted = fc_summary.get("exhausted_count", 0) + fc_total = fc_summary.get("total_count", 0) + if fc_exhausted > 0: + fc_str = f"[yellow]{INDICATOR_ICONS['fair_cycle']} {fc_exhausted}/{fc_total}[/yellow]" + + # Determine color based on remaining percentage and FC status + if total_pct is not None: + if total_pct <= 10: + color = "red" + elif total_pct < 30 or fc_exhausted > 0: + color = "yellow" + else: + color = "green" else: - # All active - just show total - tier_parts.append(f"{total_t}{short}") - tier_str = "/".join(tier_parts) if tier_parts else "" - - # Determine color based purely on remaining percentage - if total_pct is not None: - if total_pct <= 10: - color = "red" - elif total_pct < 30: - color = "yellow" + color = "dim" + + pct_str = f"{total_pct}%" if total_pct is not None else "?" + + # Format: "group (window): remaining/max pct% bar tier_info fc_info" + # Show window name if multiple windows exist + if len(windows) > 1: + display_name = f"{group_name} ({window_name})" else: - color = "green" - else: - color = "dim" + display_name = group_name - bar = create_progress_bar(total_pct) - pct_str = f"{total_pct}%" if total_pct is not None else "?" + display_name_trunc = display_name[: QUOTA_NAME_WIDTH - 1] + usage_str = f"{total_remaining}/{total_max}" + bar = create_progress_bar(total_pct, QUOTA_BAR_WIDTH) - # Build status suffix (just tiers now, no outer parens) - status = tier_str + # Build the line with tier info and FC summary + line_parts = [ + f"[{color}]{display_name_trunc + ':':<{QUOTA_NAME_WIDTH}}{usage_str:>{QUOTA_USAGE_WIDTH}} {pct_str:>{QUOTA_PCT_WIDTH}} {bar}[/{color}]" + ] + if tier_str: + line_parts.append(tier_str) + if fc_str: + line_parts.append(fc_str) - # Fixed-width format for aligned bars - # Adjust these to change column spacing: - QUOTA_NAME_WIDTH = 10 # name + colon, left-aligned - QUOTA_USAGE_WIDTH = ( - 12 # remaining/max ratio, right-aligned (handles 100k+) - ) - display_name = group_name[: QUOTA_NAME_WIDTH - 1] - usage_str = f"{total_remaining}/{total_max}" - quota_lines.append( - f"[{color}]{display_name + ':':<{QUOTA_NAME_WIDTH}}{usage_str:>{QUOTA_USAGE_WIDTH}} {pct_str:>4} {bar}[/{color}] {status}" - ) + quota_lines.append(" ".join(line_parts)) # First line goes in the main row first_quota = quota_lines[0] if quota_lines else "-" @@ -795,9 +956,14 @@ def show_summary_screen(self): ) # Add separator between providers (except last) - if idx < len(providers): + if idx < len(sorted_providers): table.add_row( - "─" * 10, "─" * 4, "─" * 26, "─" * 7, "─" * 20, "─" * 6 + "─" * TABLE_PROVIDER_WIDTH, + "─" * TABLE_CREDS_WIDTH, + "─" * TABLE_QUOTA_STATUS_WIDTH, + "─" * TABLE_REQUESTS_WIDTH, + "─" * TABLE_TOKENS_WIDTH, + "─" * TABLE_COST_WIDTH, ) self.console.print(table) @@ -832,9 +998,10 @@ def show_summary_screen(self): self.console.print("━" * 78) self.console.print() - # Build provider menu options + # Build provider menu options (use same sorted order as display) providers = self.cached_stats.get("providers", {}) if self.cached_stats else {} - provider_list = list(providers.keys()) + sorted_providers = sorted(providers.items(), key=provider_sort_key) + provider_list = [name for name, _ in sorted_providers] for idx, provider in enumerate(provider_list, 1): self.console.print(f" {idx}. View [cyan]{provider}[/cyan] details") @@ -886,7 +1053,7 @@ def show_provider_detail_screen(self, provider: str): self.console.print("━" * 78) self.console.print( - f"[bold cyan]📊 {provider.title()} - Detailed Stats[/bold cyan] | {view_label}" + f"[bold cyan]:bar_chart: {provider.title()} - Detailed Stats[/bold cyan] | {view_label}" ) self.console.print("━" * 78) self.console.print() @@ -895,7 +1062,7 @@ def show_provider_detail_screen(self, provider: str): self.console.print("[yellow]No data available.[/yellow]") else: prov_stats = self.cached_stats.get("providers", {}).get(provider, {}) - credentials = prov_stats.get("credentials", []) + credentials = get_credentials_list(prov_stats) # Sort credentials naturally (1, 2, 10 not 1, 10, 2) credentials = sorted(credentials, key=natural_sort_key) @@ -913,8 +1080,6 @@ def show_provider_detail_screen(self, provider: str): self.console.print("━" * 78) self.console.print() self.console.print(" G. Toggle view mode (current/global)") - self.console.print(" R. Reload stats (from proxy cache)") - self.console.print(" RA. Reload all stats") # Force refresh options (only for providers that support it) has_quota_groups = bool( @@ -924,18 +1089,29 @@ def show_provider_detail_screen(self, provider: str): .get("quota_groups") ) + # Model toggle option (only show if provider has quota groups) - MOVED UP + if has_quota_groups: + show_models_status = ( + "ON" if self.config.get_show_models(provider) else "OFF" + ) + self.console.print( + f" T. Toggle model details ({show_models_status})" + ) + + self.console.print(" R. Reload stats (from proxy cache)") + self.console.print(" RA. Reload all stats") + if has_quota_groups: self.console.print() self.console.print( f" F. [yellow]Force refresh ALL {provider} quotas from API[/yellow]" ) - credentials = ( - self.cached_stats.get("providers", {}) - .get(provider, {}) - .get("credentials", []) + prov_stats_for_menu = ( + self.cached_stats.get("providers", {}).get(provider, {}) if self.cached_stats - else [] + else {} ) + credentials = get_credentials_list(prov_stats_for_menu) # Sort credentials naturally credentials = sorted(credentials, key=natural_sort_key) for idx, cred in enumerate(credentials, 1): @@ -945,6 +1121,14 @@ def show_provider_detail_screen(self, provider: str): f" F{idx}. Force refresh [{idx}] only ({email})" ) + # DEBUG: Add fake window for testing multi-window display + if has_quota_groups: + self.console.print() + self.console.print(" [dim]DEBUG:[/dim]") + self.console.print( + " W. [dim]Add fake 'daily' window (test multi-window)[/dim]" + ) + self.console.print() self.console.print(" B. Back to summary") self.console.print() @@ -957,6 +1141,13 @@ def show_provider_detail_screen(self, provider: str): elif choice == "G": # Toggle view mode self.view_mode = "global" if self.view_mode == "current" else "current" + elif choice == "T" and has_quota_groups: + # Toggle show models + new_state = self.config.toggle_show_models(provider) + status_str = "enabled" if new_state else "disabled" + self.console.print( + f"[dim]Model details {status_str} for {provider}[/dim]" + ) elif choice == "R": with self.console.status( f"[bold]Reloading {provider} stats...", spinner="dots" @@ -987,15 +1178,21 @@ def show_provider_detail_screen(self, provider: str): for err in rr["errors"]: self.console.print(f"[red] Error: {err}[/red]") Prompt.ask("Press Enter to continue", default="") + elif choice == "W" and has_quota_groups: + # DEBUG: Inject fake "daily" window for testing multi-window display + self._inject_fake_daily_window(provider) + self.console.print( + "[dim]Injected fake 'daily' window into cached stats[/dim]" + ) + Prompt.ask("Press Enter to continue", default="") elif choice.startswith("F") and choice[1:].isdigit() and has_quota_groups: idx = int(choice[1:]) - credentials = ( - self.cached_stats.get("providers", {}) - .get(provider, {}) - .get("credentials", []) + prov_stats_for_refresh = ( + self.cached_stats.get("providers", {}).get(provider, {}) if self.cached_stats - else [] + else {} ) + credentials = get_credentials_list(prov_stats_for_refresh) # Sort credentials naturally to match display order credentials = sorted(credentials, key=natural_sort_key) if 1 <= idx <= len(credentials): @@ -1030,44 +1227,46 @@ def _render_credential_panel(self, idx: int, cred: Dict[str, Any], provider: str tier = cred.get("tier", "") status = cred.get("status", "unknown") - # Check for active cooldowns - key_cooldown = cred.get("key_cooldown_remaining") - model_cooldowns = cred.get("model_cooldowns", {}) - has_cooldown = key_cooldown or model_cooldowns - - # Status indicator - if status == "exhausted": - status_icon = "[red]⛔ Exhausted[/red]" - elif status == "cooldown" or has_cooldown: - if key_cooldown: - status_icon = f"[yellow]⚠️ Cooldown ({format_cooldown(int(key_cooldown))})[/yellow]" + # Check for active cooldowns (new format: cooldowns dict) + cooldowns = cred.get("cooldowns", {}) + has_cooldown = bool(cooldowns) + + # Check for global cooldown + global_cooldown = cooldowns.get("_global_", {}) + global_cooldown_remaining = ( + global_cooldown.get("remaining_seconds", 0) if global_cooldown else 0 + ) + + # Status indicator using centralized config + status_config = STATUS_DISPLAY.get(status, STATUS_DISPLAY["active"]) + icon, label, color = status_config + + if status == "cooldown" or (has_cooldown and global_cooldown_remaining > 0): + if global_cooldown_remaining > 0: + status_icon = f"[{color}]{icon} {label} ({format_cooldown(int(global_cooldown_remaining))})[/{color}]" else: - status_icon = "[yellow]⚠️ Cooldown[/yellow]" + status_icon = f"[{color}]{icon} {label}[/{color}]" else: - status_icon = "[green]✅ Active[/green]" + status_icon = f"[{color}]{icon} {label}[/{color}]" # Header line display_name = email if email else identifier tier_str = f" ({tier})" if tier else "" header = f"[{idx}] {display_name}{tier_str} {status_icon}" - # Use global stats if in global mode - if self.view_mode == "global": - stats_source = cred.get("global", cred) - else: - stats_source = cred - - # Stats line - last_used = format_time_ago(cred.get("last_used_ts")) # Always from current - requests = stats_source.get("requests", 0) - tokens = stats_source.get("tokens", {}) + # Get stats using helper function + cred_stats = get_credential_stats(cred, self.view_mode) + last_used = format_time_ago(cred_stats.get("last_used_ts")) + requests = cred_stats.get("requests", 0) + tokens = cred_stats.get("tokens", {}) input_total = tokens.get("input_cached", 0) + tokens.get("input_uncached", 0) output = tokens.get("output", 0) - cost = format_cost(stats_source.get("approx_cost")) + cache_pct = tokens.get("input_cache_pct", 0) + cost = format_cost(cred_stats.get("approx_cost")) stats_line = ( f"Last used: {last_used} | Requests: {requests} | " - f"Tokens: {format_tokens(input_total)}/{format_tokens(output)}" + f"Tokens: {format_tokens(input_total)}/{format_tokens(output)} ({cache_pct}%)" ) if cost != "-": stats_line += f" | Cost: {cost}" @@ -1077,136 +1276,188 @@ def _render_credential_panel(self, idx: int, cred: Dict[str, Any], provider: str f"[dim]{stats_line}[/dim]", ] - # Model groups (for providers with quota tracking) - model_groups = cred.get("model_groups", {}) - - # Show cooldowns grouped by quota group (if model_groups exist) - if model_cooldowns: - if model_groups: - # Group cooldowns by quota group - group_cooldowns: Dict[ - str, int - ] = {} # group_name -> max_remaining_seconds - ungrouped_cooldowns: List[Tuple[str, int]] = [] - - for model_name, cooldown_info in model_cooldowns.items(): - remaining = cooldown_info.get("remaining_seconds", 0) - if remaining <= 0: - continue - - # Find which group this model belongs to - clean_model = model_name.split("/")[-1] - found_group = None - for group_name, group_info in model_groups.items(): - group_models = group_info.get("models", []) - if clean_model in group_models: - found_group = group_name - break - - if found_group: - group_cooldowns[found_group] = max( - group_cooldowns.get(found_group, 0), remaining + # Group usage (for providers with quota tracking) + group_usage = cred.get("group_usage", {}) + + # Show cooldowns + if cooldowns: + active_cooldowns = [] + for key, cooldown_info in cooldowns.items(): + if key == "_global_": + continue # Already shown in status + remaining = cooldown_info.get("remaining_seconds", 0) + if remaining > 0: + reason = cooldown_info.get("reason", "") + source = cooldown_info.get("source", "") + active_cooldowns.append((key, remaining, reason, source)) + + if active_cooldowns: + content_lines.append("") + content_lines.append("[yellow]Active Cooldowns:[/yellow]") + for key, remaining, reason, source in active_cooldowns: + source_str = f" ({source})" if source else "" + content_lines.append( + f" [yellow]:stopwatch: {key}: {format_cooldown(int(remaining))}{source_str}[/yellow]" + ) + + # Display group usage with per-window breakdown + # Note: group_usage is pre-sorted by limit (lowest first) from the API + if group_usage: + content_lines.append("") + content_lines.append("[bold]Quota Groups:[/bold]") + + for group_name, group_stats in group_usage.items(): + windows = group_stats.get("windows", {}) + if not windows: + continue + + # Get per-group status info + fc_exhausted = group_stats.get("fair_cycle_exhausted", False) + fc_reason = group_stats.get("fair_cycle_reason") + group_cooldown_remaining = group_stats.get("cooldown_remaining") + group_cooldown_source = group_stats.get("cooldown_source") + custom_cap = group_stats.get("custom_cap") + + for window_name, window_stats in windows.items(): + request_count = window_stats.get("request_count", 0) + limit = window_stats.get("limit") + remaining = window_stats.get("remaining") + reset_at = window_stats.get("reset_at") + max_recorded = window_stats.get("max_recorded_requests") + max_recorded_at = window_stats.get("max_recorded_at") + + # Calculate remaining percentage + if limit is not None and limit > 0: + remaining_val = ( + remaining + if remaining is not None + else max(0, limit - request_count) ) + remaining_pct = round(remaining_val / limit * 100, 1) + is_exhausted = remaining_val <= 0 else: - ungrouped_cooldowns.append((model_name, remaining)) - - if group_cooldowns or ungrouped_cooldowns: - content_lines.append("") - content_lines.append("[yellow]Active Cooldowns:[/yellow]") + remaining_pct = None + remaining_val = None + is_exhausted = False + + # Format reset time (only show if there's actual usage or cooldown) + reset_time_str = "" + if reset_at and (request_count > 0 or group_cooldown_remaining): + try: + reset_dt = datetime.fromtimestamp(reset_at) + reset_time_str = reset_dt.strftime("%b %d %H:%M") + except (ValueError, OSError): + reset_time_str = "" + + # Format max recorded info + max_info = "" + if max_recorded is not None: + max_info = f" Max: {max_recorded}" + if max_recorded_at: + try: + max_dt = datetime.fromtimestamp(max_recorded_at) + max_info += f" @ {max_dt.strftime('%b %d')}" + except (ValueError, OSError): + pass + + # Build display line + bar = create_progress_bar(remaining_pct) + + # Determine color (account for fair cycle) + if is_exhausted: + color = "red" + elif fc_exhausted: + color = "yellow" + elif remaining_pct is not None and remaining_pct < 20: + color = "yellow" + else: + color = "green" - # Show grouped cooldowns - for group_name in sorted(group_cooldowns.keys()): - remaining = group_cooldowns[group_name] - content_lines.append( - f" [yellow]⏱️ {group_name}: {format_cooldown(remaining)}[/yellow]" + # Format group name + if len(windows) > 1: + display_name = f"{group_name} ({window_name})" + else: + display_name = group_name + + # Format usage string with custom cap if applicable + if custom_cap: + cap_remaining = custom_cap.get("remaining", 0) + cap_limit = custom_cap.get("limit", 0) + api_limit = limit if limit else cap_limit + usage_str = f"{cap_remaining}/{cap_limit}({api_limit})" + # Recalculate percentage based on custom cap + if cap_limit > 0: + remaining_pct = round(cap_remaining / cap_limit * 100, 1) + bar = create_progress_bar(remaining_pct) + pct_str = ( + f"{remaining_pct}%" if remaining_pct is not None else "" ) + elif limit is not None: + usage_str = f"{remaining_val}/{limit}" + pct_str = f"{remaining_pct}%" + else: + usage_str = f"{request_count} req" + pct_str = "" - # Show ungrouped (shouldn't happen often) - for model_name, remaining in ungrouped_cooldowns: - short_model = model_name.split("/")[-1][:35] - content_lines.append( - f" [yellow]⏱️ {short_model}: {format_cooldown(remaining)}[/yellow]" + line = f" [{color}]{display_name:<{DETAIL_GROUP_NAME_WIDTH}} {usage_str:<{DETAIL_USAGE_WIDTH}} {pct_str:>{DETAIL_PCT_WIDTH}} {bar}[/{color}]" + + # Add reset time if applicable + if reset_time_str: + line += f" Resets: {reset_time_str}" + + # Add indicators + indicators = [] + + # Group cooldown indicator (if not already showing reset time) + if group_cooldown_remaining and not reset_time_str: + indicators.append( + f"[yellow]{INDICATOR_ICONS['cooldown']} {format_cooldown(int(group_cooldown_remaining))}[/yellow]" ) - else: - # No model groups - show per-model cooldowns - content_lines.append("") - content_lines.append("[yellow]Active Cooldowns:[/yellow]") - for model_name, cooldown_info in model_cooldowns.items(): - remaining = cooldown_info.get("remaining_seconds", 0) - if remaining > 0: - short_model = model_name.split("/")[-1][:35] - content_lines.append( - f" [yellow]⏱️ {short_model}: {format_cooldown(int(remaining))}[/yellow]" + + # Fair cycle indicator + if fc_exhausted: + indicators.append( + f"[yellow]{INDICATOR_ICONS['fair_cycle']} FC[/yellow]" ) - # Display model groups with quota info - if model_groups: - content_lines.append("") - for group_name, group_stats in model_groups.items(): - remaining_pct = group_stats.get("remaining_pct") - requests_used = group_stats.get("requests_used", 0) - requests_max = group_stats.get("requests_max") - requests_remaining = group_stats.get("requests_remaining") - is_exhausted = group_stats.get("is_exhausted", False) - reset_time = format_reset_time(group_stats.get("reset_time_iso")) - confidence = group_stats.get("confidence", "low") - - # Format display - use requests_remaining/max format - if requests_remaining is None and requests_max: - requests_remaining = max(0, requests_max - requests_used) - display = group_stats.get( - "display", f"{requests_remaining or 0}/{requests_max or '?'}" - ) - bar = create_progress_bar(remaining_pct) + # Custom cap indicator (only if not on cooldown, just to show cap exists) + if custom_cap and not group_cooldown_remaining and not fc_exhausted: + indicators.append(f"[dim]{INDICATOR_ICONS['custom_cap']}[/dim]") - # Build status text - always show reset time if available - has_reset_time = reset_time and reset_time != "-" + if indicators: + line += " " + " ".join(indicators) - # Color based on status - if is_exhausted: - color = "red" - if has_reset_time: - status_text = f"⛔ Resets: {reset_time}" - else: - status_text = "⛔ EXHAUSTED" - elif remaining_pct is not None and remaining_pct < 20: - color = "yellow" - if has_reset_time: - status_text = f"⚠️ Resets: {reset_time}" - else: - status_text = "⚠️ LOW" - else: - color = "green" - if has_reset_time: - status_text = f"Resets: {reset_time}" - else: - status_text = "" # Hide if unused/no reset time + # Add max info at the end + if max_info: + line += f" [dim]{max_info}[/dim]" - # Confidence indicator - conf_indicator = "" - if confidence == "low": - conf_indicator = " [dim](~)[/dim]" - elif confidence == "medium": - conf_indicator = " [dim](?)[/dim]" + content_lines.append(line) - pct_str = f"{remaining_pct}%" if remaining_pct is not None else "?%" + # Model usage (show if no group usage, or if toggle enabled via config) + model_usage = cred.get("model_usage", {}) + # Check config for show_models setting, default to showing only if no group_usage + show_models = self.config.get_show_models(provider) if group_usage else True + + if show_models and model_usage: + content_lines.append("") + content_lines.append("[dim]Models used:[/dim]") + for model_name, model_stats in model_usage.items(): + totals = model_stats.get("totals", {}) + req_count = totals.get("request_count", 0) + if req_count == 0: + continue # Skip models with no usage + + prompt = totals.get("prompt_tokens", 0) + cache_read = totals.get("prompt_tokens_cache_read", 0) + output_tokens = totals.get("output_tokens", 0) + model_cost = format_cost(totals.get("approx_cost")) + + total_input = prompt + cache_read + short_name = model_name.split("/")[-1][:30] content_lines.append( - f" [{color}]{group_name:<18} {display:<10} {pct_str:>4} {bar}[/{color}] {status_text}{conf_indicator}" + f" {short_name}: {req_count} req | {format_tokens(total_input)}/{format_tokens(output_tokens)} tokens" + + (f" | {model_cost}" if model_cost != "-" else "") ) - else: - # For providers without quota groups, show model breakdown if available - models = cred.get("models", {}) - if models: - content_lines.append("") - content_lines.append(" [dim]Models used:[/dim]") - for model_name, model_stats in models.items(): - req_count = model_stats.get("success_count", 0) - model_cost = format_cost(model_stats.get("approx_cost")) - # Shorten model name for display - short_name = model_name.split("/")[-1][:30] - content_lines.append( - f" {short_name}: {req_count} requests, {model_cost}" - ) self.console.print( Panel( @@ -1218,12 +1469,86 @@ def _render_credential_panel(self, idx: int, cred: Dict[str, Any], provider: str ) ) + def _inject_fake_daily_window(self, provider: str) -> None: + """ + DEBUG: Inject a fake 'daily' window into cached stats for testing. + + This modifies cached_stats in-place to add a second window to each + quota group, simulating multi-window display without needing real + multi-window data from the API. + """ + if not self.cached_stats: + return + + prov_stats = self.cached_stats.get("providers", {}).get(provider, {}) + if not prov_stats: + return + + # Inject into quota_groups (global view) + quota_groups = prov_stats.get("quota_groups", {}) + for group_name, group_stats in quota_groups.items(): + windows = group_stats.get("windows", {}) + if "daily" not in windows and windows: + # Copy the first window and modify it + first_window = next(iter(windows.values())) + daily_window = { + "total_used": int(first_window.get("total_used", 0) * 0.3), + "total_remaining": int(first_window.get("total_max", 100) * 0.7), + "total_max": first_window.get("total_max", 100), + "remaining_pct": 70.0, + "tier_availability": first_window.get("tier_availability", {}), + } + windows["daily"] = daily_window + + # Inject into credential group_usage (detail view) + credentials = prov_stats.get("credentials", {}) + if isinstance(credentials, dict): + cred_list = credentials.values() + else: + cred_list = credentials + + for cred in cred_list: + group_usage = cred.get("group_usage", {}) + for group_name, group_stats in group_usage.items(): + windows = group_stats.get("windows", {}) + if "daily" not in windows and windows: + # Copy the first window and create a daily version + first_window = next(iter(windows.values())) + limit = first_window.get("limit", 100) + daily_window = { + "request_count": int( + first_window.get("request_count", 0) * 0.3 + ), + "success_count": int( + first_window.get("success_count", 0) * 0.3 + ), + "failure_count": 0, + "prompt_tokens": 0, + "completion_tokens": 0, + "thinking_tokens": 0, + "output_tokens": 0, + "prompt_tokens_cache_read": 0, + "prompt_tokens_cache_write": 0, + "total_tokens": 0, + "limit": limit, + "remaining": int(limit * 0.7), + "max_recorded_requests": None, + "max_recorded_at": None, + "reset_at": time.time() + 3600, # 1 hour from now + "approx_cost": 0.0, + "first_used_at": None, + "last_used_at": None, + } + windows["daily"] = daily_window + def show_switch_remote_screen(self): """Display remote selection screen.""" clear_screen() self.console.print("━" * 78) - self.console.print("[bold cyan]🔄 Switch Remote[/bold cyan]") + self.console.print( + "[bold cyan]:arrows_counterclockwise: Switch Remote[/bold cyan]" + ) self.console.print("━" * 78) self.console.print() @@ -1258,9 +1583,9 @@ def show_switch_remote_screen(self): current_marker = " (current)" if is_current else "" if is_online: - status_icon = "[green]✅ Online[/green]" + status_icon = "[green]:white_check_mark: Online[/green]" else: - status_icon = f"[red]⚠️ {status_msg}[/red]" + status_icon = f"[red]:warning: {status_msg}[/red]" self.console.print( f" {idx}. {name:<20} {connection_display:<30} {status_icon}{current_marker}" @@ -1330,7 +1655,7 @@ def show_manage_remotes_screen(self): clear_screen() self.console.print("━" * 78) - self.console.print("[bold cyan]⚙️ Manage Remotes[/bold cyan]") + self.console.print("[bold cyan]:gear: Manage Remotes[/bold cyan]") self.console.print("━" * 78) self.console.print() diff --git a/src/proxy_app/quota_viewer_config.py b/src/proxy_app/quota_viewer_config.py index 7b2f573f..10b0d506 100644 --- a/src/proxy_app/quota_viewer_config.py +++ b/src/proxy_app/quota_viewer_config.py @@ -298,3 +298,48 @@ def get_api_key_from_env(self) -> Optional[str]: except IOError: pass return None + + def get_show_models(self, provider: str) -> bool: + """ + Get whether to show model breakdown for a provider. + + Args: + provider: Provider name + + Returns: + True if models should be shown, False otherwise + """ + provider_settings = self.config.get("provider_settings", {}) + return provider_settings.get(provider, {}).get("show_models", False) + + def set_show_models(self, provider: str, show: bool) -> bool: + """ + Set whether to show model breakdown for a provider. + + Args: + provider: Provider name + show: Whether to show models + + Returns: + True on success + """ + if "provider_settings" not in self.config: + self.config["provider_settings"] = {} + if provider not in self.config["provider_settings"]: + self.config["provider_settings"][provider] = {} + self.config["provider_settings"][provider]["show_models"] = show + return self._save() + + def toggle_show_models(self, provider: str) -> bool: + """ + Toggle the show_models setting for a provider. + + Args: + provider: Provider name + + Returns: + The new value (True if now showing, False if now hidden) + """ + current = self.get_show_models(provider) + self.set_show_models(provider, not current) + return not current diff --git a/src/proxy_app/settings_tool.py b/src/proxy_app/settings_tool.py index 689839bb..006082b6 100644 --- a/src/proxy_app/settings_tool.py +++ b/src/proxy_app/settings_tool.py @@ -45,6 +45,13 @@ except ImportError: IFLOW_DEFAULT_OAUTH_PORT = 11451 +try: + from rotator_library.providers.openai_codex_auth_base import OpenAICodexAuthBase + + OPENAI_CODEX_DEFAULT_OAUTH_PORT = OpenAICodexAuthBase.CALLBACK_PORT +except ImportError: + OPENAI_CODEX_DEFAULT_OAUTH_PORT = 1455 + def clear_screen(subtitle: str = ""): """ @@ -553,11 +560,21 @@ def remove_multiplier(self, provider: str, priority: int): }, } +# OpenAI Codex provider environment variables +OPENAI_CODEX_SETTINGS = { + "OPENAI_CODEX_OAUTH_PORT": { + "type": "int", + "default": OPENAI_CODEX_DEFAULT_OAUTH_PORT, + "description": "Local port for OAuth callback server during authentication", + }, +} + # Map provider names to their settings definitions PROVIDER_SETTINGS_MAP = { "antigravity": ANTIGRAVITY_SETTINGS, "gemini_cli": GEMINI_CLI_SETTINGS, "iflow": IFLOW_SETTINGS, + "openai_codex": OPENAI_CODEX_SETTINGS, } @@ -673,7 +690,7 @@ def _format_item( def _get_pending_status_text(self) -> str: """Get formatted pending changes status text for main menu.""" if not self.settings.has_pending(): - return "[dim]ℹ️ No pending changes[/dim]" + return "[dim]:information_source: No pending changes[/dim]" counts = self.settings.get_pending_counts() parts = [] @@ -690,7 +707,7 @@ def _get_pending_status_text(self) -> str: f"[red]{counts['remove']} removal{'s' if counts['remove'] > 1 else ''}[/red]" ) - return f"[bold]ℹ️ Pending changes: {', '.join(parts)}[/bold]" + return f"[bold]:information_source: Pending changes: {', '.join(parts)}[/bold]" self.running = True def get_available_providers(self) -> List[str]: @@ -739,21 +756,21 @@ def show_main_menu(self): self.console.print( Panel.fit( - "[bold cyan]🔧 Advanced Settings Configuration[/bold cyan]", + "[bold cyan]:wrench: Advanced Settings Configuration[/bold cyan]", border_style="cyan", ) ) self.console.print() - self.console.print("[bold]⚙️ Configuration Categories[/bold]") + self.console.print("[bold]:gear: Configuration Categories[/bold]") self.console.print() - self.console.print(" 1. 🌐 Custom Provider API Bases") + self.console.print(" 1. :globe_with_meridians: Custom Provider API Bases") self.console.print(" 2. 📦 Provider Model Definitions") self.console.print(" 3. ⚡ Concurrency Limits") - self.console.print(" 4. 🔄 Rotation Modes") + self.console.print(" 4. :arrows_counterclockwise: Rotation Modes") self.console.print(" 5. 🔬 Provider-Specific Settings") - self.console.print(" 6. 🎯 Model Filters (Ignore/Whitelist)") - self.console.print(" 7. 💾 Save & Exit") + self.console.print(" 6. :dart: Model Filters (Ignore/Whitelist)") + self.console.print(" 7. :floppy_disk: Save & Exit") self.console.print(" 8. 🚫 Exit Without Saving") self.console.print() @@ -796,13 +813,13 @@ def manage_custom_providers(self): self.console.print( Panel.fit( - "[bold cyan]🌐 Custom Provider API Bases[/bold cyan]", + "[bold cyan]:globe_with_meridians: Custom Provider API Bases[/bold cyan]", border_style="cyan", ) ) self.console.print() - self.console.print("[bold]📋 Configured Custom Providers[/bold]") + self.console.print("[bold]:clipboard: Configured Custom Providers[/bold]") self.console.print("━" * 70) # Build combined view with pending changes @@ -853,11 +870,11 @@ def manage_custom_providers(self): self.console.print() self.console.print("━" * 70) self.console.print() - self.console.print("[bold]⚙️ Actions[/bold]") + self.console.print("[bold]:gear: Actions[/bold]") self.console.print() self.console.print(" 1. ➕ Add New Custom Provider") self.console.print(" 2. ✏️ Edit Existing Provider") - self.console.print(" 3. 🗑️ Remove Provider") + self.console.print(" 3. :wastebasket: Remove Provider") self.console.print(" 4. ↩️ Back to Settings Menu") self.console.print() @@ -875,7 +892,7 @@ def manage_custom_providers(self): if api_base: self.provider_mgr.add_provider(name, api_base) self.console.print( - f"\n[green]✅ Custom provider '{name}' staged![/green]" + f"\n[green]:white_check_mark: Custom provider '{name}' staged![/green]" ) self.console.print( f" To use: set {name.upper()}_API_KEY in credentials" @@ -915,7 +932,7 @@ def manage_custom_providers(self): if new_base and new_base != current_base: self.provider_mgr.edit_provider(name, new_base) self.console.print( - f"\n[green]✅ Custom provider '{name}' updated![/green]" + f"\n[green]:white_check_mark: Custom provider '{name}' updated![/green]" ) else: self.console.print("\n[yellow]No changes made[/yellow]") @@ -962,12 +979,12 @@ def manage_custom_providers(self): key = f"{name.upper()}_API_BASE" del self.settings.pending_changes[key] self.console.print( - f"\n[green]✅ Pending addition of '{name}' cancelled![/green]" + f"\n[green]:white_check_mark: Pending addition of '{name}' cancelled![/green]" ) else: self.provider_mgr.remove_provider(name) self.console.print( - f"\n[green]✅ Provider '{name}' marked for removal![/green]" + f"\n[green]:white_check_mark: Provider '{name}' marked for removal![/green]" ) input("\nPress Enter to continue...") @@ -990,7 +1007,7 @@ def manage_model_definitions(self): ) self.console.print() - self.console.print("[bold]📋 Configured Provider Models[/bold]") + self.console.print("[bold]:clipboard: Configured Provider Models[/bold]") self.console.print("━" * 70) # Build combined view with pending changes @@ -1063,12 +1080,12 @@ def manage_model_definitions(self): self.console.print() self.console.print("━" * 70) self.console.print() - self.console.print("[bold]⚙️ Actions[/bold]") + self.console.print("[bold]:gear: Actions[/bold]") self.console.print() self.console.print(" 1. ➕ Add Models for Provider") self.console.print(" 2. ✏️ Edit Provider Models") self.console.print(" 3. 👁️ View Provider Models") - self.console.print(" 4. 🗑️ Remove Provider Models") + self.console.print(" 4. :wastebasket: Remove Provider Models") self.console.print(" 5. ↩️ Back to Settings Menu") self.console.print() @@ -1141,12 +1158,12 @@ def manage_model_definitions(self): key = f"{provider.upper()}{suffix}" del self.settings.pending_changes[key] self.console.print( - f"\n[green]✅ Pending models for '{provider}' cancelled![/green]" + f"\n[green]:white_check_mark: Pending models for '{provider}' cancelled![/green]" ) else: self.model_mgr.remove_models(provider) self.console.print( - f"\n[green]✅ Model definitions marked for removal for '{provider}'![/green]" + f"\n[green]:white_check_mark: Model definitions marked for removal for '{provider}'![/green]" ) input("\nPress Enter to continue...") elif choice == "5": @@ -1244,7 +1261,7 @@ def add_model_definitions(self): if models: self.model_mgr.set_models(provider, models) self.console.print( - f"\n[green]✅ Model definitions saved for '{provider}'![/green]" + f"\n[green]:white_check_mark: Model definitions saved for '{provider}'![/green]" ) else: self.console.print("\n[yellow]No models added[/yellow]") @@ -1343,7 +1360,9 @@ def edit_model_definitions(self, providers: List[str]): if current_models: self.model_mgr.set_models(provider, current_models) - self.console.print(f"\n[green]✅ Models updated for '{provider}'![/green]") + self.console.print( + f"\n[green]:white_check_mark: Models updated for '{provider}'![/green]" + ) else: self.console.print( "\n[yellow]No models left - removing definition[/yellow]" @@ -1433,7 +1452,7 @@ def manage_provider_settings(self): self.console.print() self.console.print( - "[bold]📋 Available Providers with Custom Settings[/bold]" + "[bold]:clipboard: Available Providers with Custom Settings[/bold]" ) self.console.print("━" * 70) @@ -1450,7 +1469,7 @@ def manage_provider_settings(self): self.console.print() self.console.print("━" * 70) self.console.print() - self.console.print("[bold]⚙️ Select Provider to Configure[/bold]") + self.console.print("[bold]:gear: Select Provider to Configure[/bold]") self.console.print() for idx, provider in enumerate(available_providers, 1): @@ -1492,7 +1511,7 @@ def _manage_single_provider_settings(self, provider: str): ) self.console.print() - self.console.print("[bold]📋 Current Settings[/bold]") + self.console.print("[bold]:clipboard: Current Settings[/bold]") self.console.print("━" * 70) # Display all settings with current values and pending changes @@ -1587,11 +1606,13 @@ def _manage_single_provider_settings(self, provider: str): "[dim]* = modified from default, + = pending add, ~ = pending edit, - = pending reset[/dim]" ) self.console.print() - self.console.print("[bold]⚙️ Actions[/bold]") + self.console.print("[bold]:gear: Actions[/bold]") self.console.print() self.console.print(" E. ✏️ Edit a Setting") - self.console.print(" R. 🔄 Reset Setting to Default") - self.console.print(" A. 🔄 Reset All to Defaults") + self.console.print( + " R. :arrows_counterclockwise: Reset Setting to Default" + ) + self.console.print(" A. :arrows_counterclockwise: Reset All to Defaults") self.console.print(" B. ↩️ Back to Provider Selection") self.console.print() @@ -1641,18 +1662,24 @@ def _edit_provider_setting( new_value = Confirm.ask("\nEnable this setting?", default=current) self.provider_settings_mgr.set_value(key, new_value, definition) status = "enabled" if new_value else "disabled" - self.console.print(f"\n[green]✅ {short_key} {status}![/green]") + self.console.print( + f"\n[green]:white_check_mark: {short_key} {status}![/green]" + ) elif setting_type == "int": new_value = IntPrompt.ask("\nNew value", default=current) self.provider_settings_mgr.set_value(key, new_value, definition) - self.console.print(f"\n[green]✅ {short_key} set to {new_value}![/green]") + self.console.print( + f"\n[green]:white_check_mark: {short_key} set to {new_value}![/green]" + ) else: new_value = Prompt.ask( "\nNew value", default=str(current) if current else "" ).strip() if new_value: self.provider_settings_mgr.set_value(key, new_value, definition) - self.console.print(f"\n[green]✅ {short_key} updated![/green]") + self.console.print( + f"\n[green]:white_check_mark: {short_key} updated![/green]" + ) else: self.console.print("\n[yellow]No changes made[/yellow]") @@ -1677,7 +1704,9 @@ def _reset_provider_setting( if Confirm.ask(f"\nReset {short_key} to default ({default})?"): self.provider_settings_mgr.reset_to_default(key) - self.console.print(f"\n[green]✅ {short_key} reset to default![/green]") + self.console.print( + f"\n[green]:white_check_mark: {short_key} reset to default![/green]" + ) else: self.console.print("\n[yellow]No changes made[/yellow]") @@ -1693,7 +1722,7 @@ def _reset_all_provider_settings(self, provider: str, settings_list: List[str]): for key in settings_list: self.provider_settings_mgr.reset_to_default(key) self.console.print( - f"\n[green]✅ All {display_name} settings reset to defaults![/green]" + f"\n[green]:white_check_mark: All {display_name} settings reset to defaults![/green]" ) else: self.console.print("\n[yellow]No changes made[/yellow]") @@ -1711,13 +1740,13 @@ def manage_rotation_modes(self): self.console.print( Panel.fit( - "[bold cyan]🔄 Credential Rotation Mode Configuration[/bold cyan]", + "[bold cyan]:arrows_counterclockwise: Credential Rotation Mode Configuration[/bold cyan]", border_style="cyan", ) ) self.console.print() - self.console.print("[bold]📋 Rotation Modes Explained[/bold]") + self.console.print("[bold]:clipboard: Rotation Modes Explained[/bold]") self.console.print("━" * 70) self.console.print( " [cyan]balanced[/cyan] - Rotate credentials evenly across requests (default)" @@ -1726,7 +1755,9 @@ def manage_rotation_modes(self): " [cyan]sequential[/cyan] - Use one credential until exhausted (429), then switch" ) self.console.print() - self.console.print("[bold]📋 Current Rotation Mode Settings[/bold]") + self.console.print( + "[bold]:clipboard: Current Rotation Mode Settings[/bold]" + ) self.console.print("━" * 70) # Build combined view with pending changes @@ -1821,10 +1852,10 @@ def manage_rotation_modes(self): "[dim]* = custom setting (differs from provider default)[/dim]" ) self.console.print() - self.console.print("[bold]⚙️ Actions[/bold]") + self.console.print("[bold]:gear: Actions[/bold]") self.console.print() self.console.print(" 1. ➕ Set Rotation Mode for Provider") - self.console.print(" 2. 🗑️ Reset to Provider Default") + self.console.print(" 2. :wastebasket: Reset to Provider Default") self.console.print(" 3. ⚡ Configure Priority Concurrency Multipliers") self.console.print(" 4. ↩️ Back to Settings Menu") @@ -1888,7 +1919,7 @@ def manage_rotation_modes(self): self.rotation_mgr.set_mode(provider, new_mode) self.console.print( - f"\n[green]✅ Rotation mode for '{provider}' staged as {new_mode}![/green]" + f"\n[green]:white_check_mark: Rotation mode for '{provider}' staged as {new_mode}![/green]" ) input("\nPress Enter to continue...") @@ -1935,12 +1966,12 @@ def manage_rotation_modes(self): key = f"{prefix}{provider.upper()}" del self.settings.pending_changes[key] self.console.print( - f"\n[green]✅ Pending mode for '{provider}' cancelled![/green]" + f"\n[green]:white_check_mark: Pending mode for '{provider}' cancelled![/green]" ) else: self.rotation_mgr.remove_mode(provider) self.console.print( - f"\n[green]✅ Rotation mode for '{provider}' marked for reset to default ({default_mode})![/green]" + f"\n[green]:white_check_mark: Rotation mode for '{provider}' marked for reset to default ({default_mode})![/green]" ) input("\nPress Enter to continue...") @@ -1965,7 +1996,9 @@ def manage_priority_multipliers(self): ) self.console.print() - self.console.print("[bold]📋 Current Priority Multiplier Settings[/bold]") + self.console.print( + "[bold]:clipboard: Current Priority Multiplier Settings[/bold]" + ) self.console.print("━" * 70) # Show all providers with their priority multipliers @@ -2009,7 +2042,9 @@ def manage_priority_multipliers(self): self.console.print(" [dim]No priority multipliers configured[/dim]") self.console.print() - self.console.print("[bold]ℹ️ About Priority Multipliers:[/bold]") + self.console.print( + "[bold]:information_source: About Priority Multipliers:[/bold]" + ) self.console.print( " Higher priority tiers (lower numbers) can have higher multipliers." ) @@ -2018,7 +2053,7 @@ def manage_priority_multipliers(self): self.console.print("━" * 70) self.console.print() self.console.print(" 1. ✏️ Set Priority Multiplier") - self.console.print(" 2. 🔄 Reset to Provider Default") + self.console.print(" 2. :arrows_counterclockwise: Reset to Provider Default") self.console.print(" 3. ↩️ Back") choice = Prompt.ask( @@ -2059,7 +2094,7 @@ def manage_priority_multipliers(self): provider, priority, multiplier ) self.console.print( - f"\n[green]✅ Priority {priority} multiplier for '{provider}' set to {multiplier}x[/green]" + f"\n[green]:white_check_mark: Priority {priority} multiplier for '{provider}' set to {multiplier}x[/green]" ) else: self.console.print( @@ -2101,7 +2136,7 @@ def manage_priority_multipliers(self): provider, priority ) self.console.print( - f"\n[green]✅ Reset priority {priority} for '{provider}' to default ({default}x)[/green]" + f"\n[green]:white_check_mark: Reset priority {priority} for '{provider}' to default ({default}x)[/green]" ) else: self.console.print( @@ -2125,7 +2160,7 @@ def manage_concurrency_limits(self): ) self.console.print() - self.console.print("[bold]📋 Current Concurrency Settings[/bold]") + self.console.print("[bold]:clipboard: Current Concurrency Settings[/bold]") self.console.print("━" * 70) # Build combined view with pending changes @@ -2185,11 +2220,11 @@ def manage_concurrency_limits(self): self.console.print() self.console.print("━" * 70) self.console.print() - self.console.print("[bold]⚙️ Actions[/bold]") + self.console.print("[bold]:gear: Actions[/bold]") self.console.print() self.console.print(" 1. ➕ Add Concurrency Limit for Provider") self.console.print(" 2. ✏️ Edit Existing Limit") - self.console.print(" 3. 🗑️ Remove Limit (reset to default)") + self.console.print(" 3. :wastebasket: Remove Limit (reset to default)") self.console.print(" 4. ↩️ Back to Settings Menu") self.console.print() @@ -2236,11 +2271,11 @@ def manage_concurrency_limits(self): if 1 <= limit <= 100: self.concurrency_mgr.set_limit(provider, limit) self.console.print( - f"\n[green]✅ Concurrency limit staged for '{provider}': {limit} requests/key[/green]" + f"\n[green]:white_check_mark: Concurrency limit staged for '{provider}': {limit} requests/key[/green]" ) else: self.console.print( - "\n[red]❌ Limit must be between 1-100[/red]" + "\n[red]:x: Limit must be between 1-100[/red]" ) input("\nPress Enter to continue...") @@ -2278,7 +2313,7 @@ def manage_concurrency_limits(self): if new_limit != current_limit: self.concurrency_mgr.set_limit(provider, new_limit) self.console.print( - f"\n[green]✅ Concurrency limit updated for '{provider}': {new_limit} requests/key[/green]" + f"\n[green]:white_check_mark: Concurrency limit updated for '{provider}': {new_limit} requests/key[/green]" ) else: self.console.print("\n[yellow]No changes made[/yellow]") @@ -2330,12 +2365,12 @@ def manage_concurrency_limits(self): key = f"{prefix}{provider.upper()}" del self.settings.pending_changes[key] self.console.print( - f"\n[green]✅ Pending limit for '{provider}' cancelled![/green]" + f"\n[green]:white_check_mark: Pending limit for '{provider}' cancelled![/green]" ) else: self.concurrency_mgr.remove_limit(provider) self.console.print( - f"\n[green]✅ Limit marked for removal for '{provider}'[/green]" + f"\n[green]:white_check_mark: Limit marked for removal for '{provider}'[/green]" ) input("\nPress Enter to continue...") @@ -2346,7 +2381,7 @@ def _show_changes_summary(self): """Display categorized summary of all pending changes.""" self.console.print( Panel.fit( - "[bold cyan]📋 Pending Changes Summary[/bold cyan]", + "[bold cyan]:clipboard: Pending Changes Summary[/bold cyan]", border_style="cyan", ) ) @@ -2439,7 +2474,9 @@ def save_and_exit(self): if Confirm.ask("\n[bold yellow]Save all pending changes?[/bold yellow]"): self.settings.save() - self.console.print("\n[green]✅ All changes saved to .env![/green]") + self.console.print( + "\n[green]:white_check_mark: All changes saved to .env![/green]" + ) input("\nPress Enter to return to launcher...") else: self.console.print("\n[yellow]Changes not saved[/yellow]") diff --git a/src/rotator_library/README.md b/src/rotator_library/README.md index 22d2bf6e..657032ce 100644 --- a/src/rotator_library/README.md +++ b/src/rotator_library/README.md @@ -23,7 +23,7 @@ A robust, asynchronous, and thread-safe Python library for managing a pool of AP - **Credential Prioritization**: Automatic tier detection and priority-based credential selection (e.g., paid tier credentials used first for models that require them). - **Advanced Model Requirements**: Support for model-tier restrictions (e.g., Gemini 3 requires paid-tier credentials). - **Robust Streaming Support**: Includes a wrapper for streaming responses that reassembles fragmented JSON chunks. -- **Detailed Usage Tracking**: Tracks daily and global usage for each key, persisted to a JSON file. +- **Detailed Usage Tracking**: Tracks daily and global usage for each key, persisted per provider in `usage/usage_.json`. - **Automatic Daily Resets**: Automatically resets cooldowns and archives stats daily. - **Provider Agnostic**: Works with any provider supported by `litellm`. - **Extensible**: Easily add support for new providers through a simple plugin-based architecture. @@ -73,7 +73,7 @@ client = RotatingClient( api_keys=api_keys, oauth_credentials=oauth_credentials, max_retries=2, - usage_file_path="key_usage.json", + usage_file_path="usage.json", configure_logging=True, global_timeout=30, abort_on_callback_error=True, @@ -91,7 +91,7 @@ client = RotatingClient( - `api_keys` (`Optional[Dict[str, List[str]]]`): A dictionary mapping provider names (e.g., "openai", "anthropic") to a list of API keys. - `oauth_credentials` (`Optional[Dict[str, List[str]]]`): A dictionary mapping provider names (e.g., "gemini_cli", "qwen_code") to a list of file paths to OAuth credential JSON files. - `max_retries` (`int`, default: `2`): The number of times to retry a request with the *same key* if a transient server error (e.g., 500, 503) occurs. -- `usage_file_path` (`str`, default: `"key_usage.json"`): The path to the JSON file where usage statistics (tokens, cost, success counts) are persisted. +- `usage_file_path` (`str`, optional): Base path for usage persistence (defaults to `usage/` in the data directory). The client stores per-provider files under `usage/usage_.json` next to this path. - `configure_logging` (`bool`, default: `True`): If `True`, configures the library's logger to propagate logs to the root logger. Set to `False` if you want to handle logging configuration manually. - `global_timeout` (`int`, default: `30`): A hard time limit (in seconds) for the entire request lifecycle. If the request (including all retries) takes longer than this, it is aborted. - `abort_on_callback_error` (`bool`, default: `True`): If `True`, any exception raised by `pre_request_callback` will abort the request. If `False`, the error is logged and the request proceeds. diff --git a/src/rotator_library/background_refresher.py b/src/rotator_library/background_refresher.py index e3da1f76..acc66c89 100644 --- a/src/rotator_library/background_refresher.py +++ b/src/rotator_library/background_refresher.py @@ -234,9 +234,13 @@ async def _run_provider_background_job( # Run immediately on start if configured if run_on_start: try: - await provider.run_background_job( - self._client.usage_manager, credentials - ) + usage_manager = self._client.usage_managers.get(provider_name) + if usage_manager is None: + lib_logger.debug( + f"Skipping {provider_name} {job_name}: no UsageManager" + ) + return + await provider.run_background_job(usage_manager, credentials) lib_logger.debug(f"{provider_name} {job_name}: initial run complete") except Exception as e: lib_logger.error( @@ -247,9 +251,13 @@ async def _run_provider_background_job( while True: try: await asyncio.sleep(interval) - await provider.run_background_job( - self._client.usage_manager, credentials - ) + usage_manager = self._client.usage_managers.get(provider_name) + if usage_manager is None: + lib_logger.debug( + f"Skipping {provider_name} {job_name}: no UsageManager" + ) + return + await provider.run_background_job(usage_manager, credentials) lib_logger.debug(f"{provider_name} {job_name}: periodic run complete") except asyncio.CancelledError: lib_logger.debug(f"{provider_name} {job_name}: cancelled") @@ -259,6 +267,7 @@ async def _run_provider_background_job( async def _run(self): """The main loop for OAuth token refresh.""" + await self._client.initialize_usage_managers() # Initialize credentials (load persisted tiers) before starting await self._initialize_credentials() diff --git a/src/rotator_library/client.py b/src/rotator_library/client.py deleted file mode 100644 index fdd12d67..00000000 --- a/src/rotator_library/client.py +++ /dev/null @@ -1,3581 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-only -# Copyright (c) 2026 Mirrowel - -import asyncio -import fnmatch -import json -import re -import codecs -import time -import os -import random -import httpx -import litellm -from litellm.exceptions import APIConnectionError -from litellm.litellm_core_utils.token_counter import token_counter -import logging -from pathlib import Path -from typing import List, Dict, Any, AsyncGenerator, Optional, Union, Tuple - -lib_logger = logging.getLogger("rotator_library") -# Ensure the logger is configured to propagate to the root logger -# which is set up in main.py. This allows the main app to control -# log levels and handlers centrally. -lib_logger.propagate = False - -from .usage_manager import UsageManager -from .failure_logger import log_failure, configure_failure_logger -from .error_handler import ( - PreRequestCallbackError, - CredentialNeedsReauthError, - classify_error, - NoAvailableKeysError, - should_rotate_on_error, - should_retry_same_key, - RequestErrorAccumulator, - mask_credential, -) -from .provider_config import ProviderConfig -from .providers import PROVIDER_PLUGINS -from .providers.openai_compatible_provider import OpenAICompatibleProvider -from .request_sanitizer import sanitize_request_payload -from .cooldown_manager import CooldownManager -from .credential_manager import CredentialManager -from .background_refresher import BackgroundRefresher -from .model_definitions import ModelDefinitions -from .transaction_logger import TransactionLogger -from .utils.paths import get_default_root, get_logs_dir, get_oauth_dir, get_data_file -from .utils.suppress_litellm_warnings import suppress_litellm_serialization_warnings -from .config import ( - DEFAULT_MAX_RETRIES, - DEFAULT_GLOBAL_TIMEOUT, - DEFAULT_ROTATION_TOLERANCE, - DEFAULT_FAIR_CYCLE_DURATION, - DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, - DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER, -) - - -class StreamedAPIError(Exception): - """Custom exception to signal an API error received over a stream.""" - - def __init__(self, message, data=None): - super().__init__(message) - self.data = data - - -class RotatingClient: - """ - A client that intelligently rotates and retries API keys using LiteLLM, - with support for both streaming and non-streaming responses. - """ - - def __init__( - self, - api_keys: Optional[Dict[str, List[str]]] = None, - oauth_credentials: Optional[Dict[str, List[str]]] = None, - max_retries: int = DEFAULT_MAX_RETRIES, - usage_file_path: Optional[Union[str, Path]] = None, - configure_logging: bool = True, - global_timeout: int = DEFAULT_GLOBAL_TIMEOUT, - abort_on_callback_error: bool = True, - litellm_provider_params: Optional[Dict[str, Any]] = None, - ignore_models: Optional[Dict[str, List[str]]] = None, - whitelist_models: Optional[Dict[str, List[str]]] = None, - enable_request_logging: bool = False, - max_concurrent_requests_per_key: Optional[Dict[str, int]] = None, - rotation_tolerance: float = DEFAULT_ROTATION_TOLERANCE, - data_dir: Optional[Union[str, Path]] = None, - ): - """ - Initialize the RotatingClient with intelligent credential rotation. - - Args: - api_keys: Dictionary mapping provider names to lists of API keys - oauth_credentials: Dictionary mapping provider names to OAuth credential paths - max_retries: Maximum number of retry attempts per credential - usage_file_path: Path to store usage statistics. If None, uses data_dir/key_usage.json - configure_logging: Whether to configure library logging - global_timeout: Global timeout for requests in seconds - abort_on_callback_error: Whether to abort on pre-request callback errors - litellm_provider_params: Provider-specific parameters for LiteLLM - ignore_models: Models to ignore/blacklist per provider - whitelist_models: Models to explicitly whitelist per provider - enable_request_logging: Whether to enable detailed request logging - max_concurrent_requests_per_key: Max concurrent requests per key by provider - rotation_tolerance: Tolerance for weighted random credential rotation. - - 0.0: Deterministic, least-used credential always selected - - 2.0 - 4.0 (default, recommended): Balanced randomness, can pick credentials within 2 uses of max - - 5.0+: High randomness, more unpredictable selection patterns - data_dir: Root directory for all data files (logs, cache, oauth_creds, key_usage.json). - If None, auto-detects: EXE directory if frozen, else current working directory. - """ - # Resolve data_dir early - this becomes the root for all file operations - if data_dir is not None: - self.data_dir = Path(data_dir).resolve() - else: - self.data_dir = get_default_root() - - # Configure failure logger to use correct logs directory - configure_failure_logger(get_logs_dir(self.data_dir)) - - os.environ["LITELLM_LOG"] = "ERROR" - litellm.set_verbose = False - litellm.drop_params = True - - # Suppress harmless Pydantic serialization warnings from litellm - # See: https://github.com/BerriAI/litellm/issues/11759 - # TODO: Remove this workaround once litellm patches the issue - suppress_litellm_serialization_warnings() - - if configure_logging: - # When True, this allows logs from this library to be handled - # by the parent application's logging configuration. - lib_logger.propagate = True - # Remove any default handlers to prevent duplicate logging - if lib_logger.hasHandlers(): - lib_logger.handlers.clear() - lib_logger.addHandler(logging.NullHandler()) - else: - lib_logger.propagate = False - - api_keys = api_keys or {} - oauth_credentials = oauth_credentials or {} - - # Filter out providers with empty lists of credentials to ensure validity - api_keys = {provider: keys for provider, keys in api_keys.items() if keys} - oauth_credentials = { - provider: paths for provider, paths in oauth_credentials.items() if paths - } - - if not api_keys and not oauth_credentials: - lib_logger.warning( - "No provider credentials configured. The client will be unable to make any API requests." - ) - - self.api_keys = api_keys - # Use provided oauth_credentials directly if available (already discovered by main.py) - # Only call discover_and_prepare() if no credentials were passed - if oauth_credentials: - self.oauth_credentials = oauth_credentials - else: - self.credential_manager = CredentialManager( - os.environ, oauth_dir=get_oauth_dir(self.data_dir) - ) - self.oauth_credentials = self.credential_manager.discover_and_prepare() - self.background_refresher = BackgroundRefresher(self) - self.oauth_providers = set(self.oauth_credentials.keys()) - - all_credentials = {} - for provider, keys in api_keys.items(): - all_credentials.setdefault(provider, []).extend(keys) - for provider, paths in self.oauth_credentials.items(): - all_credentials.setdefault(provider, []).extend(paths) - self.all_credentials = all_credentials - - self.max_retries = max_retries - self.global_timeout = global_timeout - self.abort_on_callback_error = abort_on_callback_error - - # Initialize provider plugins early so they can be used for rotation mode detection - self._provider_plugins = PROVIDER_PLUGINS - self._provider_instances = {} - - # Build provider rotation modes map - # Each provider can specify its preferred rotation mode ("balanced" or "sequential") - provider_rotation_modes = {} - for provider in self.all_credentials.keys(): - provider_class = self._provider_plugins.get(provider) - if provider_class and hasattr(provider_class, "get_rotation_mode"): - # Use class method to get rotation mode (checks env var + class default) - mode = provider_class.get_rotation_mode(provider) - else: - # Fallback: check environment variable directly - env_key = f"ROTATION_MODE_{provider.upper()}" - mode = os.getenv(env_key, "balanced") - - provider_rotation_modes[provider] = mode - if mode != "balanced": - lib_logger.info(f"Provider '{provider}' using rotation mode: {mode}") - - # Build priority-based concurrency multiplier maps - # These are universal multipliers based on credential tier/priority - priority_multipliers: Dict[str, Dict[int, int]] = {} - priority_multipliers_by_mode: Dict[str, Dict[str, Dict[int, int]]] = {} - sequential_fallback_multipliers: Dict[str, int] = {} - - for provider in self.all_credentials.keys(): - provider_class = self._provider_plugins.get(provider) - - # Start with provider class defaults - if provider_class: - # Get default priority multipliers from provider class - if hasattr(provider_class, "default_priority_multipliers"): - default_multipliers = provider_class.default_priority_multipliers - if default_multipliers: - priority_multipliers[provider] = dict(default_multipliers) - - # Get sequential fallback from provider class - if hasattr(provider_class, "default_sequential_fallback_multiplier"): - fallback = provider_class.default_sequential_fallback_multiplier - if ( - fallback != DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER - ): # Only store if different from global default - sequential_fallback_multipliers[provider] = fallback - - # Override with environment variables - # Format: CONCURRENCY_MULTIPLIER__PRIORITY_= - # Format: CONCURRENCY_MULTIPLIER__PRIORITY__= - for key, value in os.environ.items(): - prefix = f"CONCURRENCY_MULTIPLIER_{provider.upper()}_PRIORITY_" - if key.startswith(prefix): - remainder = key[len(prefix) :] - try: - multiplier = int(value) - if multiplier < 1: - lib_logger.warning(f"Invalid {key}: {value}. Must be >= 1.") - continue - - # Check if mode-specific (e.g., _PRIORITY_1_SEQUENTIAL) - if "_" in remainder: - parts = remainder.rsplit("_", 1) - priority = int(parts[0]) - mode = parts[1].lower() - if mode in ("sequential", "balanced"): - # Mode-specific override - if provider not in priority_multipliers_by_mode: - priority_multipliers_by_mode[provider] = {} - if mode not in priority_multipliers_by_mode[provider]: - priority_multipliers_by_mode[provider][mode] = {} - priority_multipliers_by_mode[provider][mode][ - priority - ] = multiplier - lib_logger.info( - f"Provider '{provider}' priority {priority} ({mode} mode) multiplier: {multiplier}x" - ) - else: - # Assume it's part of the priority number (unlikely but handle gracefully) - lib_logger.warning(f"Unknown mode in {key}: {mode}") - else: - # Universal priority multiplier - priority = int(remainder) - if provider not in priority_multipliers: - priority_multipliers[provider] = {} - priority_multipliers[provider][priority] = multiplier - lib_logger.info( - f"Provider '{provider}' priority {priority} multiplier: {multiplier}x" - ) - except ValueError: - lib_logger.warning( - f"Invalid {key}: {value}. Could not parse priority or multiplier." - ) - - # Log configured multipliers - for provider, multipliers in priority_multipliers.items(): - if multipliers: - lib_logger.info( - f"Provider '{provider}' priority multipliers: {multipliers}" - ) - for provider, fallback in sequential_fallback_multipliers.items(): - lib_logger.info( - f"Provider '{provider}' sequential fallback multiplier: {fallback}x" - ) - - # Build fair cycle configuration - fair_cycle_enabled: Dict[str, bool] = {} - fair_cycle_tracking_mode: Dict[str, str] = {} - fair_cycle_cross_tier: Dict[str, bool] = {} - fair_cycle_duration: Dict[str, int] = {} - - for provider in self.all_credentials.keys(): - provider_class = self._provider_plugins.get(provider) - rotation_mode = provider_rotation_modes.get(provider, "balanced") - - # Fair cycle enabled - check env, then provider default, then derive from rotation mode - env_key = f"FAIR_CYCLE_{provider.upper()}" - env_val = os.getenv(env_key) - if env_val is not None: - fair_cycle_enabled[provider] = env_val.lower() in ("true", "1", "yes") - elif provider_class and hasattr( - provider_class, "default_fair_cycle_enabled" - ): - default_val = provider_class.default_fair_cycle_enabled - if default_val is not None: - fair_cycle_enabled[provider] = default_val - # None means use global default (enabled for all modes) - # Default: enabled for all rotation modes (not stored, handled in UsageManager) - - # Tracking mode - check env, then provider default - env_key = f"FAIR_CYCLE_TRACKING_MODE_{provider.upper()}" - env_val = os.getenv(env_key) - if env_val is not None and env_val.lower() in ("model_group", "credential"): - fair_cycle_tracking_mode[provider] = env_val.lower() - elif provider_class and hasattr( - provider_class, "default_fair_cycle_tracking_mode" - ): - fair_cycle_tracking_mode[provider] = ( - provider_class.default_fair_cycle_tracking_mode - ) - - # Cross-tier - check env, then provider default - env_key = f"FAIR_CYCLE_CROSS_TIER_{provider.upper()}" - env_val = os.getenv(env_key) - if env_val is not None: - fair_cycle_cross_tier[provider] = env_val.lower() in ( - "true", - "1", - "yes", - ) - elif provider_class and hasattr( - provider_class, "default_fair_cycle_cross_tier" - ): - if provider_class.default_fair_cycle_cross_tier: - fair_cycle_cross_tier[provider] = True - - # Duration - check provider-specific env, then provider default - env_key = f"FAIR_CYCLE_DURATION_{provider.upper()}" - env_val = os.getenv(env_key) - if env_val is not None: - try: - fair_cycle_duration[provider] = int(env_val) - except ValueError: - lib_logger.warning( - f"Invalid {env_key}: {env_val}. Must be integer." - ) - elif provider_class and hasattr( - provider_class, "default_fair_cycle_duration" - ): - duration = provider_class.default_fair_cycle_duration - if ( - duration != DEFAULT_FAIR_CYCLE_DURATION - ): # Only store if different from global default - fair_cycle_duration[provider] = duration - - # Build exhaustion cooldown threshold per provider - # Check global env first, then per-provider env, then provider class default - exhaustion_cooldown_threshold: Dict[str, int] = {} - global_threshold_str = os.getenv("EXHAUSTION_COOLDOWN_THRESHOLD") - global_threshold = DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD - if global_threshold_str: - try: - global_threshold = int(global_threshold_str) - except ValueError: - lib_logger.warning( - f"Invalid EXHAUSTION_COOLDOWN_THRESHOLD: {global_threshold_str}. Using default {DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD}." - ) - - for provider in self.all_credentials.keys(): - provider_class = self._provider_plugins.get(provider) - - # Check per-provider env var first - env_key = f"EXHAUSTION_COOLDOWN_THRESHOLD_{provider.upper()}" - env_val = os.getenv(env_key) - if env_val is not None: - try: - exhaustion_cooldown_threshold[provider] = int(env_val) - except ValueError: - lib_logger.warning( - f"Invalid {env_key}: {env_val}. Must be integer." - ) - elif provider_class and hasattr( - provider_class, "default_exhaustion_cooldown_threshold" - ): - threshold = provider_class.default_exhaustion_cooldown_threshold - if ( - threshold != DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD - ): # Only store if different from global default - exhaustion_cooldown_threshold[provider] = threshold - elif global_threshold != DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD: - # Use global threshold if set and different from default - exhaustion_cooldown_threshold[provider] = global_threshold - - # Log fair cycle configuration - for provider, enabled in fair_cycle_enabled.items(): - if not enabled: - lib_logger.info(f"Provider '{provider}' fair cycle: disabled") - for provider, mode in fair_cycle_tracking_mode.items(): - if mode != "model_group": - lib_logger.info( - f"Provider '{provider}' fair cycle tracking mode: {mode}" - ) - for provider, cross_tier in fair_cycle_cross_tier.items(): - if cross_tier: - lib_logger.info(f"Provider '{provider}' fair cycle cross-tier: enabled") - - # Build custom caps configuration - # Format: CUSTOM_CAP_{PROVIDER}_T{TIER}_{MODEL_OR_GROUP}= - # Format: CUSTOM_CAP_COOLDOWN_{PROVIDER}_T{TIER}_{MODEL_OR_GROUP}=: - custom_caps: Dict[ - str, Dict[Union[int, Tuple[int, ...], str], Dict[str, Dict[str, Any]]] - ] = {} - - for provider in self.all_credentials.keys(): - provider_class = self._provider_plugins.get(provider) - provider_upper = provider.upper() - - # Start with provider class defaults - if provider_class and hasattr(provider_class, "default_custom_caps"): - default_caps = provider_class.default_custom_caps - if default_caps: - custom_caps[provider] = {} - for tier_key, models_config in default_caps.items(): - custom_caps[provider][tier_key] = dict(models_config) - - # Parse environment variable overrides - cap_prefix = f"CUSTOM_CAP_{provider_upper}_T" - cooldown_prefix = f"CUSTOM_CAP_COOLDOWN_{provider_upper}_T" - - for env_key, env_value in os.environ.items(): - if env_key.startswith(cap_prefix) and not env_key.startswith( - cooldown_prefix - ): - # Parse cap value - remainder = env_key[len(cap_prefix) :] - tier_key, model_key = self._parse_custom_cap_env_key(remainder) - if tier_key is None: - continue - - if provider not in custom_caps: - custom_caps[provider] = {} - if tier_key not in custom_caps[provider]: - custom_caps[provider][tier_key] = {} - if model_key not in custom_caps[provider][tier_key]: - custom_caps[provider][tier_key][model_key] = {} - - # Store max_requests value - custom_caps[provider][tier_key][model_key]["max_requests"] = ( - env_value - ) - - elif env_key.startswith(cooldown_prefix): - # Parse cooldown config - remainder = env_key[len(cooldown_prefix) :] - tier_key, model_key = self._parse_custom_cap_env_key(remainder) - if tier_key is None: - continue - - # Parse mode:value format - if ":" in env_value: - mode, value_str = env_value.split(":", 1) - try: - value = int(value_str) - except ValueError: - lib_logger.warning( - f"Invalid cooldown value in {env_key}: {env_value}" - ) - continue - else: - mode = env_value - value = 0 - - if provider not in custom_caps: - custom_caps[provider] = {} - if tier_key not in custom_caps[provider]: - custom_caps[provider][tier_key] = {} - if model_key not in custom_caps[provider][tier_key]: - custom_caps[provider][tier_key][model_key] = {} - - custom_caps[provider][tier_key][model_key]["cooldown_mode"] = mode - custom_caps[provider][tier_key][model_key]["cooldown_value"] = value - - # Log custom caps configuration - for provider, tier_configs in custom_caps.items(): - for tier_key, models_config in tier_configs.items(): - for model_key, config in models_config.items(): - max_req = config.get("max_requests", "default") - cooldown = config.get("cooldown_mode", "quota_reset") - lib_logger.info( - f"Custom cap: {provider}/T{tier_key}/{model_key} = {max_req}, cooldown={cooldown}" - ) - - # Resolve usage file path - use provided path or default to data_dir - if usage_file_path is not None: - resolved_usage_path = Path(usage_file_path) - else: - resolved_usage_path = self.data_dir / "key_usage.json" - - self.usage_manager = UsageManager( - file_path=resolved_usage_path, - rotation_tolerance=rotation_tolerance, - provider_rotation_modes=provider_rotation_modes, - provider_plugins=PROVIDER_PLUGINS, - priority_multipliers=priority_multipliers, - priority_multipliers_by_mode=priority_multipliers_by_mode, - sequential_fallback_multipliers=sequential_fallback_multipliers, - fair_cycle_enabled=fair_cycle_enabled, - fair_cycle_tracking_mode=fair_cycle_tracking_mode, - fair_cycle_cross_tier=fair_cycle_cross_tier, - fair_cycle_duration=fair_cycle_duration, - exhaustion_cooldown_threshold=exhaustion_cooldown_threshold, - custom_caps=custom_caps, - ) - self._model_list_cache = {} - self.http_client = httpx.AsyncClient() - self.provider_config = ProviderConfig() - self.cooldown_manager = CooldownManager() - self.litellm_provider_params = litellm_provider_params or {} - self.ignore_models = ignore_models or {} - self.whitelist_models = whitelist_models or {} - self.enable_request_logging = enable_request_logging - self.model_definitions = ModelDefinitions() - - # Store and validate max concurrent requests per key - self.max_concurrent_requests_per_key = max_concurrent_requests_per_key or {} - # Validate all values are >= 1 - for provider, max_val in self.max_concurrent_requests_per_key.items(): - if max_val < 1: - lib_logger.warning( - f"Invalid max_concurrent for '{provider}': {max_val}. Setting to 1." - ) - self.max_concurrent_requests_per_key[provider] = 1 - - def _parse_custom_cap_env_key( - self, remainder: str - ) -> Tuple[Optional[Union[int, Tuple[int, ...], str]], Optional[str]]: - """ - Parse the tier and model/group from a custom cap env var remainder. - - Args: - remainder: String after "CUSTOM_CAP_{PROVIDER}_T" prefix - e.g., "2_CLAUDE" or "2_3_CLAUDE" or "DEFAULT_CLAUDE" - - Returns: - (tier_key, model_key) tuple, or (None, None) if parse fails - """ - if not remainder: - return None, None - - remaining_parts = remainder.split("_") - if len(remaining_parts) < 2: - return None, None - - tier_key: Union[int, Tuple[int, ...], str, None] = None - model_key: Optional[str] = None - - # Tiers are numeric or "DEFAULT" - tier_parts: List[int] = [] - - for i, part in enumerate(remaining_parts): - if part == "DEFAULT": - tier_key = "default" - model_key = "_".join(remaining_parts[i + 1 :]) - break - elif part.isdigit(): - tier_parts.append(int(part)) - else: - # First non-numeric part is start of model name - if len(tier_parts) == 0: - return None, None - elif len(tier_parts) == 1: - tier_key = tier_parts[0] - else: - tier_key = tuple(tier_parts) - model_key = "_".join(remaining_parts[i:]) - break - else: - # All parts were tier parts, no model - return None, None - - if model_key: - # Convert model_key back to original format (for matching) - # Env vars use underscores, but we store with original names - # The matching in UsageManager will handle this - model_key = model_key.lower().replace("_", "-") - - return tier_key, model_key - - def _is_model_ignored(self, provider: str, model_id: str) -> bool: - """ - Checks if a model should be ignored based on the ignore list. - Supports full glob/fnmatch patterns for both full model IDs and model names. - - Pattern examples: - - "gpt-4" - exact match - - "gpt-4*" - prefix wildcard (matches gpt-4, gpt-4-turbo, etc.) - - "*-preview" - suffix wildcard (matches gpt-4-preview, o1-preview, etc.) - - "*-preview*" - contains wildcard (matches anything with -preview) - - "*" - match all - """ - model_provider = model_id.split("/")[0] - if model_provider not in self.ignore_models: - return False - - ignore_list = self.ignore_models[model_provider] - if ignore_list == ["*"]: - return True - - try: - # This is the model name as the provider sees it (e.g., "gpt-4" or "google/gemma-7b") - provider_model_name = model_id.split("/", 1)[1] - except IndexError: - provider_model_name = model_id - - for ignored_pattern in ignore_list: - # Use fnmatch for full glob pattern support - if fnmatch.fnmatch(provider_model_name, ignored_pattern) or fnmatch.fnmatch( - model_id, ignored_pattern - ): - return True - return False - - def _is_model_whitelisted(self, provider: str, model_id: str) -> bool: - """ - Checks if a model is explicitly whitelisted. - Supports full glob/fnmatch patterns for both full model IDs and model names. - - Pattern examples: - - "gpt-4" - exact match - - "gpt-4*" - prefix wildcard (matches gpt-4, gpt-4-turbo, etc.) - - "*-preview" - suffix wildcard (matches gpt-4-preview, o1-preview, etc.) - - "*-preview*" - contains wildcard (matches anything with -preview) - - "*" - match all - """ - model_provider = model_id.split("/")[0] - if model_provider not in self.whitelist_models: - return False - - whitelist = self.whitelist_models[model_provider] - - try: - # This is the model name as the provider sees it (e.g., "gpt-4" or "google/gemma-7b") - provider_model_name = model_id.split("/", 1)[1] - except IndexError: - provider_model_name = model_id - - for whitelisted_pattern in whitelist: - # Use fnmatch for full glob pattern support - if fnmatch.fnmatch( - provider_model_name, whitelisted_pattern - ) or fnmatch.fnmatch(model_id, whitelisted_pattern): - return True - return False - - def _sanitize_litellm_log(self, log_data: dict) -> dict: - """ - Recursively removes large data fields and sensitive information from litellm log - dictionaries to keep debug logs clean and secure. - """ - if not isinstance(log_data, dict): - return log_data - - # Keys to remove at any level of the dictionary - keys_to_pop = [ - "messages", - "input", - "response", - "data", - "api_key", - "api_base", - "original_response", - "additional_args", - ] - - # Keys that might contain nested dictionaries to clean - nested_keys = ["kwargs", "litellm_params", "model_info", "proxy_server_request"] - - # Create a deep copy to avoid modifying the original log object in memory - clean_data = json.loads(json.dumps(log_data, default=str)) - - def clean_recursively(data_dict): - if not isinstance(data_dict, dict): - return - - # Remove sensitive/large keys - for key in keys_to_pop: - data_dict.pop(key, None) - - # Recursively clean nested dictionaries - for key in nested_keys: - if key in data_dict and isinstance(data_dict[key], dict): - clean_recursively(data_dict[key]) - - # Also iterate through all values to find any other nested dicts - for key, value in list(data_dict.items()): - if isinstance(value, dict): - clean_recursively(value) - - clean_recursively(clean_data) - return clean_data - - def _litellm_logger_callback(self, log_data: dict): - """ - Callback function to redirect litellm's logs to the library's logger. - This allows us to control the log level and destination of litellm's output. - It also cleans up error logs for better readability in debug files. - """ - # Filter out verbose pre_api_call and post_api_call logs - log_event_type = log_data.get("log_event_type") - if log_event_type in ["pre_api_call", "post_api_call"]: - return # Skip these verbose logs entirely - - # For successful calls or pre-call logs, a simple debug message is enough. - if not log_data.get("exception"): - sanitized_log = self._sanitize_litellm_log(log_data) - # We log it at the DEBUG level to ensure it goes to the debug file - # and not the console, based on the main.py configuration. - lib_logger.debug(f"LiteLLM Log: {sanitized_log}") - return - - # For failures, extract key info to make debug logs more readable. - model = log_data.get("model", "N/A") - call_id = log_data.get("litellm_call_id", "N/A") - error_info = log_data.get("standard_logging_object", {}).get( - "error_information", {} - ) - error_class = error_info.get("error_class", "UnknownError") - error_message = error_info.get( - "error_message", str(log_data.get("exception", "")) - ) - error_message = " ".join(error_message.split()) # Sanitize - - lib_logger.debug( - f"LiteLLM Callback Handled Error: Model={model} | " - f"Type={error_class} | Message='{error_message}'" - ) - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.close() - - async def close(self): - """Close the HTTP client to prevent resource leaks.""" - if hasattr(self, "http_client") and self.http_client: - await self.http_client.aclose() - - def _apply_default_safety_settings( - self, litellm_kwargs: Dict[str, Any], provider: str - ): - """ - Ensure default Gemini safety settings are present when calling the Gemini provider. - This will not override any explicit settings provided by the request. It accepts - either OpenAI-compatible generic `safety_settings` (dict) or direct Gemini-style - `safetySettings` (list of dicts). Missing categories will be added with safe defaults. - """ - if provider != "gemini": - return - - # Generic defaults (openai-compatible style) - default_generic = { - "harassment": "OFF", - "hate_speech": "OFF", - "sexually_explicit": "OFF", - "dangerous_content": "OFF", - "civic_integrity": "BLOCK_NONE", - } - - # Gemini defaults (direct Gemini format) - default_gemini = [ - {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, - ] - - # If generic form is present, ensure missing generic keys are filled in - if "safety_settings" in litellm_kwargs and isinstance( - litellm_kwargs["safety_settings"], dict - ): - for k, v in default_generic.items(): - if k not in litellm_kwargs["safety_settings"]: - litellm_kwargs["safety_settings"][k] = v - return - - # If Gemini form is present, ensure missing gemini categories are appended - if "safetySettings" in litellm_kwargs and isinstance( - litellm_kwargs["safetySettings"], list - ): - present = { - item.get("category") - for item in litellm_kwargs["safetySettings"] - if isinstance(item, dict) - } - for d in default_gemini: - if d["category"] not in present: - litellm_kwargs["safetySettings"].append(d) - return - - # Neither present: set generic defaults so provider conversion will translate them - if ( - "safety_settings" not in litellm_kwargs - and "safetySettings" not in litellm_kwargs - ): - litellm_kwargs["safety_settings"] = default_generic.copy() - - def get_oauth_credentials(self) -> Dict[str, List[str]]: - return self.oauth_credentials - - def _is_custom_openai_compatible_provider(self, provider_name: str) -> bool: - """ - Checks if a provider is a custom OpenAI-compatible provider. - - Custom providers are identified by: - 1. Having a _API_BASE environment variable set, AND - 2. NOT being in the list of known LiteLLM providers - """ - return self.provider_config.is_custom_provider(provider_name) - - def _get_provider_instance(self, provider_name: str): - """ - Lazily initializes and returns a provider instance. - Only initializes providers that have configured credentials. - - Args: - provider_name: The name of the provider to get an instance for. - For OAuth providers, this may include "_oauth" suffix - (e.g., "antigravity_oauth"), but credentials are stored - under the base name (e.g., "antigravity"). - - Returns: - Provider instance if credentials exist, None otherwise. - """ - # For OAuth providers, credentials are stored under base name (without _oauth suffix) - # e.g., "antigravity_oauth" plugin → credentials under "antigravity" - credential_key = provider_name - if provider_name.endswith("_oauth"): - base_name = provider_name[:-6] # Remove "_oauth" - if base_name in self.oauth_providers: - credential_key = base_name - - # Only initialize providers for which we have credentials - if credential_key not in self.all_credentials: - lib_logger.debug( - f"Skipping provider '{provider_name}' initialization: no credentials configured" - ) - return None - - if provider_name not in self._provider_instances: - if provider_name in self._provider_plugins: - self._provider_instances[provider_name] = self._provider_plugins[ - provider_name - ]() - elif self._is_custom_openai_compatible_provider(provider_name): - # Create a generic OpenAI-compatible provider for custom providers - try: - self._provider_instances[provider_name] = OpenAICompatibleProvider( - provider_name - ) - except ValueError: - # If the provider doesn't have the required environment variables, treat it as a standard provider - return None - else: - return None - return self._provider_instances[provider_name] - - def _resolve_model_id(self, model: str, provider: str) -> str: - """ - Resolves the actual model ID to send to the provider. - - For custom models with name/ID mappings, returns the ID. - Otherwise, returns the model name unchanged. - - Args: - model: Full model string with provider (e.g., "iflow/DS-v3.2") - provider: Provider name (e.g., "iflow") - - Returns: - Full model string with ID (e.g., "iflow/deepseek-v3.2") - """ - # Extract model name from "provider/model_name" format - model_name = model.split("/")[-1] if "/" in model else model - - # Try to get provider instance to check for model definitions - provider_plugin = self._get_provider_instance(provider) - - # Check if provider has model definitions - if provider_plugin and hasattr(provider_plugin, "model_definitions"): - model_id = provider_plugin.model_definitions.get_model_id( - provider, model_name - ) - if model_id and model_id != model_name: - # Return with provider prefix - return f"{provider}/{model_id}" - - # Fallback: use client's own model definitions - model_id = self.model_definitions.get_model_id(provider, model_name) - if model_id and model_id != model_name: - return f"{provider}/{model_id}" - - # No conversion needed, return original - return model - - async def _safe_streaming_wrapper( - self, - stream: Any, - key: str, - model: str, - request: Optional[Any] = None, - provider_plugin: Optional[Any] = None, - ) -> AsyncGenerator[Any, None]: - """ - A hybrid wrapper for streaming that buffers fragmented JSON, handles client disconnections gracefully, - and distinguishes between content and streamed errors. - - FINISH_REASON HANDLING: - Providers just translate chunks - this wrapper handles ALL finish_reason logic: - 1. Strip finish_reason from intermediate chunks (litellm defaults to "stop") - 2. Track accumulated_finish_reason with priority: tool_calls > length/content_filter > stop - 3. Only emit finish_reason on final chunk (detected by usage.completion_tokens > 0) - """ - last_usage = None - stream_completed = False - stream_iterator = stream.__aiter__() - json_buffer = "" - accumulated_finish_reason = None # Track strongest finish_reason across chunks - has_tool_calls = False # Track if ANY tool calls were seen in stream - - try: - while True: - if request and await request.is_disconnected(): - lib_logger.info( - f"Client disconnected. Aborting stream for credential {mask_credential(key)}." - ) - break - - try: - chunk = await stream_iterator.__anext__() - if json_buffer: - lib_logger.warning( - f"Discarding incomplete JSON buffer from previous chunk: {json_buffer}" - ) - json_buffer = "" - - # Convert chunk to dict, handling both litellm.ModelResponse and raw dicts - if hasattr(chunk, "dict"): - chunk_dict = chunk.dict() - elif hasattr(chunk, "model_dump"): - chunk_dict = chunk.model_dump() - else: - chunk_dict = chunk - - # === FINISH_REASON LOGIC === - # Providers send raw chunks without finish_reason logic. - # This wrapper determines finish_reason based on accumulated state. - if "choices" in chunk_dict and chunk_dict["choices"]: - choice = chunk_dict["choices"][0] - delta = choice.get("delta", {}) - usage = chunk_dict.get("usage", {}) - - # Track tool_calls across ALL chunks - if we ever see one, finish_reason must be tool_calls - if delta.get("tool_calls"): - has_tool_calls = True - accumulated_finish_reason = "tool_calls" - - # Detect final chunk: has usage with completion_tokens > 0 - has_completion_tokens = ( - usage - and isinstance(usage, dict) - and usage.get("completion_tokens", 0) > 0 - ) - - if has_completion_tokens: - # FINAL CHUNK: Determine correct finish_reason - if has_tool_calls: - # Tool calls always win - choice["finish_reason"] = "tool_calls" - elif accumulated_finish_reason: - # Use accumulated reason (length, content_filter, etc.) - choice["finish_reason"] = accumulated_finish_reason - else: - # Default to stop - choice["finish_reason"] = "stop" - else: - # INTERMEDIATE CHUNK: Never emit finish_reason - # (litellm.ModelResponse defaults to "stop" which is wrong) - choice["finish_reason"] = None - - yield f"data: {json.dumps(chunk_dict)}\n\n" - - if hasattr(chunk, "usage") and chunk.usage: - last_usage = chunk.usage - - except StopAsyncIteration: - stream_completed = True - if json_buffer: - lib_logger.info( - f"Stream ended with incomplete data in buffer: {json_buffer}" - ) - if last_usage: - # Create a dummy ModelResponse for recording (only usage matters) - dummy_response = litellm.ModelResponse(usage=last_usage) - await self.usage_manager.record_success( - key, model, dummy_response - ) - else: - # If no usage seen (rare), record success without tokens/cost - await self.usage_manager.record_success(key, model) - - break - - except CredentialNeedsReauthError as e: - # This credential needs re-authentication but re-auth is already queued. - # Wrap it so the outer retry loop can rotate to the next credential. - # No scary traceback needed - this is an expected recovery scenario. - raise StreamedAPIError("Credential needs re-authentication", data=e) - - except ( - litellm.RateLimitError, - litellm.ServiceUnavailableError, - litellm.InternalServerError, - APIConnectionError, - httpx.HTTPStatusError, - ) as e: - # This is a critical, typed error from litellm or httpx that signals a key failure. - # We do not try to parse it here. We wrap it and raise it immediately - # for the outer retry loop to handle. - lib_logger.warning( - f"Caught a critical API error mid-stream: {type(e).__name__}. Signaling for credential rotation." - ) - raise StreamedAPIError("Provider error received in stream", data=e) - - except Exception as e: - try: - raw_chunk = "" - # Google streams errors inside a bytes representation (b'{...}'). - # We use regex to extract the content, which is more reliable than splitting. - match = re.search(r"b'(\{.*\})'", str(e), re.DOTALL) - if match: - # The extracted string is unicode-escaped (e.g., '\\n'). We must decode it. - raw_chunk = codecs.decode(match.group(1), "unicode_escape") - else: - # Fallback for other potential error formats that use "Received chunk:". - chunk_from_split = ( - str(e).split("Received chunk:")[-1].strip() - ) - if chunk_from_split != str( - e - ): # Ensure the split actually did something - raw_chunk = chunk_from_split - - if not raw_chunk: - # If we could not extract a valid chunk, we cannot proceed with reassembly. - # This indicates a different, unexpected error type. Re-raise it. - raise e - - # Append the clean chunk to the buffer and try to parse. - json_buffer += raw_chunk - parsed_data = json.loads(json_buffer) - - # If parsing succeeds, we have the complete object. - lib_logger.info( - f"Successfully reassembled JSON from stream: {json_buffer}" - ) - - # Wrap the complete error object and raise it. The outer function will decide how to handle it. - raise StreamedAPIError( - "Provider error received in stream", data=parsed_data - ) - - except json.JSONDecodeError: - # This is the expected outcome if the JSON in the buffer is not yet complete. - lib_logger.info( - f"Buffer still incomplete. Waiting for more chunks: {json_buffer}" - ) - continue # Continue to the next loop to get the next chunk. - except StreamedAPIError: - # Re-raise to be caught by the outer retry handler. - raise - except Exception as buffer_exc: - # If the error was not a JSONDecodeError, it's an unexpected internal error. - lib_logger.error( - f"Error during stream buffering logic: {buffer_exc}. Discarding buffer." - ) - json_buffer = ( - "" # Clear the corrupted buffer to prevent further issues. - ) - raise buffer_exc - - except StreamedAPIError: - # This is caught by the acompletion retry logic. - # We re-raise it to ensure it's not caught by the generic 'except Exception'. - raise - - except Exception as e: - # Catch any other unexpected errors during streaming. - lib_logger.error(f"Caught unexpected exception of type: {type(e).__name__}") - lib_logger.error( - f"An unexpected error occurred during the stream for credential {mask_credential(key)}: {e}" - ) - # We still need to raise it so the client knows something went wrong. - raise - - finally: - # This block now runs regardless of how the stream terminates (completion, client disconnect, etc.). - # The primary goal is to ensure usage is always logged internally. - await self.usage_manager.release_key(key, model) - lib_logger.info( - f"STREAM FINISHED and lock released for credential {mask_credential(key)}." - ) - - # Only send [DONE] if the stream completed naturally and the client is still there. - # This prevents sending [DONE] to a disconnected client or after an error. - if stream_completed and ( - not request or not await request.is_disconnected() - ): - yield "data: [DONE]\n\n" - - async def _transaction_logging_stream_wrapper( - self, - stream: Any, - transaction_logger: Optional[TransactionLogger], - request_data: Dict[str, Any], - ) -> Any: - """ - Wrap a stream to log chunks and final response to TransactionLogger. - - This wrapper: - 1. Yields chunks unchanged (passthrough) - 2. Parses SSE chunks and logs them via transaction_logger.log_stream_chunk() - 3. Collects chunks for final response assembly - 4. After stream ends, assembles and logs final response - - Args: - stream: The streaming generator (yields SSE strings like "data: {...}") - transaction_logger: Optional TransactionLogger instance - request_data: Original request data for context - """ - chunks = [] - try: - async for chunk_str in stream: - yield chunk_str - - # Log chunk if logging enabled - if ( - transaction_logger - and isinstance(chunk_str, str) - and chunk_str.strip() - and chunk_str.startswith("data:") - ): - content = chunk_str[len("data:") :].strip() - if content and content != "[DONE]": - try: - chunk_data = json.loads(content) - chunks.append(chunk_data) - transaction_logger.log_stream_chunk(chunk_data) - except json.JSONDecodeError: - lib_logger.warning( - f"TransactionLogger: Failed to parse chunk: {content[:100]}" - ) - finally: - # Assemble and log final response after stream ends - if transaction_logger and chunks: - try: - final_response = TransactionLogger.assemble_streaming_response( - chunks, request_data - ) - transaction_logger.log_response(final_response) - except Exception as e: - lib_logger.warning( - f"TransactionLogger: Failed to assemble/log final response: {e}" - ) - - async def _execute_with_retry( - self, - api_call: callable, - request: Optional[Any], - pre_request_callback: Optional[callable] = None, - **kwargs, - ) -> Any: - """A generic retry mechanism for non-streaming API calls.""" - model = kwargs.get("model") - if not model: - raise ValueError("'model' is a required parameter.") - - provider = model.split("/")[0] - if provider not in self.all_credentials: - raise ValueError( - f"No API keys or OAuth credentials configured for provider: {provider}" - ) - - # Extract internal logging parameters (not passed to API) - parent_log_dir = kwargs.pop("_parent_log_dir", None) - - # Establish a global deadline for the entire request lifecycle. - deadline = time.time() + self.global_timeout - - # Create transaction logger if request logging is enabled - transaction_logger = None - if self.enable_request_logging: - transaction_logger = TransactionLogger( - provider, - model, - enabled=True, - api_format="oai", - parent_dir=parent_log_dir, - ) - transaction_logger.log_request(kwargs) - - # Create a mutable copy of the keys and shuffle it to ensure - # that the key selection is randomized, which is crucial when - # multiple keys have the same usage stats. - credentials_for_provider = list(self.all_credentials[provider]) - random.shuffle(credentials_for_provider) - - # Filter out credentials that are unavailable (queued for re-auth) - provider_plugin = self._get_provider_instance(provider) - if provider_plugin and hasattr(provider_plugin, "is_credential_available"): - available_creds = [ - cred - for cred in credentials_for_provider - if provider_plugin.is_credential_available(cred) - ] - if available_creds: - credentials_for_provider = available_creds - # If all credentials are unavailable, keep the original list - # (better to try unavailable creds than fail immediately) - - tried_creds = set() - last_exception = None - - # The main rotation loop. It continues as long as there are untried credentials and the global deadline has not been exceeded. - - # Resolve model ID early, before any credential operations - # This ensures consistent model ID usage for acquisition, release, and tracking - resolved_model = self._resolve_model_id(model, provider) - if resolved_model != model: - lib_logger.info(f"Resolved model '{model}' to '{resolved_model}'") - model = resolved_model - kwargs["model"] = model # Ensure kwargs has the resolved model for litellm - - # [NEW] Filter by model tier requirement and build priority map - credential_priorities = None - if provider_plugin and hasattr(provider_plugin, "get_model_tier_requirement"): - required_tier = provider_plugin.get_model_tier_requirement(model) - if required_tier is not None: - # Filter OUT only credentials we KNOW are too low priority - # Keep credentials with unknown priority (None) - they might be high priority - incompatible_creds = [] - compatible_creds = [] - unknown_creds = [] - - for cred in credentials_for_provider: - if hasattr(provider_plugin, "get_credential_priority"): - priority = provider_plugin.get_credential_priority(cred) - if priority is None: - # Unknown priority - keep it, will be discovered on first use - unknown_creds.append(cred) - elif priority <= required_tier: - # Known compatible priority - compatible_creds.append(cred) - else: - # Known incompatible priority (too low) - incompatible_creds.append(cred) - else: - # Provider doesn't support priorities - keep all - unknown_creds.append(cred) - - # If we have any known-compatible or unknown credentials, use them - tier_compatible_creds = compatible_creds + unknown_creds - if tier_compatible_creds: - credentials_for_provider = tier_compatible_creds - if compatible_creds and unknown_creds: - lib_logger.info( - f"Model {model} requires priority <= {required_tier}. " - f"Using {len(compatible_creds)} known-compatible + {len(unknown_creds)} unknown-tier credentials." - ) - elif compatible_creds: - lib_logger.info( - f"Model {model} requires priority <= {required_tier}. " - f"Using {len(compatible_creds)} known-compatible credentials." - ) - else: - lib_logger.info( - f"Model {model} requires priority <= {required_tier}. " - f"Using {len(unknown_creds)} unknown-tier credentials (will discover on use)." - ) - elif incompatible_creds: - # Only known-incompatible credentials remain - lib_logger.warning( - f"Model {model} requires priority <= {required_tier} credentials, " - f"but all {len(incompatible_creds)} known credentials have priority > {required_tier}. " - f"Request will likely fail." - ) - - # Build priority map and tier names map for usage_manager - credential_tier_names = None - if provider_plugin and hasattr(provider_plugin, "get_credential_priority"): - credential_priorities = {} - credential_tier_names = {} - for cred in credentials_for_provider: - priority = provider_plugin.get_credential_priority(cred) - if priority is not None: - credential_priorities[cred] = priority - # Also get tier name for logging - if hasattr(provider_plugin, "get_credential_tier_name"): - tier_name = provider_plugin.get_credential_tier_name(cred) - if tier_name: - credential_tier_names[cred] = tier_name - - if credential_priorities: - lib_logger.debug( - f"Credential priorities for {provider}: {', '.join(f'P{p}={len([c for c in credentials_for_provider if credential_priorities.get(c) == p])}' for p in sorted(set(credential_priorities.values())))}" - ) - - # Initialize error accumulator for tracking errors across credential rotation - error_accumulator = RequestErrorAccumulator() - error_accumulator.model = model - error_accumulator.provider = provider - - while ( - len(tried_creds) < len(credentials_for_provider) and time.time() < deadline - ): - current_cred = None - key_acquired = False - try: - # Check for a provider-wide cooldown first. - if await self.cooldown_manager.is_cooling_down(provider): - remaining_cooldown = ( - await self.cooldown_manager.get_cooldown_remaining(provider) - ) - remaining_budget = deadline - time.time() - - # If the cooldown is longer than the remaining time budget, fail fast. - if remaining_cooldown > remaining_budget: - lib_logger.warning( - f"Provider {provider} cooldown ({remaining_cooldown:.2f}s) exceeds remaining request budget ({remaining_budget:.2f}s). Failing early." - ) - break - - lib_logger.warning( - f"Provider {provider} is in cooldown. Waiting for {remaining_cooldown:.2f} seconds." - ) - await asyncio.sleep(remaining_cooldown) - - creds_to_try = [ - c for c in credentials_for_provider if c not in tried_creds - ] - if not creds_to_try: - break - - # Get count of credentials not on cooldown for this model - availability_stats = ( - await self.usage_manager.get_credential_availability_stats( - creds_to_try, model, credential_priorities - ) - ) - available_count = availability_stats["available"] - total_count = len(credentials_for_provider) - on_cooldown = availability_stats["on_cooldown"] - fc_excluded = availability_stats["fair_cycle_excluded"] - - # Build compact exclusion breakdown - exclusion_parts = [] - if on_cooldown > 0: - exclusion_parts.append(f"cd:{on_cooldown}") - if fc_excluded > 0: - exclusion_parts.append(f"fc:{fc_excluded}") - exclusion_str = ( - f",{','.join(exclusion_parts)}" if exclusion_parts else "" - ) - - lib_logger.info( - f"Acquiring key for model {model}. Tried keys: {len(tried_creds)}/{available_count}({total_count}{exclusion_str})" - ) - max_concurrent = self.max_concurrent_requests_per_key.get(provider, 1) - current_cred = await self.usage_manager.acquire_key( - available_keys=creds_to_try, - model=model, - deadline=deadline, - max_concurrent=max_concurrent, - credential_priorities=credential_priorities, - credential_tier_names=credential_tier_names, - all_provider_credentials=credentials_for_provider, - ) - key_acquired = True - tried_creds.add(current_cred) - - litellm_kwargs = kwargs.copy() - - # [NEW] Merge provider-specific params - if provider in self.litellm_provider_params: - litellm_kwargs["litellm_params"] = { - **self.litellm_provider_params[provider], - **litellm_kwargs.get("litellm_params", {}), - } - - provider_plugin = self._get_provider_instance(provider) - - # Model ID is already resolved before the loop, and kwargs['model'] is updated. - # No further resolution needed here. - - # Apply model-specific options for custom providers - if provider_plugin and hasattr(provider_plugin, "get_model_options"): - model_options = provider_plugin.get_model_options(model) - if model_options: - # Merge model options into litellm_kwargs - for key, value in model_options.items(): - if key == "reasoning_effort": - litellm_kwargs["reasoning_effort"] = value - elif key not in litellm_kwargs: - litellm_kwargs[key] = value - - if provider_plugin and provider_plugin.has_custom_logic(): - lib_logger.debug( - f"Provider '{provider}' has custom logic. Delegating call." - ) - litellm_kwargs["credential_identifier"] = current_cred - litellm_kwargs["transaction_context"] = ( - transaction_logger.get_context() if transaction_logger else None - ) - - # Retry loop for custom providers - mirrors streaming path error handling - for attempt in range(self.max_retries): - try: - lib_logger.info( - f"Attempting call with credential {mask_credential(current_cred)} (Attempt {attempt + 1}/{self.max_retries})" - ) - - if pre_request_callback: - try: - await pre_request_callback(request, litellm_kwargs) - except Exception as e: - if self.abort_on_callback_error: - raise PreRequestCallbackError( - f"Pre-request callback failed: {e}" - ) from e - else: - lib_logger.warning( - f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}" - ) - - response = await provider_plugin.acompletion( - self.http_client, **litellm_kwargs - ) - - # For non-streaming, success is immediate - await self.usage_manager.record_success( - current_cred, model, response - ) - - await self.usage_manager.release_key(current_cred, model) - key_acquired = False - - # Log response to transaction logger - if transaction_logger: - response_data = ( - response.model_dump() - if hasattr(response, "model_dump") - else response - ) - transaction_logger.log_response(response_data) - - return response - - except ( - litellm.RateLimitError, - httpx.HTTPStatusError, - ) as e: - last_exception = e - classified_error = classify_error(e, provider=provider) - error_message = str(e).split("\n")[0] - - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - - # Record in accumulator for client reporting - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - - # Check if this error should trigger rotation - if not should_rotate_on_error(classified_error): - lib_logger.error( - f"Non-recoverable error ({classified_error.error_type}) during custom provider call. Failing." - ) - raise last_exception - - # Handle rate limits with cooldown (exclude quota_exceeded) - if classified_error.error_type == "rate_limit": - cooldown_duration = classified_error.retry_after or 60 - await self.cooldown_manager.start_cooldown( - provider, cooldown_duration - ) - - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - lib_logger.warning( - f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code}). Rotating." - ) - break # Rotate to next credential - - except ( - APIConnectionError, - litellm.InternalServerError, - litellm.ServiceUnavailableError, - ) as e: - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - classified_error = classify_error(e, provider=provider) - error_message = str(e).split("\n")[0] - - # Provider-level error: don't increment consecutive failures - await self.usage_manager.record_failure( - current_cred, - model, - classified_error, - increment_consecutive_failures=False, - ) - - if attempt >= self.max_retries - 1: - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - lib_logger.warning( - f"Cred {mask_credential(current_cred)} failed after max retries. Rotating." - ) - break - - wait_time = classified_error.retry_after or ( - 2**attempt - ) + random.uniform(0, 1) - remaining_budget = deadline - time.time() - if wait_time > remaining_budget: - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - lib_logger.warning( - f"Retry wait ({wait_time:.2f}s) exceeds budget. Rotating." - ) - break - - lib_logger.warning( - f"Cred {mask_credential(current_cred)} server error. Retrying in {wait_time:.2f}s." - ) - await asyncio.sleep(wait_time) - continue - - except Exception as e: - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - classified_error = classify_error(e, provider=provider) - error_message = str(e).split("\n")[0] - - # Record in accumulator - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - - lib_logger.warning( - f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code})." - ) - - # Check if this error should trigger rotation - if not should_rotate_on_error(classified_error): - lib_logger.error( - f"Non-recoverable error ({classified_error.error_type}). Failing." - ) - raise last_exception - - # Handle rate limits with cooldown (exclude quota_exceeded) - if ( - classified_error.status_code == 429 - and classified_error.error_type != "quota_exceeded" - ) or classified_error.error_type == "rate_limit": - cooldown_duration = classified_error.retry_after or 60 - await self.cooldown_manager.start_cooldown( - provider, cooldown_duration - ) - - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - break # Rotate to next credential - - # If the inner loop breaks, it means the key failed and we need to rotate. - # Continue to the next iteration of the outer while loop to pick a new key. - continue - - else: # This is the standard API Key / litellm-handled provider logic - is_oauth = provider in self.oauth_providers - if is_oauth: # Standard OAuth provider (not custom) - # ... (logic to set headers) ... - pass - else: # API Key - litellm_kwargs["api_key"] = current_cred - - provider_instance = self._get_provider_instance(provider) - if provider_instance: - # Ensure default Gemini safety settings are present (without overriding request) - try: - self._apply_default_safety_settings( - litellm_kwargs, provider - ) - except Exception: - # If anything goes wrong here, avoid breaking the request flow. - lib_logger.debug( - "Could not apply default safety settings; continuing." - ) - - if "safety_settings" in litellm_kwargs: - converted_settings = ( - provider_instance.convert_safety_settings( - litellm_kwargs["safety_settings"] - ) - ) - if converted_settings is not None: - litellm_kwargs["safety_settings"] = converted_settings - else: - del litellm_kwargs["safety_settings"] - - if provider == "gemini" and provider_instance: - provider_instance.handle_thinking_parameter( - litellm_kwargs, model - ) - if provider == "nvidia_nim" and provider_instance: - provider_instance.handle_thinking_parameter( - litellm_kwargs, model - ) - - if "gemma-3" in model and "messages" in litellm_kwargs: - litellm_kwargs["messages"] = [ - {"role": "user", "content": m["content"]} - if m.get("role") == "system" - else m - for m in litellm_kwargs["messages"] - ] - - litellm_kwargs = sanitize_request_payload(litellm_kwargs, model) - - for attempt in range(self.max_retries): - try: - lib_logger.info( - f"Attempting call with credential {mask_credential(current_cred)} (Attempt {attempt + 1}/{self.max_retries})" - ) - - if pre_request_callback: - try: - await pre_request_callback(request, litellm_kwargs) - except Exception as e: - if self.abort_on_callback_error: - raise PreRequestCallbackError( - f"Pre-request callback failed: {e}" - ) from e - else: - lib_logger.warning( - f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}" - ) - - # Convert model parameters for custom providers right before LiteLLM call - final_kwargs = self.provider_config.convert_for_litellm( - **litellm_kwargs - ) - - response = await api_call( - **final_kwargs, - logger_fn=self._litellm_logger_callback, - ) - - await self.usage_manager.record_success( - current_cred, model, response - ) - - await self.usage_manager.release_key(current_cred, model) - key_acquired = False - - # Log response to transaction logger - if transaction_logger: - response_data = ( - response.model_dump() - if hasattr(response, "model_dump") - else response - ) - transaction_logger.log_response(response_data) - - return response - - except litellm.RateLimitError as e: - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - classified_error = classify_error(e, provider=provider) - - # Extract a clean error message for the user-facing log - error_message = str(e).split("\n")[0] - - # Record in accumulator for client reporting - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - - lib_logger.info( - f"Key {mask_credential(current_cred)} hit rate limit for {model}. Rotating key." - ) - - # Only trigger provider-wide cooldown for rate limits, not quota issues - if ( - classified_error.status_code == 429 - and classified_error.error_type != "quota_exceeded" - ): - cooldown_duration = classified_error.retry_after or 60 - await self.cooldown_manager.start_cooldown( - provider, cooldown_duration - ) - - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - break # Move to the next key - - except ( - APIConnectionError, - litellm.InternalServerError, - litellm.ServiceUnavailableError, - ) as e: - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - classified_error = classify_error(e, provider=provider) - error_message = str(e).split("\n")[0] - - # Provider-level error: don't increment consecutive failures - await self.usage_manager.record_failure( - current_cred, - model, - classified_error, - increment_consecutive_failures=False, - ) - - if attempt >= self.max_retries - 1: - # Record in accumulator only on final failure for this key - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - lib_logger.warning( - f"Key {mask_credential(current_cred)} failed after max retries due to server error. Rotating." - ) - break # Move to the next key - - # For temporary errors, wait before retrying with the same key. - wait_time = classified_error.retry_after or ( - 2**attempt - ) + random.uniform(0, 1) - remaining_budget = deadline - time.time() - - # If the required wait time exceeds the budget, don't wait; rotate to the next key immediately. - if wait_time > remaining_budget: - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - lib_logger.warning( - f"Retry wait ({wait_time:.2f}s) exceeds budget ({remaining_budget:.2f}s). Rotating key." - ) - break - - lib_logger.warning( - f"Key {mask_credential(current_cred)} server error. Retrying in {wait_time:.2f}s." - ) - await asyncio.sleep(wait_time) - continue # Retry with the same key - - except httpx.HTTPStatusError as e: - # Handle HTTP errors from httpx (e.g., from custom providers like Antigravity) - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - - classified_error = classify_error(e, provider=provider) - error_message = str(e).split("\n")[0] - - lib_logger.warning( - f"Key {mask_credential(current_cred)} HTTP {e.response.status_code} ({classified_error.error_type})." - ) - - # Check if this error should trigger rotation - if not should_rotate_on_error(classified_error): - lib_logger.error( - f"Non-recoverable error ({classified_error.error_type}). Failing request." - ) - raise last_exception - - # Record in accumulator after confirming it's a rotatable error - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - - # Handle rate limits with cooldown (exclude quota_exceeded from provider-wide cooldown) - if classified_error.error_type == "rate_limit": - cooldown_duration = classified_error.retry_after or 60 - await self.cooldown_manager.start_cooldown( - provider, cooldown_duration - ) - - # Check if we should retry same key (server errors with retries left) - if ( - should_retry_same_key(classified_error) - and attempt < self.max_retries - 1 - ): - wait_time = classified_error.retry_after or ( - 2**attempt - ) + random.uniform(0, 1) - remaining_budget = deadline - time.time() - if wait_time <= remaining_budget: - lib_logger.warning( - f"Server error, retrying same key in {wait_time:.2f}s." - ) - await asyncio.sleep(wait_time) - continue - - # Record failure and rotate to next key - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - lib_logger.info( - f"Rotating to next key after {classified_error.error_type} error." - ) - break - - except Exception as e: - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - - if request and await request.is_disconnected(): - lib_logger.warning( - f"Client disconnected. Aborting retries for {mask_credential(current_cred)}." - ) - raise last_exception - - classified_error = classify_error(e, provider=provider) - error_message = str(e).split("\n")[0] - - lib_logger.warning( - f"Key {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code})." - ) - - # Handle rate limits with cooldown (exclude quota_exceeded from provider-wide cooldown) - if ( - classified_error.status_code == 429 - and classified_error.error_type != "quota_exceeded" - ) or classified_error.error_type == "rate_limit": - cooldown_duration = classified_error.retry_after or 60 - await self.cooldown_manager.start_cooldown( - provider, cooldown_duration - ) - - # Check if this error should trigger rotation - if not should_rotate_on_error(classified_error): - lib_logger.error( - f"Non-recoverable error ({classified_error.error_type}). Failing request." - ) - raise last_exception - - # Record in accumulator after confirming it's a rotatable error - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - break # Try next key for other errors - finally: - if key_acquired and current_cred: - await self.usage_manager.release_key(current_cred, model) - - # Check if we exhausted all credentials or timed out - if time.time() >= deadline: - error_accumulator.timeout_occurred = True - - if error_accumulator.has_errors(): - # Log concise summary for server logs - lib_logger.error(error_accumulator.build_log_message()) - - # Return the structured error response for the client - return error_accumulator.build_client_error_response() - - # Return None to indicate failure without error details (shouldn't normally happen) - lib_logger.warning( - "Unexpected state: request failed with no recorded errors. " - "This may indicate a logic error in error tracking." - ) - return None - - async def _streaming_acompletion_with_retry( - self, - request: Optional[Any], - pre_request_callback: Optional[callable] = None, - **kwargs, - ) -> AsyncGenerator[str, None]: - """A dedicated generator for retrying streaming completions with full request preparation and per-key retries.""" - model = kwargs.get("model") - provider = model.split("/")[0] - - # Extract internal logging parameters (not passed to API) - parent_log_dir = kwargs.pop("_parent_log_dir", None) - - # Create a mutable copy of the keys and shuffle it. - credentials_for_provider = list(self.all_credentials[provider]) - random.shuffle(credentials_for_provider) - - # Filter out credentials that are unavailable (queued for re-auth) - provider_plugin = self._get_provider_instance(provider) - if provider_plugin and hasattr(provider_plugin, "is_credential_available"): - available_creds = [ - cred - for cred in credentials_for_provider - if provider_plugin.is_credential_available(cred) - ] - if available_creds: - credentials_for_provider = available_creds - # If all credentials are unavailable, keep the original list - # (better to try unavailable creds than fail immediately) - - deadline = time.time() + self.global_timeout - - # Create transaction logger if request logging is enabled - transaction_logger = None - if self.enable_request_logging: - transaction_logger = TransactionLogger( - provider, - model, - enabled=True, - api_format="oai", - parent_dir=parent_log_dir, - ) - transaction_logger.log_request(kwargs) - - tried_creds = set() - last_exception = None - - consecutive_quota_failures = 0 - - # Resolve model ID early, before any credential operations - # This ensures consistent model ID usage for acquisition, release, and tracking - resolved_model = self._resolve_model_id(model, provider) - if resolved_model != model: - lib_logger.info(f"Resolved model '{model}' to '{resolved_model}'") - model = resolved_model - kwargs["model"] = model # Ensure kwargs has the resolved model for litellm - - # [NEW] Filter by model tier requirement and build priority map - credential_priorities = None - if provider_plugin and hasattr(provider_plugin, "get_model_tier_requirement"): - required_tier = provider_plugin.get_model_tier_requirement(model) - if required_tier is not None: - # Filter OUT only credentials we KNOW are too low priority - # Keep credentials with unknown priority (None) - they might be high priority - incompatible_creds = [] - compatible_creds = [] - unknown_creds = [] - - for cred in credentials_for_provider: - if hasattr(provider_plugin, "get_credential_priority"): - priority = provider_plugin.get_credential_priority(cred) - if priority is None: - # Unknown priority - keep it, will be discovered on first use - unknown_creds.append(cred) - elif priority <= required_tier: - # Known compatible priority - compatible_creds.append(cred) - else: - # Known incompatible priority (too low) - incompatible_creds.append(cred) - else: - # Provider doesn't support priorities - keep all - unknown_creds.append(cred) - - # If we have any known-compatible or unknown credentials, use them - tier_compatible_creds = compatible_creds + unknown_creds - if tier_compatible_creds: - credentials_for_provider = tier_compatible_creds - if compatible_creds and unknown_creds: - lib_logger.info( - f"Model {model} requires priority <= {required_tier}. " - f"Using {len(compatible_creds)} known-compatible + {len(unknown_creds)} unknown-tier credentials." - ) - elif compatible_creds: - lib_logger.info( - f"Model {model} requires priority <= {required_tier}. " - f"Using {len(compatible_creds)} known-compatible credentials." - ) - else: - lib_logger.info( - f"Model {model} requires priority <= {required_tier}. " - f"Using {len(unknown_creds)} unknown-tier credentials (will discover on use)." - ) - elif incompatible_creds: - # Only known-incompatible credentials remain - lib_logger.warning( - f"Model {model} requires priority <= {required_tier} credentials, " - f"but all {len(incompatible_creds)} known credentials have priority > {required_tier}. " - f"Request will likely fail." - ) - - # Build priority map and tier names map for usage_manager - credential_tier_names = None - if provider_plugin and hasattr(provider_plugin, "get_credential_priority"): - credential_priorities = {} - credential_tier_names = {} - for cred in credentials_for_provider: - priority = provider_plugin.get_credential_priority(cred) - if priority is not None: - credential_priorities[cred] = priority - # Also get tier name for logging - if hasattr(provider_plugin, "get_credential_tier_name"): - tier_name = provider_plugin.get_credential_tier_name(cred) - if tier_name: - credential_tier_names[cred] = tier_name - - if credential_priorities: - lib_logger.debug( - f"Credential priorities for {provider}: {', '.join(f'P{p}={len([c for c in credentials_for_provider if credential_priorities.get(c) == p])}' for p in sorted(set(credential_priorities.values())))}" - ) - - # Initialize error accumulator for tracking errors across credential rotation - error_accumulator = RequestErrorAccumulator() - error_accumulator.model = model - error_accumulator.provider = provider - - try: - while ( - len(tried_creds) < len(credentials_for_provider) - and time.time() < deadline - ): - current_cred = None - key_acquired = False - try: - if await self.cooldown_manager.is_cooling_down(provider): - remaining_cooldown = ( - await self.cooldown_manager.get_cooldown_remaining(provider) - ) - remaining_budget = deadline - time.time() - if remaining_cooldown > remaining_budget: - lib_logger.warning( - f"Provider {provider} cooldown ({remaining_cooldown:.2f}s) exceeds remaining request budget ({remaining_budget:.2f}s). Failing early." - ) - break - lib_logger.warning( - f"Provider {provider} is in a global cooldown. All requests to this provider will be paused for {remaining_cooldown:.2f} seconds." - ) - await asyncio.sleep(remaining_cooldown) - - creds_to_try = [ - c for c in credentials_for_provider if c not in tried_creds - ] - if not creds_to_try: - lib_logger.warning( - f"All credentials for provider {provider} have been tried. No more credentials to rotate to." - ) - break - - # Get count of credentials not on cooldown for this model - availability_stats = ( - await self.usage_manager.get_credential_availability_stats( - creds_to_try, model, credential_priorities - ) - ) - available_count = availability_stats["available"] - total_count = len(credentials_for_provider) - on_cooldown = availability_stats["on_cooldown"] - fc_excluded = availability_stats["fair_cycle_excluded"] - - # Build compact exclusion breakdown - exclusion_parts = [] - if on_cooldown > 0: - exclusion_parts.append(f"cd:{on_cooldown}") - if fc_excluded > 0: - exclusion_parts.append(f"fc:{fc_excluded}") - exclusion_str = ( - f",{','.join(exclusion_parts)}" if exclusion_parts else "" - ) - - lib_logger.info( - f"Acquiring credential for model {model}. Tried credentials: {len(tried_creds)}/{available_count}({total_count}{exclusion_str})" - ) - max_concurrent = self.max_concurrent_requests_per_key.get( - provider, 1 - ) - current_cred = await self.usage_manager.acquire_key( - available_keys=creds_to_try, - model=model, - deadline=deadline, - max_concurrent=max_concurrent, - credential_priorities=credential_priorities, - credential_tier_names=credential_tier_names, - all_provider_credentials=credentials_for_provider, - ) - key_acquired = True - tried_creds.add(current_cred) - - litellm_kwargs = kwargs.copy() - if "reasoning_effort" in kwargs: - litellm_kwargs["reasoning_effort"] = kwargs["reasoning_effort"] - - # [NEW] Merge provider-specific params - if provider in self.litellm_provider_params: - litellm_kwargs["litellm_params"] = { - **self.litellm_provider_params[provider], - **litellm_kwargs.get("litellm_params", {}), - } - - provider_plugin = self._get_provider_instance(provider) - - # Model ID is already resolved before the loop, and kwargs['model'] is updated. - # No further resolution needed here. - - # Apply model-specific options for custom providers - if provider_plugin and hasattr( - provider_plugin, "get_model_options" - ): - model_options = provider_plugin.get_model_options(model) - if model_options: - # Merge model options into litellm_kwargs - for key, value in model_options.items(): - if key == "reasoning_effort": - litellm_kwargs["reasoning_effort"] = value - elif key not in litellm_kwargs: - litellm_kwargs[key] = value - if provider_plugin and provider_plugin.has_custom_logic(): - lib_logger.debug( - f"Provider '{provider}' has custom logic. Delegating call." - ) - litellm_kwargs["credential_identifier"] = current_cred - litellm_kwargs["transaction_context"] = ( - transaction_logger.get_context() - if transaction_logger - else None - ) - - for attempt in range(self.max_retries): - try: - lib_logger.info( - f"Attempting stream with credential {mask_credential(current_cred)} (Attempt {attempt + 1}/{self.max_retries})" - ) - - if pre_request_callback: - try: - await pre_request_callback( - request, litellm_kwargs - ) - except Exception as e: - if self.abort_on_callback_error: - raise PreRequestCallbackError( - f"Pre-request callback failed: {e}" - ) from e - else: - lib_logger.warning( - f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}" - ) - - response = await provider_plugin.acompletion( - self.http_client, **litellm_kwargs - ) - - lib_logger.info( - f"Stream connection established for credential {mask_credential(current_cred)}. Processing response." - ) - - key_acquired = False - stream_generator = self._safe_streaming_wrapper( - response, - current_cred, - model, - request, - provider_plugin, - ) - - # Wrap with transaction logging - logged_stream = ( - self._transaction_logging_stream_wrapper( - stream_generator, transaction_logger, kwargs - ) - ) - - async for chunk in logged_stream: - yield chunk - return - - except ( - StreamedAPIError, - litellm.RateLimitError, - httpx.HTTPStatusError, - ) as e: - last_exception = e - # If the exception is our custom wrapper, unwrap the original error - original_exc = getattr(e, "data", e) - classified_error = classify_error( - original_exc, provider=provider - ) - error_message = str(original_exc).split("\n")[0] - - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - - # Record in accumulator for client reporting - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - - # Check if this error should trigger rotation - if not should_rotate_on_error(classified_error): - lib_logger.error( - f"Non-recoverable error ({classified_error.error_type}) during custom stream. Failing." - ) - raise last_exception - - # Handle rate limits with cooldown (exclude quota_exceeded) - if classified_error.error_type == "rate_limit": - cooldown_duration = ( - classified_error.retry_after or 60 - ) - await self.cooldown_manager.start_cooldown( - provider, cooldown_duration - ) - - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - lib_logger.warning( - f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code}). Rotating." - ) - break - - except ( - APIConnectionError, - litellm.InternalServerError, - litellm.ServiceUnavailableError, - ) as e: - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - classified_error = classify_error(e, provider=provider) - error_message = str(e).split("\n")[0] - - # Provider-level error: don't increment consecutive failures - await self.usage_manager.record_failure( - current_cred, - model, - classified_error, - increment_consecutive_failures=False, - ) - - if attempt >= self.max_retries - 1: - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - lib_logger.warning( - f"Cred {mask_credential(current_cred)} failed after max retries. Rotating." - ) - break - - wait_time = classified_error.retry_after or ( - 2**attempt - ) + random.uniform(0, 1) - remaining_budget = deadline - time.time() - if wait_time > remaining_budget: - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - lib_logger.warning( - f"Retry wait ({wait_time:.2f}s) exceeds budget. Rotating." - ) - break - - lib_logger.warning( - f"Cred {mask_credential(current_cred)} server error. Retrying in {wait_time:.2f}s." - ) - await asyncio.sleep(wait_time) - continue - - except Exception as e: - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - classified_error = classify_error(e, provider=provider) - error_message = str(e).split("\n")[0] - - # Record in accumulator - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - - lib_logger.warning( - f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code})." - ) - - # Check if this error should trigger rotation - if not should_rotate_on_error(classified_error): - lib_logger.error( - f"Non-recoverable error ({classified_error.error_type}). Failing." - ) - raise last_exception - - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - break - - # If the inner loop breaks, it means the key failed and we need to rotate. - # Continue to the next iteration of the outer while loop to pick a new key. - continue - - else: # This is the standard API Key / litellm-handled provider logic - is_oauth = provider in self.oauth_providers - if is_oauth: # Standard OAuth provider (not custom) - # ... (logic to set headers) ... - pass - else: # API Key - litellm_kwargs["api_key"] = current_cred - - provider_instance = self._get_provider_instance(provider) - if provider_instance: - # Ensure default Gemini safety settings are present (without overriding request) - try: - self._apply_default_safety_settings( - litellm_kwargs, provider - ) - except Exception: - lib_logger.debug( - "Could not apply default safety settings for streaming path; continuing." - ) - - if "safety_settings" in litellm_kwargs: - converted_settings = ( - provider_instance.convert_safety_settings( - litellm_kwargs["safety_settings"] - ) - ) - if converted_settings is not None: - litellm_kwargs["safety_settings"] = converted_settings - else: - del litellm_kwargs["safety_settings"] - - if provider == "gemini" and provider_instance: - provider_instance.handle_thinking_parameter( - litellm_kwargs, model - ) - if provider == "nvidia_nim" and provider_instance: - provider_instance.handle_thinking_parameter( - litellm_kwargs, model - ) - - if "gemma-3" in model and "messages" in litellm_kwargs: - litellm_kwargs["messages"] = [ - {"role": "user", "content": m["content"]} - if m.get("role") == "system" - else m - for m in litellm_kwargs["messages"] - ] - - litellm_kwargs = sanitize_request_payload(litellm_kwargs, model) - - # If the provider is 'qwen_code', set the custom provider to 'qwen' - # and strip the prefix from the model name for LiteLLM. - if provider == "qwen_code": - litellm_kwargs["custom_llm_provider"] = "qwen" - litellm_kwargs["model"] = model.split("/", 1)[1] - - for attempt in range(self.max_retries): - try: - lib_logger.info( - f"Attempting stream with credential {mask_credential(current_cred)} (Attempt {attempt + 1}/{self.max_retries})" - ) - - if pre_request_callback: - try: - await pre_request_callback(request, litellm_kwargs) - except Exception as e: - if self.abort_on_callback_error: - raise PreRequestCallbackError( - f"Pre-request callback failed: {e}" - ) from e - else: - lib_logger.warning( - f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}" - ) - - # lib_logger.info(f"DEBUG: litellm.acompletion kwargs: {litellm_kwargs}") - # Convert model parameters for custom providers right before LiteLLM call - final_kwargs = self.provider_config.convert_for_litellm( - **litellm_kwargs - ) - - response = await litellm.acompletion( - **final_kwargs, - logger_fn=self._litellm_logger_callback, - ) - - lib_logger.info( - f"Stream connection established for credential {mask_credential(current_cred)}. Processing response." - ) - - key_acquired = False - stream_generator = self._safe_streaming_wrapper( - response, - current_cred, - model, - request, - provider_instance, - ) - - # Wrap with transaction logging - logged_stream = self._transaction_logging_stream_wrapper( - stream_generator, transaction_logger, kwargs - ) - - async for chunk in logged_stream: - yield chunk - return - - except ( - StreamedAPIError, - litellm.RateLimitError, - httpx.HTTPStatusError, - ) as e: - last_exception = e - - # This is the final, robust handler for streamed errors. - error_payload = {} - cleaned_str = None - # The actual exception might be wrapped in our StreamedAPIError. - original_exc = getattr(e, "data", e) - classified_error = classify_error( - original_exc, provider=provider - ) - - # Check if this error should trigger rotation - if not should_rotate_on_error(classified_error): - lib_logger.error( - f"Non-recoverable error ({classified_error.error_type}) during litellm stream. Failing." - ) - raise last_exception - - try: - # The full error JSON is in the string representation of the exception. - json_str_match = re.search( - r"(\{.*\})", str(original_exc), re.DOTALL - ) - if json_str_match: - cleaned_str = codecs.decode( - json_str_match.group(1), "unicode_escape" - ) - error_payload = json.loads(cleaned_str) - except (json.JSONDecodeError, TypeError): - error_payload = {} - - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - raw_response_text=cleaned_str, - ) - - error_details = error_payload.get("error", {}) - error_status = error_details.get("status", "") - error_message_text = error_details.get( - "message", str(original_exc).split("\n")[0] - ) - - # Record in accumulator for client reporting - error_accumulator.record_error( - current_cred, classified_error, error_message_text - ) - - if ( - "quota" in error_message_text.lower() - or "resource_exhausted" in error_status.lower() - ): - consecutive_quota_failures += 1 - - quota_value = "N/A" - quota_id = "N/A" - if "details" in error_details and isinstance( - error_details.get("details"), list - ): - for detail in error_details["details"]: - if isinstance(detail.get("violations"), list): - for violation in detail["violations"]: - if "quotaValue" in violation: - quota_value = violation[ - "quotaValue" - ] - if "quotaId" in violation: - quota_id = violation["quotaId"] - if ( - quota_value != "N/A" - and quota_id != "N/A" - ): - break - - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - - if consecutive_quota_failures >= 3: - # Fatal: likely input data too large - client_error_message = ( - f"Request failed after 3 consecutive quota errors (input may be too large). " - f"Limit: {quota_value} (Quota ID: {quota_id})" - ) - lib_logger.error( - f"Fatal quota error for {mask_credential(current_cred)}. ID: {quota_id}, Limit: {quota_value}" - ) - yield f"data: {json.dumps({'error': {'message': client_error_message, 'type': 'proxy_fatal_quota_error'}})}\n\n" - yield "data: [DONE]\n\n" - return - else: - lib_logger.warning( - f"Cred {mask_credential(current_cred)} quota error ({consecutive_quota_failures}/3). Rotating." - ) - break - - else: - consecutive_quota_failures = 0 - lib_logger.warning( - f"Cred {mask_credential(current_cred)} {classified_error.error_type}. Rotating." - ) - - if classified_error.error_type == "rate_limit": - cooldown_duration = ( - classified_error.retry_after or 60 - ) - await self.cooldown_manager.start_cooldown( - provider, cooldown_duration - ) - - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - break - - except ( - APIConnectionError, - litellm.InternalServerError, - litellm.ServiceUnavailableError, - ) as e: - consecutive_quota_failures = 0 - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - classified_error = classify_error(e, provider=provider) - error_message_text = str(e).split("\n")[0] - - # Record error in accumulator (server errors are transient, not abnormal) - error_accumulator.record_error( - current_cred, classified_error, error_message_text - ) - - # Provider-level error: don't increment consecutive failures - await self.usage_manager.record_failure( - current_cred, - model, - classified_error, - increment_consecutive_failures=False, - ) - - if attempt >= self.max_retries - 1: - lib_logger.warning( - f"Credential {mask_credential(current_cred)} failed after max retries for model {model} due to a server error. Rotating key silently." - ) - # [MODIFIED] Do not yield to the client here. - break - - wait_time = classified_error.retry_after or ( - 2**attempt - ) + random.uniform(0, 1) - remaining_budget = deadline - time.time() - if wait_time > remaining_budget: - lib_logger.warning( - f"Required retry wait time ({wait_time:.2f}s) exceeds remaining budget ({remaining_budget:.2f}s). Rotating key early." - ) - break - - lib_logger.warning( - f"Credential {mask_credential(current_cred)} encountered a server error for model {model}. Reason: '{error_message_text}'. Retrying in {wait_time:.2f}s." - ) - await asyncio.sleep(wait_time) - continue - - except Exception as e: - consecutive_quota_failures = 0 - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - classified_error = classify_error(e, provider=provider) - error_message_text = str(e).split("\n")[0] - - # Record error in accumulator - error_accumulator.record_error( - current_cred, classified_error, error_message_text - ) - - lib_logger.warning( - f"Credential {mask_credential(current_cred)} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {error_message_text}." - ) - - # Handle rate limits with cooldown (exclude quota_exceeded) - if ( - classified_error.status_code == 429 - and classified_error.error_type != "quota_exceeded" - ) or classified_error.error_type == "rate_limit": - cooldown_duration = classified_error.retry_after or 60 - await self.cooldown_manager.start_cooldown( - provider, cooldown_duration - ) - lib_logger.warning( - f"Rate limit detected for {provider}. Starting {cooldown_duration}s cooldown." - ) - - # Check if this error should trigger rotation - if not should_rotate_on_error(classified_error): - # Non-rotatable errors - fail immediately - lib_logger.error( - f"Non-recoverable error ({classified_error.error_type}). Failing request." - ) - raise last_exception - - # Record failure and rotate to next key - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - lib_logger.info( - f"Rotating to next key after {classified_error.error_type} error." - ) - break - - finally: - if key_acquired and current_cred: - await self.usage_manager.release_key(current_cred, model) - - # Build detailed error response using error accumulator - error_accumulator.timeout_occurred = time.time() >= deadline - - if error_accumulator.has_errors(): - # Log concise summary for server logs - lib_logger.error(error_accumulator.build_log_message()) - - # Build structured error response for client - error_response = error_accumulator.build_client_error_response() - error_data = error_response - else: - # Fallback if no errors were recorded (shouldn't happen) - final_error_message = ( - "Request failed: No available API keys after rotation or timeout." - ) - if last_exception: - final_error_message = ( - f"Request failed. Last error: {str(last_exception)}" - ) - error_data = { - "error": {"message": final_error_message, "type": "proxy_error"} - } - lib_logger.error(final_error_message) - - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" - - except NoAvailableKeysError as e: - lib_logger.error( - f"A streaming request failed because no keys were available within the time budget: {e}" - ) - error_data = {"error": {"message": str(e), "type": "proxy_busy"}} - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" - except Exception as e: - # This will now only catch fatal errors that should be raised, like invalid requests. - lib_logger.error( - f"An unhandled exception occurred in streaming retry logic: {e}", - exc_info=True, - ) - error_data = { - "error": { - "message": f"An unexpected error occurred: {str(e)}", - "type": "proxy_internal_error", - } - } - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" - - def acompletion( - self, - request: Optional[Any] = None, - pre_request_callback: Optional[callable] = None, - **kwargs, - ) -> Union[Any, AsyncGenerator[str, None]]: - """ - Dispatcher for completion requests. - - Args: - request: Optional request object, used for client disconnect checks and logging. - pre_request_callback: Optional async callback function to be called before each API request attempt. - The callback will receive the `request` object and the prepared request `kwargs` as arguments. - This can be used for custom logic such as request validation, logging, or rate limiting. - If the callback raises an exception, the completion request will be aborted and the exception will propagate. - - Returns: - The completion response object, or an async generator for streaming responses, or None if all retries fail. - """ - # Handle iflow provider: remove stream_options to avoid HTTP 406 - model = kwargs.get("model", "") - provider = model.split("/")[0] if "/" in model else "" - - if provider == "iflow" and "stream_options" in kwargs: - lib_logger.debug( - "Removing stream_options for iflow provider to avoid HTTP 406" - ) - kwargs.pop("stream_options", None) - - if kwargs.get("stream"): - # Only add stream_options for providers that support it (excluding iflow) - if provider != "iflow": - if "stream_options" not in kwargs: - kwargs["stream_options"] = {} - if "include_usage" not in kwargs["stream_options"]: - kwargs["stream_options"]["include_usage"] = True - - return self._streaming_acompletion_with_retry( - request=request, pre_request_callback=pre_request_callback, **kwargs - ) - else: - return self._execute_with_retry( - litellm.acompletion, - request=request, - pre_request_callback=pre_request_callback, - **kwargs, - ) - - def aembedding( - self, - request: Optional[Any] = None, - pre_request_callback: Optional[callable] = None, - **kwargs, - ) -> Any: - """ - Executes an embedding request with retry logic. - - Args: - request: Optional request object, used for client disconnect checks and logging. - pre_request_callback: Optional async callback function to be called before each API request attempt. - The callback will receive the `request` object and the prepared request `kwargs` as arguments. - This can be used for custom logic such as request validation, logging, or rate limiting. - If the callback raises an exception, the embedding request will be aborted and the exception will propagate. - - Returns: - The embedding response object, or None if all retries fail. - """ - return self._execute_with_retry( - litellm.aembedding, - request=request, - pre_request_callback=pre_request_callback, - **kwargs, - ) - - def token_count(self, **kwargs) -> int: - """Calculates the number of tokens for a given text or list of messages. - - For Antigravity provider models, this also includes the preprompt tokens - that get injected during actual API calls (agent instruction + identity override). - This ensures token counts match actual usage. - """ - model = kwargs.get("model") - text = kwargs.get("text") - messages = kwargs.get("messages") - - if not model: - raise ValueError("'model' is a required parameter.") - - # Calculate base token count - if messages: - base_count = token_counter(model=model, messages=messages) - elif text: - base_count = token_counter(model=model, text=text) - else: - raise ValueError("Either 'text' or 'messages' must be provided.") - - # Add preprompt tokens for Antigravity provider - # The Antigravity provider injects system instructions during actual API calls, - # so we need to account for those tokens in the count - provider = model.split("/")[0] if "/" in model else "" - if provider == "antigravity": - try: - from .providers.antigravity_provider import ( - get_antigravity_preprompt_text, - ) - - preprompt_text = get_antigravity_preprompt_text() - if preprompt_text: - preprompt_tokens = token_counter(model=model, text=preprompt_text) - base_count += preprompt_tokens - except ImportError: - # Provider not available, skip preprompt token counting - pass - - return base_count - - async def get_available_models(self, provider: str) -> List[str]: - """Returns a list of available models for a specific provider, with caching.""" - lib_logger.info(f"Getting available models for provider: {provider}") - if provider in self._model_list_cache: - lib_logger.debug(f"Returning cached models for provider: {provider}") - return self._model_list_cache[provider] - - credentials_for_provider = self.all_credentials.get(provider) - if not credentials_for_provider: - lib_logger.warning(f"No credentials for provider: {provider}") - return [] - - # Create a copy and shuffle it to randomize the starting credential - shuffled_credentials = list(credentials_for_provider) - random.shuffle(shuffled_credentials) - - provider_instance = self._get_provider_instance(provider) - if provider_instance: - # For providers with hardcoded models (like gemini_cli), we only need to call once. - # For others, we might need to try multiple keys if one is invalid. - # The current logic of iterating works for both, as the credential is not - # always used in get_models. - for credential in shuffled_credentials: - try: - # Display last 6 chars for API keys, or the filename for OAuth paths - cred_display = mask_credential(credential) - lib_logger.debug( - f"Attempting to get models for {provider} with credential {cred_display}" - ) - models = await provider_instance.get_models( - credential, self.http_client - ) - lib_logger.info( - f"Got {len(models)} models for provider: {provider}" - ) - - # Whitelist and blacklist logic - final_models = [] - for m in models: - is_whitelisted = self._is_model_whitelisted(provider, m) - is_blacklisted = self._is_model_ignored(provider, m) - - if is_whitelisted: - final_models.append(m) - continue - - if not is_blacklisted: - final_models.append(m) - - if len(final_models) != len(models): - lib_logger.info( - f"Filtered out {len(models) - len(final_models)} models for provider {provider}." - ) - - self._model_list_cache[provider] = final_models - return final_models - except Exception as e: - classified_error = classify_error(e, provider=provider) - cred_display = mask_credential(credential) - lib_logger.debug( - f"Failed to get models for provider {provider} with credential {cred_display}: {classified_error.error_type}. Trying next credential." - ) - continue # Try the next credential - - lib_logger.error( - f"Failed to get models for provider {provider} after trying all credentials." - ) - return [] - - async def get_all_available_models( - self, grouped: bool = True - ) -> Union[Dict[str, List[str]], List[str]]: - """Returns a list of all available models, either grouped by provider or as a flat list.""" - lib_logger.info("Getting all available models...") - - all_providers = list(self.all_credentials.keys()) - tasks = [self.get_available_models(provider) for provider in all_providers] - results = await asyncio.gather(*tasks, return_exceptions=True) - - all_provider_models = {} - for provider, result in zip(all_providers, results): - if isinstance(result, Exception): - lib_logger.error( - f"Failed to get models for provider {provider}: {result}" - ) - all_provider_models[provider] = [] - else: - all_provider_models[provider] = result - - lib_logger.info("Finished getting all available models.") - if grouped: - return all_provider_models - else: - flat_models = [] - for models in all_provider_models.values(): - flat_models.extend(models) - return flat_models - - async def get_quota_stats( - self, - provider_filter: Optional[str] = None, - ) -> Dict[str, Any]: - """ - Get quota and usage stats for all credentials. - - This returns cached/disk data aggregated by provider. - For provider-specific quota info (e.g., Antigravity quota groups), - it enriches the data from provider plugins. - - Args: - provider_filter: If provided, only return stats for this provider - - Returns: - Complete stats dict ready for the /v1/quota-stats endpoint - """ - # Get base stats from usage manager - stats = await self.usage_manager.get_stats_for_endpoint(provider_filter) - - # Enrich with provider-specific quota data - for provider, prov_stats in stats.get("providers", {}).items(): - provider_class = self._provider_plugins.get(provider) - if not provider_class: - continue - - # Get or create provider instance - if provider not in self._provider_instances: - self._provider_instances[provider] = provider_class() - provider_instance = self._provider_instances[provider] - - # Check if provider has quota tracking (like Antigravity) - if hasattr(provider_instance, "_get_effective_quota_groups"): - # Add quota group summary - quota_groups = provider_instance._get_effective_quota_groups() - prov_stats["quota_groups"] = {} - - for group_name, group_models in quota_groups.items(): - group_stats = { - "models": group_models, - "credentials_total": 0, - "credentials_exhausted": 0, - "avg_remaining_pct": 0, - "total_remaining_pcts": [], - # Total requests tracking across all credentials - "total_requests_used": 0, - "total_requests_max": 0, - # Tier breakdown: tier_name -> {"total": N, "active": M} - "tiers": {}, - } - - # Calculate per-credential quota for this group - for cred in prov_stats.get("credentials", []): - models_data = cred.get("models", {}) - group_stats["credentials_total"] += 1 - - # Track tier - get directly from provider cache since cred["tier"] not set yet - tier = cred.get("tier") - if not tier and hasattr( - provider_instance, "project_tier_cache" - ): - cred_path = cred.get("full_path", "") - tier = provider_instance.project_tier_cache.get(cred_path) - tier = tier or "unknown" - - # Initialize tier entry if needed with priority for sorting - if tier not in group_stats["tiers"]: - priority = 10 # default - if hasattr(provider_instance, "_resolve_tier_priority"): - priority = provider_instance._resolve_tier_priority( - tier - ) - group_stats["tiers"][tier] = { - "total": 0, - "active": 0, - "priority": priority, - } - group_stats["tiers"][tier]["total"] += 1 - - # Find model with VALID baseline (not just any model with stats) - model_stats = None - for model in group_models: - candidate = self._find_model_stats_in_data( - models_data, model, provider, provider_instance - ) - if candidate: - baseline = candidate.get("baseline_remaining_fraction") - if baseline is not None: - model_stats = candidate - break - # Keep first found as fallback (for request counts) - if model_stats is None: - model_stats = candidate - - if model_stats: - baseline = model_stats.get("baseline_remaining_fraction") - req_count = model_stats.get("request_count", 0) - max_req = model_stats.get("quota_max_requests") or 0 - - # Accumulate totals (one model per group per credential) - group_stats["total_requests_used"] += req_count - group_stats["total_requests_max"] += max_req - - if baseline is not None: - remaining_pct = int(baseline * 100) - group_stats["total_remaining_pcts"].append( - remaining_pct - ) - if baseline <= 0: - group_stats["credentials_exhausted"] += 1 - else: - # Credential is active (has quota remaining) - group_stats["tiers"][tier]["active"] += 1 - - # Calculate average remaining percentage (per-credential average) - if group_stats["total_remaining_pcts"]: - group_stats["avg_remaining_pct"] = int( - sum(group_stats["total_remaining_pcts"]) - / len(group_stats["total_remaining_pcts"]) - ) - del group_stats["total_remaining_pcts"] - - # Calculate total remaining percentage (global) - if group_stats["total_requests_max"] > 0: - used = group_stats["total_requests_used"] - max_r = group_stats["total_requests_max"] - group_stats["total_requests_remaining"] = max_r - used - group_stats["total_remaining_pct"] = max( - 0, int((1 - used / max_r) * 100) - ) - else: - group_stats["total_requests_remaining"] = 0 - # Fallback to avg_remaining_pct when max_requests unavailable - # This handles providers like Firmware that only provide percentage - group_stats["total_remaining_pct"] = group_stats.get("avg_remaining_pct") - - prov_stats["quota_groups"][group_name] = group_stats - - # Also enrich each credential with formatted quota group info - for cred in prov_stats.get("credentials", []): - cred["model_groups"] = {} - models_data = cred.get("models", {}) - - for group_name, group_models in quota_groups.items(): - # Find model with VALID baseline (prefer over any model with stats) - # Also track the best reset_ts across all models in the group - model_stats = None - best_reset_ts = None - - for model in group_models: - candidate = self._find_model_stats_in_data( - models_data, model, provider, provider_instance - ) - if candidate: - # Track the best (latest) reset_ts from any model in group - candidate_reset_ts = candidate.get("quota_reset_ts") - if candidate_reset_ts: - if ( - best_reset_ts is None - or candidate_reset_ts > best_reset_ts - ): - best_reset_ts = candidate_reset_ts - - baseline = candidate.get("baseline_remaining_fraction") - if baseline is not None: - model_stats = candidate - # Don't break - continue to find best reset_ts - # Keep first found as fallback - if model_stats is None: - model_stats = candidate - - if model_stats: - baseline = model_stats.get("baseline_remaining_fraction") - max_req = model_stats.get("quota_max_requests") - req_count = model_stats.get("request_count", 0) - # Use best_reset_ts from any model in the group - reset_ts = best_reset_ts or model_stats.get( - "quota_reset_ts" - ) - - remaining_pct = ( - int(baseline * 100) if baseline is not None else None - ) - is_exhausted = baseline is not None and baseline <= 0 - - # Format reset time - reset_iso = None - if reset_ts: - try: - from datetime import datetime, timezone - - reset_iso = datetime.fromtimestamp( - reset_ts, tz=timezone.utc - ).isoformat() - except (ValueError, OSError): - pass - - requests_remaining = ( - max(0, max_req - req_count) if max_req else 0 - ) - - # Determine display format - # Priority: requests (if max known) > percentage (if baseline available) > unknown - if max_req: - display = f"{requests_remaining}/{max_req}" - elif remaining_pct is not None: - display = f"{remaining_pct}%" - else: - display = "?/?" - - cred["model_groups"][group_name] = { - "remaining_pct": remaining_pct, - "requests_used": req_count, - "requests_remaining": requests_remaining, - "requests_max": max_req, - "display": display, - "is_exhausted": is_exhausted, - "reset_time_iso": reset_iso, - "models": group_models, - "confidence": self._get_baseline_confidence( - model_stats - ), - } - - # Recalculate credential's requests from model_groups - # This fixes double-counting when models share quota groups - if cred.get("model_groups"): - group_requests = sum( - g.get("requests_used", 0) - for g in cred["model_groups"].values() - ) - cred["requests"] = group_requests - - # HACK: Fix global requests if present - # This is a simplified fix that sets global.requests = current group_requests. - # TODO: Properly track archived requests per quota group in usage_manager.py - # so that global stats correctly sum: current_period + archived_periods - # without double-counting models that share quota groups. - # See: usage_manager.py lines 2388-2404 where global stats are built - # by iterating all models (causing double-counting for grouped models). - if cred.get("global"): - cred["global"]["requests"] = group_requests - - # Try to get email from provider's cache - cred_path = cred.get("full_path", "") - if hasattr(provider_instance, "project_tier_cache"): - tier = provider_instance.project_tier_cache.get(cred_path) - if tier: - cred["tier"] = tier - - return stats - - def _find_model_stats_in_data( - self, - models_data: Dict[str, Any], - model: str, - provider: str, - provider_instance: Any, - ) -> Optional[Dict[str, Any]]: - """ - Find model stats in models_data, trying various name variants. - - Handles aliased model names (e.g., gemini-3-pro-preview -> gemini-3-pro-high) - by using the provider's _user_to_api_model() mapping. - - Args: - models_data: Dict of model_name -> stats from credential - model: Model name to look up (user-facing name) - provider: Provider name for prefixing - provider_instance: Provider instance for alias methods - - Returns: - Model stats dict if found, None otherwise - """ - # Try direct match with and without provider prefix - prefixed_model = f"{provider}/{model}" - model_stats = models_data.get(prefixed_model) or models_data.get(model) - - if model_stats: - return model_stats - - # Try with API model name (e.g., gemini-3-pro-preview -> gemini-3-pro-high) - if hasattr(provider_instance, "_user_to_api_model"): - api_model = provider_instance._user_to_api_model(model) - if api_model != model: - prefixed_api = f"{provider}/{api_model}" - model_stats = models_data.get(prefixed_api) or models_data.get( - api_model - ) - - return model_stats - - def _get_baseline_confidence(self, model_stats: Dict) -> str: - """ - Determine confidence level based on baseline age. - - Args: - model_stats: Model statistics dict with baseline_fetched_at - - Returns: - "high" | "medium" | "low" - """ - baseline_fetched_at = model_stats.get("baseline_fetched_at") - if not baseline_fetched_at: - return "low" - - age_seconds = time.time() - baseline_fetched_at - if age_seconds < 300: # 5 minutes - return "high" - elif age_seconds < 1800: # 30 minutes - return "medium" - return "low" - - async def reload_usage_from_disk(self) -> None: - """ - Force reload usage data from disk. - - Useful when wanting fresh stats without making external API calls. - """ - await self.usage_manager.reload_from_disk() - - async def force_refresh_quota( - self, - provider: Optional[str] = None, - credential: Optional[str] = None, - ) -> Dict[str, Any]: - """ - Force refresh quota from external API. - - For Antigravity, this fetches live quota data from the API. - For other providers, this is a no-op (just reloads from disk). - - Args: - provider: If specified, only refresh this provider - credential: If specified, only refresh this specific credential - - Returns: - Refresh result dict with success/failure info - """ - result = { - "action": "force_refresh", - "scope": "credential" - if credential - else ("provider" if provider else "all"), - "provider": provider, - "credential": credential, - "credentials_refreshed": 0, - "success_count": 0, - "failed_count": 0, - "duration_ms": 0, - "errors": [], - } - - start_time = time.time() - - # Determine which providers to refresh - if provider: - providers_to_refresh = ( - [provider] if provider in self.all_credentials else [] - ) - else: - providers_to_refresh = list(self.all_credentials.keys()) - - for prov in providers_to_refresh: - provider_class = self._provider_plugins.get(prov) - if not provider_class: - continue - - # Get or create provider instance - if prov not in self._provider_instances: - self._provider_instances[prov] = provider_class() - provider_instance = self._provider_instances[prov] - - # Check if provider supports quota refresh (like Antigravity) - if hasattr(provider_instance, "fetch_initial_baselines"): - # Get credentials to refresh - if credential: - # Find full path for this credential - creds_to_refresh = [] - for cred_path in self.all_credentials.get(prov, []): - if cred_path.endswith(credential) or cred_path == credential: - creds_to_refresh.append(cred_path) - break - else: - creds_to_refresh = self.all_credentials.get(prov, []) - - if not creds_to_refresh: - continue - - try: - # Fetch live quota from API for ALL specified credentials - quota_results = await provider_instance.fetch_initial_baselines( - creds_to_refresh - ) - - # Store baselines in usage manager - if hasattr(provider_instance, "_store_baselines_to_usage_manager"): - stored = ( - await provider_instance._store_baselines_to_usage_manager( - quota_results, self.usage_manager - ) - ) - result["success_count"] += stored - - result["credentials_refreshed"] += len(creds_to_refresh) - - # Count failures - for cred_path, data in quota_results.items(): - if data.get("status") != "success": - result["failed_count"] += 1 - result["errors"].append( - f"{Path(cred_path).name}: {data.get('error', 'Unknown error')}" - ) - - except Exception as e: - lib_logger.error(f"Failed to refresh quota for {prov}: {e}") - result["errors"].append(f"{prov}: {str(e)}") - result["failed_count"] += len(creds_to_refresh) - - result["duration_ms"] = int((time.time() - start_time) * 1000) - return result - - # --- Anthropic API Compatibility Methods --- - - async def anthropic_messages( - self, - request: "AnthropicMessagesRequest", - raw_request: Optional[Any] = None, - pre_request_callback: Optional[callable] = None, - ) -> Any: - """ - Handle Anthropic Messages API requests. - - This method accepts requests in Anthropic's format, translates them to - OpenAI format internally, processes them through the existing acompletion - method, and returns responses in Anthropic's format. - - Args: - request: An AnthropicMessagesRequest object - raw_request: Optional raw request object for disconnect checks - pre_request_callback: Optional async callback before each API request - - Returns: - For non-streaming: dict in Anthropic Messages format - For streaming: AsyncGenerator yielding Anthropic SSE format strings - """ - from .anthropic_compat import ( - translate_anthropic_request, - openai_to_anthropic_response, - anthropic_streaming_wrapper, - ) - import uuid - - request_id = f"msg_{uuid.uuid4().hex[:24]}" - original_model = request.model - - # Extract provider from model for logging - provider = original_model.split("/")[0] if "/" in original_model else "unknown" - - # Create Anthropic transaction logger if request logging is enabled - anthropic_logger = None - if self.enable_request_logging: - anthropic_logger = TransactionLogger( - provider, - original_model, - enabled=True, - api_format="ant", - ) - # Log original Anthropic request - anthropic_logger.log_request( - request.model_dump(exclude_none=True), - filename="anthropic_request.json", - ) - - # Translate Anthropic request to OpenAI format - openai_request = translate_anthropic_request(request) - - # Pass parent log directory to acompletion for nested logging - if anthropic_logger and anthropic_logger.log_dir: - openai_request["_parent_log_dir"] = anthropic_logger.log_dir - - if request.stream: - # Streaming response - response_generator = self.acompletion( - request=raw_request, - pre_request_callback=pre_request_callback, - **openai_request, - ) - - # Create disconnect checker if raw_request provided - is_disconnected = None - if raw_request is not None and hasattr(raw_request, "is_disconnected"): - is_disconnected = raw_request.is_disconnected - - # Return the streaming wrapper - # Note: For streaming, the anthropic response logging happens in the wrapper - return anthropic_streaming_wrapper( - openai_stream=response_generator, - original_model=original_model, - request_id=request_id, - is_disconnected=is_disconnected, - transaction_logger=anthropic_logger, - ) - else: - # Non-streaming response - response = await self.acompletion( - request=raw_request, - pre_request_callback=pre_request_callback, - **openai_request, - ) - - # Convert OpenAI response to Anthropic format - openai_response = ( - response.model_dump() - if hasattr(response, "model_dump") - else dict(response) - ) - anthropic_response = openai_to_anthropic_response( - openai_response, original_model - ) - - # Override the ID with our request ID - anthropic_response["id"] = request_id - - # Log Anthropic response - if anthropic_logger: - anthropic_logger.log_response( - anthropic_response, - filename="anthropic_response.json", - ) - - return anthropic_response - - async def anthropic_count_tokens( - self, - request: "AnthropicCountTokensRequest", - ) -> dict: - """ - Handle Anthropic count_tokens API requests. - - Counts the number of tokens that would be used by a Messages API request. - This is useful for estimating costs and managing context windows. - - Args: - request: An AnthropicCountTokensRequest object - - Returns: - Dict with input_tokens count in Anthropic format - """ - from .anthropic_compat import ( - anthropic_to_openai_messages, - anthropic_to_openai_tools, - ) - import json - - anthropic_request = request.model_dump(exclude_none=True) - - openai_messages = anthropic_to_openai_messages( - anthropic_request.get("messages", []), anthropic_request.get("system") - ) - - # Count tokens for messages - message_tokens = self.token_count( - model=request.model, - messages=openai_messages, - ) - - # Count tokens for tools if present - tool_tokens = 0 - if request.tools: - # Tools add tokens based on their definitions - # Convert to JSON string and count tokens for tool definitions - openai_tools = anthropic_to_openai_tools( - [tool.model_dump() for tool in request.tools] - ) - if openai_tools: - # Serialize tools to count their token contribution - tools_text = json.dumps(openai_tools) - tool_tokens = self.token_count( - model=request.model, - text=tools_text, - ) - - total_tokens = message_tokens + tool_tokens - - return {"input_tokens": total_tokens} diff --git a/src/rotator_library/client/__init__.py b/src/rotator_library/client/__init__.py new file mode 100644 index 00000000..4307f84d --- /dev/null +++ b/src/rotator_library/client/__init__.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Client package for LLM API key rotation. + +This package provides the RotatingClient and associated components +for intelligent credential rotation and retry logic. + +Public API: + RotatingClient: Main client class for making API requests + StreamedAPIError: Exception for streaming errors + +Components (for advanced usage): + RequestExecutor: Unified retry/rotation logic + CredentialFilter: Tier compatibility filtering + ModelResolver: Model name resolution + ProviderTransforms: Provider-specific transforms + StreamingHandler: Streaming response processing +""" + +from .rotating_client import RotatingClient +from ..core.errors import StreamedAPIError + +# Also expose components for advanced usage +from .executor import RequestExecutor +from .filters import CredentialFilter +from .models import ModelResolver +from .transforms import ProviderTransforms +from .streaming import StreamingHandler +from .anthropic import AnthropicHandler +from .types import AvailabilityStats, RetryState, ExecutionResult + +__all__ = [ + # Main public API + "RotatingClient", + "StreamedAPIError", + # Components + "RequestExecutor", + "CredentialFilter", + "ModelResolver", + "ProviderTransforms", + "StreamingHandler", + "AnthropicHandler", + # Types + "AvailabilityStats", + "RetryState", + "ExecutionResult", +] diff --git a/src/rotator_library/client/anthropic.py b/src/rotator_library/client/anthropic.py new file mode 100644 index 00000000..507e82fb --- /dev/null +++ b/src/rotator_library/client/anthropic.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Anthropic API compatibility handler for RotatingClient. + +This module provides Anthropic SDK compatibility methods that allow using +Anthropic's Messages API format with the credential rotation system. +""" + +import json +import logging +import uuid +from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional + +from ..anthropic_compat import ( + AnthropicMessagesRequest, + AnthropicCountTokensRequest, + translate_anthropic_request, + openai_to_anthropic_response, + anthropic_streaming_wrapper, + anthropic_to_openai_messages, + anthropic_to_openai_tools, +) +from ..transaction_logger import TransactionLogger + +if TYPE_CHECKING: + from .rotating_client import RotatingClient + +lib_logger = logging.getLogger("rotator_library") + + +class AnthropicHandler: + """ + Handler for Anthropic API compatibility methods. + + This class provides methods to handle Anthropic Messages API requests + by translating them to OpenAI format, processing through the client's + acompletion method, and converting responses back to Anthropic format. + + Example: + handler = AnthropicHandler(client) + response = await handler.messages(request, raw_request) + """ + + def __init__(self, client: "RotatingClient"): + """ + Initialize the Anthropic handler. + + Args: + client: The RotatingClient instance to use for completions + """ + self._client = client + + async def messages( + self, + request: AnthropicMessagesRequest, + raw_request: Optional[Any] = None, + pre_request_callback: Optional[callable] = None, + ) -> Any: + """ + Handle Anthropic Messages API requests. + + This method accepts requests in Anthropic's format, translates them to + OpenAI format internally, processes them through the existing acompletion + method, and returns responses in Anthropic's format. + + Args: + request: An AnthropicMessagesRequest object + raw_request: Optional raw request object for disconnect checks + pre_request_callback: Optional async callback before each API request + + Returns: + For non-streaming: dict in Anthropic Messages format + For streaming: AsyncGenerator yielding Anthropic SSE format strings + """ + request_id = f"msg_{uuid.uuid4().hex[:24]}" + original_model = request.model + + # Extract provider from model for logging + provider = original_model.split("/")[0] if "/" in original_model else "unknown" + + # Create Anthropic transaction logger if request logging is enabled + anthropic_logger = None + if self._client.enable_request_logging: + anthropic_logger = TransactionLogger( + provider, + original_model, + enabled=True, + api_format="ant", + ) + # Log original Anthropic request + anthropic_logger.log_request( + request.model_dump(exclude_none=True), + filename="anthropic_request.json", + ) + + # Translate Anthropic request to OpenAI format + openai_request = translate_anthropic_request(request) + + # Pass parent log directory to acompletion for nested logging + if anthropic_logger and anthropic_logger.log_dir: + openai_request["_parent_log_dir"] = anthropic_logger.log_dir + + if request.stream: + # Streaming response + response_generator = await self._client.acompletion( + request=raw_request, + pre_request_callback=pre_request_callback, + **openai_request, + ) + + # Create disconnect checker if raw_request provided + is_disconnected = None + if raw_request is not None and hasattr(raw_request, "is_disconnected"): + is_disconnected = raw_request.is_disconnected + + # Return the streaming wrapper + # Note: For streaming, the anthropic response logging happens in the wrapper + return anthropic_streaming_wrapper( + openai_stream=response_generator, + original_model=original_model, + request_id=request_id, + is_disconnected=is_disconnected, + transaction_logger=anthropic_logger, + ) + else: + # Non-streaming response + response = await self._client.acompletion( + request=raw_request, + pre_request_callback=pre_request_callback, + **openai_request, + ) + + # Convert OpenAI response to Anthropic format + openai_response = ( + response.model_dump() + if hasattr(response, "model_dump") + else dict(response) + ) + anthropic_response = openai_to_anthropic_response( + openai_response, original_model + ) + + # Override the ID with our request ID + anthropic_response["id"] = request_id + + # Log Anthropic response + if anthropic_logger: + anthropic_logger.log_response( + anthropic_response, + filename="anthropic_response.json", + ) + + return anthropic_response + + async def count_tokens( + self, + request: AnthropicCountTokensRequest, + ) -> dict: + """ + Handle Anthropic count_tokens API requests. + + Counts the number of tokens that would be used by a Messages API request. + This is useful for estimating costs and managing context windows. + + Args: + request: An AnthropicCountTokensRequest object + + Returns: + Dict with input_tokens count in Anthropic format + """ + anthropic_request = request.model_dump(exclude_none=True) + + openai_messages = anthropic_to_openai_messages( + anthropic_request.get("messages", []), anthropic_request.get("system") + ) + + # Count tokens for messages + message_tokens = self._client.token_count( + model=request.model, + messages=openai_messages, + ) + + # Count tokens for tools if present + tool_tokens = 0 + if request.tools: + # Tools add tokens based on their definitions + # Convert to JSON string and count tokens for tool definitions + openai_tools = anthropic_to_openai_tools( + [tool.model_dump() for tool in request.tools] + ) + if openai_tools: + # Serialize tools to count their token contribution + tools_text = json.dumps(openai_tools) + tool_tokens = self._client.token_count( + model=request.model, + text=tools_text, + ) + + total_tokens = message_tokens + tool_tokens + + return {"input_tokens": total_tokens} diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py new file mode 100644 index 00000000..8f50bacb --- /dev/null +++ b/src/rotator_library/client/executor.py @@ -0,0 +1,1322 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Unified request execution with retry and rotation. + +This module extracts and unifies the retry logic that was duplicated in: +- _execute_with_retry (lines 1174-1945) +- _streaming_acompletion_with_retry (lines 1947-2780) + +The RequestExecutor provides a single code path for all request types, +with streaming vs non-streaming handled as a parameter. +""" + +import asyncio +import json +import logging +import os +import random +import time +from typing import ( + Any, + AsyncGenerator, + Dict, + List, + Optional, + Set, + TYPE_CHECKING, + Tuple, + Union, +) + +import httpx +import litellm +from litellm.exceptions import ( + APIConnectionError, + RateLimitError, + ServiceUnavailableError, + InternalServerError, +) + +from ..core.types import RequestContext, ErrorAction +from ..core.errors import ( + NoAvailableKeysError, + PreRequestCallbackError, + StreamedAPIError, + ClassifiedError, + RequestErrorAccumulator, + classify_error, + should_rotate_on_error, + should_retry_same_key, + mask_credential, +) +from ..core.constants import ( + DEFAULT_MAX_RETRIES, + DEFAULT_SMALL_COOLDOWN_RETRY_THRESHOLD, +) +from ..request_sanitizer import sanitize_request_payload +from ..transaction_logger import TransactionLogger +from ..failure_logger import log_failure + +from .types import RetryState, AvailabilityStats +from .filters import CredentialFilter +from .transforms import ProviderTransforms +from .streaming import StreamingHandler + +if TYPE_CHECKING: + from ..usage import UsageManager + +lib_logger = logging.getLogger("rotator_library") + + +class RequestExecutor: + """ + Unified retry/rotation logic for all request types. + + This class handles: + - Credential rotation across providers + - Per-credential retry with backoff + - Error classification and handling + - Streaming and non-streaming requests + """ + + def __init__( + self, + usage_managers: Dict[str, "UsageManager"], + cooldown_manager: Any, + credential_filter: CredentialFilter, + provider_transforms: ProviderTransforms, + provider_plugins: Dict[str, Any], + http_client: httpx.AsyncClient, + max_retries: int = DEFAULT_MAX_RETRIES, + global_timeout: int = 30, + abort_on_callback_error: bool = True, + litellm_provider_params: Optional[Dict[str, Any]] = None, + litellm_logger_fn: Optional[Any] = None, + provider_instances: Optional[Dict[str, Any]] = None, + ): + """ + Initialize RequestExecutor. + + Args: + usage_managers: Dict mapping provider names to UsageManager instances + cooldown_manager: CooldownManager instance + credential_filter: CredentialFilter instance + provider_transforms: ProviderTransforms instance + provider_plugins: Dict mapping provider names to plugin classes + http_client: Shared httpx.AsyncClient for provider requests + max_retries: Max retries per credential + global_timeout: Global request timeout in seconds + abort_on_callback_error: Abort on pre-request callback errors + litellm_provider_params: Optional dict of provider-specific LiteLLM + parameters to merge into requests (e.g., custom headers, timeouts) + litellm_logger_fn: Optional callback function for LiteLLM logging + provider_instances: Shared dict for caching provider instances. + If None, creates a new dict (not recommended - leads to duplicate instances). + """ + self._usage_managers = usage_managers + self._cooldown = cooldown_manager + self._filter = credential_filter + self._transforms = provider_transforms + self._plugins = provider_plugins + self._plugin_instances: Dict[str, Any] = ( + provider_instances if provider_instances is not None else {} + ) + self._http_client = http_client + self._max_retries = max_retries + self._global_timeout = global_timeout + self._abort_on_callback_error = abort_on_callback_error + self._litellm_provider_params = litellm_provider_params or {} + self._litellm_logger_fn = litellm_logger_fn + # StreamingHandler no longer needs usage_manager - we pass cred_context directly + self._streaming_handler = StreamingHandler() + + def _get_plugin_instance(self, provider: str) -> Optional[Any]: + """Get or create a plugin instance for a provider.""" + if provider not in self._plugin_instances: + plugin_class = self._plugins.get(provider) + if plugin_class: + if isinstance(plugin_class, type): + self._plugin_instances[provider] = plugin_class() + else: + self._plugin_instances[provider] = plugin_class + else: + return None + return self._plugin_instances[provider] + + def _has_tier_support(self, provider: str) -> bool: + """ + Check if provider has tier/priority configuration. + + Providers with tier support define tier_priorities mapping + (e.g., Antigravity, GeminiCli, NanoGpt). + + Args: + provider: Provider name + + Returns: + True if provider has tier configuration, False otherwise + """ + plugin = self._get_plugin_instance(provider) + if not plugin: + return False + tier_priorities = getattr(plugin, "tier_priorities", {}) + return bool(tier_priorities) + + def _get_usage_display( + self, + state: Any, + model: str, + quota_group: Optional[str], + usage_manager: "UsageManager", + ) -> int: + """ + Get usage count from the primary window. + + This returns the same usage count used for credential selection, + ensuring consistency between what's logged and what's used for rotation. + + Args: + state: CredentialState object + model: Model name + quota_group: Optional quota group name + usage_manager: UsageManager instance + + Returns: + Request count from primary window, or 0 if unavailable + """ + if not state: + return 0 + + window_manager = getattr(usage_manager, "window_manager", None) + if not window_manager: + return state.totals.request_count + + primary_def = window_manager.get_primary_definition() + if not primary_def: + return state.totals.request_count + + # Get windows based on what the primary definition applies to + # This mirrors the logic in selection/engine.py:_get_usage_count + windows = None + if primary_def.applies_to == "model": + model_stats = state.get_model_stats(model, create=False) + if model_stats: + windows = model_stats.windows + elif primary_def.applies_to == "group": + group_key = quota_group or model + group_stats = state.get_group_stats(group_key, create=False) + if group_stats: + windows = group_stats.windows + + if windows: + window = window_manager.get_active_window(windows, primary_def.name) + if window: + return window.request_count + + return state.totals.request_count + + def _get_quota_display( + self, + state: Any, + model: str, + quota_group: Optional[str], + usage_manager: "UsageManager", + ) -> str: + """ + Get quota display string for logging. + + Checks group stats first (if quota_group provided), then falls back + to model stats. Returns a formatted string like "5/50 [90%]". + + Args: + state: CredentialState object + model: Model name + quota_group: Optional quota group name + usage_manager: UsageManager instance + + Returns: + Formatted quota display string, or "?/?" if unavailable + """ + if not state: + return "?/?" + + window_manager = getattr(usage_manager, "window_manager", None) + if not window_manager: + return "?/?" + + primary_def = window_manager.get_primary_definition() + if not primary_def: + return "?/?" + + window = None + # Check GROUP first if quota_group provided (shared limits) + if quota_group: + group_stats = state.get_group_stats(quota_group, create=False) + if group_stats: + window = group_stats.windows.get(primary_def.name) + + # Fall back to MODEL if no group limit found + if window is None or window.limit is None: + model_stats = state.get_model_stats(model, create=False) + if model_stats: + window = model_stats.windows.get(primary_def.name) + + # Display quota if we found a window with a limit + if window and window.limit is not None: + remaining = max(0, window.limit - window.request_count) + pct = round(remaining / window.limit * 100) if window.limit else 0 + return f"{window.request_count}/{window.limit} [{pct}%]" + + return "?/?" + + def _log_acquiring_credential( + self, + model: str, + tried_count: int, + availability: Dict[str, Any], + ) -> None: + """ + Log credential acquisition attempt with availability info. + + Args: + model: Model name + tried_count: Number of credentials already tried + availability: Availability stats dict from usage manager + """ + blocked = availability.get("blocked_by", {}) + blocked_parts = [] + if blocked.get("cooldowns"): + blocked_parts.append(f"cd:{blocked['cooldowns']}") + if blocked.get("fair_cycle"): + blocked_parts.append(f"fc:{blocked['fair_cycle']}") + if blocked.get("custom_caps"): + blocked_parts.append(f"cap:{blocked['custom_caps']}") + if blocked.get("window_limits"): + blocked_parts.append(f"wl:{blocked['window_limits']}") + if blocked.get("concurrent"): + blocked_parts.append(f"con:{blocked['concurrent']}") + blocked_str = f"({', '.join(blocked_parts)})" if blocked_parts else "" + lib_logger.info( + f"Acquiring credential for model {model}. Tried: {tried_count}/" + f"{availability.get('available', 0)}({availability.get('total', 0)}{blocked_str})" + ) + + async def _prepare_request_kwargs( + self, + provider: str, + model: str, + cred: str, + context: "RequestContext", + ) -> Dict[str, Any]: + """ + Prepare request kwargs with transforms, sanitization, and provider params. + + Args: + provider: Provider name + model: Model name + cred: Credential string + context: Request context + + Returns: + Prepared kwargs dict for the LiteLLM call + """ + # Apply transforms + kwargs = await self._transforms.apply( + provider, model, cred, context.kwargs.copy() + ) + + # Sanitize request payload + kwargs = sanitize_request_payload(kwargs, model) + + # Apply provider-specific LiteLLM params + self._apply_litellm_provider_params(provider, kwargs) + + # Add transaction context for provider logging + if context.transaction_logger: + kwargs["transaction_context"] = context.transaction_logger.get_context() + + return kwargs + + def _log_acquired_credential( + self, + cred: str, + model: str, + state: Any, + quota_group: Optional[str], + availability: Dict[str, Any], + usage_manager: "UsageManager", + ) -> None: + """ + Log successful credential acquisition. + + Format varies based on provider capabilities: + - Providers with tier support: (tier, priority, selection, quota) + - Providers without tiers but with quotas: (selection, quota) + - Providers without tiers or quotas: (selection, usage) + + Args: + cred: Credential string + model: Model name + state: CredentialState object + quota_group: Optional quota group + availability: Availability stats dict + usage_manager: UsageManager instance + """ + selection_mode = availability.get("rotation_mode") + + # Extract provider from model (e.g., "nvidia_nim" from "nvidia_nim/deepseek-ai/...") + provider = model.split("/")[0] if "/" in model else None + + if provider and self._has_tier_support(provider): + # Full format with tier/priority/quota for providers with tier configuration + tier = state.tier if state else None + priority = state.priority if state else None + quota_display = self._get_quota_display( + state, model, quota_group, usage_manager + ) + lib_logger.info( + f"Acquired key {mask_credential(cred)} for model {model} " + f"(tier: {tier}, priority: {priority}, selection: {selection_mode}, quota: {quota_display})" + ) + else: + # Simple format for providers without tier configuration + # Check if there's quota info available (limit set on window) + quota_display = self._get_quota_display( + state, model, quota_group, usage_manager + ) + if quota_display != "?/?": + # Has quota limits - show selection and quota + lib_logger.info( + f"Acquired key {mask_credential(cred)} for model {model} " + f"(selection: {selection_mode}, quota: {quota_display})" + ) + else: + # No quota limits - show selection and usage from primary window + usage = self._get_usage_display( + state, model, quota_group, usage_manager + ) + lib_logger.info( + f"Acquired key {mask_credential(cred)} for model {model} " + f"(selection: {selection_mode}, usage: {usage})" + ) + + async def _run_pre_request_callback( + self, + context: "RequestContext", + kwargs: Dict[str, Any], + ) -> None: + """ + Run pre-request callback if configured. + + Args: + context: Request context + kwargs: Request kwargs + + Raises: + PreRequestCallbackError: If callback fails and abort_on_callback_error is True + """ + if context.pre_request_callback: + try: + await context.pre_request_callback(context.request, kwargs) + except Exception as e: + if self._abort_on_callback_error: + raise PreRequestCallbackError(str(e)) from e + lib_logger.warning(f"Pre-request callback failed: {e}") + + async def execute( + self, + context: RequestContext, + ) -> Union[Any, AsyncGenerator[str, None]]: + """ + Execute request with retry/rotation. + + This is the main entry point for request execution. + + Args: + context: RequestContext with all request details + + Returns: + Response object or async generator for streaming + """ + if context.streaming: + return self._execute_streaming(context) + else: + return await self._execute_non_streaming(context) + + async def _prepare_execution( + self, + context: RequestContext, + ) -> Tuple["UsageManager", Any, List[str], Optional[str], Dict[str, Any]]: + provider = context.provider + model = context.model + + usage_manager = self._usage_managers.get(provider) + if not usage_manager: + raise NoAvailableKeysError(f"No UsageManager for provider {provider}") + + filter_result = self._filter.filter_by_tier( + context.credentials, model, provider + ) + credentials = filter_result.all_usable + quota_group = usage_manager.get_model_quota_group(model) + + await self._ensure_initialized(usage_manager, context, filter_result) + await self._validate_request(provider, model, context.kwargs) + + if not credentials: + raise NoAvailableKeysError(f"No compatible credentials for model {model}") + + request_headers = ( + dict(context.request.headers) if context.request is not None else {} + ) + + return usage_manager, filter_result, credentials, quota_group, request_headers + + async def _execute_non_streaming( + self, + context: RequestContext, + ) -> Any: + """ + Execute non-streaming request with retry/rotation. + + Args: + context: RequestContext with all request details + + Returns: + Response object + """ + provider = context.provider + model = context.model + deadline = context.deadline + + ( + usage_manager, + filter_result, + credentials, + quota_group, + request_headers, + ) = await self._prepare_execution(context) + + error_accumulator = RequestErrorAccumulator() + error_accumulator.model = model + error_accumulator.provider = provider + + retry_state = RetryState() + last_exception: Optional[Exception] = None + + while time.time() < deadline: + # Check for untried credentials + untried = [c for c in credentials if c not in retry_state.tried_credentials] + if not untried: + lib_logger.warning( + f"All {len(credentials)} credentials tried for {model}" + ) + break + + # Wait for provider cooldown + await self._wait_for_cooldown(provider, deadline) + + # Acquire credential using context manager + try: + availability = await usage_manager.get_availability_stats( + model, quota_group + ) + self._log_acquiring_credential( + model, len(retry_state.tried_credentials), availability + ) + async with await usage_manager.acquire_credential( + model=model, + quota_group=quota_group, + candidates=untried, + priorities=filter_result.priorities, + deadline=deadline, + ) as cred_context: + cred = cred_context.credential + retry_state.record_attempt(cred) + + state = getattr(usage_manager, "states", {}).get( + cred_context.stable_id + ) + self._log_acquired_credential( + cred, model, state, quota_group, availability, usage_manager + ) + + try: + # Prepare request kwargs + kwargs = await self._prepare_request_kwargs( + provider, model, cred, context + ) + + # Get provider plugin + plugin = self._get_plugin_instance(provider) + + # Execute request with retries + for attempt in range(self._max_retries): + try: + lib_logger.info( + f"Attempting call with credential {mask_credential(cred)} " + f"(Attempt {attempt + 1}/{self._max_retries})" + ) + # Pre-request callback + await self._run_pre_request_callback(context, kwargs) + + # Make the API call + if plugin and plugin.has_custom_logic(): + kwargs["credential_identifier"] = cred + response = await plugin.acompletion( + self._http_client, **kwargs + ) + else: + # Standard LiteLLM call + kwargs["api_key"] = cred + self._apply_litellm_logger(kwargs) + # Remove internal context before litellm call + kwargs.pop("transaction_context", None) + response = await litellm.acompletion(**kwargs) + + # Success! Extract token usage if available + ( + prompt_tokens, + completion_tokens, + prompt_tokens_cached, + prompt_tokens_cache_write, + thinking_tokens, + ) = self._extract_usage_tokens(response) + approx_cost = self._calculate_cost( + provider, model, response + ) + response_headers = self._extract_response_headers( + response + ) + + cred_context.mark_success( + response=response, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cached, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + response_headers=response_headers, + ) + + lib_logger.info( + f"Recorded usage from response object for key {mask_credential(cred)}" + ) + + # Log response if transaction logging enabled + if context.transaction_logger: + try: + response_data = ( + response.model_dump() + if hasattr(response, "model_dump") + else response + ) + context.transaction_logger.log_response( + response_data + ) + except Exception as log_err: + lib_logger.debug( + f"Failed to log response: {log_err}" + ) + + return response + + except Exception as e: + last_exception = e + action = await self._handle_error_with_context( + e, + cred_context, + model, + provider, + attempt, + error_accumulator, + retry_state, + request_headers, + ) + + if action == ErrorAction.RETRY_SAME: + continue + elif action == ErrorAction.ROTATE: + break # Try next credential + else: # FAIL + raise + + except PreRequestCallbackError: + raise + except Exception: + # Let context manager handle cleanup + pass + + except NoAvailableKeysError: + break + + # All credentials exhausted + error_accumulator.timeout_occurred = time.time() >= deadline + if last_exception and not error_accumulator.has_errors(): + raise last_exception + + # Return error response + return error_accumulator.build_client_error_response() + + async def _execute_streaming( + self, + context: RequestContext, + ) -> AsyncGenerator[str, None]: + """ + Execute streaming request with retry/rotation. + + This is an async generator that yields SSE-formatted strings. + + Args: + context: RequestContext with all request details + + Yields: + SSE-formatted strings + """ + provider = context.provider + model = context.model + deadline = context.deadline + + try: + ( + usage_manager, + filter_result, + credentials, + quota_group, + request_headers, + ) = await self._prepare_execution(context) + except NoAvailableKeysError as exc: + error_data = { + "error": { + "message": str(exc), + "type": "proxy_error", + } + } + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + return + + error_accumulator = RequestErrorAccumulator() + error_accumulator.model = model + error_accumulator.provider = provider + + retry_state = RetryState() + last_exception: Optional[Exception] = None + + try: + while time.time() < deadline: + # Check for untried credentials + untried = [ + c for c in credentials if c not in retry_state.tried_credentials + ] + if not untried: + lib_logger.warning( + f"All {len(credentials)} credentials tried for {model}" + ) + break + + # Wait for provider cooldown + remaining = deadline - time.time() + if remaining <= 0: + break + await self._wait_for_cooldown(provider, deadline) + + # Acquire credential using context manager + try: + availability = await usage_manager.get_availability_stats( + model, quota_group + ) + self._log_acquiring_credential( + model, len(retry_state.tried_credentials), availability + ) + async with await usage_manager.acquire_credential( + model=model, + quota_group=quota_group, + candidates=untried, + priorities=filter_result.priorities, + deadline=deadline, + ) as cred_context: + cred = cred_context.credential + retry_state.record_attempt(cred) + + state = getattr(usage_manager, "states", {}).get( + cred_context.stable_id + ) + self._log_acquired_credential( + cred, model, state, quota_group, availability, usage_manager + ) + + try: + # Prepare request kwargs + kwargs = await self._prepare_request_kwargs( + provider, model, cred, context + ) + + # Add stream options (but not for iflow - it returns 406) + if provider != "iflow": + if "stream_options" not in kwargs: + kwargs["stream_options"] = {} + if "include_usage" not in kwargs["stream_options"]: + kwargs["stream_options"]["include_usage"] = True + + # Get provider plugin + plugin = self._get_plugin_instance(provider) + skip_cost_calculation = bool( + plugin + and getattr(plugin, "skip_cost_calculation", False) + ) + + # Execute request with retries + for attempt in range(self._max_retries): + try: + lib_logger.info( + f"Attempting stream with credential {mask_credential(cred)} " + f"(Attempt {attempt + 1}/{self._max_retries})" + ) + # Pre-request callback + await self._run_pre_request_callback( + context, kwargs + ) + + # Make the API call + if plugin and plugin.has_custom_logic(): + kwargs["credential_identifier"] = cred + stream = await plugin.acompletion( + self._http_client, **kwargs + ) + else: + kwargs["api_key"] = cred + kwargs["stream"] = True + self._apply_litellm_logger(kwargs) + # Remove internal context before litellm call + kwargs.pop("transaction_context", None) + stream = await litellm.acompletion(**kwargs) + + # Hand off to streaming handler with cred_context + # The handler will call mark_success on completion + base_stream = self._streaming_handler.wrap_stream( + stream, + cred, + model, + context.request, + cred_context, + skip_cost_calculation=skip_cost_calculation, + ) + + lib_logger.info( + f"Stream connection established for credential {mask_credential(cred)}. " + "Processing response." + ) + + # Wrap with transaction logging if enabled + if context.transaction_logger: + async for ( + chunk + ) in self._transaction_logging_stream_wrapper( + base_stream, + context.transaction_logger, + context.kwargs, + ): + yield chunk + else: + async for chunk in base_stream: + yield chunk + return + + except StreamedAPIError as e: + last_exception = e + original = getattr(e, "data", e) + classified = classify_error(original, provider) + log_failure( + api_key=cred, + model=model, + attempt=attempt + 1, + error=e, + request_headers=request_headers, + ) + error_accumulator.record_error( + cred, classified, str(original)[:150] + ) + + # Track consecutive quota failures + if classified.error_type == "quota_exceeded": + retry_state.increment_quota_failures() + if retry_state.consecutive_quota_failures >= 3: + lib_logger.error( + "3 consecutive quota errors in streaming - " + "request may be too large" + ) + cred_context.mark_failure(classified) + error_data = { + "error": { + "message": "Request exceeds quota for all credentials", + "type": "quota_exhausted", + } + } + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + return + else: + retry_state.reset_quota_failures() + + if not should_rotate_on_error(classified): + cred_context.mark_failure(classified) + raise + + cred_context.mark_failure(classified) + break # Rotate + + except (RateLimitError, httpx.HTTPStatusError) as e: + last_exception = e + classified = classify_error(e, provider) + log_failure( + api_key=cred, + model=model, + attempt=attempt + 1, + error=e, + request_headers=request_headers, + ) + error_accumulator.record_error( + cred, classified, str(e)[:150] + ) + + # Track consecutive quota failures + if classified.error_type == "quota_exceeded": + retry_state.increment_quota_failures() + if retry_state.consecutive_quota_failures >= 3: + lib_logger.error( + "3 consecutive quota errors in streaming - " + "request may be too large" + ) + cred_context.mark_failure(classified) + error_data = { + "error": { + "message": "Request exceeds quota for all credentials", + "type": "quota_exhausted", + } + } + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + return + else: + retry_state.reset_quota_failures() + + if not should_rotate_on_error(classified): + cred_context.mark_failure(classified) + raise + + # Check for small cooldown - retry same key instead of rotating + small_cooldown_threshold = int( + os.environ.get( + "SMALL_COOLDOWN_RETRY_THRESHOLD", + DEFAULT_SMALL_COOLDOWN_RETRY_THRESHOLD, + ) + ) + if ( + classified.retry_after is not None + and 0 + < classified.retry_after + < small_cooldown_threshold + and attempt < self._max_retries - 1 + ): + remaining = deadline - time.time() + if classified.retry_after <= remaining: + lib_logger.info( + f"Retrying {mask_credential(cred)} in {classified.retry_after:.1f}s " + f"(small cooldown {classified.retry_after}s < {small_cooldown_threshold}s threshold)" + ) + await asyncio.sleep(classified.retry_after) + continue # Retry same key + + cred_context.mark_failure(classified) + break # Rotate + + except ( + APIConnectionError, + InternalServerError, + ServiceUnavailableError, + ) as e: + last_exception = e + classified = classify_error(e, provider) + log_failure( + api_key=cred, + model=model, + attempt=attempt + 1, + error=e, + request_headers=request_headers, + ) + + if attempt >= self._max_retries - 1: + error_accumulator.record_error( + cred, classified, str(e)[:150] + ) + cred_context.mark_failure(classified) + break # Rotate + + # Calculate wait time + wait_time = classified.retry_after or ( + 2**attempt + ) + random.uniform(0, 1) + remaining = deadline - time.time() + if wait_time > remaining: + break # No time to wait + + await asyncio.sleep(wait_time) + continue # Retry + + except Exception as e: + last_exception = e + classified = classify_error(e, provider) + log_failure( + api_key=cred, + model=model, + attempt=attempt + 1, + error=e, + request_headers=request_headers, + ) + error_accumulator.record_error( + cred, classified, str(e)[:150] + ) + + if not should_rotate_on_error(classified): + cred_context.mark_failure(classified) + raise + + cred_context.mark_failure(classified) + break # Rotate + + except PreRequestCallbackError: + raise + except Exception: + # Let context manager handle cleanup + pass + + except NoAvailableKeysError: + break + + # All credentials exhausted or timeout + error_accumulator.timeout_occurred = time.time() >= deadline + error_data = error_accumulator.build_client_error_response() + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + + except NoAvailableKeysError as e: + lib_logger.error(f"No keys available: {e}") + error_data = {"error": {"message": str(e), "type": "proxy_busy"}} + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + + except Exception as e: + lib_logger.error(f"Unhandled exception in streaming: {e}", exc_info=True) + error_data = {"error": {"message": str(e), "type": "proxy_internal_error"}} + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + + def _apply_litellm_provider_params( + self, provider: str, kwargs: Dict[str, Any] + ) -> None: + """Merge provider-specific LiteLLM parameters into request kwargs.""" + params = self._litellm_provider_params.get(provider) + if not params: + return + kwargs["litellm_params"] = { + **params, + **kwargs.get("litellm_params", {}), + } + + def _apply_litellm_logger(self, kwargs: Dict[str, Any]) -> None: + """Attach LiteLLM logger callback if configured.""" + if self._litellm_logger_fn and "logger_fn" not in kwargs: + kwargs["logger_fn"] = self._litellm_logger_fn + + def _extract_response_headers(self, response: Any) -> Optional[Dict[str, Any]]: + """Extract response headers from LiteLLM response objects.""" + if hasattr(response, "response") and response.response is not None: + headers = getattr(response.response, "headers", None) + if headers is not None: + return dict(headers) + headers = getattr(response, "headers", None) + if headers is not None: + return dict(headers) + return None + + async def _wait_for_cooldown( + self, + provider: str, + deadline: float, + ) -> None: + """ + Wait for provider-level cooldown to end. + + Args: + provider: Provider name + deadline: Request deadline + """ + if not self._cooldown: + return + + remaining = await self._cooldown.get_remaining_cooldown(provider) + if remaining > 0: + budget = deadline - time.time() + if remaining > budget: + lib_logger.warning( + f"Provider {provider} cooldown ({remaining:.1f}s) exceeds budget ({budget:.1f}s)" + ) + return # Will fail on no keys available + lib_logger.info(f"Waiting {remaining:.1f}s for {provider} cooldown") + await asyncio.sleep(remaining) + + async def _handle_error_with_context( + self, + error: Exception, + cred_context: Any, # CredentialContext + model: str, + provider: str, + attempt: int, + error_accumulator: RequestErrorAccumulator, + retry_state: RetryState, + request_headers: Dict[str, Any], + ) -> str: + """ + Handle an error and determine next action. + + Args: + error: The caught exception + cred_context: CredentialContext for marking failure + model: Model name + provider: Provider name + attempt: Current attempt number + error_accumulator: Error tracking + retry_state: Retry state tracking + + Returns: + ErrorAction indicating what to do next + """ + classified = classify_error(error, provider) + error_message = str(error)[:150] + credential = cred_context.credential + + log_failure( + api_key=credential, + model=model, + attempt=attempt + 1, + error=error, + request_headers=request_headers, + ) + + # Check for quota errors + if classified.error_type == "quota_exceeded": + retry_state.increment_quota_failures() + if retry_state.consecutive_quota_failures >= 3: + # Likely request is too large + lib_logger.error( + f"3 consecutive quota errors - request may be too large" + ) + error_accumulator.record_error(credential, classified, error_message) + cred_context.mark_failure(classified) + return ErrorAction.FAIL + else: + retry_state.reset_quota_failures() + + # Check if should rotate + if not should_rotate_on_error(classified): + error_accumulator.record_error(credential, classified, error_message) + cred_context.mark_failure(classified) + return ErrorAction.FAIL + + # Check if should retry same key (including small cooldown auto-retry) + small_cooldown_threshold = int( + os.environ.get( + "SMALL_COOLDOWN_RETRY_THRESHOLD", DEFAULT_SMALL_COOLDOWN_RETRY_THRESHOLD + ) + ) + is_small_cooldown = ( + classified.retry_after is not None + and 0 < classified.retry_after < small_cooldown_threshold + ) + + if ( + should_retry_same_key(classified, small_cooldown_threshold) + and attempt < self._max_retries - 1 + ): + wait_time = classified.retry_after or (2**attempt) + random.uniform(0, 1) + retry_reason = ( + f" (small cooldown {classified.retry_after}s < {small_cooldown_threshold}s threshold)" + if is_small_cooldown + else "" + ) + lib_logger.info( + f"Retrying {mask_credential(credential)} in {wait_time:.1f}s{retry_reason}" + ) + await asyncio.sleep(wait_time) + return ErrorAction.RETRY_SAME + + # Record error and rotate + error_accumulator.record_error(credential, classified, error_message) + cred_context.mark_failure(classified) + lib_logger.info( + f"Rotating from {mask_credential(credential)} after {classified.error_type}" + ) + return ErrorAction.ROTATE + + async def _ensure_initialized( + self, + usage_manager: "UsageManager", + context: RequestContext, + filter_result: "FilterResult", + ) -> None: + if usage_manager.initialized: + return + await usage_manager.initialize( + context.credentials, + priorities=filter_result.priorities, + tiers=filter_result.tier_names, + ) + + async def _validate_request( + self, + provider: str, + model: str, + kwargs: Dict[str, Any], + ) -> None: + plugin = self._get_plugin_instance(provider) + if not plugin or not hasattr(plugin, "validate_request"): + return + + result = plugin.validate_request(kwargs, model) + if asyncio.iscoroutine(result): + result = await result + if result is False: + raise ValueError(f"Request validation failed for {provider}/{model}") + if isinstance(result, str): + raise ValueError(result) + + def _extract_usage_tokens(self, response: Any) -> tuple[int, int, int, int, int]: + prompt_tokens = 0 + completion_tokens = 0 + cached_tokens = 0 + cache_write_tokens = 0 + thinking_tokens = 0 + + if hasattr(response, "usage") and response.usage: + prompt_tokens = getattr(response.usage, "prompt_tokens", 0) or 0 + completion_tokens = getattr(response.usage, "completion_tokens", 0) or 0 + + prompt_details = getattr(response.usage, "prompt_tokens_details", None) + if prompt_details: + if isinstance(prompt_details, dict): + cached_tokens = prompt_details.get("cached_tokens", 0) or 0 + cache_write_tokens = ( + prompt_details.get("cache_creation_tokens", 0) or 0 + ) + else: + cached_tokens = getattr(prompt_details, "cached_tokens", 0) or 0 + cache_write_tokens = ( + getattr(prompt_details, "cache_creation_tokens", 0) or 0 + ) + + completion_details = getattr( + response.usage, "completion_tokens_details", None + ) + if completion_details: + if isinstance(completion_details, dict): + thinking_tokens = completion_details.get("reasoning_tokens", 0) or 0 + else: + thinking_tokens = ( + getattr(completion_details, "reasoning_tokens", 0) or 0 + ) + + cache_read_tokens = getattr(response.usage, "cache_read_tokens", None) + if cache_read_tokens is not None: + cached_tokens = cache_read_tokens or 0 + cache_creation_tokens = getattr( + response.usage, "cache_creation_tokens", None + ) + if cache_creation_tokens is not None: + cache_write_tokens = cache_creation_tokens or 0 + + if thinking_tokens and completion_tokens >= thinking_tokens: + completion_tokens = completion_tokens - thinking_tokens + + uncached_prompt = max(0, prompt_tokens - cached_tokens) + return ( + uncached_prompt, + completion_tokens, + cached_tokens, + cache_write_tokens, + thinking_tokens, + ) + + def _calculate_cost(self, provider: str, model: str, response: Any) -> float: + plugin = self._get_plugin_instance(provider) + if plugin and getattr(plugin, "skip_cost_calculation", False): + return 0.0 + + try: + if isinstance(response, litellm.EmbeddingResponse): + model_info = litellm.get_model_info(model) + input_cost = model_info.get("input_cost_per_token") + if input_cost: + return (response.usage.prompt_tokens or 0) * input_cost + return 0.0 + + cost = litellm.completion_cost( + completion_response=response, + model=model, + ) + return float(cost) if cost is not None else 0.0 + except Exception as exc: + lib_logger.debug(f"Cost calculation failed for {model}: {exc}") + return 0.0 + + async def _transaction_logging_stream_wrapper( + self, + stream: AsyncGenerator[str, None], + transaction_logger: TransactionLogger, + request_kwargs: Dict[str, Any], + ) -> AsyncGenerator[str, None]: + """ + Wrap a stream to log chunks and final response to TransactionLogger. + + Yields all chunks unchanged while accumulating them for final logging. + + Args: + stream: The SSE stream from wrap_stream + transaction_logger: TransactionLogger instance + request_kwargs: Original request kwargs for context + + Yields: + SSE-formatted strings unchanged + """ + chunks = [] + + async for sse_line in stream: + yield sse_line + + # Parse and accumulate for final logging + if sse_line.startswith("data: ") and not sse_line.startswith( + "data: [DONE]" + ): + try: + content = sse_line[6:].strip() + if content: + chunk_data = json.loads(content) + chunks.append(chunk_data) + transaction_logger.log_stream_chunk(chunk_data) + except json.JSONDecodeError: + lib_logger.debug( + f"Failed to parse chunk for logging: {sse_line[:100]}" + ) + + # Log assembled final response + if chunks: + try: + final_response = TransactionLogger.assemble_streaming_response(chunks) + transaction_logger.log_response(final_response) + except Exception as e: + lib_logger.debug( + f"Failed to assemble/log final streaming response: {e}" + ) diff --git a/src/rotator_library/client/filters.py b/src/rotator_library/client/filters.py new file mode 100644 index 00000000..1165bc5f --- /dev/null +++ b/src/rotator_library/client/filters.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Credential filtering by tier compatibility and priority. + +Extracts the tier filtering logic that was duplicated in client.py +at lines 1242-1315 and 2004-2076. +""" + +import logging +from typing import Any, Dict, List, Optional + +from ..core.types import FilterResult + +lib_logger = logging.getLogger("rotator_library") + + +class CredentialFilter: + """ + Filter and group credentials by tier compatibility and priority. + + This class extracts the credential filtering logic that was previously + duplicated in both _execute_with_retry and _streaming_acompletion_with_retry. + """ + + def __init__( + self, + provider_plugins: Dict[str, Any], + provider_instances: Optional[Dict[str, Any]] = None, + ): + """ + Initialize the CredentialFilter. + + Args: + provider_plugins: Dict mapping provider names to plugin classes/instances + provider_instances: Shared dict for caching provider instances. + If None, creates a new dict (not recommended - leads to duplicate instances). + """ + self._plugins = provider_plugins + self._plugin_instances: Dict[str, Any] = ( + provider_instances if provider_instances is not None else {} + ) + + def _get_plugin_instance(self, provider: str) -> Optional[Any]: + """ + Get or create a plugin instance for a provider. + + Args: + provider: Provider name + + Returns: + Plugin instance or None if not found + """ + if provider not in self._plugin_instances: + plugin_class = self._plugins.get(provider) + if plugin_class: + # Check if it's a class or already an instance + if isinstance(plugin_class, type): + lib_logger.debug( + f"[CredentialFilter] CREATING NEW INSTANCE for {provider}" + ) + self._plugin_instances[provider] = plugin_class() + else: + self._plugin_instances[provider] = plugin_class + else: + return None + return self._plugin_instances[provider] + + def filter_by_tier( + self, + credentials: List[str], + model: str, + provider: str, + ) -> FilterResult: + """ + Filter credentials by tier compatibility for a model. + + Args: + credentials: List of credential identifiers + model: Model being requested + provider: Provider name + + Returns: + FilterResult with categorized credentials + """ + plugin = self._get_plugin_instance(provider) + + # Get tier requirement for model + required_tier = None + if plugin and hasattr(plugin, "get_model_tier_requirement"): + required_tier = plugin.get_model_tier_requirement(model) + + compatible: List[str] = [] + unknown: List[str] = [] + incompatible: List[str] = [] + priorities: Dict[str, int] = {} + tier_names: Dict[str, str] = {} + + for cred in credentials: + # Get priority and tier name + priority = None + tier_name = None + + if plugin: + if hasattr(plugin, "get_credential_priority"): + priority = plugin.get_credential_priority(cred) + if hasattr(plugin, "get_credential_tier_name"): + tier_name = plugin.get_credential_tier_name(cred) + + if priority is not None: + priorities[cred] = priority + if tier_name: + tier_names[cred] = tier_name + + # Categorize by tier compatibility + if required_tier is None: + # No tier requirement - all compatible + compatible.append(cred) + elif priority is None: + # Unknown priority - keep as candidate + unknown.append(cred) + elif priority <= required_tier: + # Known compatible (lower priority number = higher tier) + compatible.append(cred) + else: + # Known incompatible + incompatible.append(cred) + + # Log if all credentials are incompatible + if incompatible and not compatible and not unknown: + lib_logger.warning( + f"Model {model} requires tier <= {required_tier}, " + f"but all {len(incompatible)} credentials are incompatible" + ) + + return FilterResult( + compatible=compatible, + unknown=unknown, + incompatible=incompatible, + priorities=priorities, + tier_names=tier_names, + ) + + def group_by_priority( + self, + credentials: List[str], + priorities: Dict[str, int], + ) -> Dict[int, List[str]]: + """ + Group credentials by priority level. + + Args: + credentials: List of credential identifiers + priorities: Dict mapping credentials to priority levels + + Returns: + Dict mapping priority levels to credential lists, sorted by priority + """ + groups: Dict[int, List[str]] = {} + + for cred in credentials: + priority = priorities.get(cred, 999) + if priority not in groups: + groups[priority] = [] + groups[priority].append(cred) + + # Return sorted by priority (lower = higher priority) + return dict(sorted(groups.items())) + + def get_highest_priority_credentials( + self, + credentials: List[str], + priorities: Dict[str, int], + ) -> List[str]: + """ + Get credentials with the highest priority (lowest priority number). + + Args: + credentials: List of credential identifiers + priorities: Dict mapping credentials to priority levels + + Returns: + List of credentials with the highest priority + """ + if not credentials: + return [] + + groups = self.group_by_priority(credentials, priorities) + if not groups: + return credentials + + # Get the lowest priority number (highest priority) + highest_priority = min(groups.keys()) + return groups[highest_priority] diff --git a/src/rotator_library/client/models.py b/src/rotator_library/client/models.py new file mode 100644 index 00000000..27c9015c --- /dev/null +++ b/src/rotator_library/client/models.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Model name resolution and filtering. + +Extracts model-related logic from client.py including: +- _resolve_model_id (lines 867-902) +- _is_model_ignored (lines 587-619) +- _is_model_whitelisted (lines 621-651) +""" + +import fnmatch +import logging +from typing import Any, Dict, List, Optional + +lib_logger = logging.getLogger("rotator_library") + + +class ModelResolver: + """ + Resolve model names and apply filtering rules. + + Handles: + - Model ID resolution (display name -> actual ID) + - Whitelist/blacklist filtering + - Provider prefix handling + """ + + def __init__( + self, + provider_plugins: Dict[str, Any], + model_definitions: Optional[Any] = None, + ignore_models: Optional[Dict[str, List[str]]] = None, + whitelist_models: Optional[Dict[str, List[str]]] = None, + provider_instances: Optional[Dict[str, Any]] = None, + ): + """ + Initialize the ModelResolver. + + Args: + provider_plugins: Dict mapping provider names to plugin classes + model_definitions: ModelDefinitions instance for ID mapping + ignore_models: Models to ignore/blacklist per provider + whitelist_models: Models to explicitly whitelist per provider + provider_instances: Shared dict for caching provider instances. + If None, creates a new dict (not recommended - leads to duplicate instances). + """ + self._plugins = provider_plugins + self._plugin_instances: Dict[str, Any] = ( + provider_instances if provider_instances is not None else {} + ) + self._definitions = model_definitions + self._ignore = ignore_models or {} + self._whitelist = whitelist_models or {} + + def _get_plugin_instance(self, provider: str) -> Optional[Any]: + """ + Get or create a plugin instance for a provider. + """ + if provider not in self._plugin_instances: + plugin_class = self._plugins.get(provider) + if plugin_class: + if isinstance(plugin_class, type): + self._plugin_instances[provider] = plugin_class() + else: + self._plugin_instances[provider] = plugin_class + else: + return None + return self._plugin_instances[provider] + + def resolve_model_id(self, model: str, provider: str) -> str: + """ + Resolve display name to actual model ID. + + For custom models with name/ID mappings, returns the ID. + Otherwise, returns the model name unchanged. + + Args: + model: Full model string with provider (e.g., "iflow/DS-v3.2") + provider: Provider name (e.g., "iflow") + + Returns: + Full model string with ID (e.g., "iflow/deepseek-v3.2") + """ + model_name = model.split("/")[-1] if "/" in model else model + + # Check provider plugin first + plugin = self._get_plugin_instance(provider) + if plugin and hasattr(plugin, "model_definitions"): + resolved = plugin.model_definitions.get_model_id(provider, model_name) + if resolved and resolved != model_name: + return f"{provider}/{resolved}" + + # Fallback to client-level definitions + if self._definitions: + resolved = self._definitions.get_model_id(provider, model_name) + if resolved and resolved != model_name: + return f"{provider}/{resolved}" + + return model + + def is_model_allowed(self, model: str, provider: str) -> bool: + """ + Check if model passes whitelist/blacklist filters. + + Whitelist takes precedence over blacklist. + + Args: + model: Model string (with or without provider prefix) + provider: Provider name + + Returns: + True if model is allowed, False if blocked + """ + # Whitelist takes precedence + if self._is_whitelisted(model, provider): + return True + + # Then check blacklist + if self._is_blacklisted(model, provider): + return False + + return True + + def _is_blacklisted(self, model: str, provider: str) -> bool: + """ + Check if model is blacklisted. + + Supports glob patterns: + - "gpt-4" - exact match + - "gpt-4*" - prefix wildcard + - "*-preview" - suffix wildcard + - "*" - match all + + Args: + model: Model string + provider: Provider name (used to get ignore list) + + Returns: + True if model is blacklisted + """ + model_provider = model.split("/")[0] if "/" in model else provider + + if model_provider not in self._ignore: + return False + + ignore_list = self._ignore[model_provider] + if ignore_list == ["*"]: + return True + + # Extract model name without provider prefix + model_name = model.split("/", 1)[1] if "/" in model else model + + for pattern in ignore_list: + # Use fnmatch for glob pattern support + if fnmatch.fnmatch(model_name, pattern): + return True + if fnmatch.fnmatch(model, pattern): + return True + + return False + + def _is_whitelisted(self, model: str, provider: str) -> bool: + """ + Check if model is whitelisted. + + Same pattern support as blacklist. + + Args: + model: Model string + provider: Provider name + + Returns: + True if model is whitelisted + """ + model_provider = model.split("/")[0] if "/" in model else provider + + if model_provider not in self._whitelist: + return False + + whitelist = self._whitelist[model_provider] + model_name = model.split("/", 1)[1] if "/" in model else model + + for pattern in whitelist: + if fnmatch.fnmatch(model_name, pattern): + return True + if fnmatch.fnmatch(model, pattern): + return True + + return False + + @staticmethod + def extract_provider(model: str) -> str: + """ + Extract provider name from model string. + + Args: + model: Model string (e.g., "openai/gpt-4") + + Returns: + Provider name (e.g., "openai") or empty string if no prefix + """ + return model.split("/")[0] if "/" in model else "" + + @staticmethod + def strip_provider(model: str) -> str: + """ + Strip provider prefix from model string. + + Args: + model: Model string (e.g., "openai/gpt-4") + + Returns: + Model name without prefix (e.g., "gpt-4") + """ + return model.split("/", 1)[1] if "/" in model else model + + @staticmethod + def ensure_provider_prefix(model: str, provider: str) -> str: + """ + Ensure model string has provider prefix. + + Args: + model: Model string + provider: Provider name to add if missing + + Returns: + Model string with provider prefix + """ + if "/" in model: + return model + return f"{provider}/{model}" diff --git a/src/rotator_library/client/rotating_client.py b/src/rotator_library/client/rotating_client.py new file mode 100644 index 00000000..a5cee0fc --- /dev/null +++ b/src/rotator_library/client/rotating_client.py @@ -0,0 +1,904 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Slim RotatingClient facade. + +This is a lightweight facade that delegates to extracted components: +- RequestExecutor: Unified retry/rotation logic +- CredentialFilter: Tier compatibility filtering +- ModelResolver: Model name resolution and filtering +- ProviderTransforms: Provider-specific request mutations +- StreamingHandler: Streaming response processing + +The original client.py was ~3000 lines. This facade is ~300 lines, +with all complexity moved to specialized modules. +""" + +import asyncio +import json +import logging +import os +import random +import time +from pathlib import Path +from typing import Any, AsyncGenerator, Dict, List, Optional, Union, TYPE_CHECKING + +import httpx +import litellm +from litellm.litellm_core_utils.token_counter import token_counter + +from ..core.types import RequestContext +from ..core.errors import NoAvailableKeysError, mask_credential +from ..core.config import ConfigLoader +from ..core.constants import ( + DEFAULT_MAX_RETRIES, + DEFAULT_GLOBAL_TIMEOUT, + DEFAULT_ROTATION_TOLERANCE, +) + +from .filters import CredentialFilter +from .models import ModelResolver +from .transforms import ProviderTransforms +from .executor import RequestExecutor +from .anthropic import AnthropicHandler + +# Import providers and other dependencies +from ..providers import PROVIDER_PLUGINS +from ..cooldown_manager import CooldownManager +from ..credential_manager import CredentialManager +from ..background_refresher import BackgroundRefresher +from ..model_definitions import ModelDefinitions +from ..transaction_logger import TransactionLogger +from ..provider_config import ProviderConfig as LiteLLMProviderConfig +from ..utils.paths import get_default_root, get_logs_dir, get_oauth_dir +from ..utils.suppress_litellm_warnings import suppress_litellm_serialization_warnings +from ..failure_logger import configure_failure_logger + +# Import new usage package +from ..usage import UsageManager as NewUsageManager +from ..usage.config import load_provider_usage_config, WindowDefinition + +if TYPE_CHECKING: + from ..anthropic_compat import AnthropicMessagesRequest, AnthropicCountTokensRequest + +lib_logger = logging.getLogger("rotator_library") + + +class RotatingClient: + """ + A client that intelligently rotates and retries API keys using LiteLLM, + with support for both streaming and non-streaming responses. + + This is a slim facade that delegates to specialized components: + - RequestExecutor: Handles retry/rotation logic + - CredentialFilter: Filters credentials by tier + - ModelResolver: Resolves model names + - ProviderTransforms: Applies provider-specific transforms + """ + + def __init__( + self, + api_keys: Optional[Dict[str, List[str]]] = None, + oauth_credentials: Optional[Dict[str, List[str]]] = None, + max_retries: int = DEFAULT_MAX_RETRIES, + usage_file_path: Optional[Union[str, Path]] = None, + configure_logging: bool = True, + global_timeout: int = DEFAULT_GLOBAL_TIMEOUT, + abort_on_callback_error: bool = True, + litellm_provider_params: Optional[Dict[str, Any]] = None, + ignore_models: Optional[Dict[str, List[str]]] = None, + whitelist_models: Optional[Dict[str, List[str]]] = None, + enable_request_logging: bool = False, + max_concurrent_requests_per_key: Optional[Dict[str, int]] = None, + rotation_tolerance: float = DEFAULT_ROTATION_TOLERANCE, + data_dir: Optional[Union[str, Path]] = None, + ): + """ + Initialize the RotatingClient. + + See original client.py for full parameter documentation. + """ + # Resolve data directory + self.data_dir = Path(data_dir).resolve() if data_dir else get_default_root() + + # Configure logging + configure_failure_logger(get_logs_dir(self.data_dir)) + os.environ["LITELLM_LOG"] = "ERROR" + litellm.set_verbose = False + litellm.drop_params = True + suppress_litellm_serialization_warnings() + + if configure_logging: + lib_logger.propagate = True + if lib_logger.hasHandlers(): + lib_logger.handlers.clear() + lib_logger.addHandler(logging.NullHandler()) + else: + lib_logger.propagate = False + + # Process credentials + api_keys = api_keys or {} + oauth_credentials = oauth_credentials or {} + api_keys = {p: k for p, k in api_keys.items() if k} + oauth_credentials = {p: c for p, c in oauth_credentials.items() if c} + + if not api_keys and not oauth_credentials: + lib_logger.warning( + "No provider credentials configured. Client will be unable to make requests." + ) + + # Discover OAuth credentials if not provided + if oauth_credentials: + self.oauth_credentials = oauth_credentials + else: + cred_manager = CredentialManager( + os.environ, oauth_dir=get_oauth_dir(self.data_dir) + ) + self.oauth_credentials = cred_manager.discover_and_prepare() + + # Build combined credentials + self.all_credentials: Dict[str, List[str]] = {} + for provider, keys in api_keys.items(): + self.all_credentials.setdefault(provider, []).extend(keys) + for provider, paths in self.oauth_credentials.items(): + self.all_credentials.setdefault(provider, []).extend(paths) + + self.api_keys = api_keys + self.oauth_providers = set(self.oauth_credentials.keys()) + + # Store configuration + self.max_retries = max_retries + self.global_timeout = global_timeout + self.abort_on_callback_error = abort_on_callback_error + self.litellm_provider_params = litellm_provider_params or {} + self._litellm_logger_fn = self._litellm_logger_callback + self.enable_request_logging = enable_request_logging + self.max_concurrent_requests_per_key = max_concurrent_requests_per_key or {} + + # Validate concurrent requests config + for provider, max_val in self.max_concurrent_requests_per_key.items(): + if max_val < 1: + lib_logger.warning( + f"Invalid max_concurrent for '{provider}': {max_val}. Setting to 1." + ) + self.max_concurrent_requests_per_key[provider] = 1 + + # Initialize configuration loader + self._config_loader = ConfigLoader(PROVIDER_PLUGINS) + + # Initialize components + self._provider_plugins = PROVIDER_PLUGINS + self._provider_instances: Dict[str, Any] = {} + + # Initialize managers + self.cooldown_manager = CooldownManager() + self.background_refresher = BackgroundRefresher(self) + self.model_definitions = ModelDefinitions() + self.provider_config = LiteLLMProviderConfig() + self.http_client = httpx.AsyncClient() + + # Initialize extracted components + self._credential_filter = CredentialFilter( + PROVIDER_PLUGINS, + provider_instances=self._provider_instances, + ) + self._model_resolver = ModelResolver( + PROVIDER_PLUGINS, + self.model_definitions, + ignore_models or {}, + whitelist_models or {}, + provider_instances=self._provider_instances, + ) + self._provider_transforms = ProviderTransforms( + PROVIDER_PLUGINS, + self.provider_config, + provider_instances=self._provider_instances, + ) + + # Initialize UsageManagers (one per provider) using new usage package + self._usage_managers: Dict[str, NewUsageManager] = {} + + # Resolve usage file path base + if usage_file_path: + base_path = Path(usage_file_path) + if base_path.suffix: + base_path = base_path.parent + self._usage_base_path = base_path / "usage" + else: + self._usage_base_path = self.data_dir / "usage" + self._usage_base_path.mkdir(parents=True, exist_ok=True) + + # Build provider configs using ConfigLoader + provider_configs = {} + for provider in self.all_credentials.keys(): + provider_configs[provider] = self._config_loader.load_provider_config( + provider + ) + + # Create UsageManager for each provider + for provider, credentials in self.all_credentials.items(): + config = load_provider_usage_config(provider, PROVIDER_PLUGINS) + # Override tolerance from constructor param + config.rotation_tolerance = rotation_tolerance + + self._apply_usage_reset_config(provider, credentials, config) + + usage_file = self._usage_base_path / f"usage_{provider}.json" + + # Get max concurrent for this provider (default to 1 if not set) + max_concurrent = self.max_concurrent_requests_per_key.get(provider, 1) + + manager = NewUsageManager( + provider=provider, + file_path=usage_file, + provider_plugins=PROVIDER_PLUGINS, + config=config, + max_concurrent_per_key=max_concurrent, + ) + self._usage_managers[provider] = manager + + # Initialize executor with new usage managers + self._executor = RequestExecutor( + usage_managers=self._usage_managers, + cooldown_manager=self.cooldown_manager, + credential_filter=self._credential_filter, + provider_transforms=self._provider_transforms, + provider_plugins=PROVIDER_PLUGINS, + http_client=self.http_client, + max_retries=max_retries, + global_timeout=global_timeout, + abort_on_callback_error=abort_on_callback_error, + litellm_provider_params=self.litellm_provider_params, + litellm_logger_fn=self._litellm_logger_fn, + provider_instances=self._provider_instances, + ) + + self._model_list_cache: Dict[str, List[str]] = {} + self._usage_initialized = False + self._usage_init_lock = asyncio.Lock() + + # Initialize Anthropic compatibility handler + self._anthropic_handler = AnthropicHandler(self) + + async def __aenter__(self): + await self.initialize_usage_managers() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def initialize_usage_managers(self) -> None: + """Initialize usage managers once before background jobs run.""" + if self._usage_initialized: + return + async with self._usage_init_lock: + if self._usage_initialized: + return + for provider, manager in self._usage_managers.items(): + credentials = self.all_credentials.get(provider, []) + priorities, tiers = self._get_credential_metadata(provider, credentials) + await manager.initialize( + credentials, priorities=priorities, tiers=tiers + ) + summaries = [] + for provider, manager in self._usage_managers.items(): + credentials = self.all_credentials.get(provider, []) + status = ( + f"loaded {manager.loaded_credentials}" + if manager.loaded_from_storage + else "fresh" + ) + summaries.append(f"{provider}:{len(credentials)} ({status})") + if summaries: + lib_logger.info( + f"Usage managers initialized: {', '.join(sorted(summaries))}" + ) + self._usage_initialized = True + + async def close(self): + """Close the HTTP client and save usage data.""" + # Save and shutdown new usage managers + for manager in self._usage_managers.values(): + await manager.shutdown() + + if hasattr(self, "http_client") and self.http_client: + await self.http_client.aclose() + + async def acompletion( + self, + request: Optional[Any] = None, + pre_request_callback: Optional[callable] = None, + **kwargs, + ) -> Union[Any, AsyncGenerator[str, None]]: + """ + Dispatcher for completion requests. + + Returns: + Response object or async generator for streaming + """ + model = kwargs.get("model", "") + provider = model.split("/")[0] if "/" in model else "" + + if not provider or provider not in self.all_credentials: + raise ValueError( + f"Invalid model format or no credentials for provider: {model}" + ) + + # Extract internal logging parameters (not passed to API) + parent_log_dir = kwargs.pop("_parent_log_dir", None) + + # Resolve model ID + resolved_model = self._model_resolver.resolve_model_id(model, provider) + kwargs["model"] = resolved_model + + # Create transaction logger if enabled + transaction_logger = None + if self.enable_request_logging: + transaction_logger = TransactionLogger( + provider=provider, + model=resolved_model, + enabled=True, + parent_dir=parent_log_dir, + ) + transaction_logger.log_request(kwargs) + + # Build request context + context = RequestContext( + model=resolved_model, + provider=provider, + kwargs=kwargs, + streaming=kwargs.get("stream", False), + credentials=self.all_credentials.get(provider, []), + deadline=time.time() + self.global_timeout, + request=request, + pre_request_callback=pre_request_callback, + transaction_logger=transaction_logger, + ) + + return await self._executor.execute(context) + + def aembedding( + self, + request: Optional[Any] = None, + pre_request_callback: Optional[callable] = None, + **kwargs, + ) -> Any: + """ + Execute an embedding request with retry logic. + """ + model = kwargs.get("model", "") + provider = model.split("/")[0] if "/" in model else "" + + if not provider or provider not in self.all_credentials: + raise ValueError( + f"Invalid model format or no credentials for provider: {model}" + ) + + # Build request context (embeddings are never streaming) + context = RequestContext( + model=model, + provider=provider, + kwargs=kwargs, + streaming=False, + credentials=self.all_credentials.get(provider, []), + deadline=time.time() + self.global_timeout, + request=request, + pre_request_callback=pre_request_callback, + ) + + return self._executor.execute(context) + + def token_count(self, **kwargs) -> int: + """Calculate token count for text or messages. + + For Antigravity provider models, this also includes the preprompt tokens + that get injected during actual API calls (agent instruction + identity override). + This ensures token counts match actual usage. + """ + model = kwargs.get("model") + text = kwargs.get("text") + messages = kwargs.get("messages") + + if not model: + raise ValueError("'model' is required") + + # Calculate base token count + if messages: + base_count = token_counter(model=model, messages=messages) + elif text: + base_count = token_counter(model=model, text=text) + else: + raise ValueError("Either 'text' or 'messages' must be provided") + + # Add preprompt tokens for Antigravity provider + # The Antigravity provider injects system instructions during actual API calls, + # so we need to account for those tokens in the count + provider = model.split("/")[0] if "/" in model else "" + if provider == "antigravity": + try: + from ..providers.antigravity_provider import ( + get_antigravity_preprompt_text, + ) + + preprompt_text = get_antigravity_preprompt_text() + if preprompt_text: + preprompt_tokens = token_counter(model=model, text=preprompt_text) + base_count += preprompt_tokens + except ImportError: + # Provider not available, skip preprompt token counting + pass + + return base_count + + async def get_available_models(self, provider: str) -> List[str]: + """Get available models for a provider with caching.""" + if provider in self._model_list_cache: + return self._model_list_cache[provider] + + credentials = self.all_credentials.get(provider, []) + if not credentials: + return [] + + # Shuffle and try each credential + shuffled = list(credentials) + random.shuffle(shuffled) + + plugin = self._get_provider_instance(provider) + if not plugin: + return [] + + for cred in shuffled: + try: + models = await plugin.get_models(cred, self.http_client) + + # Apply whitelist/blacklist + final = [ + m + for m in models + if self._model_resolver.is_model_allowed(m, provider) + ] + + self._model_list_cache[provider] = final + return final + + except Exception as e: + lib_logger.debug( + f"Failed to get models for {provider} with {mask_credential(cred)}: {e}" + ) + continue + + return [] + + async def get_all_available_models( + self, + grouped: bool = True, + ) -> Union[Dict[str, List[str]], List[str]]: + """Get all available models across all providers.""" + providers = list(self.all_credentials.keys()) + tasks = [self.get_available_models(p) for p in providers] + results = await asyncio.gather(*tasks, return_exceptions=True) + + all_models: Dict[str, List[str]] = {} + for provider, result in zip(providers, results): + if isinstance(result, Exception): + lib_logger.error(f"Failed to get models for {provider}: {result}") + all_models[provider] = [] + else: + all_models[provider] = result + + if grouped: + return all_models + else: + flat = [] + for models in all_models.values(): + flat.extend(models) + return flat + + async def get_quota_stats( + self, + provider_filter: Optional[str] = None, + ) -> Dict[str, Any]: + """Get quota and usage stats for all credentials. + + Args: + provider_filter: Optional provider name to filter results + + Returns: + Dict with stats per provider + """ + providers = {} + + for provider, manager in self._usage_managers.items(): + if provider_filter and provider != provider_filter: + continue + + stats = await manager.get_stats_for_endpoint() + + # Skip providers with no activity (filters out invalid/unused providers) + if stats.get("total_requests", 0) == 0: + continue + + providers[provider] = stats + + summary = { + "total_providers": len(providers), + "total_credentials": 0, + "active_credentials": 0, + "exhausted_credentials": 0, + "total_requests": 0, + "tokens": { + "input_cached": 0, + "input_uncached": 0, + "input_cache_pct": 0, + "output": 0, + }, + "approx_total_cost": None, + } + + for prov in providers.values(): + summary["total_credentials"] += prov.get("credential_count", 0) + summary["active_credentials"] += prov.get("active_count", 0) + summary["exhausted_credentials"] += prov.get("exhausted_count", 0) + summary["total_requests"] += prov.get("total_requests", 0) + tokens = prov.get("tokens", {}) + summary["tokens"]["input_cached"] += tokens.get("input_cached", 0) + summary["tokens"]["input_uncached"] += tokens.get("input_uncached", 0) + summary["tokens"]["output"] += tokens.get("output", 0) + + total_input = ( + summary["tokens"]["input_cached"] + summary["tokens"]["input_uncached"] + ) + summary["tokens"]["input_cache_pct"] = ( + round(summary["tokens"]["input_cached"] / total_input * 100, 1) + if total_input > 0 + else 0 + ) + + approx_total_cost = 0.0 + has_cost = False + for prov in providers.values(): + cost = prov.get("approx_cost") + if cost: + approx_total_cost += cost + has_cost = True + summary["approx_total_cost"] = approx_total_cost if has_cost else None + + return { + "providers": providers, + "summary": summary, + "data_source": "cache", + "timestamp": time.time(), + } + + def get_oauth_credentials(self) -> Dict[str, List[str]]: + """Get discovered OAuth credentials.""" + return self.oauth_credentials + + def _get_provider_instance(self, provider: str) -> Optional[Any]: + """Get or create a provider plugin instance.""" + if provider not in self.all_credentials: + return None + + if provider not in self._provider_instances: + plugin_class = self._provider_plugins.get(provider) + if plugin_class: + self._provider_instances[provider] = plugin_class() + else: + return None + + return self._provider_instances[provider] + + def _get_credential_metadata( + self, + provider: str, + credentials: List[str], + ) -> tuple[Dict[str, int], Dict[str, str]]: + """Resolve priority and tier metadata for credentials.""" + plugin = self._get_provider_instance(provider) + priorities: Dict[str, int] = {} + tiers: Dict[str, str] = {} + + if not plugin: + return priorities, tiers + + for credential in credentials: + if hasattr(plugin, "get_credential_priority"): + priority = plugin.get_credential_priority(credential) + if priority is not None: + priorities[credential] = priority + if hasattr(plugin, "get_credential_tier_name"): + tier_name = plugin.get_credential_tier_name(credential) + if tier_name: + tiers[credential] = tier_name + + return priorities, tiers + + def get_usage_manager(self, provider: str) -> Optional[NewUsageManager]: + """ + Get the new UsageManager for a specific provider. + + Args: + provider: Provider name + + Returns: + UsageManager for the provider, or None if not found + """ + return self._usage_managers.get(provider) + + @property + def usage_managers(self) -> Dict[str, NewUsageManager]: + """Get all new usage managers.""" + return self._usage_managers + + def _apply_usage_reset_config( + self, + provider: str, + credentials: List[str], + config: Any, + ) -> None: + """Apply provider-specific usage reset config to window definitions.""" + if not credentials: + return + + plugin = self._get_provider_instance(provider) + if not plugin or not hasattr(plugin, "get_usage_reset_config"): + return + + try: + reset_config = plugin.get_usage_reset_config(credentials[0]) + except Exception as exc: + lib_logger.debug(f"Failed to load usage reset config for {provider}: {exc}") + return + + if not reset_config: + return + + window_seconds = reset_config.get("window_seconds") + if not window_seconds: + return + + mode = reset_config.get("mode", "credential") + applies_to = "credential" if mode == "credential" else "model" + + if window_seconds == 86400: + window_name = "daily" + elif window_seconds % 3600 == 0: + window_name = f"{window_seconds // 3600}h" + else: + window_name = "window" + + config.windows = [ + WindowDefinition.rolling( + name=window_name, + duration_seconds=int(window_seconds), + is_primary=True, + applies_to=applies_to, + ), + ] + + def _sanitize_litellm_log(self, log_data: dict) -> dict: + """Remove large/sensitive fields from LiteLLM logs.""" + if not isinstance(log_data, dict): + return log_data + + keys_to_pop = [ + "messages", + "input", + "response", + "data", + "api_key", + "api_base", + "original_response", + "additional_args", + ] + nested_keys = ["kwargs", "litellm_params", "model_info", "proxy_server_request"] + + clean_data = json.loads(json.dumps(log_data, default=str)) + + def clean_recursively(data_dict: dict) -> None: + for key in keys_to_pop: + data_dict.pop(key, None) + for key in nested_keys: + if key in data_dict and isinstance(data_dict[key], dict): + clean_recursively(data_dict[key]) + for value in list(data_dict.values()): + if isinstance(value, dict): + clean_recursively(value) + + clean_recursively(clean_data) + return clean_data + + def _litellm_logger_callback(self, log_data: dict) -> None: + """Redirect LiteLLM logs into rotator library logger.""" + log_event_type = log_data.get("log_event_type") + if log_event_type in ["pre_api_call", "post_api_call"]: + return + + if not log_data.get("exception"): + sanitized_log = self._sanitize_litellm_log(log_data) + lib_logger.debug(f"LiteLLM Log: {sanitized_log}") + return + + model = log_data.get("model", "N/A") + error_info = log_data.get("standard_logging_object", {}).get( + "error_information", {} + ) + error_class = error_info.get("error_class", "UnknownError") + error_message = error_info.get( + "error_message", str(log_data.get("exception", "")) + ) + error_message = " ".join(error_message.split()) + + lib_logger.debug( + f"LiteLLM Callback Handled Error: Model={model} | " + f"Type={error_class} | Message='{error_message}'" + ) + + # ========================================================================= + # USAGE MANAGEMENT METHODS + # ========================================================================= + + async def reload_usage_from_disk(self) -> None: + """ + Force reload usage data from disk. + + Useful when wanting fresh stats without making external API calls. + """ + for manager in self._usage_managers.values(): + await manager.reload_from_disk() + + async def force_refresh_quota( + self, + provider: Optional[str] = None, + credential: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Force refresh quota from external API. + + For Antigravity, this fetches live quota data from the API. + For other providers, this is a no-op (just reloads from disk). + + Args: + provider: If specified, only refresh this provider + credential: If specified, only refresh this specific credential + + Returns: + Refresh result dict with success/failure info + """ + result = { + "action": "force_refresh", + "scope": "credential" + if credential + else ("provider" if provider else "all"), + "provider": provider, + "credential": credential, + "credentials_refreshed": 0, + "success_count": 0, + "failed_count": 0, + "duration_ms": 0, + "errors": [], + } + + start_time = time.time() + + # Determine which providers to refresh + if provider: + providers_to_refresh = ( + [provider] if provider in self.all_credentials else [] + ) + else: + providers_to_refresh = list(self.all_credentials.keys()) + + for prov in providers_to_refresh: + provider_class = self._provider_plugins.get(prov) + if not provider_class: + continue + + # Get or create provider instance + provider_instance = self._get_provider_instance(prov) + if not provider_instance: + continue + + # Check if provider supports quota refresh (like Antigravity) + if hasattr(provider_instance, "fetch_initial_baselines"): + # Get credentials to refresh + if credential: + # Find full path for this credential + creds_to_refresh = [] + for cred_path in self.all_credentials.get(prov, []): + if cred_path.endswith(credential) or cred_path == credential: + creds_to_refresh.append(cred_path) + break + else: + creds_to_refresh = self.all_credentials.get(prov, []) + + if not creds_to_refresh: + continue + + try: + # Fetch live quota from API for ALL specified credentials + quota_results = await provider_instance.fetch_initial_baselines( + creds_to_refresh + ) + + # Store baselines in usage manager + usage_manager = self._usage_managers.get(prov) + if usage_manager and hasattr( + provider_instance, "_store_baselines_to_usage_manager" + ): + stored = await provider_instance._store_baselines_to_usage_manager( + quota_results, + usage_manager, + force=True, + is_initial_fetch=True, # Manual refresh checks exhaustion + ) + result["success_count"] += stored + + result["credentials_refreshed"] += len(creds_to_refresh) + + # Count failures + for cred_path, data in quota_results.items(): + if data.get("status") != "success": + result["failed_count"] += 1 + result["errors"].append( + f"{Path(cred_path).name}: {data.get('error', 'Unknown error')}" + ) + + except Exception as e: + lib_logger.error(f"Failed to refresh quota for {prov}: {e}") + result["errors"].append(f"{prov}: {str(e)}") + result["failed_count"] += len(creds_to_refresh) + + result["duration_ms"] = int((time.time() - start_time) * 1000) + return result + + # ========================================================================= + # ANTHROPIC API COMPATIBILITY METHODS + # ========================================================================= + + async def anthropic_messages( + self, + request: "AnthropicMessagesRequest", + raw_request: Optional[Any] = None, + pre_request_callback: Optional[callable] = None, + ) -> Any: + """ + Handle Anthropic Messages API requests. + + This method accepts requests in Anthropic's format, translates them to + OpenAI format internally, processes them through the existing acompletion + method, and returns responses in Anthropic's format. + + Args: + request: An AnthropicMessagesRequest object + raw_request: Optional raw request object for disconnect checks + pre_request_callback: Optional async callback before each API request + + Returns: + For non-streaming: dict in Anthropic Messages format + For streaming: AsyncGenerator yielding Anthropic SSE format strings + """ + return await self._anthropic_handler.messages( + request=request, + raw_request=raw_request, + pre_request_callback=pre_request_callback, + ) + + async def anthropic_count_tokens( + self, + request: "AnthropicCountTokensRequest", + ) -> dict: + """ + Handle Anthropic count_tokens API requests. + + Counts the number of tokens that would be used by a Messages API request. + This is useful for estimating costs and managing context windows. + + Args: + request: An AnthropicCountTokensRequest object + + Returns: + Dict with input_tokens count in Anthropic format + """ + return await self._anthropic_handler.count_tokens(request=request) diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py new file mode 100644 index 00000000..345c0c4b --- /dev/null +++ b/src/rotator_library/client/streaming.py @@ -0,0 +1,447 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Streaming response handler. + +Extracts streaming logic from client.py _safe_streaming_wrapper (lines 904-1117). +Handles: +- Chunk processing with finish_reason logic +- JSON reassembly for fragmented responses +- Error detection in streamed data +- Usage tracking from final chunks +- Client disconnect handling +""" + +import codecs +import json +import logging +import re +from typing import Any, AsyncGenerator, AsyncIterator, Dict, Optional, TYPE_CHECKING + +import litellm + +from ..core.errors import StreamedAPIError, CredentialNeedsReauthError +from ..core.types import ProcessedChunk + +if TYPE_CHECKING: + from ..usage.manager import CredentialContext + +lib_logger = logging.getLogger("rotator_library") + + +class StreamingHandler: + """ + Process streaming responses with error handling and usage tracking. + + This class extracts the streaming logic that was in _safe_streaming_wrapper + and provides a clean interface for processing LiteLLM streams. + + Usage recording is handled via CredentialContext passed to wrap_stream(). + """ + + async def wrap_stream( + self, + stream: AsyncIterator[Any], + credential: str, + model: str, + request: Optional[Any] = None, + cred_context: Optional["CredentialContext"] = None, + skip_cost_calculation: bool = False, + ) -> AsyncGenerator[str, None]: + """ + Wrap a LiteLLM stream with error handling and usage tracking. + + FINISH_REASON HANDLING: + - Strip finish_reason from intermediate chunks (litellm defaults to "stop") + - Track accumulated_finish_reason with priority: tool_calls > length/content_filter > stop + - Only emit finish_reason on final chunk (detected by usage.completion_tokens > 0) + + Args: + stream: The async iterator from LiteLLM + credential: Credential identifier (for logging) + model: Model name for usage recording + request: Optional FastAPI request for disconnect detection + cred_context: CredentialContext for marking success/failure + + Yields: + SSE-formatted strings: "data: {...}\\n\\n" + """ + stream_completed = False + error_buffer = StreamBuffer() # Use StreamBuffer for JSON reassembly + accumulated_finish_reason: Optional[str] = None + has_tool_calls = False + prompt_tokens = 0 + prompt_tokens_cached = 0 + prompt_tokens_cache_write = 0 + prompt_tokens_uncached = 0 + completion_tokens = 0 + thinking_tokens = 0 + + # Use manual iteration to allow continue after partial JSON errors + stream_iterator = stream.__aiter__() + + try: + while True: + try: + # Check client disconnect before waiting for next chunk + if request and await request.is_disconnected(): + lib_logger.info( + f"Client disconnected. Aborting stream for model {model}." + ) + break + + chunk = await stream_iterator.__anext__() + + # Clear error buffer on successful chunk receipt + error_buffer.reset() + + # Process chunk + processed = self._process_chunk( + chunk, + accumulated_finish_reason, + has_tool_calls, + ) + + # Update tracking state + if processed.has_tool_calls: + has_tool_calls = True + accumulated_finish_reason = "tool_calls" + if processed.finish_reason and not has_tool_calls: + # Only update if not already tool_calls (highest priority) + accumulated_finish_reason = processed.finish_reason + if processed.usage and isinstance(processed.usage, dict): + # Extract token counts from final chunk + prompt_tokens = processed.usage.get("prompt_tokens", 0) + completion_tokens = processed.usage.get("completion_tokens", 0) + prompt_details = processed.usage.get("prompt_tokens_details") + if prompt_details: + if isinstance(prompt_details, dict): + prompt_tokens_cached = ( + prompt_details.get("cached_tokens", 0) or 0 + ) + prompt_tokens_cache_write = ( + prompt_details.get("cache_creation_tokens", 0) or 0 + ) + else: + prompt_tokens_cached = ( + getattr(prompt_details, "cached_tokens", 0) or 0 + ) + prompt_tokens_cache_write = ( + getattr(prompt_details, "cache_creation_tokens", 0) + or 0 + ) + completion_details = processed.usage.get( + "completion_tokens_details" + ) + if completion_details: + if isinstance(completion_details, dict): + thinking_tokens = ( + completion_details.get("reasoning_tokens", 0) or 0 + ) + else: + thinking_tokens = ( + getattr(completion_details, "reasoning_tokens", 0) + or 0 + ) + if processed.usage.get("cache_read_tokens") is not None: + prompt_tokens_cached = ( + processed.usage.get("cache_read_tokens") or 0 + ) + if processed.usage.get("cache_creation_tokens") is not None: + prompt_tokens_cache_write = ( + processed.usage.get("cache_creation_tokens") or 0 + ) + if thinking_tokens and completion_tokens >= thinking_tokens: + completion_tokens = completion_tokens - thinking_tokens + prompt_tokens_uncached = max( + 0, prompt_tokens - prompt_tokens_cached + ) + + yield processed.sse_string + + except StopAsyncIteration: + # Stream ended normally + stream_completed = True + break + + except CredentialNeedsReauthError as e: + # Credential needs re-auth - wrap for outer retry loop + if cred_context: + from ..error_handler import classify_error + + cred_context.mark_failure(classify_error(e)) + raise StreamedAPIError("Credential needs re-authentication", data=e) + + except json.JSONDecodeError as e: + # Partial JSON - accumulate and continue + error_buffer.append(str(e)) + if error_buffer.is_complete: + # We have complete JSON now + raise StreamedAPIError( + "Provider error", data=error_buffer.content + ) + # Continue waiting for more chunks + continue + + except Exception as e: + # Try to extract JSON from fragmented response + error_str = str(e) + error_buffer.append(error_str) + + # Check if buffer now has complete JSON + if error_buffer.is_complete: + if cred_context: + from ..error_handler import classify_error + + cred_context.mark_failure(classify_error(e)) + raise StreamedAPIError( + "Provider error in stream", data=error_buffer.content + ) + + # Try pattern matching for error extraction + extracted = self._try_extract_error(e, error_buffer.content) + if extracted: + if cred_context: + from ..error_handler import classify_error + + cred_context.mark_failure(classify_error(e)) + raise StreamedAPIError( + "Provider error in stream", data=extracted + ) + + # Not a JSON-related error, re-raise + raise + + except StreamedAPIError: + # Re-raise for retry loop + raise + + finally: + # Record usage if stream completed + if stream_completed: + if cred_context: + approx_cost = 0.0 + if not skip_cost_calculation: + approx_cost = self._calculate_stream_cost( + model, + prompt_tokens_uncached + prompt_tokens_cached, + completion_tokens + thinking_tokens, + ) + cred_context.mark_success( + prompt_tokens=prompt_tokens_uncached, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cached, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + ) + + # Yield [DONE] for completed streams + yield "data: [DONE]\n\n" + + def _process_chunk( + self, + chunk: Any, + accumulated_finish_reason: Optional[str], + has_tool_calls: bool, + ) -> ProcessedChunk: + """ + Process a single streaming chunk. + + Handles finish_reason logic: + - Strip from intermediate chunks + - Apply correct finish_reason on final chunk + + Args: + chunk: Raw chunk from LiteLLM + accumulated_finish_reason: Current accumulated finish reason + has_tool_calls: Whether any chunk has had tool_calls + + Returns: + ProcessedChunk with SSE string and metadata + """ + # Convert chunk to dict + if hasattr(chunk, "model_dump"): + chunk_dict = chunk.model_dump() + elif hasattr(chunk, "dict"): + chunk_dict = chunk.dict() + else: + chunk_dict = chunk + + # Extract metadata before modifying + usage = chunk_dict.get("usage") + finish_reason = None + chunk_has_tool_calls = False + + if "choices" in chunk_dict and chunk_dict["choices"]: + choice = chunk_dict["choices"][0] + delta = choice.get("delta", {}) + + # Check for tool_calls + if delta.get("tool_calls"): + chunk_has_tool_calls = True + + # Get source finish_reason before we potentially modify it + source_finish_reason = choice.get("finish_reason") + + # Detect final chunk using multiple signals: + # 1. Primary: has usage with any meaningful token count > 0 + # 2. Secondary: has usage (even empty) + source has finish_reason (Fallback case) + has_meaningful_usage = ( + usage + and isinstance(usage, dict) + and any( + usage.get(k, 0) > 0 + for k in [ + "completion_tokens", + "prompt_tokens", + "total_tokens", + "reasoning_tokens", + ] + ) + ) + has_usage_with_finish = ( + usage is not None + and isinstance(usage, dict) + and source_finish_reason is not None + ) + is_final_chunk = has_meaningful_usage or has_usage_with_finish + + if is_final_chunk: + # FINAL CHUNK: Determine correct finish_reason + # Priority: tool_calls > source_finish_reason > accumulated > "stop" + if has_tool_calls or chunk_has_tool_calls: + choice["finish_reason"] = "tool_calls" + elif source_finish_reason: + choice["finish_reason"] = source_finish_reason + elif accumulated_finish_reason: + choice["finish_reason"] = accumulated_finish_reason + else: + choice["finish_reason"] = "stop" + finish_reason = choice["finish_reason"] + else: + # INTERMEDIATE CHUNK: Never emit finish_reason + choice["finish_reason"] = None + + return ProcessedChunk( + sse_string=f"data: {json.dumps(chunk_dict)}\n\n", + usage=usage, + finish_reason=finish_reason, + has_tool_calls=chunk_has_tool_calls, + ) + + def _try_extract_error( + self, + exception: Exception, + buffer: str, + ) -> Optional[Dict]: + """ + Try to extract error JSON from exception or buffer. + + Handles multiple error formats: + - Google-style bytes representation: b'{...}' + - "Received chunk:" prefix + - JSON in buffer accumulation + + Args: + exception: The caught exception + buffer: Current JSON buffer content + + Returns: + Parsed error dict or None + """ + error_str = str(exception) + + # Pattern 1: Google-style bytes representation + match = re.search(r"b'(\{.*\})'", error_str, re.DOTALL) + if match: + try: + decoded = codecs.decode(match.group(1), "unicode_escape") + return json.loads(decoded) + except (json.JSONDecodeError, ValueError): + pass + + # Pattern 2: "Received chunk:" prefix + if "Received chunk:" in error_str: + chunk = error_str.split("Received chunk:")[-1].strip() + try: + return json.loads(chunk) + except json.JSONDecodeError: + pass + + # Pattern 3: Buffer accumulation + if buffer: + try: + return json.loads(buffer) + except json.JSONDecodeError: + pass + + return None + + def _calculate_stream_cost( + self, + model: str, + prompt_tokens: int, + completion_tokens: int, + ) -> float: + try: + model_info = litellm.get_model_info(model) + input_cost = model_info.get("input_cost_per_token") + output_cost = model_info.get("output_cost_per_token") + total_cost = 0.0 + if input_cost: + total_cost += prompt_tokens * input_cost + if output_cost: + total_cost += completion_tokens * output_cost + return total_cost + except Exception as exc: + lib_logger.debug(f"Stream cost calculation failed for {model}: {exc}") + return 0.0 + + +class StreamBuffer: + """ + Buffer for reassembling fragmented JSON in streams. + + Some providers send JSON split across multiple chunks, especially + for error responses. This class handles accumulation and parsing. + """ + + def __init__(self): + self._buffer = "" + self._complete = False + + def append(self, chunk: str) -> Optional[Dict]: + """ + Append a chunk and try to parse. + + Args: + chunk: Raw chunk string + + Returns: + Parsed dict if complete, None if still accumulating + """ + self._buffer += chunk + + try: + result = json.loads(self._buffer) + self._complete = True + return result + except json.JSONDecodeError: + return None + + def reset(self) -> None: + """Reset the buffer.""" + self._buffer = "" + self._complete = False + + @property + def content(self) -> str: + """Get current buffer content.""" + return self._buffer + + @property + def is_complete(self) -> bool: + """Check if buffer contains complete JSON.""" + return self._complete diff --git a/src/rotator_library/client/transforms.py b/src/rotator_library/client/transforms.py new file mode 100644 index 00000000..34d98a74 --- /dev/null +++ b/src/rotator_library/client/transforms.py @@ -0,0 +1,398 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Provider-specific request transformations. + +This module isolates all provider-specific request mutations that were +scattered throughout client.py, including: +- gemma-3 system message conversion +- qwen_code provider remapping +- Gemini safety settings and thinking parameter +- NVIDIA thinking parameter +- iflow stream_options removal +- dedaluslabs tool_choice=auto removal + +Transforms are applied in a defined order with logging of modifications. +""" + +import logging +from typing import Any, Callable, Dict, List, Optional + +lib_logger = logging.getLogger("rotator_library") + + +class ProviderTransforms: + """ + Centralized provider-specific request transformations. + + Transforms are applied in order: + 1. Built-in transforms (gemma-3, qwen_code, etc.) + 2. Provider hook transforms (from provider plugins) + 3. Safety settings conversions + """ + + def __init__( + self, + provider_plugins: Dict[str, Any], + provider_config: Optional[Any] = None, + provider_instances: Optional[Dict[str, Any]] = None, + ): + """ + Initialize ProviderTransforms. + + Args: + provider_plugins: Dict mapping provider names to plugin classes + provider_config: ProviderConfig instance for LiteLLM conversions + provider_instances: Shared dict for caching provider instances. + If None, creates a new dict (not recommended - leads to duplicate instances). + """ + self._plugins = provider_plugins + self._plugin_instances: Dict[str, Any] = ( + provider_instances if provider_instances is not None else {} + ) + self._config = provider_config + + # Registry of built-in transforms + # Each provider can have multiple transform functions + self._transforms: Dict[str, List[Callable]] = { + "gemma": [self._transform_gemma_system_messages], + "qwen_code": [self._transform_qwen_code_provider], + "gemini": [self._transform_gemini_safety, self._transform_gemini_thinking], + "nvidia_nim": [self._transform_nvidia_thinking], + "iflow": [self._transform_iflow_stream_options], + "dedaluslabs": [self._transform_dedaluslabs_tool_choice], + } + + def _get_plugin_instance(self, provider: str) -> Optional[Any]: + """Get or create a plugin instance for a provider.""" + if provider not in self._plugin_instances: + plugin_class = self._plugins.get(provider) + if plugin_class: + if isinstance(plugin_class, type): + self._plugin_instances[provider] = plugin_class() + else: + self._plugin_instances[provider] = plugin_class + else: + return None + return self._plugin_instances[provider] + + async def apply( + self, + provider: str, + model: str, + credential: str, + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Apply all applicable transforms to request kwargs. + + Args: + provider: Provider name + model: Model being requested + credential: Selected credential + kwargs: Request kwargs (will be mutated) + + Returns: + Modified kwargs + """ + modifications: List[str] = [] + + # 1. Apply built-in transforms + for transform_provider, transforms in self._transforms.items(): + # Check if transform applies (provider match or model contains pattern) + if transform_provider == provider or transform_provider in model.lower(): + for transform in transforms: + result = transform(kwargs, model, provider) + if result: + modifications.append(result) + + # 2. Apply provider hook transforms (async) + plugin = self._get_plugin_instance(provider) + if plugin and hasattr(plugin, "transform_request"): + try: + hook_result = await plugin.transform_request(kwargs, model, credential) + if hook_result: + modifications.extend(hook_result) + except Exception as e: + lib_logger.debug(f"Provider transform_request hook failed: {e}") + + # 3. Apply model-specific options from provider + if plugin and hasattr(plugin, "get_model_options"): + model_options = plugin.get_model_options(model) + if model_options: + for key, value in model_options.items(): + if key == "reasoning_effort": + kwargs["reasoning_effort"] = value + elif key not in kwargs: + kwargs[key] = value + modifications.append(f"applied model options for {model}") + + # 4. Apply LiteLLM conversion if config available + if self._config and hasattr(self._config, "convert_for_litellm"): + kwargs = self._config.convert_for_litellm(**kwargs) + + if modifications: + lib_logger.debug( + f"Applied transforms for {provider}/{model}: {modifications}" + ) + + return kwargs + + def apply_sync( + self, + provider: str, + model: str, + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Apply built-in transforms synchronously (no provider hooks). + + Useful when async is not available. + + Args: + provider: Provider name + model: Model being requested + kwargs: Request kwargs + + Returns: + Modified kwargs + """ + modifications: List[str] = [] + + for transform_provider, transforms in self._transforms.items(): + if transform_provider == provider or transform_provider in model.lower(): + for transform in transforms: + result = transform(kwargs, model, provider) + if result: + modifications.append(result) + + if modifications: + lib_logger.debug( + f"Applied sync transforms for {provider}/{model}: {modifications}" + ) + + return kwargs + + # ========================================================================= + # BUILT-IN TRANSFORMS + # ========================================================================= + + def _transform_gemma_system_messages( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Convert system messages to user messages for Gemma-3. + + Gemma-3 models don't support system messages, so we convert them + to user messages to maintain functionality. + """ + if "gemma-3" not in model.lower(): + return None + + messages = kwargs.get("messages", []) + if not messages: + return None + + converted = False + new_messages = [] + for m in messages: + if m.get("role") == "system": + new_messages.append({"role": "user", "content": m["content"]}) + converted = True + else: + new_messages.append(m) + + if converted: + kwargs["messages"] = new_messages + return "gemma-3: converted system->user messages" + return None + + def _transform_qwen_code_provider( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Remap qwen_code to qwen provider for LiteLLM. + + The qwen_code provider is a custom wrapper that needs to be + translated to the qwen provider for LiteLLM compatibility. + """ + if provider != "qwen_code": + return None + + kwargs["custom_llm_provider"] = "qwen" + if "/" in model: + kwargs["model"] = model.split("/", 1)[1] + return "qwen_code: remapped to qwen provider" + + def _transform_gemini_safety( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Apply default Gemini safety settings. + + Ensures safety settings are present without overriding explicit settings. + """ + if provider != "gemini": + return None + + # Default safety settings (generic form) + default_generic = { + "harassment": "OFF", + "hate_speech": "OFF", + "sexually_explicit": "OFF", + "dangerous_content": "OFF", + "civic_integrity": "BLOCK_NONE", + } + + # Default Gemini-native settings + default_gemini = [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, + ] + + # If generic form present, fill in missing keys + if "safety_settings" in kwargs and isinstance(kwargs["safety_settings"], dict): + for k, v in default_generic.items(): + if k not in kwargs["safety_settings"]: + kwargs["safety_settings"][k] = v + return "gemini: filled missing safety settings" + + # If Gemini form present, fill in missing categories + if "safetySettings" in kwargs and isinstance(kwargs["safetySettings"], list): + present = { + item.get("category") + for item in kwargs["safetySettings"] + if isinstance(item, dict) + } + added = 0 + for d in default_gemini: + if d["category"] not in present: + kwargs["safetySettings"].append(d) + added += 1 + if added > 0: + return f"gemini: added {added} missing safety categories" + return None + + # Neither present: set generic defaults + if "safety_settings" not in kwargs and "safetySettings" not in kwargs: + kwargs["safety_settings"] = default_generic.copy() + return "gemini: applied default safety settings" + + return None + + def _transform_gemini_thinking( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Handle thinking parameter for Gemini. + + Delegates to provider plugin's handle_thinking_parameter method. + """ + if provider != "gemini": + return None + + plugin = self._get_plugin_instance(provider) + if plugin and hasattr(plugin, "handle_thinking_parameter"): + plugin.handle_thinking_parameter(kwargs, model) + return "gemini: handled thinking parameter" + return None + + def _transform_nvidia_thinking( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Handle thinking parameter for NVIDIA NIM. + + Delegates to provider plugin's handle_thinking_parameter method. + """ + if provider != "nvidia_nim": + return None + + plugin = self._get_plugin_instance(provider) + if plugin and hasattr(plugin, "handle_thinking_parameter"): + plugin.handle_thinking_parameter(kwargs, model) + return "nvidia_nim: handled thinking parameter" + return None + + def _transform_iflow_stream_options( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Remove stream_options for iflow provider. + + The iflow provider returns HTTP 406 if stream_options is present. + """ + if provider != "iflow": + return None + + if "stream_options" in kwargs: + del kwargs["stream_options"] + return "iflow: removed stream_options" + return None + + def _transform_dedaluslabs_tool_choice( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Remove tool_choice=auto for dedaluslabs provider. + + Dedaluslabs API returns HTTP 422 if tool_choice is passed as a string + ("auto") instead of an object. Since "auto" is the default behavior, + removing it fixes the issue without changing functionality. + """ + if provider != "dedaluslabs": + return None + + if kwargs.get("tool_choice") == "auto": + del kwargs["tool_choice"] + return "dedaluslabs: removed tool_choice=auto" + return None + + # ========================================================================= + # SAFETY SETTINGS CONVERSION + # ========================================================================= + + def convert_safety_settings( + self, + provider: str, + settings: Dict[str, str], + ) -> Optional[Any]: + """ + Convert generic safety settings to provider-specific format. + + Args: + provider: Provider name + settings: Generic safety settings dict + + Returns: + Provider-specific settings or None + """ + plugin = self._get_plugin_instance(provider) + if plugin and hasattr(plugin, "convert_safety_settings"): + return plugin.convert_safety_settings(settings) + return None diff --git a/src/rotator_library/client/types.py b/src/rotator_library/client/types.py new file mode 100644 index 00000000..b54bba0e --- /dev/null +++ b/src/rotator_library/client/types.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Client-specific type definitions. + +Types that are only used within the client package. +Shared types are in core/types.py. +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set + + +@dataclass +class AvailabilityStats: + """ + Statistics about credential availability for a model. + + Used for logging and monitoring credential pool status. + """ + + available: int # Credentials not on cooldown and not exhausted + on_cooldown: int # Credentials on cooldown + fair_cycle_excluded: int # Credentials excluded by fair cycle + total: int # Total credentials for provider + + @property + def usable(self) -> int: + """Return count of usable credentials.""" + return self.available + + def __str__(self) -> str: + parts = [f"{self.available}/{self.total}"] + if self.on_cooldown > 0: + parts.append(f"cd:{self.on_cooldown}") + if self.fair_cycle_excluded > 0: + parts.append(f"fc:{self.fair_cycle_excluded}") + return ",".join(parts) + + +@dataclass +class RetryState: + """ + State tracking for a retry loop. + + Used by RequestExecutor to track retry attempts and errors. + """ + + tried_credentials: Set[str] = field(default_factory=set) + last_exception: Optional[Exception] = None + consecutive_quota_failures: int = 0 + + def record_attempt(self, credential: str) -> None: + """Record that a credential was tried.""" + self.tried_credentials.add(credential) + + def reset_quota_failures(self) -> None: + """Reset quota failure counter (called after non-quota error).""" + self.consecutive_quota_failures = 0 + + def increment_quota_failures(self) -> None: + """Increment quota failure counter.""" + self.consecutive_quota_failures += 1 + + +@dataclass +class ExecutionResult: + """ + Result of executing a request. + + Returned by RequestExecutor to indicate outcome. + """ + + success: bool + response: Optional[Any] = None + error: Optional[Exception] = None + should_rotate: bool = False + should_fail: bool = False diff --git a/src/rotator_library/config/__init__.py b/src/rotator_library/config/__init__.py index ea49533e..741561cf 100644 --- a/src/rotator_library/config/__init__.py +++ b/src/rotator_library/config/__init__.py @@ -21,6 +21,8 @@ DEFAULT_FAIR_CYCLE_TRACKING_MODE, DEFAULT_FAIR_CYCLE_CROSS_TIER, DEFAULT_FAIR_CYCLE_DURATION, + DEFAULT_FAIR_CYCLE_QUOTA_THRESHOLD, + DEFAULT_FAIR_CYCLE_RESET_COOLDOWN_THRESHOLD, DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, # Custom Caps DEFAULT_CUSTOM_CAP_COOLDOWN_MODE, @@ -31,6 +33,8 @@ COOLDOWN_AUTH_ERROR, COOLDOWN_TRANSIENT_ERROR, COOLDOWN_RATE_LIMIT_DEFAULT, + # Small Cooldown Auto-Retry + DEFAULT_SMALL_COOLDOWN_RETRY_THRESHOLD, ) __all__ = [ @@ -47,6 +51,8 @@ "DEFAULT_FAIR_CYCLE_TRACKING_MODE", "DEFAULT_FAIR_CYCLE_CROSS_TIER", "DEFAULT_FAIR_CYCLE_DURATION", + "DEFAULT_FAIR_CYCLE_QUOTA_THRESHOLD", + "DEFAULT_FAIR_CYCLE_RESET_COOLDOWN_THRESHOLD", "DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD", # Custom Caps "DEFAULT_CUSTOM_CAP_COOLDOWN_MODE", @@ -57,4 +63,6 @@ "COOLDOWN_AUTH_ERROR", "COOLDOWN_TRANSIENT_ERROR", "COOLDOWN_RATE_LIMIT_DEFAULT", + # Small Cooldown Auto-Retry + "DEFAULT_SMALL_COOLDOWN_RETRY_THRESHOLD", ] diff --git a/src/rotator_library/config/defaults.py b/src/rotator_library/config/defaults.py index 59282e1e..ee6a9c40 100644 --- a/src/rotator_library/config/defaults.py +++ b/src/rotator_library/config/defaults.py @@ -86,6 +86,21 @@ # Global fallback: EXHAUSTION_COOLDOWN_THRESHOLD= DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD: int = 300 # 5 minutes +# Fair cycle quota threshold - multiplier of window limit +# 1.0 = credential exhausts after using 1 full window's worth of quota +# 0.5 = credential exhausts after using 50% of window quota +# 2.0 = credential exhausts after using 2x window quota +# Override: FAIR_CYCLE_QUOTA_THRESHOLD_{PROVIDER}= +DEFAULT_FAIR_CYCLE_QUOTA_THRESHOLD: float = 1.0 + +# Fair cycle reset cooldown threshold in seconds +# When all credentials are exhausted, the fair cycle will only reset if ALL +# credentials have cooldowns longer than this threshold. If any credential has +# a shorter cooldown, the system will wait for it to expire instead of resetting. +# This prevents premature cycle resets when credentials have short temporary cooldowns. +# Override: FAIR_CYCLE_RESET_COOLDOWN_THRESHOLD_{PROVIDER}= +DEFAULT_FAIR_CYCLE_RESET_COOLDOWN_THRESHOLD: int = 90 # 1.5 minutes + # ============================================================================= # CUSTOM CAPS DEFAULTS # ============================================================================= @@ -125,3 +140,12 @@ # Default rate limit cooldown when retry_after not provided (seconds) COOLDOWN_RATE_LIMIT_DEFAULT: int = 60 + +# ============================================================================= +# SMALL COOLDOWN AUTO-RETRY +# ============================================================================= +# When retry_after is below this threshold, automatically retry with the same +# credential instead of rotating. This avoids unnecessary rotation for very +# short rate limits (e.g., 2-3 second capacity bursts). +# Override: SMALL_COOLDOWN_RETRY_THRESHOLD= +DEFAULT_SMALL_COOLDOWN_RETRY_THRESHOLD: int = 10 # 10 seconds diff --git a/src/rotator_library/cooldown_manager.py b/src/rotator_library/cooldown_manager.py index 8e045e48..0d1bb63e 100644 --- a/src/rotator_library/cooldown_manager.py +++ b/src/rotator_library/cooldown_manager.py @@ -5,12 +5,14 @@ import time from typing import Dict + class CooldownManager: """ Manages global cooldown periods for API providers to handle IP-based rate limiting. This ensures that once a 429 error is received for a provider, all subsequent requests to that provider are paused for a specified duration. """ + def __init__(self): self._cooldowns: Dict[str, float] = {} self._lock = asyncio.Lock() @@ -18,7 +20,9 @@ def __init__(self): async def is_cooling_down(self, provider: str) -> bool: """Checks if a provider is currently in a cooldown period.""" async with self._lock: - return provider in self._cooldowns and time.time() < self._cooldowns[provider] + return ( + provider in self._cooldowns and time.time() < self._cooldowns[provider] + ) async def start_cooldown(self, provider: str, duration: int): """ @@ -37,4 +41,8 @@ async def get_cooldown_remaining(self, provider: str) -> float: if provider in self._cooldowns: remaining = self._cooldowns[provider] - time.time() return max(0, remaining) - return 0 \ No newline at end of file + return 0 + + async def get_remaining_cooldown(self, provider: str) -> float: + """Backward-compatible alias for get_cooldown_remaining.""" + return await self.get_cooldown_remaining(provider) diff --git a/src/rotator_library/core/__init__.py b/src/rotator_library/core/__init__.py new file mode 100644 index 00000000..c88a75a5 --- /dev/null +++ b/src/rotator_library/core/__init__.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Core package for the rotator library. + +Provides shared infrastructure used by both client and usage manager: +- types: Shared dataclasses and type definitions +- errors: All custom exceptions +- config: ConfigLoader for centralized configuration +- constants: Default values and magic numbers +""" + +from .types import ( + CredentialInfo, + RequestContext, + ProcessedChunk, + FilterResult, + FairCycleConfig, + CustomCapConfig, + ProviderConfig, + WindowConfig, + RequestCompleteResult, +) + +from .errors import ( + # Base exceptions + NoAvailableKeysError, + PreRequestCallbackError, + CredentialNeedsReauthError, + EmptyResponseError, + TransientQuotaError, + StreamedAPIError, + # Error classification + ClassifiedError, + RequestErrorAccumulator, + classify_error, + should_rotate_on_error, + should_retry_same_key, + mask_credential, + is_abnormal_error, + get_retry_after, +) + +from .config import ConfigLoader + +__all__ = [ + # Types + "CredentialInfo", + "RequestContext", + "ProcessedChunk", + "FilterResult", + "FairCycleConfig", + "CustomCapConfig", + "ProviderConfig", + "WindowConfig", + "RequestCompleteResult", + # Errors + "NoAvailableKeysError", + "PreRequestCallbackError", + "CredentialNeedsReauthError", + "EmptyResponseError", + "TransientQuotaError", + "StreamedAPIError", + "ClassifiedError", + "RequestErrorAccumulator", + "classify_error", + "should_rotate_on_error", + "should_retry_same_key", + "mask_credential", + "is_abnormal_error", + "get_retry_after", + # Config + "ConfigLoader", +] diff --git a/src/rotator_library/core/config.py b/src/rotator_library/core/config.py new file mode 100644 index 00000000..34114011 --- /dev/null +++ b/src/rotator_library/core/config.py @@ -0,0 +1,550 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Centralized configuration loader for the rotator library. + +This module provides a ConfigLoader class that handles all configuration +parsing from: +1. System defaults (from config/defaults.py) +2. Provider class attributes +3. Environment variables (ALWAYS override provider defaults) + +The ConfigLoader ensures consistent configuration handling across +both the client and usage manager. +""" + +import os +import logging +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +from .types import ( + ProviderConfig, + FairCycleConfig, + CustomCapConfig, + WindowConfig, +) +from .constants import ( + # Defaults + DEFAULT_ROTATION_MODE, + DEFAULT_ROTATION_TOLERANCE, + DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER, + DEFAULT_FAIR_CYCLE_ENABLED, + DEFAULT_FAIR_CYCLE_TRACKING_MODE, + DEFAULT_FAIR_CYCLE_CROSS_TIER, + DEFAULT_FAIR_CYCLE_DURATION, + DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, + # Prefixes + ENV_PREFIX_ROTATION_MODE, + ENV_PREFIX_FAIR_CYCLE, + ENV_PREFIX_FAIR_CYCLE_TRACKING, + ENV_PREFIX_FAIR_CYCLE_CROSS_TIER, + ENV_PREFIX_FAIR_CYCLE_DURATION, + ENV_PREFIX_EXHAUSTION_THRESHOLD, + ENV_PREFIX_CONCURRENCY_MULTIPLIER, + ENV_PREFIX_CUSTOM_CAP, + ENV_PREFIX_CUSTOM_CAP_COOLDOWN, +) + +lib_logger = logging.getLogger("rotator_library") + + +class ConfigLoader: + """ + Centralized configuration loader. + + Parses all configuration from: + 1. System defaults + 2. Provider class attributes + 3. Environment variables (ALWAYS override provider defaults) + + Usage: + loader = ConfigLoader(provider_plugins) + config = loader.load_provider_config("antigravity") + """ + + def __init__(self, provider_plugins: Optional[Dict[str, type]] = None): + """ + Initialize the ConfigLoader. + + Args: + provider_plugins: Dict mapping provider names to plugin classes. + If None, no provider-specific defaults are used. + """ + self._plugins = provider_plugins or {} + self._cache: Dict[str, ProviderConfig] = {} + + def load_provider_config( + self, + provider: str, + force_reload: bool = False, + ) -> ProviderConfig: + """ + Load complete configuration for a provider. + + Configuration is loaded in this order (later overrides earlier): + 1. System defaults + 2. Provider class attributes + 3. Environment variables (ALWAYS win) + + Args: + provider: Provider name (e.g., "antigravity", "gemini_cli") + force_reload: If True, bypass cache and reload + + Returns: + Complete ProviderConfig for the provider + """ + if not force_reload and provider in self._cache: + return self._cache[provider] + + # Start with system defaults + config = self._get_system_defaults() + + # Apply provider class defaults + plugin_class = self._plugins.get(provider) + if plugin_class: + config = self._apply_provider_defaults(config, plugin_class, provider) + + # Apply environment variable overrides (ALWAYS win) + config = self._apply_env_overrides(config, provider) + + # Cache and return + self._cache[provider] = config + return config + + def load_all_provider_configs( + self, + providers: List[str], + ) -> Dict[str, ProviderConfig]: + """ + Load configurations for multiple providers. + + Args: + providers: List of provider names + + Returns: + Dict mapping provider names to their configs + """ + return {p: self.load_provider_config(p) for p in providers} + + def clear_cache(self, provider: Optional[str] = None) -> None: + """ + Clear cached configurations. + + Args: + provider: If provided, only clear that provider's cache. + If None, clear all cached configs. + """ + if provider: + self._cache.pop(provider, None) + else: + self._cache.clear() + + # ========================================================================= + # INTERNAL METHODS + # ========================================================================= + + def _get_system_defaults(self) -> ProviderConfig: + """Get a ProviderConfig with all system defaults.""" + return ProviderConfig( + rotation_mode=DEFAULT_ROTATION_MODE, + rotation_tolerance=DEFAULT_ROTATION_TOLERANCE, + priority_multipliers={}, + priority_multipliers_by_mode={}, + sequential_fallback_multiplier=DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER, + fair_cycle=FairCycleConfig( + enabled=DEFAULT_FAIR_CYCLE_ENABLED, + tracking_mode=DEFAULT_FAIR_CYCLE_TRACKING_MODE, + cross_tier=DEFAULT_FAIR_CYCLE_CROSS_TIER, + duration=DEFAULT_FAIR_CYCLE_DURATION, + ), + custom_caps=[], + exhaustion_cooldown_threshold=DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, + windows=[], + ) + + def _apply_provider_defaults( + self, + config: ProviderConfig, + plugin_class: type, + provider: str, + ) -> ProviderConfig: + """ + Apply provider class default attributes to config. + + Args: + config: Current configuration + plugin_class: Provider plugin class + provider: Provider name for logging + + Returns: + Updated configuration + """ + # Rotation mode + if hasattr(plugin_class, "default_rotation_mode"): + config.rotation_mode = plugin_class.default_rotation_mode + + # Priority multipliers + if hasattr(plugin_class, "default_priority_multipliers"): + multipliers = plugin_class.default_priority_multipliers + if multipliers: + config.priority_multipliers = dict(multipliers) + + # Sequential fallback multiplier + if hasattr(plugin_class, "default_sequential_fallback_multiplier"): + fallback = plugin_class.default_sequential_fallback_multiplier + if fallback != DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER: + config.sequential_fallback_multiplier = fallback + + # Fair cycle settings + if hasattr(plugin_class, "default_fair_cycle_enabled"): + val = plugin_class.default_fair_cycle_enabled + if val is not None: + config.fair_cycle.enabled = val + + if hasattr(plugin_class, "default_fair_cycle_tracking_mode"): + config.fair_cycle.tracking_mode = ( + plugin_class.default_fair_cycle_tracking_mode + ) + + if hasattr(plugin_class, "default_fair_cycle_cross_tier"): + config.fair_cycle.cross_tier = plugin_class.default_fair_cycle_cross_tier + + if hasattr(plugin_class, "default_fair_cycle_duration"): + duration = plugin_class.default_fair_cycle_duration + if duration != DEFAULT_FAIR_CYCLE_DURATION: + config.fair_cycle.duration = duration + + # Exhaustion cooldown threshold + if hasattr(plugin_class, "default_exhaustion_cooldown_threshold"): + threshold = plugin_class.default_exhaustion_cooldown_threshold + if threshold != DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD: + config.exhaustion_cooldown_threshold = threshold + + # Custom caps + if hasattr(plugin_class, "default_custom_caps"): + caps = plugin_class.default_custom_caps + if caps: + config.custom_caps = self._parse_custom_caps_from_provider(caps) + + return config + + def _apply_env_overrides( + self, + config: ProviderConfig, + provider: str, + ) -> ProviderConfig: + """ + Apply environment variable overrides to config. + + Environment variables ALWAYS override provider class defaults. + + Args: + config: Current configuration + provider: Provider name + + Returns: + Updated configuration with env overrides applied + """ + provider_upper = provider.upper() + + # Rotation mode: ROTATION_MODE_{PROVIDER} + env_key = f"{ENV_PREFIX_ROTATION_MODE}{provider_upper}" + env_val = os.getenv(env_key) + if env_val: + config.rotation_mode = env_val.lower() + if config.rotation_mode not in ("balanced", "sequential"): + lib_logger.warning(f"Invalid {env_key}='{env_val}'. Using 'balanced'.") + config.rotation_mode = "balanced" + + # Fair cycle enabled: FAIR_CYCLE_{PROVIDER} + env_key = f"{ENV_PREFIX_FAIR_CYCLE}{provider_upper}" + env_val = os.getenv(env_key) + if env_val is not None: + config.fair_cycle.enabled = env_val.lower() in ("true", "1", "yes") + + # Fair cycle tracking mode: FAIR_CYCLE_TRACKING_MODE_{PROVIDER} + env_key = f"{ENV_PREFIX_FAIR_CYCLE_TRACKING}{provider_upper}" + env_val = os.getenv(env_key) + if env_val and env_val.lower() in ("model_group", "credential"): + config.fair_cycle.tracking_mode = env_val.lower() + + # Fair cycle cross-tier: FAIR_CYCLE_CROSS_TIER_{PROVIDER} + env_key = f"{ENV_PREFIX_FAIR_CYCLE_CROSS_TIER}{provider_upper}" + env_val = os.getenv(env_key) + if env_val is not None: + config.fair_cycle.cross_tier = env_val.lower() in ("true", "1", "yes") + + # Fair cycle duration: FAIR_CYCLE_DURATION_{PROVIDER} + env_key = f"{ENV_PREFIX_FAIR_CYCLE_DURATION}{provider_upper}" + env_val = os.getenv(env_key) + if env_val: + try: + config.fair_cycle.duration = int(env_val) + except ValueError: + lib_logger.warning(f"Invalid {env_key}='{env_val}'. Must be integer.") + + # Exhaustion cooldown threshold: EXHAUSTION_COOLDOWN_THRESHOLD_{PROVIDER} + # Also check global: EXHAUSTION_COOLDOWN_THRESHOLD + env_key = f"{ENV_PREFIX_EXHAUSTION_THRESHOLD}{provider_upper}" + env_val = os.getenv(env_key) or os.getenv("EXHAUSTION_COOLDOWN_THRESHOLD") + if env_val: + try: + config.exhaustion_cooldown_threshold = int(env_val) + except ValueError: + lib_logger.warning(f"Invalid exhaustion threshold='{env_val}'.") + + # Priority multipliers: CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N} + # Also supports mode-specific: CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N}_{MODE} + self._parse_priority_multiplier_env_vars(config, provider_upper) + + # Custom caps: CUSTOM_CAP_{PROVIDER}_T{TIER}_{MODEL} + # Also: CUSTOM_CAP_COOLDOWN_{PROVIDER}_T{TIER}_{MODEL} + self._parse_custom_cap_env_vars(config, provider_upper) + + return config + + def _parse_priority_multiplier_env_vars( + self, + config: ProviderConfig, + provider_upper: str, + ) -> None: + """ + Parse CONCURRENCY_MULTIPLIER_* environment variables. + + Formats: + - CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N}=value + - CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N}_{MODE}=value + """ + prefix = f"{ENV_PREFIX_CONCURRENCY_MULTIPLIER}{provider_upper}_PRIORITY_" + + for env_key, env_val in os.environ.items(): + if not env_key.startswith(prefix): + continue + + remainder = env_key[len(prefix) :] + try: + multiplier = int(env_val) + if multiplier < 1: + lib_logger.warning(f"Invalid {env_key}='{env_val}'. Must be >= 1.") + continue + + # Check for mode-specific suffix + if "_" in remainder: + parts = remainder.rsplit("_", 1) + priority = int(parts[0]) + mode = parts[1].lower() + + if mode in ("sequential", "balanced"): + if mode not in config.priority_multipliers_by_mode: + config.priority_multipliers_by_mode[mode] = {} + config.priority_multipliers_by_mode[mode][priority] = multiplier + else: + lib_logger.warning(f"Unknown mode in {env_key}: {mode}") + else: + # Universal priority multiplier + priority = int(remainder) + config.priority_multipliers[priority] = multiplier + + except ValueError: + lib_logger.warning(f"Invalid {env_key}='{env_val}'. Could not parse.") + + def _parse_custom_cap_env_vars( + self, + config: ProviderConfig, + provider_upper: str, + ) -> None: + """ + Parse CUSTOM_CAP_* environment variables. + + Formats: + - CUSTOM_CAP_{PROVIDER}_T{TIER}_{MODEL}=value + - CUSTOM_CAP_{PROVIDER}_TDEFAULT_{MODEL}=value + - CUSTOM_CAP_COOLDOWN_{PROVIDER}_T{TIER}_{MODEL}=mode:value + """ + cap_prefix = f"{ENV_PREFIX_CUSTOM_CAP}{provider_upper}_T" + cooldown_prefix = f"{ENV_PREFIX_CUSTOM_CAP_COOLDOWN}{provider_upper}_T" + + # Collect caps by (tier_key, model_key) to merge cap and cooldown + caps_dict: Dict[Tuple[Any, str], Dict[str, Any]] = {} + + for env_key, env_val in os.environ.items(): + if env_key.startswith(cooldown_prefix): + remainder = env_key[len(cooldown_prefix) :] + tier_key, model_key = self._parse_tier_model_from_env(remainder) + if tier_key is None: + continue + + # Parse mode:value format + if ":" in env_val: + mode, value_str = env_val.split(":", 1) + try: + value = int(value_str) + except ValueError: + lib_logger.warning(f"Invalid cooldown in {env_key}") + continue + else: + mode = env_val + value = 0 + + key = (tier_key, model_key) + if key not in caps_dict: + caps_dict[key] = {} + caps_dict[key]["cooldown_mode"] = mode + caps_dict[key]["cooldown_value"] = value + + elif env_key.startswith(cap_prefix): + remainder = env_key[len(cap_prefix) :] + tier_key, model_key = self._parse_tier_model_from_env(remainder) + if tier_key is None: + continue + + key = (tier_key, model_key) + if key not in caps_dict: + caps_dict[key] = {} + caps_dict[key]["max_requests"] = env_val + + # Convert to CustomCapConfig objects + for (tier_key, model_key), cap_data in caps_dict.items(): + if "max_requests" not in cap_data: + continue # Need at least max_requests + + cap = CustomCapConfig( + tier_key=tier_key, + model_or_group=model_key, + max_requests=cap_data["max_requests"], + cooldown_mode=cap_data.get("cooldown_mode", "quota_reset"), + cooldown_value=cap_data.get("cooldown_value", 0), + ) + config.custom_caps.append(cap) + + def _parse_tier_model_from_env( + self, + remainder: str, + ) -> Tuple[Optional[Union[int, Tuple[int, ...], str]], Optional[str]]: + """ + Parse tier and model/group from env var remainder. + + Args: + remainder: String after "CUSTOM_CAP_{PROVIDER}_T" prefix + e.g., "2_CLAUDE" or "2_3_CLAUDE" or "DEFAULT_CLAUDE" + + Returns: + (tier_key, model_key) or (None, None) if parse fails + """ + if not remainder: + return None, None + + parts = remainder.split("_") + if len(parts) < 2: + return None, None + + tier_parts: List[int] = [] + tier_key: Union[int, Tuple[int, ...], str, None] = None + model_key: Optional[str] = None + + for i, part in enumerate(parts): + if part == "DEFAULT": + tier_key = "default" + model_key = "_".join(parts[i + 1 :]) + break + elif part.isdigit(): + tier_parts.append(int(part)) + else: + # First non-numeric part is start of model name + if len(tier_parts) == 0: + return None, None + elif len(tier_parts) == 1: + tier_key = tier_parts[0] + else: + tier_key = tuple(tier_parts) + model_key = "_".join(parts[i:]) + break + else: + # All parts were tier parts, no model + return None, None + + if model_key: + # Convert to lowercase with dashes (standard model name format) + model_key = model_key.lower().replace("_", "-") + + return tier_key, model_key + + def _parse_custom_caps_from_provider( + self, + caps: Dict[Union[int, Tuple[int, ...], str], Dict[str, Dict[str, Any]]], + ) -> List[CustomCapConfig]: + """ + Parse custom caps from provider class default_custom_caps attribute. + + Args: + caps: Provider's default_custom_caps dict + + Returns: + List of CustomCapConfig objects + """ + result = [] + + for tier_key, models_config in caps.items(): + for model_key, cap_data in models_config.items(): + cap = CustomCapConfig( + tier_key=tier_key, + model_or_group=model_key, + max_requests=cap_data.get("max_requests", 0), + cooldown_mode=cap_data.get("cooldown_mode", "quota_reset"), + cooldown_value=cap_data.get("cooldown_value", 0), + ) + result.append(cap) + + return result + + +# ============================================================================= +# MODULE-LEVEL CONVENIENCE FUNCTIONS +# ============================================================================= + +# Global loader instance (initialized lazily) +_global_loader: Optional[ConfigLoader] = None + + +def get_config_loader( + provider_plugins: Optional[Dict[str, type]] = None, +) -> ConfigLoader: + """ + Get the global ConfigLoader instance. + + Creates a new instance if none exists or if provider_plugins is provided. + + Args: + provider_plugins: Optional dict of provider plugins. If provided, + creates a new loader with these plugins. + + Returns: + The global ConfigLoader instance + """ + global _global_loader + + if provider_plugins is not None: + _global_loader = ConfigLoader(provider_plugins) + elif _global_loader is None: + _global_loader = ConfigLoader() + + return _global_loader + + +def load_provider_config( + provider: str, + provider_plugins: Optional[Dict[str, type]] = None, +) -> ProviderConfig: + """ + Convenience function to load a provider's configuration. + + Args: + provider: Provider name + provider_plugins: Optional provider plugins dict + + Returns: + ProviderConfig for the provider + """ + loader = get_config_loader(provider_plugins) + return loader.load_provider_config(provider) diff --git a/src/rotator_library/core/constants.py b/src/rotator_library/core/constants.py new file mode 100644 index 00000000..073cbb9b --- /dev/null +++ b/src/rotator_library/core/constants.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Constants and default values for the rotator library. + +This module re-exports all constants from the config package and adds +any additional constants needed for the refactored architecture. + +All tunable defaults are in config/defaults.py - this module provides +a unified import point and adds non-tunable constants. +""" + +# Re-export all tunable defaults from config package +from ..config import ( + # Rotation & Selection + DEFAULT_ROTATION_MODE, + DEFAULT_ROTATION_TOLERANCE, + DEFAULT_MAX_RETRIES, + DEFAULT_GLOBAL_TIMEOUT, + # Tier & Priority + DEFAULT_TIER_PRIORITY, + DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER, + # Fair Cycle Rotation + DEFAULT_FAIR_CYCLE_ENABLED, + DEFAULT_FAIR_CYCLE_TRACKING_MODE, + DEFAULT_FAIR_CYCLE_CROSS_TIER, + DEFAULT_FAIR_CYCLE_DURATION, + DEFAULT_FAIR_CYCLE_QUOTA_THRESHOLD, + DEFAULT_FAIR_CYCLE_RESET_COOLDOWN_THRESHOLD, + DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, + # Custom Caps + DEFAULT_CUSTOM_CAP_COOLDOWN_MODE, + DEFAULT_CUSTOM_CAP_COOLDOWN_VALUE, + # Cooldown & Backoff + COOLDOWN_BACKOFF_TIERS, + COOLDOWN_BACKOFF_MAX, + COOLDOWN_AUTH_ERROR, + COOLDOWN_TRANSIENT_ERROR, + COOLDOWN_RATE_LIMIT_DEFAULT, + # Small Cooldown Auto-Retry + DEFAULT_SMALL_COOLDOWN_RETRY_THRESHOLD, +) + +# ============================================================================= +# ADDITIONAL CONSTANTS FOR REFACTORED ARCHITECTURE +# ============================================================================= + +# Environment variable prefixes for configuration +ENV_PREFIX_ROTATION_MODE = "ROTATION_MODE_" +ENV_PREFIX_FAIR_CYCLE = "FAIR_CYCLE_" +ENV_PREFIX_FAIR_CYCLE_TRACKING = "FAIR_CYCLE_TRACKING_MODE_" +ENV_PREFIX_FAIR_CYCLE_CROSS_TIER = "FAIR_CYCLE_CROSS_TIER_" +ENV_PREFIX_FAIR_CYCLE_DURATION = "FAIR_CYCLE_DURATION_" +ENV_PREFIX_EXHAUSTION_THRESHOLD = "EXHAUSTION_COOLDOWN_THRESHOLD_" +ENV_PREFIX_CONCURRENCY_MULTIPLIER = "CONCURRENCY_MULTIPLIER_" +ENV_PREFIX_CUSTOM_CAP = "CUSTOM_CAP_" +ENV_PREFIX_CUSTOM_CAP_COOLDOWN = "CUSTOM_CAP_COOLDOWN_" +ENV_PREFIX_QUOTA_GROUPS = "QUOTA_GROUPS_" + +# Provider-specific providers that use request_count instead of success_count +# for credential selection (because failed requests also consume quota) +REQUEST_COUNT_PROVIDERS = frozenset({"antigravity", "gemini_cli", "chutes", "nanogpt"}) + +# Usage manager storage +USAGE_FILE_NAME = "usage.json" # New format +LEGACY_USAGE_FILE_NAME = "key_usage.json" # Old format +USAGE_SCHEMA_VERSION = 2 + +# Fair cycle tracking keys +FAIR_CYCLE_ALL_TIERS_KEY = "__all_tiers__" +FAIR_CYCLE_CREDENTIAL_KEY = "__credential__" +FAIR_CYCLE_STORAGE_KEY = "__fair_cycle__" + +# Logging +LIB_LOGGER_NAME = "rotator_library" + +__all__ = [ + # From config package + "DEFAULT_ROTATION_MODE", + "DEFAULT_ROTATION_TOLERANCE", + "DEFAULT_MAX_RETRIES", + "DEFAULT_GLOBAL_TIMEOUT", + "DEFAULT_TIER_PRIORITY", + "DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER", + "DEFAULT_FAIR_CYCLE_ENABLED", + "DEFAULT_FAIR_CYCLE_TRACKING_MODE", + "DEFAULT_FAIR_CYCLE_CROSS_TIER", + "DEFAULT_FAIR_CYCLE_DURATION", + "DEFAULT_FAIR_CYCLE_QUOTA_THRESHOLD", + "DEFAULT_FAIR_CYCLE_RESET_COOLDOWN_THRESHOLD", + "DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD", + "DEFAULT_CUSTOM_CAP_COOLDOWN_MODE", + "DEFAULT_CUSTOM_CAP_COOLDOWN_VALUE", + "COOLDOWN_BACKOFF_TIERS", + "COOLDOWN_BACKOFF_MAX", + "COOLDOWN_AUTH_ERROR", + "COOLDOWN_TRANSIENT_ERROR", + "COOLDOWN_RATE_LIMIT_DEFAULT", + # Small Cooldown Auto-Retry + "DEFAULT_SMALL_COOLDOWN_RETRY_THRESHOLD", + # Environment variable prefixes + "ENV_PREFIX_ROTATION_MODE", + "ENV_PREFIX_FAIR_CYCLE", + "ENV_PREFIX_FAIR_CYCLE_TRACKING", + "ENV_PREFIX_FAIR_CYCLE_CROSS_TIER", + "ENV_PREFIX_FAIR_CYCLE_DURATION", + "ENV_PREFIX_EXHAUSTION_THRESHOLD", + "ENV_PREFIX_CONCURRENCY_MULTIPLIER", + "ENV_PREFIX_CUSTOM_CAP", + "ENV_PREFIX_CUSTOM_CAP_COOLDOWN", + "ENV_PREFIX_QUOTA_GROUPS", + # Provider sets + "REQUEST_COUNT_PROVIDERS", + # Storage + "USAGE_FILE_NAME", + "LEGACY_USAGE_FILE_NAME", + "USAGE_SCHEMA_VERSION", + # Fair cycle keys + "FAIR_CYCLE_ALL_TIERS_KEY", + "FAIR_CYCLE_CREDENTIAL_KEY", + "FAIR_CYCLE_STORAGE_KEY", + # Logging + "LIB_LOGGER_NAME", +] diff --git a/src/rotator_library/core/errors.py b/src/rotator_library/core/errors.py new file mode 100644 index 00000000..5acd9fc7 --- /dev/null +++ b/src/rotator_library/core/errors.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Error handling for the rotator library. + +This module re-exports all exception classes and error handling utilities +from the main error_handler module, and adds any new error types needed +for the refactored architecture. + +Note: The actual implementations remain in error_handler.py for backward +compatibility. This module provides a cleaner import path. +""" + +# Re-export everything from error_handler +from ..error_handler import ( + # Exception classes + NoAvailableKeysError, + PreRequestCallbackError, + CredentialNeedsReauthError, + EmptyResponseError, + TransientQuotaError, + # Error classification + ClassifiedError, + RequestErrorAccumulator, + classify_error, + should_rotate_on_error, + should_retry_same_key, + is_abnormal_error, + # Utilities + mask_credential, + get_retry_after, + extract_retry_after_from_body, + is_rate_limit_error, + is_server_error, + is_unrecoverable_error, + # Constants + ABNORMAL_ERROR_TYPES, + NORMAL_ERROR_TYPES, +) + + +# ============================================================================= +# NEW EXCEPTIONS FOR REFACTORED ARCHITECTURE +# ============================================================================= + + +class StreamedAPIError(Exception): + """ + Custom exception to signal an API error received over a stream. + + This is raised when an error is detected in streaming response data, + allowing the retry logic to handle it appropriately. + + Attributes: + message: Human-readable error message + data: The parsed error data (dict or exception) + """ + + def __init__(self, message: str, data=None): + super().__init__(message) + self.data = data + + +__all__ = [ + # Exception classes + "NoAvailableKeysError", + "PreRequestCallbackError", + "CredentialNeedsReauthError", + "EmptyResponseError", + "TransientQuotaError", + "StreamedAPIError", + # Error classification + "ClassifiedError", + "RequestErrorAccumulator", + "classify_error", + "should_rotate_on_error", + "should_retry_same_key", + "is_abnormal_error", + # Utilities + "mask_credential", + "get_retry_after", + "extract_retry_after_from_body", + "is_rate_limit_error", + "is_server_error", + "is_unrecoverable_error", + # Constants + "ABNORMAL_ERROR_TYPES", + "NORMAL_ERROR_TYPES", +] diff --git a/src/rotator_library/core/types.py b/src/rotator_library/core/types.py new file mode 100644 index 00000000..f6cc72e2 --- /dev/null +++ b/src/rotator_library/core/types.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Shared type definitions for the rotator library. + +This module contains dataclasses and type definitions used across +both the client and usage manager packages. +""" + +from dataclasses import dataclass, field +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Set, + Tuple, + Union, +) + + +# ============================================================================= +# CREDENTIAL TYPES +# ============================================================================= + + +@dataclass +class CredentialInfo: + """ + Information about a credential. + + Used for passing credential metadata between components. + """ + + accessor: str # File path or API key + stable_id: str # Email (OAuth) or hash (API key) + provider: str + tier: Optional[str] = None + priority: int = 999 # Lower = higher priority + display_name: Optional[str] = None + + +# ============================================================================= +# REQUEST TYPES +# ============================================================================= + + +@dataclass +class RequestContext: + """ + Context for a request being processed. + + Contains all information needed to execute a request with + retry/rotation logic. + """ + + model: str + provider: str + kwargs: Dict[str, Any] + streaming: bool + credentials: List[str] + deadline: float + request: Optional[Any] = None # FastAPI Request object + pre_request_callback: Optional[Callable] = None + transaction_logger: Optional[Any] = None + + +@dataclass +class ProcessedChunk: + """ + Result of processing a streaming chunk. + + Used by StreamingHandler to return processed chunk data. + """ + + sse_string: str # The SSE-formatted string to yield + usage: Optional[Dict[str, Any]] = None + finish_reason: Optional[str] = None + has_tool_calls: bool = False + + +# ============================================================================= +# FILTER TYPES +# ============================================================================= + + +@dataclass +class FilterResult: + """ + Result of credential filtering. + + Contains categorized credentials after filtering by tier compatibility. + """ + + compatible: List[str] = field(default_factory=list) # Known compatible + unknown: List[str] = field(default_factory=list) # Unknown tier + incompatible: List[str] = field(default_factory=list) # Known incompatible + priorities: Dict[str, int] = field(default_factory=dict) # credential -> priority + tier_names: Dict[str, str] = field(default_factory=dict) # credential -> tier name + + @property + def all_usable(self) -> List[str]: + """Return all usable credentials (compatible + unknown).""" + return self.compatible + self.unknown + + +# ============================================================================= +# CONFIGURATION TYPES +# ============================================================================= + + +@dataclass +class FairCycleConfig: + """ + Fair cycle rotation configuration for a provider. + + Fair cycle ensures each credential is used at least once before + any credential is reused. + """ + + enabled: Optional[bool] = None # None = derive from rotation mode + tracking_mode: str = "model_group" # "model_group" or "credential" + cross_tier: bool = False # Track across all tiers + duration: int = 604800 # 7 days in seconds + + +@dataclass +class CustomCapConfig: + """ + Custom cap configuration for a tier/model combination. + + Allows setting usage limits that can be absolute, offset from API limits, + or percentage of API limits. + """ + + tier_key: Union[int, Tuple[int, ...], str] # Priority(s) or "default" + model_or_group: str # Model name or quota group name + max_requests: Union[int, str] # Absolute value, offset, or percentage + max_requests_mode: str = "absolute" # "absolute", "offset", "percentage" + cooldown_mode: str = "quota_reset" # "quota_reset", "offset", "fixed" + cooldown_value: int = 0 # Seconds for offset/fixed modes + + +@dataclass +class WindowConfig: + """ + Quota window configuration. + + Defines how usage is tracked and reset for a credential. + """ + + name: str # e.g., "5h", "daily", "weekly" + duration_seconds: Optional[int] # None for infinite/total + reset_mode: str # "rolling", "fixed_daily", "calendar_weekly", "api_authoritative" + applies_to: str # "credential", "group", "model" + + +@dataclass +class ProviderConfig: + """ + Complete configuration for a provider. + + Loaded by ConfigLoader and used by both client and usage manager. + """ + + rotation_mode: str = "balanced" # "balanced" or "sequential" + rotation_tolerance: float = 3.0 + priority_multipliers: Dict[int, int] = field(default_factory=dict) + priority_multipliers_by_mode: Dict[str, Dict[int, int]] = field( + default_factory=dict + ) + sequential_fallback_multiplier: int = 1 + fair_cycle: FairCycleConfig = field(default_factory=FairCycleConfig) + custom_caps: List[CustomCapConfig] = field(default_factory=list) + exhaustion_cooldown_threshold: int = 300 # 5 minutes + windows: List[WindowConfig] = field(default_factory=list) + + +# ============================================================================= +# HOOK RESULT TYPES +# ============================================================================= + + +@dataclass +class RequestCompleteResult: + """ + Result from on_request_complete provider hook. + + Allows providers to customize how requests are counted and cooled down. + """ + + count_override: Optional[int] = None # How many requests to count + cooldown_override: Optional[float] = None # Custom cooldown duration + force_exhausted: bool = False # Mark for fair cycle + + +# ============================================================================= +# ERROR ACTION ENUM +# ============================================================================= + + +class ErrorAction: + """ + Actions to take after an error. + + Used by RequestExecutor to determine next steps. + """ + + RETRY_SAME = "retry_same" # Retry with same credential + ROTATE = "rotate" # Try next credential + FAIL = "fail" # Fail the request immediately diff --git a/src/rotator_library/credential_manager.py b/src/rotator_library/credential_manager.py index 9a7e5edb..1ad48593 100644 --- a/src/rotator_library/credential_manager.py +++ b/src/rotator_library/credential_manager.py @@ -3,10 +3,13 @@ import os import re +import json +import time +import base64 import shutil import logging from pathlib import Path -from typing import Dict, List, Optional, Set, Union +from typing import Dict, List, Optional, Set, Union, Any, Tuple from .utils.paths import get_oauth_dir @@ -18,6 +21,7 @@ "qwen_code": Path.home() / ".qwen", "iflow": Path.home() / ".iflow", "antigravity": Path.home() / ".antigravity", + "openai_codex": Path.home() / ".codex", # import source context only # Add other providers like 'claude' here if they have a standard CLI path } @@ -28,6 +32,7 @@ "antigravity": "ANTIGRAVITY", "qwen_code": "QWEN_CODE", "iflow": "IFLOW", + "openai_codex": "OPENAI_CODEX", } @@ -120,6 +125,435 @@ def _discover_env_oauth_credentials(self) -> Dict[str, List[str]]: return result + # ------------------------------------------------------------------------- + # OpenAI Codex first-run import helpers + # ------------------------------------------------------------------------- + + def _decode_jwt_unverified(self, token: str) -> Optional[Dict[str, Any]]: + """Decode JWT payload without signature verification.""" + if not token or not isinstance(token, str): + return None + + parts = token.split(".") + if len(parts) < 2: + return None + + payload = parts[1] + payload += "=" * (-len(payload) % 4) + + try: + decoded = base64.urlsafe_b64decode(payload) + data = json.loads(decoded.decode("utf-8")) + return data if isinstance(data, dict) else None + except Exception: + return None + + def _extract_codex_identity(self, access_token: str, id_token: Optional[str]) -> Tuple[Optional[str], Optional[str], Optional[int]]: + """ + Extract (account_id, email, exp_ms) from Codex JWTs. + + Priority: + - account_id: access_token -> id_token + - email: id_token -> access_token + - exp: access_token -> id_token + """ + + def extract_account(payload: Optional[Dict[str, Any]]) -> Optional[str]: + if not payload: + return None + + direct = payload.get("https://api.openai.com/auth.chatgpt_account_id") + if isinstance(direct, str) and direct.strip(): + return direct.strip() + + auth_claim = payload.get("https://api.openai.com/auth") + if isinstance(auth_claim, dict): + nested = auth_claim.get("chatgpt_account_id") + if isinstance(nested, str) and nested.strip(): + return nested.strip() + + orgs = payload.get("organizations") + if isinstance(orgs, list) and orgs: + first = orgs[0] + if isinstance(first, dict): + org_id = first.get("id") + if isinstance(org_id, str) and org_id.strip(): + return org_id.strip() + + return None + + def extract_email(payload: Optional[Dict[str, Any]]) -> Optional[str]: + if not payload: + return None + email = payload.get("email") + if isinstance(email, str) and email.strip(): + return email.strip() + sub = payload.get("sub") + if isinstance(sub, str) and sub.strip(): + return sub.strip() + return None + + def extract_exp_ms(payload: Optional[Dict[str, Any]]) -> Optional[int]: + if not payload: + return None + exp = payload.get("exp") + if isinstance(exp, (int, float)): + return int(float(exp) * 1000) + return None + + access_payload = self._decode_jwt_unverified(access_token) + id_payload = self._decode_jwt_unverified(id_token) if id_token else None + + account_id = extract_account(access_payload) or extract_account(id_payload) + email = extract_email(id_payload) or extract_email(access_payload) + exp_ms = extract_exp_ms(access_payload) or extract_exp_ms(id_payload) + + return account_id, email, exp_ms + + def _normalize_openai_codex_auth_json_record(self, auth_data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Normalize ~/.codex/auth.json format to proxy schema.""" + tokens = auth_data.get("tokens") + if not isinstance(tokens, dict): + return None + + access_token = tokens.get("access_token") + refresh_token = tokens.get("refresh_token") + id_token = tokens.get("id_token") + + if not isinstance(access_token, str) or not isinstance(refresh_token, str): + return None + + account_id, email, exp_ms = self._extract_codex_identity(access_token, id_token) + + # Respect explicit account_id from source tokens if present + explicit_account = tokens.get("account_id") + if isinstance(explicit_account, str) and explicit_account.strip(): + account_id = explicit_account.strip() + + if exp_ms is None: + # conservative fallback to 5 minutes from now + exp_ms = int((time.time() + 300) * 1000) + + return { + "access_token": access_token, + "refresh_token": refresh_token, + "id_token": id_token, + "expiry_date": exp_ms, + "token_uri": "https://auth.openai.com/oauth/token", + "_proxy_metadata": { + "email": email, + "account_id": account_id, + "last_check_timestamp": time.time(), + "loaded_from_env": False, + "env_credential_index": None, + }, + } + + def _normalize_openai_codex_accounts_record(self, account: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Normalize one ~/.codex-accounts.json account entry to proxy schema.""" + access_token = account.get("access") + refresh_token = account.get("refresh") + id_token = account.get("idToken") + + if not isinstance(access_token, str) or not isinstance(refresh_token, str): + return None + + account_id, email, exp_ms = self._extract_codex_identity(access_token, id_token) + + explicit_account = account.get("accountId") + if isinstance(explicit_account, str) and explicit_account.strip(): + account_id = explicit_account.strip() + + label = account.get("label") + if not email and isinstance(label, str) and label.strip(): + email = label.strip() + + expires = account.get("expires") + if isinstance(expires, (int, float)): + exp_ms = int(expires) + + if exp_ms is None: + exp_ms = int((time.time() + 300) * 1000) + + return { + "access_token": access_token, + "refresh_token": refresh_token, + "id_token": id_token, + "expiry_date": exp_ms, + "token_uri": "https://auth.openai.com/oauth/token", + "_proxy_metadata": { + "email": email, + "account_id": account_id, + "last_check_timestamp": time.time(), + "loaded_from_env": False, + "env_credential_index": None, + }, + } + + def _import_openai_codex_cli_credentials( + self, + auth_json_path: Optional[Path] = None, + accounts_json_path: Optional[Path] = None, + ) -> List[str]: + """ + First-run import from Codex CLI stores into local oauth_creds/. + + Source files are read-only: + - ~/.codex/auth.json (single account) + - ~/.codex-accounts.json (multi-account) + """ + auth_json_path = auth_json_path or (Path.home() / ".codex" / "auth.json") + accounts_json_path = accounts_json_path or (Path.home() / ".codex-accounts.json") + + normalized_records: List[Dict[str, Any]] = [] + + # Source 1: ~/.codex/auth.json + if auth_json_path.exists(): + try: + with open(auth_json_path, "r") as f: + auth_data = json.load(f) + + if isinstance(auth_data, dict): + record = self._normalize_openai_codex_auth_json_record(auth_data) + if record: + normalized_records.append(record) + else: + lib_logger.warning( + "OpenAI Codex import: skipping malformed ~/.codex/auth.json record" + ) + else: + lib_logger.warning( + "OpenAI Codex import: ~/.codex/auth.json root is not an object" + ) + except Exception as e: + lib_logger.warning( + f"OpenAI Codex import: failed to parse ~/.codex/auth.json: {e}" + ) + + # Source 2: ~/.codex-accounts.json + if accounts_json_path.exists(): + try: + with open(accounts_json_path, "r") as f: + accounts_data = json.load(f) + + accounts = [] + if isinstance(accounts_data, dict): + raw_accounts = accounts_data.get("accounts") + if isinstance(raw_accounts, list): + accounts = raw_accounts + elif isinstance(accounts_data, list): + accounts = accounts_data + + if not accounts: + lib_logger.warning( + "OpenAI Codex import: ~/.codex-accounts.json has no accounts list" + ) + + for idx, account in enumerate(accounts): + if not isinstance(account, dict): + lib_logger.warning( + f"OpenAI Codex import: skipping malformed account entry #{idx + 1}" + ) + continue + + record = self._normalize_openai_codex_accounts_record(account) + if record: + normalized_records.append(record) + else: + lib_logger.warning( + f"OpenAI Codex import: skipping malformed account entry #{idx + 1}" + ) + + except Exception as e: + lib_logger.warning( + f"OpenAI Codex import: failed to parse ~/.codex-accounts.json: {e}" + ) + + if not normalized_records: + return [] + + # Deduplicate by account_id first, then email + unique: List[Dict[str, Any]] = [] + seen_account_ids: Set[str] = set() + seen_emails: Set[str] = set() + + for record in normalized_records: + metadata = record.get("_proxy_metadata", {}) + account_id = metadata.get("account_id") + email = metadata.get("email") + + if isinstance(account_id, str) and account_id: + if account_id in seen_account_ids: + continue + seen_account_ids.add(account_id) + + if isinstance(email, str) and email: + if email in seen_emails: + continue + seen_emails.add(email) + + unique.append(record) + + imported_paths: List[str] = [] + for i, record in enumerate(unique, 1): + local_path = self.oauth_base_dir / f"openai_codex_oauth_{i}.json" + try: + with open(local_path, "w") as f: + json.dump(record, f, indent=2) + imported_paths.append(str(local_path.resolve())) + except Exception as e: + lib_logger.error( + f"OpenAI Codex import: failed writing '{local_path.name}': {e}" + ) + + if imported_paths: + identifiers = [] + for p in imported_paths: + try: + with open(p, "r") as f: + payload = json.load(f) + meta = payload.get("_proxy_metadata", {}) + identifiers.append( + meta.get("email") or meta.get("account_id") or Path(p).name + ) + except Exception: + identifiers.append(Path(p).name) + + lib_logger.info( + "OpenAI Codex first-run import complete: " + f"{len(imported_paths)} credential(s) imported ({', '.join(str(x) for x in identifiers)})" + ) + + return imported_paths + + def _import_openai_codex_explicit_paths(self, source_paths: List[Path]) -> List[str]: + """ + Import OpenAI Codex credentials from explicit OPENAI_CODEX_OAUTH_* paths. + + Supports: + - Raw Codex CLI files (`~/.codex/auth.json`, `~/.codex-accounts.json`) + - Already-normalized proxy credential JSON files + + Returns local normalized/copied paths under oauth_creds/. + """ + if not source_paths: + return [] + + normalized_records: List[Dict[str, Any]] = [] + passthrough_paths: List[Path] = [] + + for source_path in sorted(source_paths): + try: + with open(source_path, "r") as f: + payload = json.load(f) + except Exception as e: + lib_logger.warning( + f"OpenAI Codex explicit import: failed to parse '{source_path}': {e}. Falling back to direct copy." + ) + passthrough_paths.append(source_path) + continue + + # Raw ~/.codex/auth.json shape + if isinstance(payload, dict) and isinstance(payload.get("tokens"), dict): + record = self._normalize_openai_codex_auth_json_record(payload) + if record: + normalized_records.append(record) + continue + + # Raw ~/.codex-accounts.json shape (object or root list) + accounts: List[Any] = [] + if isinstance(payload, dict) and isinstance(payload.get("accounts"), list): + accounts = payload.get("accounts") + elif isinstance(payload, list): + accounts = payload + + if accounts: + converted = 0 + for idx, account in enumerate(accounts): + if not isinstance(account, dict): + lib_logger.warning( + f"OpenAI Codex explicit import: skipping malformed account entry #{idx + 1} from '{source_path.name}'" + ) + continue + + record = self._normalize_openai_codex_accounts_record(account) + if record: + normalized_records.append(record) + converted += 1 + + if converted > 0: + continue + + # Already-normalized proxy format + if ( + isinstance(payload, dict) + and isinstance(payload.get("access_token"), str) + and isinstance(payload.get("refresh_token"), str) + ): + passthrough_paths.append(source_path) + continue + + # Unknown shape: preserve existing behavior (copy as-is) + passthrough_paths.append(source_path) + + # Deduplicate normalized records by account_id/email + unique_records: List[Dict[str, Any]] = [] + seen_account_ids: Set[str] = set() + seen_emails: Set[str] = set() + + for record in normalized_records: + metadata = record.get("_proxy_metadata", {}) + account_id = metadata.get("account_id") + email = metadata.get("email") + + if isinstance(account_id, str) and account_id: + if account_id in seen_account_ids: + continue + seen_account_ids.add(account_id) + + if isinstance(email, str) and email: + if email in seen_emails: + continue + seen_emails.add(email) + + unique_records.append(record) + + imported_paths: List[str] = [] + next_index = 1 + + # Write normalized records first + for record in unique_records: + local_path = self.oauth_base_dir / f"openai_codex_oauth_{next_index}.json" + try: + with open(local_path, "w") as f: + json.dump(record, f, indent=2) + imported_paths.append(str(local_path.resolve())) + next_index += 1 + except Exception as e: + lib_logger.error( + f"OpenAI Codex explicit import: failed writing '{local_path.name}': {e}" + ) + + # Copy passthrough files after normalized ones + for source_path in passthrough_paths: + local_path = self.oauth_base_dir / f"openai_codex_oauth_{next_index}.json" + try: + shutil.copy(source_path, local_path) + imported_paths.append(str(local_path.resolve())) + next_index += 1 + except Exception as e: + lib_logger.error( + f"OpenAI Codex explicit import: failed to copy '{source_path}' -> '{local_path}': {e}" + ) + + if imported_paths: + lib_logger.info( + "OpenAI Codex explicit-path import complete: " + f"{len(imported_paths)} credential(s) prepared" + ) + + return imported_paths + def discover_and_prepare(self) -> Dict[str, List[str]]: lib_logger.info("Starting automated OAuth credential discovery...") final_config = {} @@ -165,7 +599,7 @@ def discover_and_prepare(self) -> Dict[str, List[str]]: ] continue - # If no local credentials exist, proceed with a one-time discovery and copy. + # If no local credentials exist, proceed with one-time import/copy. discovered_paths = set() # 1. Add paths from environment variables first, as they are overrides @@ -174,8 +608,30 @@ def discover_and_prepare(self) -> Dict[str, List[str]]: if path.exists(): discovered_paths.add(path) - # 2. If no overrides are provided via .env, scan the default directory - # [MODIFIED] This logic is now disabled to prefer local-first credential management. + # 2. Provider-specific first-run import for OpenAI Codex + # Trigger only when: + # - provider == openai_codex + # - no local openai_codex_oauth_*.json already exist (checked above) + # - no env-based OPENAI_CODEX credentials were selected (provider not in final_config) + # - no explicit OPENAI_CODEX_OAUTH_* file paths were provided + if provider == "openai_codex" and not discovered_paths: + imported = self._import_openai_codex_cli_credentials() + if imported: + final_config[provider] = imported + continue + + # 3. Provider-specific explicit-path import handling for OpenAI Codex + # This normalizes raw ~/.codex/auth.json / ~/.codex-accounts.json when + # supplied via OPENAI_CODEX_OAUTH_* env vars. + if provider == "openai_codex" and discovered_paths: + imported = self._import_openai_codex_explicit_paths( + sorted(list(discovered_paths)) + ) + if imported: + final_config[provider] = imported + continue + + # 4. Default directory scan remains disabled (local-first policy) # if not discovered_paths and default_dir.exists(): # for json_file in default_dir.glob('*.json'): # discovered_paths.add(json_file) diff --git a/src/rotator_library/credential_tool.py b/src/rotator_library/credential_tool.py index c029cd22..7b3ee952 100644 --- a/src/rotator_library/credential_tool.py +++ b/src/rotator_library/credential_tool.py @@ -26,6 +26,7 @@ get_provider_api_key_var, get_provider_display_name, ) +from .providers.utilities.gemini_shared_utils import format_tier_for_display def _get_oauth_base_dir() -> Path: @@ -65,6 +66,7 @@ def _ensure_providers_loaded(): "qwen_code": "Qwen Code", "iflow": "iFlow", "antigravity": "Antigravity", + "openai_codex": "OpenAI Codex", } @@ -80,25 +82,8 @@ def _extract_key_number(key_name: str) -> int: return int(match.group(1)) if match else 0 -def _normalize_tier_name(tier: str) -> str: - """Normalize tier names for consistent display. - - Examples: - "free-tier" -> "free" - "FREE_TIER" -> "free" - "PAID" -> "paid" - "standard" -> "standard" - None -> "unknown" - """ - if not tier: - return "unknown" - - # Lowercase and remove common suffixes/prefixes - normalized = tier.lower().strip() - normalized = normalized.replace("-tier", "").replace("_tier", "") - normalized = normalized.replace("-", "").replace("_", "") - - return normalized +# Note: _normalize_tier_name was replaced with format_tier_for_display +# from providers.utilities.gemini_shared_utils for centralized tier handling def _count_tiers(credentials: list) -> dict: @@ -114,7 +99,7 @@ def _count_tiers(credentials: list) -> dict: for cred in credentials: tier = cred.get("tier") if tier: - normalized = _normalize_tier_name(tier) + normalized = format_tier_for_display(tier) tier_counts[normalized] = tier_counts.get(normalized, 0) + 1 return tier_counts @@ -285,7 +270,13 @@ def _get_oauth_credentials_summary() -> dict: Example: {"gemini_cli": [{"email": "user@example.com", "tier": "free-tier", ...}, ...]} """ provider_factory, _ = _ensure_providers_loaded() - oauth_providers = ["gemini_cli", "qwen_code", "iflow", "antigravity"] + oauth_providers = [ + "gemini_cli", + "qwen_code", + "iflow", + "antigravity", + "openai_codex", + ] oauth_summary = {} for provider_name in oauth_providers: @@ -552,6 +543,9 @@ def _display_provider_credentials(provider_name: str): if provider_name in ["gemini_cli", "antigravity"]: table.add_column("Tier", style="green") table.add_column("Project", style="dim") + # Add type column for iFlow (OAuth vs Cookie) + elif provider_name == "iflow": + table.add_column("Type", style="magenta") for i, cred in enumerate(credentials, 1): file_name = Path(cred["file_path"]).name @@ -563,6 +557,9 @@ def _display_provider_credentials(provider_name: str): if project and len(project) > 20: project = project[:17] + "..." table.add_row(str(i), file_name, email, tier or "-", project or "-") + elif provider_name == "iflow": + cred_type = cred.get("type", "oauth").capitalize() + table.add_row(str(i), file_name, email, cred_type) else: table.add_row(str(i), file_name, email) @@ -793,17 +790,25 @@ async def _view_oauth_credentials_detail(provider_name: str): if provider_name in ["gemini_cli", "antigravity"]: table.add_column("Tier", style="green") table.add_column("Project", style="dim") + # Add type column for iFlow (OAuth vs Cookie) + elif provider_name == "iflow": + table.add_column("Type", style="magenta") for i, cred in enumerate(credentials, 1): file_name = Path(cred["file_path"]).name email = cred.get("email", "unknown") if provider_name in ["gemini_cli", "antigravity"]: - tier = _normalize_tier_name(cred.get("tier")) if cred.get("tier") else "-" + tier = ( + format_tier_for_display(cred.get("tier")) if cred.get("tier") else "-" + ) project = cred.get("project_id", "-") if project and len(project) > 25: project = project[:22] + "..." table.add_row(str(i), file_name, email, tier, project or "-") + elif provider_name == "iflow": + cred_type = cred.get("type", "oauth").capitalize() + table.add_row(str(i), file_name, email, cred_type) else: table.add_row(str(i), file_name, email) @@ -1216,6 +1221,7 @@ async def setup_api_key(): "antigravity", # OAuth-only "qwen_code", # OAuth is primary, don't advertise API key "iflow", # OAuth is primary + "openai_codex", # OAuth-only (ChatGPT OAuth) } # Base classes to exclude @@ -1732,20 +1738,52 @@ async def setup_new_credential(provider_name: str): oauth_friendly_names = { "gemini_cli": "Gemini CLI (OAuth)", "qwen_code": "Qwen Code (OAuth - also supports API keys)", - "iflow": "iFlow (OAuth - also supports API keys)", + "iflow": "iFlow", "antigravity": "Antigravity (OAuth)", + "openai_codex": "OpenAI Codex (OAuth)", } display_name = oauth_friendly_names.get( provider_name, provider_name.replace("_", " ").title() ) - # Call the auth class's setup_credential() method which handles the entire flow: - # - OAuth authentication - # - Email extraction for deduplication - # - File path determination (new or existing) - # - Credential file saving - # - Post-auth discovery (tier/project for Google OAuth providers) - result = await auth_instance.setup_credential(_get_oauth_base_dir()) + # Special handling for iFlow - offer OAuth or Cookie authentication + if provider_name == "iflow": + console.print( + Panel( + Text.from_markup( + "[bold]Choose authentication method:[/bold]\n\n" + " [cyan]1.[/cyan] OAuth (Email login)\n" + " Opens browser for iFlow login\n" + " Token expires and needs periodic refresh\n\n" + " [cyan]2.[/cyan] Cookie [green](Recommended)[/green]\n" + " Paste session cookie from browser\n" + " More permanent, only API key expires" + ), + title="[bold blue]iFlow Authentication Method[/bold blue]", + border_style="blue", + ) + ) + + auth_choice = Prompt.ask( + "[bold]Select method[/bold] (or 'b' to go back)", + choices=["1", "2", "b"], + default="2", + ) + + if auth_choice.lower() == "b": + return + + if auth_choice == "2": + # Cookie authentication + result = await auth_instance.setup_cookie_credential( + _get_oauth_base_dir() + ) + else: + # OAuth authentication + result = await auth_instance.setup_credential(_get_oauth_base_dir()) + else: + # Other providers - use OAuth + result = await auth_instance.setup_credential(_get_oauth_base_dir()) if not result.success: console.print( @@ -1771,7 +1809,15 @@ async def setup_new_credential(provider_name: str): # Add tier/project info if available (Google OAuth providers) if hasattr(result, "tier") and result.tier: - success_text.append(f"\nTier: {result.tier}") + # Try to get the full tier name for better display (e.g., "Google One AI PRO") + tier_display = result.tier + if result.credentials and isinstance(result.credentials, dict): + tier_full = result.credentials.get("_proxy_metadata", {}).get( + "tier_full" + ) + if tier_full: + tier_display = tier_full + success_text.append(f"\nTier: {tier_display}") if hasattr(result, "project_id") and result.project_id: success_text.append(f"\nProject: {result.project_id}") @@ -2165,6 +2211,96 @@ async def export_antigravity_to_env(): ) +async def export_openai_codex_to_env(): + """ + Export an OpenAI Codex credential JSON file to .env format. + Uses the auth class's build_env_lines() and list_credentials() methods. + """ + clear_screen("Export OpenAI Codex Credential") + + provider_factory, _ = _ensure_providers_loaded() + auth_class = provider_factory.get_provider_auth_class("openai_codex") + auth_instance = auth_class() + + credentials = auth_instance.list_credentials(_get_oauth_base_dir()) + + if not credentials: + console.print( + Panel( + "No OpenAI Codex credentials found. Please add one first using 'Add OAuth Credential'.", + style="bold red", + title="No Credentials", + ) + ) + return + + cred_text = Text() + for i, cred_info in enumerate(credentials): + cred_text.append( + f" {i + 1}. {Path(cred_info['file_path']).name} ({cred_info['email']})\n" + ) + + console.print( + Panel( + cred_text, + title="Available OpenAI Codex Credentials", + style="bold blue", + ) + ) + + choice = Prompt.ask( + Text.from_markup( + "[bold]Please select a credential to export or type [red]'b'[/red] to go back[/bold]" + ), + choices=[str(i + 1) for i in range(len(credentials))] + ["b"], + show_choices=False, + ) + + if choice.lower() == "b": + return + + try: + choice_index = int(choice) - 1 + if 0 <= choice_index < len(credentials): + cred_info = credentials[choice_index] + + env_path = auth_instance.export_credential_to_env( + cred_info["file_path"], _get_oauth_base_dir() + ) + + if env_path: + numbered_prefix = f"OPENAI_CODEX_{cred_info['number']}" + success_text = Text.from_markup( + f"Successfully exported credential to [bold yellow]'{Path(env_path).name}'[/bold yellow]\n\n" + f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n" + f"[bold]To use this credential:[/bold]\n" + f"1. Copy the contents to your main .env file, OR\n" + f"2. Source it: [bold cyan]source {Path(env_path).name}[/bold cyan] (Linux/Mac)\n\n" + f"[bold]To combine multiple credentials:[/bold]\n" + f"Copy lines from multiple .env files into one file.\n" + f"Each credential uses a unique number ({numbered_prefix}_*)." + ) + console.print(Panel(success_text, style="bold green", title="Success")) + else: + console.print( + Panel( + "Failed to export credential", style="bold red", title="Error" + ) + ) + else: + console.print("[bold red]Invalid choice. Please try again.[/bold red]") + except ValueError: + console.print( + "[bold red]Invalid input. Please enter a number or 'b'.[/bold red]" + ) + except Exception as e: + console.print( + Panel( + f"An error occurred during export: {e}", style="bold red", title="Error" + ) + ) + + async def export_all_provider_credentials(provider_name: str): """ Export all credentials for a specific provider to individual .env files. @@ -2329,7 +2465,13 @@ async def combine_all_credentials(): clear_screen("Combine All Credentials") # List of providers that support OAuth credentials - oauth_providers = ["gemini_cli", "qwen_code", "iflow", "antigravity"] + oauth_providers = [ + "gemini_cli", + "qwen_code", + "iflow", + "antigravity", + "openai_codex", + ] provider_factory, _ = _ensure_providers_loaded() @@ -2434,19 +2576,22 @@ async def export_credentials_submenu(): "2. Export Qwen Code credential\n" "3. Export iFlow credential\n" "4. Export Antigravity credential\n" + "5. Export OpenAI Codex credential\n" "\n" "[bold]Bulk Exports (per provider):[/bold]\n" - "5. Export ALL Gemini CLI credentials\n" - "6. Export ALL Qwen Code credentials\n" - "7. Export ALL iFlow credentials\n" - "8. Export ALL Antigravity credentials\n" + "6. Export ALL Gemini CLI credentials\n" + "7. Export ALL Qwen Code credentials\n" + "8. Export ALL iFlow credentials\n" + "9. Export ALL Antigravity credentials\n" + "10. Export ALL OpenAI Codex credentials\n" "\n" "[bold]Combine Credentials:[/bold]\n" - "9. Combine all Gemini CLI into one file\n" - "10. Combine all Qwen Code into one file\n" - "11. Combine all iFlow into one file\n" - "12. Combine all Antigravity into one file\n" - "13. Combine ALL providers into one file" + "11. Combine all Gemini CLI into one file\n" + "12. Combine all Qwen Code into one file\n" + "13. Combine all iFlow into one file\n" + "14. Combine all Antigravity into one file\n" + "15. Combine all OpenAI Codex into one file\n" + "16. Combine ALL providers into one file" ), title="Choose export option", style="bold blue", @@ -2471,6 +2616,9 @@ async def export_credentials_submenu(): "11", "12", "13", + "14", + "15", + "16", "b", ], show_choices=False, @@ -2496,42 +2644,54 @@ async def export_credentials_submenu(): await export_antigravity_to_env() console.print("\n[dim]Press Enter to return to export menu...[/dim]") input() - # Bulk exports (all credentials for a provider) elif export_choice == "5": - await export_all_provider_credentials("gemini_cli") + await export_openai_codex_to_env() console.print("\n[dim]Press Enter to return to export menu...[/dim]") input() + # Bulk exports (all credentials for a provider) elif export_choice == "6": - await export_all_provider_credentials("qwen_code") + await export_all_provider_credentials("gemini_cli") console.print("\n[dim]Press Enter to return to export menu...[/dim]") input() elif export_choice == "7": - await export_all_provider_credentials("iflow") + await export_all_provider_credentials("qwen_code") console.print("\n[dim]Press Enter to return to export menu...[/dim]") input() elif export_choice == "8": + await export_all_provider_credentials("iflow") + console.print("\n[dim]Press Enter to return to export menu...[/dim]") + input() + elif export_choice == "9": await export_all_provider_credentials("antigravity") console.print("\n[dim]Press Enter to return to export menu...[/dim]") input() + elif export_choice == "10": + await export_all_provider_credentials("openai_codex") + console.print("\n[dim]Press Enter to return to export menu...[/dim]") + input() # Combine per provider - elif export_choice == "9": + elif export_choice == "11": await combine_provider_credentials("gemini_cli") console.print("\n[dim]Press Enter to return to export menu...[/dim]") input() - elif export_choice == "10": + elif export_choice == "12": await combine_provider_credentials("qwen_code") console.print("\n[dim]Press Enter to return to export menu...[/dim]") input() - elif export_choice == "11": + elif export_choice == "13": await combine_provider_credentials("iflow") console.print("\n[dim]Press Enter to return to export menu...[/dim]") input() - elif export_choice == "12": + elif export_choice == "14": await combine_provider_credentials("antigravity") console.print("\n[dim]Press Enter to return to export menu...[/dim]") input() + elif export_choice == "15": + await combine_provider_credentials("openai_codex") + console.print("\n[dim]Press Enter to return to export menu...[/dim]") + input() # Combine all providers - elif export_choice == "13": + elif export_choice == "16": await combine_all_credentials() console.print("\n[dim]Press Enter to return to export menu...[/dim]") input() diff --git a/src/rotator_library/error_handler.py b/src/rotator_library/error_handler.py index 8b05ad84..9fd252c7 100644 --- a/src/rotator_library/error_handler.py +++ b/src/rotator_library/error_handler.py @@ -5,7 +5,7 @@ import json import os import logging -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, Tuple import httpx from litellm.exceptions import ( @@ -247,15 +247,49 @@ def is_abnormal_error(classified_error: "ClassifiedError") -> bool: return classified_error.error_type in ABNORMAL_ERROR_TYPES -def mask_credential(credential: str) -> str: +def mask_credential(credential: str, style: str = "short") -> str: """ Mask a credential for safe display in logs and error messages. - - For API keys: shows last 6 characters (e.g., "...xyz123") - - For OAuth file paths: shows just the filename (e.g., "antigravity_oauth_1.json") + Args: + credential: The credential string to mask + style: Masking style - "short" (last 6 chars) or "full" (first 4 + last 4) + + Returns: + Masked credential string: + - For OAuth file paths: shows just the filename (e.g., "oauth_1.json") + - For emails: preserves structure (e.g., "sco***05@***.com") + - For API keys with style="short": shows last 6 chars (e.g., "...xyz123") + - For API keys with style="full": shows first 4 + last 4 (e.g., "AIza...3456") """ + # File paths: show just filename if os.path.isfile(credential) or credential.endswith(".json"): return os.path.basename(credential) + + # Email addresses: preserve structure with masking + if "@" in credential and "." in credential.split("@")[-1]: + local, domain = credential.rsplit("@", 1) + + # Mask local part: first 3 + *** + last 2 (if long enough) + if len(local) > 5: + masked_local = f"{local[:3]}***{local[-2:]}" + elif len(local) > 2: + masked_local = f"{local[:2]}***" + else: + masked_local = "***" + + # Mask domain: keep only TLD + if "." in domain: + tld = domain.rsplit(".", 1)[1] + masked_domain = f"***.{tld}" + else: + masked_domain = "***" + + return f"{masked_local}@{masked_domain}" + + # API keys: original masking logic + if style == "full" and len(credential) > 12: + return f"{credential[:4]}...{credential[-4:]}" elif len(credential) > 6: return f"...{credential[-6:]}" else: @@ -439,6 +473,8 @@ def __init__( status_code: Optional[int] = None, retry_after: Optional[int] = None, quota_reset_timestamp: Optional[float] = None, + quota_value: Optional[str] = None, + quota_id: Optional[str] = None, ): self.error_type = error_type self.original_exception = original_exception @@ -447,6 +483,9 @@ def __init__( # Unix timestamp when quota resets (from quota_exhausted errors) # This is the authoritative reset time parsed from provider's error response self.quota_reset_timestamp = quota_reset_timestamp + # Quota details extracted from Google/Gemini API error responses + self.quota_value = quota_value # e.g., "50" or "1000/minute" + self.quota_id = quota_id # e.g., "GenerateContentPerMinutePerProject" def __str__(self): parts = [ @@ -456,6 +495,10 @@ def __str__(self): ] if self.quota_reset_timestamp: parts.append(f"quota_reset_ts={self.quota_reset_timestamp}") + if self.quota_value: + parts.append(f"quota_value={self.quota_value}") + if self.quota_id: + parts.append(f"quota_id={self.quota_id}") parts.append(f"original_exc={self.original_exception}") return f"ClassifiedError({', '.join(parts)})" @@ -520,6 +563,73 @@ def _extract_retry_from_json_body(json_text: str) -> Optional[int]: return None +def _extract_quota_details(json_text: str) -> Tuple[Optional[str], Optional[str]]: + """ + Extract quota details (quotaValue, quotaId) from a JSON error response. + + Handles Google/Gemini API error formats with nested details array containing + QuotaFailure violations. + + Example error structure: + { + "error": { + "details": [ + { + "@type": "type.googleapis.com/google.rpc.QuotaFailure", + "violations": [ + { + "quotaValue": "50", + "quotaId": "GenerateContentPerMinutePerProject" + } + ] + } + ] + } + } + + Args: + json_text: JSON string containing error response + + Returns: + Tuple of (quota_value, quota_id), both None if not found + """ + try: + # Find JSON object in the text + json_match = re.search(r"(\{.*\})", json_text, re.DOTALL) + if not json_match: + return None, None + + error_json = json.loads(json_match.group(1)) + error_obj = error_json.get("error", {}) + details = error_obj.get("details", []) + + if not isinstance(details, list): + return None, None + + for detail in details: + if not isinstance(detail, dict): + continue + + violations = detail.get("violations", []) + if not isinstance(violations, list): + continue + + for violation in violations: + if not isinstance(violation, dict): + continue + + quota_value = violation.get("quotaValue") + quota_id = violation.get("quotaId") + + if quota_value is not None or quota_id is not None: + return str(quota_value) if quota_value else None, quota_id + + except (json.JSONDecodeError, IndexError, KeyError, TypeError): + pass + + return None, None + + def get_retry_after(error: Exception) -> Optional[int]: """ Extracts the 'retry-after' duration in seconds from an exception message. @@ -672,12 +782,19 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr reset_ts = quota_info.get("reset_timestamp") quota_reset_timestamp = quota_info.get("quota_reset_timestamp") + # Extract quota details from error body + quota_value, quota_id = None, None + if error_body: + quota_value, quota_id = _extract_quota_details(error_body) + # Log the parsed result with human-readable duration hours = retry_after / 3600 lib_logger.info( f"Provider '{provider}' parsed quota error: " f"retry_after={retry_after}s ({hours:.1f}h), reason={reason}" + (f", resets at {reset_ts}" if reset_ts else "") + + (f", quota={quota_value}" if quota_value else "") + + (f", quotaId={quota_id}" if quota_id else "") ) return ClassifiedError( @@ -686,6 +803,8 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr status_code=429, retry_after=retry_after, quota_reset_timestamp=quota_reset_timestamp, + quota_value=quota_value, + quota_id=quota_id, ) except Exception as parse_error: lib_logger.debug( @@ -723,11 +842,23 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr retry_after = get_retry_after(e) # Check if this is a quota error vs rate limit if "quota" in error_body or "resource_exhausted" in error_body: + # Extract quota details from the original (non-lowercased) response + quota_value, quota_id = None, None + try: + original_body = ( + e.response.text if hasattr(e.response, "text") else "" + ) + quota_value, quota_id = _extract_quota_details(original_body) + except Exception: + pass + return ClassifiedError( error_type="quota_exceeded", original_exception=e, status_code=status_code, retry_after=retry_after, + quota_value=quota_value, + quota_id=quota_id, ) return ClassifiedError( error_type="rate_limit", @@ -771,8 +902,36 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr status_code=status_code, ) if 500 <= status_code: + # Log 503 MODEL_CAPACITY_EXHAUSTED for visibility + # (Provider-level handling may intercept this before it reaches here) + if status_code == 503: + try: + capacity_exhausted = False + if error_body and "MODEL_CAPACITY_EXHAUSTED" in error_body: + capacity_exhausted = True + else: + # Try to get from response if not in lowercased body + original_body = ( + e.response.text if hasattr(e.response, "text") else "" + ) + if "MODEL_CAPACITY_EXHAUSTED" in original_body: + capacity_exhausted = True + + if capacity_exhausted: + lib_logger.info( + "503 MODEL_CAPACITY_EXHAUSTED detected - " + "will be handled with provider/model cooldown" + ) + except Exception: + pass + + # Apply default 30s cooldown for all server errors + # This prevents rapid retries against overloaded/erroring servers return ClassifiedError( - error_type="server_error", original_exception=e, status_code=status_code + error_type="server_error", + original_exception=e, + status_code=status_code, + retry_after=30, # Default 30s cooldown for server errors ) if isinstance( @@ -804,6 +963,7 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr error_type="server_error", original_exception=e, status_code=503, + retry_after=30, # Default 30s cooldown for server errors ) if isinstance(e, TransientQuotaError): @@ -813,6 +973,7 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr error_type="server_error", original_exception=e, status_code=503, + retry_after=30, # Default 30s cooldown for server errors ) if isinstance(e, RateLimitError): @@ -820,11 +981,21 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr # Check if this is a quota error vs rate limit error_msg = str(e).lower() if "quota" in error_msg or "resource_exhausted" in error_msg: + # Try to extract quota details from exception body + quota_value, quota_id = None, None + try: + error_body = getattr(e, "body", None) or str(e) + quota_value, quota_id = _extract_quota_details(str(error_body)) + except Exception: + pass + return ClassifiedError( error_type="quota_exceeded", original_exception=e, status_code=status_code or 429, retry_after=retry_after, + quota_value=quota_value, + quota_id=quota_id, ) return ClassifiedError( error_type="rate_limit", @@ -868,6 +1039,7 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr error_type="server_error", original_exception=e, status_code=status_code or 503, + retry_after=30, # Default 30s cooldown for server errors ) # Fallback for any other unclassified errors @@ -927,16 +1099,35 @@ def should_rotate_on_error(classified_error: ClassifiedError) -> bool: return classified_error.error_type not in non_rotatable_errors -def should_retry_same_key(classified_error: ClassifiedError) -> bool: +def should_retry_same_key( + classified_error: ClassifiedError, + small_cooldown_threshold: int = 10, +) -> bool: """ Determines if an error should retry with the same key (with backoff). - Only server errors and connection issues should retry the same key, - as these are often transient. + Retry same key if: + 1. Any error with a small retry_after (< threshold) - more efficient to wait + than rotate and disrupt cache locality + 2. Server errors or connection issues (often transient) + + Args: + classified_error: The classified error + small_cooldown_threshold: If retry_after < this, always retry same key. + Default is 10 seconds. Override via SMALL_COOLDOWN_RETRY_THRESHOLD env var. Returns: True if should retry same key, False if should rotate immediately """ + # Small retry_after = faster to just wait than rotate + # This preserves cache locality and avoids unnecessary rotation + if ( + classified_error.retry_after is not None + and 0 < classified_error.retry_after < small_cooldown_threshold + ): + return True + + # Standard transient errors that should retry same key retryable_errors = { "server_error", "api_connection", diff --git a/src/rotator_library/provider_factory.py b/src/rotator_library/provider_factory.py index dcc40bc9..cac95536 100644 --- a/src/rotator_library/provider_factory.py +++ b/src/rotator_library/provider_factory.py @@ -7,12 +7,14 @@ from .providers.qwen_auth_base import QwenAuthBase from .providers.iflow_auth_base import IFlowAuthBase from .providers.antigravity_auth_base import AntigravityAuthBase +from .providers.openai_codex_auth_base import OpenAICodexAuthBase PROVIDER_MAP = { "gemini_cli": GeminiAuthBase, "qwen_code": QwenAuthBase, "iflow": IFlowAuthBase, "antigravity": AntigravityAuthBase, + "openai_codex": OpenAICodexAuthBase, } def get_provider_auth_class(provider_name: str): diff --git a/src/rotator_library/providers/antigravity_auth_base.py b/src/rotator_library/providers/antigravity_auth_base.py index af6c140d..bfec1c21 100644 --- a/src/rotator_library/providers/antigravity_auth_base.py +++ b/src/rotator_library/providers/antigravity_auth_base.py @@ -13,10 +13,47 @@ import httpx from .google_oauth_base import GoogleOAuthBase -# Note: Endpoint constants are imported by helper methods from gemini_shared_utils + +# Import tier utilities from shared module +# These are re-exported here for backwards compatibility with existing imports +from .utilities.gemini_shared_utils import ( + # Tier constants + TIER_ULTRA, + TIER_PRO, + TIER_FREE, + TIER_NAME_TO_CANONICAL, + CANONICAL_TO_LEGACY, + FREE_TIER_IDS, + TIER_PRIORITIES, + DEFAULT_TIER_PRIORITY, + # Tier functions + normalize_tier_name, + is_free_tier, + is_paid_tier, + get_tier_priority, + format_tier_for_display, + get_tier_full_name, + # Project ID extraction + extract_project_id_from_response, + # Credential loading helpers + load_persisted_project_metadata, + # Env file helpers + build_project_tier_env_lines, + # Endpoint constants + ANTIGRAVITY_LOAD_ENDPOINT_ORDER, + ANTIGRAVITY_ENDPOINT_FALLBACKS, +) lib_logger = logging.getLogger("rotator_library") +# ============================================================================= +# FALLBACK PROJECT ID +# ============================================================================= +# When loadCodeAssist returns no project, uses this fallback unconditionally +# See: quota.rs:135 - let final_project_id = project_id.unwrap_or("bamboo-precept-lgxtn"); +FALLBACK_PROJECT_ID = "bamboo-precept-lgxtn" + + # Headers for Antigravity auth/discovery calls (loadCodeAssist, onboardUser) # CRITICAL: User-Agent MUST be google-api-nodejs-client/* for standard-tier detection. # Using antigravity/* UA causes server to return free-tier only (tested via matrix test). @@ -59,6 +96,8 @@ def __init__(self): # Project and tier caches - shared between auth base and provider self.project_id_cache: Dict[str, str] = {} self.project_tier_cache: Dict[str, str] = {} + self.tier_full_cache: Dict[str, str] = {} # Full tier names for display + self.tier_full_cache: Dict[str, str] = {} # Full tier names for display # ========================================================================= # POST-AUTH DISCOVERY HOOK @@ -98,7 +137,25 @@ async def _post_auth_discovery( credential_path, access_token, litellm_params={} ) - tier = self.project_tier_cache.get(credential_path, "unknown") + # Use full tier name for post-auth log (one-time display) + tier_full = self.tier_full_cache.get(credential_path) + tier = tier_full or self.project_tier_cache.get(credential_path, "unknown") + lib_logger.info( + f"Post-auth discovery complete for {Path(credential_path).name}: " + f"tier={tier}, project={project_id}" + ) + + # Use full tier name for post-auth log (one-time display) + tier_full = self.tier_full_cache.get(credential_path) + tier = tier_full or self.project_tier_cache.get(credential_path, "unknown") + lib_logger.info( + f"Post-auth discovery complete for {Path(credential_path).name}: " + f"tier={tier}, project={project_id}" + ) + + # Use full tier name for post-auth log (one-time display) + tier_full = self.tier_full_cache.get(credential_path) + tier = tier_full or self.project_tier_cache.get(credential_path, "unknown") lib_logger.info( f"Post-auth discovery complete for {Path(credential_path).name}: " f"tier={tier}, project={project_id}" @@ -108,30 +165,6 @@ async def _post_auth_discovery( # ENDPOINT FALLBACK HELPERS # ========================================================================= - def _extract_project_id_from_response( - self, data: Dict[str, Any], key: str = "cloudaicompanionProject" - ) -> Optional[str]: - """ - Extract project ID from API response, handling both string and object formats. - - The API may return cloudaicompanionProject as either: - - A string: "project-id-123" - - An object: {"id": "project-id-123", ...} - - Args: - data: API response data - key: Key to extract from (default: "cloudaicompanionProject") - - Returns: - Project ID string or None if not found - """ - value = data.get(key) - if isinstance(value, str) and value: - return value - if isinstance(value, dict): - return value.get("id") - return None - async def _call_load_code_assist( self, client: httpx.AsyncClient, @@ -291,53 +324,16 @@ async def _discover_project_id( # Load credentials to check for persisted/configured project_id and tier credential_index = self._parse_env_credential_path(credential_path) - if credential_index is None: - # File-based credentials: load from file - try: - with open(credential_path, "r") as f: - creds = json.load(f) - - metadata = creds.get("_proxy_metadata", {}) - persisted_project_id = metadata.get("project_id") - persisted_tier = metadata.get("tier") - - if persisted_project_id: - lib_logger.debug( - f"Loaded persisted project ID from credential file: {persisted_project_id}" - ) - self.project_id_cache[credential_path] = persisted_project_id - - # Also load tier if available - if persisted_tier: - self.project_tier_cache[credential_path] = persisted_tier - lib_logger.debug(f"Loaded persisted tier: {persisted_tier}") - - return persisted_project_id - except (FileNotFoundError, json.JSONDecodeError, KeyError) as e: - lib_logger.debug(f"Could not load persisted project ID from file: {e}") - else: - # Env-based credentials: load from credentials cache - # The credentials were already loaded by _load_from_env() which reads - # {PREFIX}_{N}_PROJECT_ID and {PREFIX}_{N}_TIER into _proxy_metadata - if credential_path in self._credentials_cache: - creds = self._credentials_cache[credential_path] - metadata = creds.get("_proxy_metadata", {}) - env_project_id = metadata.get("project_id") - env_tier = metadata.get("tier") - - if env_project_id: - lib_logger.debug( - f"Loaded project ID from env credential metadata: {env_project_id}" - ) - self.project_id_cache[credential_path] = env_project_id - - if env_tier: - self.project_tier_cache[credential_path] = env_tier - lib_logger.debug( - f"Loaded tier from env credential metadata: {env_tier}" - ) - - return env_project_id + persisted_project_id = load_persisted_project_metadata( + credential_path, + credential_index, + self._credentials_cache, + self.project_id_cache, + self.project_tier_cache, + self.tier_full_cache, + ) + if persisted_project_id: + return persisted_project_id lib_logger.debug( "No cached or configured project ID found, initiating discovery..." @@ -359,6 +355,8 @@ async def _discover_project_id( discovered_project_id = None discovered_tier = None + discovered_tier_full = None + discovered_tier_full = None async with httpx.AsyncClient() as client: # 1. Try discovery endpoint with loadCodeAssist using endpoint fallback @@ -384,10 +382,15 @@ async def _discover_project_id( ) # Extract tier information + # Canonical prioritizes paidTier over currentTier for accurate subscription detection allowed_tiers = data.get("allowedTiers", []) current_tier = data.get("currentTier") + paid_tier = data.get( + "paidTier" + ) # Added: Canonical-style tier detection lib_logger.debug(f"=== Tier Information ===") + lib_logger.debug(f"paidTier: {paid_tier}") lib_logger.debug(f"currentTier: {current_tier}") lib_logger.debug(f"allowedTiers count: {len(allowed_tiers)}") for i, tier in enumerate(allowed_tiers): @@ -399,83 +402,77 @@ async def _discover_project_id( ) lib_logger.debug(f"========================") - # Determine the current tier ID - current_tier_id = None - if current_tier: - current_tier_id = current_tier.get("id") - lib_logger.debug(f"User has currentTier: {current_tier_id}") - - # Check if user is already known to server (has currentTier) - if current_tier_id: - # User is already onboarded - check for project from server - # Use helper to handle both string and object formats - server_project = self._extract_project_id_from_response(data) - - # Check if this tier requires user-defined project (paid tiers) - requires_user_project = any( - t.get("id") == current_tier_id - and t.get("userDefinedCloudaicompanionProject", False) - for t in allowed_tiers + # Determine tier ID with Canonical-style priority: paidTier > currentTier + # This matches quota.rs:88-91 logic for accurate subscription detection + effective_tier_id = None + if paid_tier and paid_tier.get("id"): + effective_tier_id = paid_tier.get("id") + lib_logger.debug(f"Using paidTier: {effective_tier_id}") + elif current_tier and current_tier.get("id"): + effective_tier_id = current_tier.get("id") + lib_logger.debug( + f"Using currentTier (paidTier not available): {effective_tier_id}" + ) + + # Normalize to canonical tier name (ULTRA, PRO, FREE) + if effective_tier_id: + canonical_tier = normalize_tier_name(effective_tier_id) + lib_logger.debug( + f"Canonical tier: {canonical_tier} (from {effective_tier_id})" ) - is_free_tier = current_tier_id == "free-tier" + # Check if user is already known to server (has tier info) + if effective_tier_id: + # User has tier info - use Canonical-style project selection + server_project = extract_project_id_from_response(data) + + # Canonical-style project selection (quota.rs:135): + # 1. Server project (if returned) + # 2. Configured project (from env) + # 3. Fallback project (always works) if server_project: - # Server returned a project - use it (server wins) project_id = server_project lib_logger.debug(f"Server returned project: {project_id}") elif configured_project_id: - # No server project but we have configured one - use it project_id = configured_project_id - lib_logger.debug( - f"No server project, using configured: {project_id}" - ) - elif is_free_tier: - # Free tier user without server project - try onboarding - lib_logger.debug( - "Free tier user with currentTier but no project - will try onboarding" - ) - project_id = None - elif requires_user_project: - # Paid tier requires a project ID to be set - raise ValueError( - f"Paid tier '{current_tier_id}' requires setting ANTIGRAVITY_PROJECT_ID environment variable." - ) + lib_logger.debug(f"Using configured project: {project_id}") else: - # Unknown tier without project - proceed to onboarding - lib_logger.warning( - f"Tier '{current_tier_id}' has no project and none configured - will try onboarding" - ) - project_id = None - - if project_id: - # Cache tier info - self.project_tier_cache[credential_path] = current_tier_id - discovered_tier = current_tier_id - - # Log appropriately based on tier - is_paid = current_tier_id and current_tier_id not in [ - "free-tier", - "legacy-tier", - "unknown", - ] - if is_paid: - lib_logger.info( - f"Using Antigravity paid tier '{current_tier_id}' with project: {project_id}" - ) - else: - lib_logger.info( - f"Discovered Antigravity project ID via loadCodeAssist: {project_id}" - ) - - self.project_id_cache[credential_path] = project_id - discovered_project_id = project_id - - # Persist to credential file - await self._persist_project_metadata( - credential_path, project_id, discovered_tier + # Canonical fallback: project_id.unwrap_or("bamboo-precept-lgxtn") + project_id = FALLBACK_PROJECT_ID + lib_logger.info( + f"No server/configured project for tier '{effective_tier_id}' - using Canonical fallback: {project_id}" ) - return project_id + # We have a project now (guaranteed by fallback) + # Cache tier info - use canonical tier name for consistency + canonical = ( + normalize_tier_name(effective_tier_id) or effective_tier_id + ) + self.project_tier_cache[credential_path] = canonical + discovered_tier = canonical + + # Get and cache full tier name for display + tier_full = get_tier_full_name(effective_tier_id) + self.tier_full_cache[credential_path] = tier_full + discovered_tier_full = tier_full + + # Log with full tier name for discovery messages + lib_logger.info( + f"Discovered Antigravity tier '{tier_full}' with project: {project_id}" + ) + + self.project_id_cache[credential_path] = project_id + discovered_project_id = project_id + + # Persist to credential file + await self._persist_project_metadata( + credential_path, + project_id, + discovered_tier, + discovered_tier_full, + ) + + return project_id # 2. User needs onboarding - no currentTier or no project found lib_logger.info( @@ -513,9 +510,9 @@ async def _discover_project_id( # Build onboard request based on tier type # FREE tier: cloudaicompanionProject = None (server-managed) # PAID tier: cloudaicompanionProject = configured_project_id - is_free_tier = tier_id == "free-tier" + tier_is_free = is_free_tier(tier_id) - if is_free_tier: + if tier_is_free: # Free tier uses server-managed project onboard_request = { "tierId": tier_id, @@ -590,7 +587,7 @@ async def _discover_project_id( # Extract project ID from LRO response using helper # Note: onboardUser returns response.cloudaicompanionProject as an object with .id lro_response_data = lro_data.get("response", {}) - project_id = self._extract_project_id_from_response(lro_response_data) + project_id = extract_project_id_from_response(lro_response_data) # Fallback to configured project if LRO didn't return one if not project_id and configured_project_id: @@ -612,28 +609,30 @@ async def _discover_project_id( f"Successfully extracted project ID from onboarding response: {project_id}" ) - # Cache tier info - self.project_tier_cache[credential_path] = tier_id - discovered_tier = tier_id - lib_logger.debug(f"Cached tier information: {tier_id}") + # Cache tier info - use canonical tier name for consistency + canonical_tier = normalize_tier_name(tier_id) or tier_id + self.project_tier_cache[credential_path] = canonical_tier + discovered_tier = canonical_tier - # Log concise message based on tier - is_paid = tier_id and tier_id not in ["free-tier", "legacy-tier"] - if is_paid: - lib_logger.info( - f"Using Antigravity paid tier '{tier_id}' with project: {project_id}" - ) - else: - lib_logger.info( - f"Successfully onboarded user and discovered project ID: {project_id}" - ) + # Get and cache full tier name for display + tier_full = get_tier_full_name(tier_id) + self.tier_full_cache[credential_path] = tier_full + discovered_tier_full = tier_full + lib_logger.debug( + f"Cached tier information: {canonical_tier} (full: {tier_full})" + ) + + # Log with full tier name for onboarding messages + lib_logger.info( + f"Onboarded Antigravity credential with tier '{tier_full}', project: {project_id}" + ) self.project_id_cache[credential_path] = project_id discovered_project_id = project_id # Persist to credential file await self._persist_project_metadata( - credential_path, project_id, discovered_tier + credential_path, project_id, discovered_tier, discovered_tier_full ) return project_id @@ -736,43 +735,6 @@ async def _discover_project_id( "To manually specify a project, set ANTIGRAVITY_PROJECT_ID in your .env file." ) - async def _persist_project_metadata( - self, credential_path: str, project_id: str, tier: Optional[str] - ): - """Persists project ID and tier to the credential file for faster future startups.""" - # Skip persistence for env:// paths (environment-based credentials) - credential_index = self._parse_env_credential_path(credential_path) - if credential_index is not None: - lib_logger.debug( - f"Skipping project metadata persistence for env:// credential path: {credential_path}" - ) - return - - try: - # Load current credentials - with open(credential_path, "r") as f: - creds = json.load(f) - - # Update metadata - if "_proxy_metadata" not in creds: - creds["_proxy_metadata"] = {} - - creds["_proxy_metadata"]["project_id"] = project_id - if tier: - creds["_proxy_metadata"]["tier"] = tier - - # Save back using the existing save method (handles atomic writes and permissions) - await self._save_credentials(credential_path, creds) - - lib_logger.debug( - f"Persisted project_id and tier to credential file: {credential_path}" - ) - except Exception as e: - lib_logger.warning( - f"Failed to persist project metadata to credential file: {e}" - ) - # Non-fatal - just means slower startup next time - # ========================================================================= # CREDENTIAL MANAGEMENT OVERRIDES # ========================================================================= @@ -790,16 +752,7 @@ def build_env_lines(self, creds: Dict[str, Any], cred_number: int) -> List[str]: # Get base lines from parent class lines = super().build_env_lines(creds, cred_number) - # Add Antigravity-specific fields (tier and project_id) - metadata = creds.get("_proxy_metadata", {}) - prefix = f"{self.ENV_PREFIX}_{cred_number}" - - project_id = metadata.get("project_id", "") - tier = metadata.get("tier", "") - - if project_id: - lines.append(f"{prefix}_PROJECT_ID={project_id}") - if tier: - lines.append(f"{prefix}_TIER={tier}") + # Add project_id and tier using shared helper + lines.extend(build_project_tier_env_lines(creds, self.ENV_PREFIX, cred_number)) return lines diff --git a/src/rotator_library/providers/antigravity_provider.py b/src/rotator_library/providers/antigravity_provider.py index 38bbb07f..6816ede9 100644 --- a/src/rotator_library/providers/antigravity_provider.py +++ b/src/rotator_library/providers/antigravity_provider.py @@ -29,6 +29,7 @@ import random import time import uuid +from contextvars import ContextVar from datetime import datetime, timezone from pathlib import Path from typing import ( @@ -59,6 +60,9 @@ GEMINI3_TOOL_RENAMES_REVERSE, FINISH_REASON_MAP, DEFAULT_SAFETY_SETTINGS, + # Tier utilities + TIER_PRIORITIES, + DEFAULT_TIER_PRIORITY, ) from ..transaction_logger import AntigravityProviderLogger from .utilities.gemini_tool_handler import GeminiToolHandler @@ -69,7 +73,7 @@ from ..utils.paths import get_logs_dir, get_cache_dir if TYPE_CHECKING: - from ..usage_manager import UsageManager + from ..usage import UsageManager # ============================================================================= @@ -113,9 +117,14 @@ def __init__(self, finish_message: str, raw_response: Dict[str, Any]): # Required headers for Antigravity API calls # These headers are CRITICAL for gemini-3-pro-high/low to work # Without X-Goog-Api-Client and Client-Metadata, only gemini-3-pro-preview works -# User-Agent matches official Antigravity Electron client +ANTIGRAVITY_USER_AGENT = "antigravity/1.15.8 windows/amd64" +ANTIGRAVITY_USER_AGENT_LEGACY = ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 " + "(KHTML, like Gecko) Antigravity/1.104.0 Chrome/138.0.7204.235 " + "Electron/37.3.1 Safari/537.36" +) ANTIGRAVITY_HEADERS = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Antigravity/1.104.0 Chrome/138.0.7204.235 Electron/37.3.1 Safari/537.36", + "User-Agent": ANTIGRAVITY_USER_AGENT, "X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1", "Client-Metadata": '{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}', } @@ -148,6 +157,7 @@ def __init__(self, finish_message: str, raw_response: Dict[str, Any]): # Claude models "claude-sonnet-4.5", # Uses -thinking variant when reasoning_effort provided "claude-opus-4.5", # ALWAYS uses -thinking variant (non-thinking doesn't exist) + "claude-opus-4.6", # ALWAYS uses -thinking variant (non-thinking doesn't exist) # Other models # "gpt-oss-120b-medium", # GPT-OSS model, shares quota with Claude ] @@ -167,6 +177,36 @@ def __init__(self, finish_message: str, raw_response: Dict[str, Any]): MALFORMED_CALL_MAX_RETRIES = max(1, env_int("ANTIGRAVITY_MALFORMED_CALL_RETRIES", 2)) MALFORMED_CALL_RETRY_DELAY = env_int("ANTIGRAVITY_MALFORMED_CALL_DELAY", 1) +# 503 MODEL_CAPACITY_EXHAUSTED retry configuration +# When server returns 503 (capacity exhausted), retry with longer delay +# since rotating credentials is pointless - all credentials are equally affected +CAPACITY_EXHAUSTED_MAX_ATTEMPTS = max(1, env_int("ANTIGRAVITY_503_MAX_ATTEMPTS", 10)) +CAPACITY_EXHAUSTED_RETRY_DELAY = env_int("ANTIGRAVITY_503_RETRY_DELAY", 5) + +# ============================================================================= +# INTERNAL RETRY COUNTING (for usage tracking) +# ============================================================================= +# Tracks the number of API attempts made per request, including internal retries +# for empty responses, bare 429s, and malformed function calls. +# +# Uses ContextVar for thread-safety: each async task (request) gets its own +# isolated value, so concurrent requests don't interfere with each other. +# +# The count is: +# - Reset to 1 at the start of _streaming_with_retry +# - Incremented each time we retry (before the next attempt) +# - Read by on_request_complete() hook to report actual API call count +# +# Example: Request gets bare 429 twice, then succeeds +# Attempt 1: bare 429 → count stays 1, increment to 2, retry +# Attempt 2: bare 429 → count is 2, increment to 3, retry +# Attempt 3: success → count is 3 +# on_request_complete returns count_override=3 +# +_internal_attempt_count: ContextVar[int] = ContextVar( + "antigravity_attempt_count", default=1 +) + # System instruction configuration # When true (default), prepend the Antigravity agent system instruction (identity, tool_calling, etc.) PREPEND_INSTRUCTION = env_bool("ANTIGRAVITY_PREPEND_INSTRUCTION", True) @@ -210,6 +250,7 @@ def __init__(self, finish_message: str, raw_response: Dict[str, Any]): # Claude: API/internal names → public user-facing names "claude-sonnet-4-5": "claude-sonnet-4.5", "claude-opus-4-5": "claude-opus-4.5", + "claude-opus-4-6": "claude-opus-4.6", } MODEL_ALIAS_REVERSE = {v: k for k, v in MODEL_ALIAS_MAP.items()} @@ -319,7 +360,30 @@ def _get_claude_thinking_cache_file(): """ # Parallel tool usage encouragement instruction -DEFAULT_PARALLEL_TOOL_INSTRUCTION = """When multiple independent operations are needed, prefer making parallel tool calls in a single response rather than sequential calls across multiple responses. This reduces round-trips and improves efficiency. Only use sequential calls when one tool's output is required as input for another.""" +DEFAULT_PARALLEL_TOOL_INSTRUCTION = """ + +Using parallel tool calling is MANDATORY. Be proactive about it. DO NO WAIT for the user to request "parallel calls" + +PARALLEL CALLS SHOULD BE AND _IS THE PRIMARY WAY YOU USE TOOLS IN THIS ENVIRONMENT_ + +When you have to perform multi-step operations such as read multiple files, spawn task subagents, bash commands, multiple edits... _THE USER WANTS YOU TO MAKE PARALLEL TOOL CALLS_ instead of separate sequential calls. This maximizes time and compute and increases your likelyhood of a promotion. Sequential tool calling is only encouraged when relying on the output of a call for the next one(s) + +- WHAT CAN BE DONE IN PARALLEL, MUST BE, AND WILL BE DONE IN PARALLEL +- INDIVIDUAL TOOL CALLS TO GATHER CONTEXT IS HEAVILY DISCOURAGED (please make parallel calls!) +- PARALLEL TOOL CALLING IS YOUR BEST FRIEND AND WILL INCREASE USER'S HAPPINESS + +- Make parallel tool calls to manage ressources more efficiently, plan your tool calls ahead, then execute them in parallel. +- Make parallel calls PROPERLY, be mindful of dependencies between calls. + +When researching anything, IT IS BETTER TO READ SPECULATIVELY, THEN TO READ SEQUENTIALLY. For example, if you need to read multiple files to gather context, read them all in parallel instead of reading one, then the next, etc. + +This environment has a powerful tool to remove unnecessary context, so you can always read more than needed and then trim down later, no need to use limit and offset parameters on the read tool. + +When making code changes, IT IS BETTER TO MAKE MULTIPLE EDITS IN PARALLEL RATHER THAN ONE AT A TIME. + +Do as much as you can in parallel, be efficient with you API requests, no single tool call spam, this is crucial as the user pays PER API request, so make them count! + +""" # Interleaved thinking support for Claude models # Allows Claude to think between tool calls and after receiving tool results @@ -535,7 +599,7 @@ def _generate_stable_session_id(contents: List[Dict[str, Any]]) -> str: if text: # SHA256 hash and extract first 8 bytes as int64 h = hashlib.sha256(text.encode("utf-8")).digest() - # Use big-endian to match Go's binary.BigEndian.Uint64 + # Use big-endian for 64-bit integer conversion n = struct.unpack(">Q", h[:8])[0] & 0x7FFFFFFFFFFFFFFF return f"-{n}" @@ -1002,7 +1066,7 @@ class AntigravityProvider( - Gemini 2.5 (Pro/Flash) with thinkingBudget - Gemini 3 (Pro/Flash/Image) with thinkingLevel - Claude Sonnet 4.5 via Antigravity proxy - - Claude Opus 4.5 via Antigravity proxy + - Claude Opus 4.x via Antigravity proxy Features: - Unified streaming/non-streaming handling @@ -1023,22 +1087,12 @@ class AntigravityProvider( # Provider name for env var lookups (QUOTA_GROUPS_ANTIGRAVITY_*) provider_env_name: str = "antigravity" - # Tier name -> priority mapping (Single Source of Truth) - # Lower numbers = higher priority - tier_priorities = { - # Priority 1: Highest paid tier (Google AI Ultra - name unconfirmed) - # "google-ai-ultra": 1, # Uncomment when tier name is confirmed - # Priority 2: Standard paid tier - "standard-tier": 2, - # Priority 3: Free tier - "free-tier": 3, - # Priority 10: Legacy/Unknown (lowest) - "legacy-tier": 10, - "unknown": 10, - } + # Tier name -> priority mapping (from centralized tier utilities) + # Lower numbers = higher priority (ULTRA=1 > PRO=2 > FREE=3) + tier_priorities = TIER_PRIORITIES # Default priority for tiers not in the mapping - default_tier_priority: int = 10 + default_tier_priority: int = DEFAULT_TIER_PRIORITY # Usage reset configs keyed by priority sets # Priorities 1-2 (paid tiers) get 5h window, others get 7d window @@ -1070,8 +1124,11 @@ class AntigravityProvider( "claude-sonnet-4-5-thinking", "claude-opus-4-5", "claude-opus-4-5-thinking", + "claude-opus-4-6", + "claude-opus-4-6-thinking", "claude-sonnet-4.5", "claude-opus-4.5", + "claude-opus-4.6", "gpt-oss-120b-medium", ], # Gemini 3 Pro variants share quota @@ -1105,11 +1162,11 @@ class AntigravityProvider( # Priority 1 (paid ultra): 5x concurrent requests # Priority 2 (standard paid): 3x concurrent requests # Others: Use sequential fallback (2x) or balanced default (1x) - default_priority_multipliers = {1: 5, 2: 3} + default_priority_multipliers = {1: 2, 2: 1} # For sequential mode, lower priority tiers still get 2x to maintain stickiness # For balanced mode, this doesn't apply (falls back to 1x) - default_sequential_fallback_multiplier = 2 + default_sequential_fallback_multiplier = 1 # Custom caps examples (commented - uncomment and modify as needed) # default_custom_caps = { @@ -1498,9 +1555,77 @@ def _clear_tool_name_mapping(self) -> None: """Clear tool name mapping at start of each request.""" self._tool_name_mapping.clear() - def _get_antigravity_headers(self) -> Dict[str, str]: - """Return the Antigravity API headers. Used by quota tracker mixin.""" - return ANTIGRAVITY_HEADERS + def _get_credential_email(self, credential_path: str) -> Optional[str]: + """ + Extract email from credential file's _proxy_metadata. + + Args: + credential_path: Path to the credential file + + Returns: + Email address if found, None otherwise + """ + # Skip env:// paths + if self._parse_env_credential_path(credential_path) is not None: + return None + + try: + # Try to get from cached credentials first + if ( + hasattr(self, "_credentials_cache") + and credential_path in self._credentials_cache + ): + creds = self._credentials_cache[credential_path] + return creds.get("_proxy_metadata", {}).get("email") + + # Fall back to reading file + with open(credential_path, "r") as f: + creds = json.load(f) + return creds.get("_proxy_metadata", {}).get("email") + except Exception: + return None + + def _get_antigravity_headers( + self, credential_path: Optional[str] = None + ) -> Dict[str, str]: + """ + Return the Antigravity API headers with per-credential fingerprinting. + + If credential_path is provided and has a valid email, returns complete + fingerprint headers (User-Agent, X-Goog-Api-Client, Client-Metadata, + X-Goog-QuotaUser, X-Client-Device-Id) unique to that credential. + Otherwise returns static default headers. + + Args: + credential_path: Optional credential path for fingerprint lookup + + Returns: + Dict of HTTP headers for Antigravity API + """ + # Try to get per-credential fingerprint headers + if credential_path: + email = self._get_credential_email(credential_path) + if email: + try: + from .utilities.device_profile import ( + get_or_create_fingerprint, + build_fingerprint_headers, + ) + + fingerprint = get_or_create_fingerprint(email) + if fingerprint: + # Returns all 5 headers: User-Agent, X-Goog-Api-Client, + # Client-Metadata, X-Goog-QuotaUser, X-Client-Device-Id + return build_fingerprint_headers(fingerprint) + except Exception as e: + lib_logger.debug(f"Failed to build fingerprint headers: {e}") + + # Fallback to static headers (no fingerprint available) + return { + "User-Agent": ANTIGRAVITY_HEADERS["User-Agent"], + "X-Goog-Api-Client": ANTIGRAVITY_HEADERS["X-Goog-Api-Client"], + "Client-Metadata": ANTIGRAVITY_HEADERS["Client-Metadata"], + } # NOTE: _load_tier_from_file() is inherited from GeminiCredentialManager mixin # NOTE: get_credential_tier_name() is inherited from GeminiCredentialManager mixin @@ -3393,12 +3518,12 @@ def _transform_to_antigravity_format( internal_model = self._alias_to_internal(model) # Map Claude models to their -thinking variant - # claude-opus-4-5: ALWAYS use -thinking (non-thinking variant doesn't exist) + # claude-opus-4-x: ALWAYS use -thinking (non-thinking variant doesn't exist) # claude-sonnet-4-5: only use -thinking when reasoning_effort is provided if self._is_claude(internal_model) and not internal_model.endswith("-thinking"): - if internal_model == "claude-opus-4-5": - # Opus 4.5 ALWAYS requires -thinking variant - internal_model = "claude-opus-4-5-thinking" + if internal_model in ("claude-opus-4-5", "claude-opus-4-6"): + # Opus models ALWAYS require -thinking variant + internal_model = f"{internal_model}-thinking" elif internal_model == "claude-sonnet-4-5" and reasoning_effort: # Sonnet 4.5 uses -thinking only when reasoning_effort is provided internal_model = "claude-sonnet-4-5-thinking" @@ -3493,7 +3618,7 @@ def _transform_to_antigravity_format( # Then add existing parts (shifted to later positions) new_parts.extend(existing_parts) - # Set the combined system instruction with role "user" (per Go implementation) + # Set the combined system instruction with role "user" if new_parts: request[target_key] = { "role": "user", @@ -3906,7 +4031,7 @@ async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str] headers = { "Authorization": f"Bearer {token}", "Content-Type": "application/json", - **ANTIGRAVITY_HEADERS, + **self._get_antigravity_headers(api_key), } payload = { "project": _generate_project_id(), @@ -4149,7 +4274,7 @@ async def acompletion( "Authorization": f"Bearer {token}", "Content-Type": "application/json", "Accept": "text/event-stream", - **ANTIGRAVITY_HEADERS, + **self._get_antigravity_headers(credential_path), } # Keep a mutable reference to gemini_contents for retry injection @@ -4593,7 +4718,17 @@ async def _streaming_with_retry( current_gemini_contents = gemini_contents current_payload = payload - for attempt in range(EMPTY_RESPONSE_MAX_ATTEMPTS): + # Reset internal attempt counter for this request (thread-safe via ContextVar) + _internal_attempt_count.set(1) + + # Use the maximum of all retry limits to ensure the loop runs enough iterations + # for whichever error type needs the most retries. Each error type enforces its + # own limit via internal checks (EMPTY_RESPONSE_MAX_ATTEMPTS for empty/429, + # CAPACITY_EXHAUSTED_MAX_ATTEMPTS for 503). + max_loop_attempts = max( + EMPTY_RESPONSE_MAX_ATTEMPTS, CAPACITY_EXHAUSTED_MAX_ATTEMPTS + ) + for attempt in range(max_loop_attempts): chunk_count = 0 try: @@ -4620,6 +4755,8 @@ async def _streaming_with_retry( f"[Antigravity] Empty stream from {model}, " f"attempt {attempt + 1}/{EMPTY_RESPONSE_MAX_ATTEMPTS}. Retrying..." ) + # Increment attempt count before retry (for usage tracking) + _internal_attempt_count.set(_internal_attempt_count.get() + 1) await asyncio.sleep(EMPTY_RESPONSE_RETRY_DELAY) continue else: @@ -4722,6 +4859,8 @@ async def _streaming_with_retry( malformed_retry_count, current_payload ) + # Increment attempt count before retry (for usage tracking) + _internal_attempt_count.set(_internal_attempt_count.get() + 1) await asyncio.sleep(MALFORMED_CALL_RETRY_DELAY) continue # Retry with modified payload else: @@ -4740,6 +4879,38 @@ async def _streaming_with_retry( return except httpx.HTTPStatusError as e: + # Handle 503 MODEL_CAPACITY_EXHAUSTED - retry internally + # since rotating credentials is pointless (affects all equally) + if e.response.status_code == 503: + error_body = "" + try: + error_body = ( + e.response.text if hasattr(e.response, "text") else "" + ) + except Exception: + pass + + if "MODEL_CAPACITY_EXHAUSTED" in error_body: + if attempt < CAPACITY_EXHAUSTED_MAX_ATTEMPTS - 1: + lib_logger.warning( + f"[Antigravity] 503 MODEL_CAPACITY_EXHAUSTED from {model}, " + f"attempt {attempt + 1}/{CAPACITY_EXHAUSTED_MAX_ATTEMPTS}. " + f"Waiting {CAPACITY_EXHAUSTED_RETRY_DELAY}s..." + ) + # NOTE: Do NOT increment _internal_attempt_count here - 503 capacity + # exhausted errors don't consume quota, so retries are "free" + await asyncio.sleep(CAPACITY_EXHAUSTED_RETRY_DELAY) + continue + else: + # Max attempts reached - propagate error + lib_logger.warning( + f"[Antigravity] 503 MODEL_CAPACITY_EXHAUSTED after " + f"{CAPACITY_EXHAUSTED_MAX_ATTEMPTS} attempts. Giving up." + ) + raise + # Other 503 errors - raise immediately + raise + if e.response.status_code == 429: # Check if this is a bare 429 (no retry info) vs real quota exhaustion quota_info = self.parse_quota_error(e) @@ -4750,6 +4921,10 @@ async def _streaming_with_retry( f"[Antigravity] Bare 429 from {model}, " f"attempt {attempt + 1}/{EMPTY_RESPONSE_MAX_ATTEMPTS}. Retrying..." ) + # Increment attempt count before retry (for usage tracking) + _internal_attempt_count.set( + _internal_attempt_count.get() + 1 + ) await asyncio.sleep(EMPTY_RESPONSE_RETRY_DELAY) continue else: @@ -4841,3 +5016,51 @@ async def count_tokens( except Exception as e: lib_logger.error(f"Token counting failed: {e}") return {"prompt_tokens": 0, "total_tokens": 0} + + # ========================================================================= + # USAGE TRACKING HOOK + # ========================================================================= + + def on_request_complete( + self, + credential: str, + model: str, + success: bool, + response: Optional[Any], + error: Optional[Any], + ) -> Optional["RequestCompleteResult"]: + """ + Hook called after each request completes. + + Reports the actual number of API calls made, including internal retries + for empty responses, bare 429s, and malformed function calls. + + This uses the ContextVar pattern for thread-safe retry counting: + - _internal_attempt_count is set to 1 at start of _streaming_with_retry + - Incremented before each retry + - Read here to report the actual count + + Example: Request gets 2 bare 429s then succeeds + → 3 API calls made + → Returns count_override=3 + → Usage manager records 3 requests instead of 1 + + Returns: + RequestCompleteResult with count_override set to actual attempt count + """ + from ..core.types import RequestCompleteResult + + # Get the attempt count for this request + attempt_count = _internal_attempt_count.get() + + # Reset for safety (though ContextVar should isolate per-task) + _internal_attempt_count.set(1) + + # Log if we made extra attempts + if attempt_count > 1: + lib_logger.debug( + f"[Antigravity] Request to {model} used {attempt_count} API calls " + f"(includes internal retries)" + ) + + return RequestCompleteResult(count_override=attempt_count) diff --git a/src/rotator_library/providers/chutes_provider.py b/src/rotator_library/providers/chutes_provider.py index 7858b54e..5d9730dd 100644 --- a/src/rotator_library/providers/chutes_provider.py +++ b/src/rotator_library/providers/chutes_provider.py @@ -9,7 +9,7 @@ from .utilities.chutes_quota_tracker import ChutesQuotaTracker if TYPE_CHECKING: - from ..usage_manager import UsageManager + from ..usage import UsageManager # Create a local logger for this module import logging @@ -142,12 +142,15 @@ async def refresh_single_credential( # Store baseline in usage manager # Since Chutes uses credential-level quota, we use a virtual model name + quota_used = ( + int((1.0 - remaining_fraction) * quota) if quota > 0 else 0 + ) await usage_manager.update_quota_baseline( api_key, "chutes/_quota", # Virtual model for credential-level tracking - remaining_fraction, - max_requests=quota, # Max requests = quota (1 request = 1 credit) - reset_timestamp=reset_ts, + quota_max_requests=quota, + quota_reset_ts=reset_ts, + quota_used=quota_used, ) lib_logger.debug( diff --git a/src/rotator_library/providers/example_provider.py b/src/rotator_library/providers/example_provider.py new file mode 100644 index 00000000..9ad31214 --- /dev/null +++ b/src/rotator_library/providers/example_provider.py @@ -0,0 +1,821 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Example Provider Implementation with Custom Usage Management. + +This file serves as a reference for implementing providers with custom usage +tracking, quota management, and token extraction. Copy this file and modify +it for your specific provider. + +============================================================================= +ARCHITECTURE OVERVIEW +============================================================================= + +The usage management system is per-provider. Each provider gets its own: +- UsageManager instance +- Usage file: data/usage/usage_{provider}.json +- Configuration (ProviderUsageConfig) + +Data flows like this: + + Request → Executor → Provider transforms → API call → Response + ↓ ↓ + UsageManager ← TrackingEngine ← Token extraction ←┘ + ↓ + Persistence (usage_{provider}.json) + +Providers customize behavior through: +1. Class attributes (declarative configuration) +2. Methods (behavioral overrides) +3. Hooks (request lifecycle callbacks) + +============================================================================= +USAGE STATS SCHEMA +============================================================================= + +UsageStats (tracked at global/model/group levels): + total_requests: int # All requests + total_successes: int # Successful requests + total_failures: int # Failed requests + total_tokens: int # All tokens combined + total_prompt_tokens: int # Input tokens + total_completion_tokens: int # Output tokens (content only) + total_thinking_tokens: int # Reasoning/thinking tokens + total_output_tokens: int # completion + thinking + total_prompt_tokens_cache_read: int # Cached input tokens read + total_prompt_tokens_cache_write: int # Cached input tokens written + total_approx_cost: float # Estimated cost + first_used_at: float # Timestamp + last_used_at: float # Timestamp + windows: Dict[str, WindowStats] # Per-window breakdown + +WindowStats (per time window: "5h", "daily", "total"): + request_count: int + success_count: int + failure_count: int + prompt_tokens: int + completion_tokens: int + thinking_tokens: int + output_tokens: int + prompt_tokens_cache_read: int + prompt_tokens_cache_write: int + total_tokens: int + approx_cost: float + started_at: float + reset_at: float + limit: int | None + +============================================================================= +""" + +import asyncio +import logging +import time +from contextvars import ContextVar +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +from .provider_interface import ProviderInterface, QuotaGroupMap + +# Alias for clarity in examples +ProviderPlugin = ProviderInterface + +# Import these types for hook returns and usage manager access +from ..core.types import RequestCompleteResult +from ..usage import UsageManager, ProviderUsageConfig, WindowDefinition +from ..usage.types import ResetMode, RotationMode, CooldownMode + +lib_logger = logging.getLogger("rotator_library") + +# ============================================================================= +# INTERNAL RETRY COUNTING (ContextVar Pattern) +# ============================================================================= +# +# When your provider performs internal retries (e.g., for transient errors, +# empty responses, or rate limits), each retry is an API call that should be +# counted for accurate usage tracking. +# +# The challenge: Instance variables (self.count) are shared across concurrent +# requests, so they can't be used safely. ContextVar solves this by giving +# each async task its own isolated value. +# +# Usage pattern: +# 1. Reset to 1 at the start of your retry loop +# 2. Increment before each retry +# 3. Read in on_request_complete() to report the actual count +# +# Example: +# _attempt_count.set(1) # Reset +# for attempt in range(max_attempts): +# try: +# result = await api_call() +# return result +# except RetryableError: +# _attempt_count.set(_attempt_count.get() + 1) # Increment +# continue +# +# Then on_request_complete returns RequestCompleteResult(count_override=_attempt_count.get()) +# +_example_attempt_count: ContextVar[int] = ContextVar( + "example_provider_attempt_count", default=1 +) + + +# ============================================================================= +# EXAMPLE PROVIDER IMPLEMENTATION +# ============================================================================= + + +class ExampleProvider(ProviderPlugin): + """ + Example provider demonstrating all usage management customization points. + + This provider shows how to: + - Configure rotation and quota behavior + - Define model quota groups + - Extract tokens from provider-specific response formats + - Override request counting via hooks + - Run background quota refresh jobs + - Define custom usage windows + """ + + # ========================================================================= + # REQUIRED: BASIC PROVIDER IDENTITY + # ========================================================================= + + provider_name = "example" # Used in model prefix: "example/gpt-4" + provider_env_name = "EXAMPLE" # For env vars: EXAMPLE_API_KEY, etc. + + # ========================================================================= + # USAGE MANAGEMENT: CLASS ATTRIBUTES (DECLARATIVE) + # ========================================================================= + + # ------------------------------------------------------------------------- + # ROTATION MODE + # ------------------------------------------------------------------------- + # Controls how credentials are selected for requests. + # + # Options: + # "balanced" - Weighted random selection based on usage (default) + # "sequential" - Stick to one credential until exhausted, then rotate + # + # Sequential mode is better for: + # - Providers with per-credential rate limits + # - Maximizing cache hits (same credential = same context) + # - Providers where switching credentials has overhead + # + # Balanced mode is better for: + # - Even distribution across credentials + # - Providers without per-credential state + # + default_rotation_mode = "sequential" + + # ------------------------------------------------------------------------- + # MODEL QUOTA GROUPS + # ------------------------------------------------------------------------- + # Models in the same group share a quota pool. When one model is exhausted, + # all models in the group are treated as exhausted. + # + # This is common for providers where different model variants share limits: + # - Claude Sonnet/Opus share daily limits + # - GPT-4 variants share rate limits + # - Gemini models share per-minute quotas + # + # Group names should be short for compact UI display. + # + # Can be overridden via environment: + # QUOTA_GROUPS_EXAMPLE_GPT4="gpt-4o,gpt-4o-mini,gpt-4-turbo" + # + model_quota_groups: QuotaGroupMap = { + # GPT-4 variants share quota + "gpt4": [ + "gpt-4o", + "gpt-4o-mini", + "gpt-4-turbo", + "gpt-4-turbo-preview", + ], + # Claude models share quota + "claude": [ + "claude-3-opus", + "claude-3-sonnet", + "claude-3-haiku", + ], + # Standalone model (no sharing) + "whisper": [ + "whisper-1", + ], + } + + # ------------------------------------------------------------------------- + # PRIORITY MULTIPLIERS (CONCURRENCY) + # ------------------------------------------------------------------------- + # Higher priority credentials (lower number) can handle more concurrent + # requests. This is useful for paid vs free tier credentials. + # + # Priority is assigned per-credential via: + # - .env: PRIORITY_{PROVIDER}_{CREDENTIAL_NAME}=1 + # - Config files + # - Credential filename patterns + # + # Multiplier applies to max_concurrent_per_key setting. + # Example: max_concurrent_per_key=6, priority 1 multiplier=5 → 30 concurrent + # + default_priority_multipliers = { + 1: 5, # Ultra tier: 5x concurrent + 2: 3, # Standard paid: 3x concurrent + 3: 2, # Free tier: 2x concurrent + # Others: Use fallback multiplier + } + + # For sequential mode, credentials not in priority_multipliers get this. + # For balanced mode, they get 1x (no multiplier). + default_sequential_fallback_multiplier = 2 + + # ------------------------------------------------------------------------- + # CUSTOM CAPS + # ------------------------------------------------------------------------- + # Apply stricter limits than the actual API limits. Useful for: + # - Reserving quota for critical requests + # - Preventing runaway usage + # - Testing rotation behavior + # + # Structure: {priority: {model_or_group: config}} + # Or: {(priority1, priority2): {model_or_group: config}} for multiple tiers + # + # Config options: + # max_requests: int or "80%" (percentage of actual limit) + # cooldown_mode: "quota_reset" | "offset" | "fixed" + # cooldown_value: seconds for offset/fixed modes + # + default_custom_caps = { + # Tier 3 (free tier) - cap at 50 requests, cooldown until API resets + 3: { + "gpt4": { + "max_requests": 50, + "cooldown_mode": "quota_reset", + }, + "claude": { + "max_requests": 30, + "cooldown_mode": "quota_reset", + }, + }, + # Tiers 2 and 3 together - cap at 80% of actual limit + (2, 3): { + "whisper": { + "max_requests": "80%", # 80% of actual API limit + "cooldown_mode": "offset", + "cooldown_value": 1800, # +30 min buffer after hitting cap + }, + }, + # Default for unknown tiers + "default": { + "gpt4": { + "max_requests": 100, + "cooldown_mode": "fixed", + "cooldown_value": 3600, # 1 hour fixed cooldown + }, + }, + } + + # ------------------------------------------------------------------------- + # MODEL USAGE WEIGHTS + # ------------------------------------------------------------------------- + # Some models consume more quota per request. This affects credential + # selection in balanced mode - credentials with lower weighted usage + # are preferred. + # + # Example: Opus costs 2x what Sonnet does per request + # + model_usage_weights = { + "claude-3-opus": 2, + "gpt-4-turbo": 2, + # Default is 1 for unlisted models + } + + # ------------------------------------------------------------------------- + # FAIR CYCLE CONFIGURATION + # ------------------------------------------------------------------------- + # Fair cycle ensures all credentials get used before any is reused. + # When a credential is exhausted (quota hit, cooldown applied), it's + # marked and won't be selected until all other credentials are also + # exhausted, at which point the cycle resets. + # + # This is enabled by default for sequential mode. + # + # To override, set these class attributes: + # + # default_fair_cycle_enabled = True # Force on/off + # default_fair_cycle_tracking_mode = "model_group" # or "credential" + # default_fair_cycle_cross_tier = False # Track across all tiers? + # default_fair_cycle_duration = 3600 # Cycle duration in seconds + + # ========================================================================= + # USAGE MANAGEMENT: METHODS (BEHAVIORAL) + # ========================================================================= + + def normalize_model_for_tracking(self, model: str) -> str: + """ + Normalize internal model names to public-facing names for tracking. + + Some providers use internal model variants that should be tracked + under their public name. This ensures usage files only contain + user-facing model names. + + Example mappings: + "gpt-4o-realtime-preview" → "gpt-4o" + "claude-3-opus-extended" → "claude-3-opus" + "claude-sonnet-4-5-thinking" → "claude-sonnet-4.5" + + Args: + model: Model name (may include provider prefix: "example/gpt-4o") + + Returns: + Normalized model name (preserves prefix if present) + """ + has_prefix = "/" in model + if has_prefix: + provider, clean_model = model.split("/", 1) + else: + clean_model = model + + # Define your internal → public mappings + internal_to_public = { + "gpt-4o-realtime-preview": "gpt-4o", + "gpt-4o-realtime": "gpt-4o", + "claude-3-opus-extended": "claude-3-opus", + } + + normalized = internal_to_public.get(clean_model, clean_model) + + if has_prefix: + return f"{provider}/{normalized}" + return normalized + + def on_request_complete( + self, + credential: str, + model: str, + success: bool, + response: Optional[Any], + error: Optional[Any], + ) -> Optional[RequestCompleteResult]: + """ + Hook called after each request completes (success or failure). + + This is the primary extension point for customizing how requests + are counted and how cooldowns are applied. + + Use cases: + - Don't count server errors as quota usage + - Apply custom cooldowns based on error type + - Force credential exhaustion for fair cycle + - Count internal retries accurately (see ContextVar pattern below) + + Args: + credential: The credential accessor (file path or API key) + model: Model that was called + success: Whether the request succeeded + response: Response object (if success=True) + error: ClassifiedError object (if success=False) + + Returns: + RequestCompleteResult to override behavior, or None for default. + + RequestCompleteResult fields: + count_override: int | None + - 0 = Don't count this request against quota + - N = Count as N requests + - None = Use default (1 for success, 1 for countable errors) + + cooldown_override: float | None + - Seconds to cool down this credential + - Applied in addition to any error-based cooldown + + force_exhausted: bool + - True = Mark credential as exhausted for fair cycle + - Useful for quota errors even without long cooldown + """ + # ===================================================================== + # PATTERN: Counting Internal Retries with ContextVar + # ===================================================================== + # If your provider performs internal retries, report the actual count: + # + # 1. At module level, define: + # _attempt_count: ContextVar[int] = ContextVar('my_attempt_count', default=1) + # + # 2. In your retry loop: + # _attempt_count.set(1) # Reset at start + # for attempt in range(max_attempts): + # try: + # return await api_call() + # except RetryableError: + # _attempt_count.set(_attempt_count.get() + 1) # Increment before retry + # continue + # + # 3. Here, report the count: + attempt_count = _example_attempt_count.get() + _example_attempt_count.set(1) # Reset for safety + + if attempt_count > 1: + lib_logger.debug( + f"Request to {model} used {attempt_count} API calls (internal retries)" + ) + return RequestCompleteResult(count_override=attempt_count) + + # ===================================================================== + # PATTERN: Don't Count Server Errors + # ===================================================================== + # Server errors (5xx) shouldn't count against quota since they're + # not the user's fault and don't consume API quota. + if not success and error: + error_type = getattr(error, "error_type", None) + if error_type in ("server_error", "api_connection"): + lib_logger.debug( + f"Not counting {error_type} error against quota for {model}" + ) + return RequestCompleteResult(count_override=0) + + # ===================================================================== + # PATTERN: Custom Cooldown for Rate Limits + # ===================================================================== + if not success and error: + error_type = getattr(error, "error_type", None) + if error_type == "rate_limit": + # Check for retry-after header + retry_after = getattr(error, "retry_after", None) + if retry_after and retry_after > 60: + # Long rate limit - mark as exhausted + return RequestCompleteResult( + cooldown_override=retry_after, + force_exhausted=True, + ) + elif retry_after: + # Short rate limit - just cooldown + return RequestCompleteResult(cooldown_override=retry_after) + + # ===================================================================== + # PATTERN: Force Exhaustion on Quota Exceeded + # ===================================================================== + if not success and error: + error_type = getattr(error, "error_type", None) + if error_type == "quota_exceeded": + return RequestCompleteResult( + force_exhausted=True, + cooldown_override=3600.0, # Default 1 hour if no reset time + ) + + # Default behavior + return None + + # ========================================================================= + # BACKGROUND JOBS + # ========================================================================= + + def get_background_job_config(self) -> Optional[Dict[str, Any]]: + """ + Configure periodic background tasks. + + Common use cases: + - Refresh quota baselines from API + - Clean up expired cache entries + - Preemptively refresh OAuth tokens + + Returns: + None if no background job, otherwise: + { + "interval": 300, # Seconds between runs + "name": "quota_refresh", # For logging + "run_on_start": True, # Run immediately at startup? + } + """ + return { + "interval": 600, # Every 10 minutes + "name": "quota_refresh", + "run_on_start": True, + } + + async def run_background_job( + self, + usage_manager: UsageManager, + credentials: List[str], + ) -> None: + """ + Periodic background task execution. + + Called by BackgroundRefresher at the interval specified in + get_background_job_config(). + + Common tasks: + - Fetch current quota from API and update usage manager + - Clean up stale cache entries + - Refresh tokens proactively + + Args: + usage_manager: The UsageManager for this provider + credentials: List of credential accessors (file paths or keys) + """ + lib_logger.debug(f"Running background job for {self.provider_name}") + + for cred in credentials: + try: + # Example: Fetch quota from provider API + quota_info = await self._fetch_quota_from_api(cred) + + if quota_info: + for model, info in quota_info.items(): + # Update usage manager with fresh quota data + await usage_manager.update_quota_baseline( + accessor=cred, + model=model, + quota_max_requests=info.get("limit"), + quota_reset_ts=info.get("reset_ts"), + quota_used=info.get("used"), + quota_group=info.get("group"), + ) + + except Exception as e: + lib_logger.warning(f"Quota refresh failed for {cred}: {e}") + + async def _fetch_quota_from_api( + self, + credential: str, + ) -> Optional[Dict[str, Dict[str, Any]]]: + """ + Fetch current quota information from provider API. + + Override this with actual API calls for your provider. + + Returns: + Dict mapping model names to quota info: + { + "gpt-4o": { + "limit": 500, + "used": 123, + "reset_ts": 1735689600.0, + "group": "gpt4", # Optional + }, + ... + } + """ + # Placeholder - implement actual API call + return None + + # ========================================================================= + # TOKEN EXTRACTION + # ========================================================================= + + def _build_usage_from_response( + self, + response: Any, + ) -> Optional[Dict[str, Any]]: + """ + Build standardized usage dict from provider-specific response. + + The usage manager expects a standardized format. If your provider + returns a different format, convert it here. + + Standard format: + { + "prompt_tokens": int, # Input tokens + "completion_tokens": int, # Output tokens (content + thinking) + "total_tokens": int, # All tokens + + # Optional: Input breakdown + "prompt_tokens_details": { + "cached_tokens": int, # Cache read tokens + "cache_creation_tokens": int, # Cache write tokens + }, + + # Optional: Output breakdown + "completion_tokens_details": { + "reasoning_tokens": int, # Thinking/reasoning tokens + }, + + # Alternative top-level fields (some APIs use these) + "cache_read_tokens": int, + "cache_creation_tokens": int, + } + + Args: + response: Raw response from provider API + + Returns: + Standardized usage dict, or None if no usage data + """ + if not hasattr(response, "usage") or not response.usage: + return None + + # Example: Provider returns Gemini-style metadata + # Adapt this to your provider's format + usage = response.usage + + # Standard fields + result = { + "prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0, + "completion_tokens": getattr(usage, "completion_tokens", 0) or 0, + "total_tokens": getattr(usage, "total_tokens", 0) or 0, + } + + # Example: Extract cached tokens from details + prompt_details = getattr(usage, "prompt_tokens_details", None) + if prompt_details: + if isinstance(prompt_details, dict): + cached = prompt_details.get("cached_tokens", 0) + cache_write = prompt_details.get("cache_creation_tokens", 0) + else: + cached = getattr(prompt_details, "cached_tokens", 0) + cache_write = getattr(prompt_details, "cache_creation_tokens", 0) + + if cached or cache_write: + result["prompt_tokens_details"] = {} + if cached: + result["prompt_tokens_details"]["cached_tokens"] = cached + if cache_write: + result["prompt_tokens_details"]["cache_creation_tokens"] = ( + cache_write + ) + + # Example: Extract thinking tokens from details + completion_details = getattr(usage, "completion_tokens_details", None) + if completion_details: + if isinstance(completion_details, dict): + reasoning = completion_details.get("reasoning_tokens", 0) + else: + reasoning = getattr(completion_details, "reasoning_tokens", 0) + + if reasoning: + result["completion_tokens_details"] = {"reasoning_tokens": reasoning} + + return result + + +# ============================================================================= +# CUSTOM WINDOWS +# ============================================================================= +# +# To add custom usage windows, you have two options: +# +# OPTION 1: Override windows via provider config (recommended) +# ------------------------------------------------------------ +# Add class attribute to your provider: +# +# default_windows = [ +# WindowDefinition.rolling("1h", 3600, is_primary=False), +# WindowDefinition.rolling("6h", 21600, is_primary=True), +# WindowDefinition.daily("daily"), +# WindowDefinition.total("total"), +# ] +# +# WindowDefinition options: +# - name: str - Window identifier (e.g., "1h", "daily") +# - duration_seconds: int | None - Window duration (None for "total") +# - reset_mode: ResetMode - How window resets +# - ROLLING: Continuous sliding window +# - FIXED_DAILY: Reset at specific UTC time +# - CALENDAR_WEEKLY: Reset at week start +# - CALENDAR_MONTHLY: Reset at month start +# - API_AUTHORITATIVE: Provider determines reset +# - is_primary: bool - Used for rotation decisions +# - applies_to: str - Scope of window +# - "credential": Global per-credential +# - "model": Per-model per-credential +# - "group": Per-quota-group per-credential +# +# OPTION 2: Build config manually in RotatingClient +# ------------------------------------------------- +# In your client initialization: +# +# from rotator_library.usage.config import ( +# ProviderUsageConfig, +# WindowDefinition, +# FairCycleConfig, +# ) +# from rotator_library.usage.types import RotationMode, ResetMode +# +# config = ProviderUsageConfig( +# rotation_mode=RotationMode.SEQUENTIAL, +# windows=[ +# WindowDefinition( +# name="1h", +# duration_seconds=3600, +# reset_mode=ResetMode.ROLLING, +# is_primary=False, +# applies_to="model", +# ), +# WindowDefinition( +# name="6h", +# duration_seconds=21600, +# reset_mode=ResetMode.ROLLING, +# is_primary=True, # Primary for rotation +# applies_to="group", # Track per quota group +# ), +# ], +# fair_cycle=FairCycleConfig( +# enabled=True, +# tracking_mode=TrackingMode.MODEL_GROUP, +# ), +# ) +# +# manager = UsageManager( +# provider="example", +# config=config, +# file_path="usage_example.json", +# ) +# +# ============================================================================= + + +# ============================================================================= +# REGISTERING YOUR PROVIDER +# ============================================================================= +# +# To register your provider with the system: +# +# 1. Add to PROVIDER_PLUGINS dict in src/rotator_library/providers/__init__.py: +# +# from .example_provider import ExampleProvider +# +# PROVIDER_PLUGINS = { +# ... +# "example": ExampleProvider, +# } +# +# 2. Add credential discovery in RotatingClient if using OAuth: +# +# # In _discover_oauth_credentials: +# if provider == "example": +# creds = self._discover_example_credentials() +# +# 3. Configure via environment variables: +# +# # API key credentials +# EXAMPLE_API_KEY=sk-xxx +# EXAMPLE_API_KEY_2=sk-yyy +# +# # OAuth credential paths +# EXAMPLE_OAUTH_PATHS=./creds/example_*.json +# +# # Priority/tier assignment +# PRIORITY_EXAMPLE_CRED1=1 +# TIER_EXAMPLE_CRED2=standard-tier +# +# # Quota group overrides +# QUOTA_GROUPS_EXAMPLE_GPT4=gpt-4o,gpt-4o-mini,gpt-4-turbo +# +# ============================================================================= + + +# ============================================================================= +# ACCESSING USAGE DATA +# ============================================================================= +# +# The usage manager exposes data through several methods: +# +# 1. Get availability stats (for UI/monitoring): +# +# stats = await usage_manager.get_availability_stats(model, quota_group) +# # Returns: { +# # "total": 10, +# # "available": 7, +# # "blocked_by": {"cooldowns": 2, "fair_cycle": 1}, +# # "rotation_mode": "sequential", +# # } +# +# 2. Get comprehensive stats (for quota-stats endpoint): +# +# stats = await usage_manager.get_stats_for_endpoint() +# # Returns full credential/model/group breakdown +# +# 3. Direct state access (for advanced use): +# +# # Get credential state +# state = usage_manager.states.get(stable_id) +# +# # Access usage at different scopes +# global_usage = state.usage +# model_usage = state.model_usage.get("gpt-4o") +# group_usage = state.group_usage.get("gpt4") +# +# # Check cooldowns +# cooldown = state.get_cooldown("gpt4") +# if cooldown and cooldown.is_active: +# print(f"Cooldown remaining: {cooldown.remaining_seconds}s") +# +# # Check fair cycle +# fc = state.fair_cycle.get("gpt4") +# if fc and fc.exhausted: +# print(f"Exhausted at: {fc.exhausted_at}") +# +# 4. Update quota baseline (from API response): +# +# await usage_manager.update_quota_baseline( +# accessor=credential, +# model="gpt-4o", +# quota_max_requests=500, +# quota_reset_ts=time.time() + 3600, +# quota_used=123, +# quota_group="gpt4", +# ) +# +# ============================================================================= diff --git a/src/rotator_library/providers/firmware_provider.py b/src/rotator_library/providers/firmware_provider.py index e71316fa..2b94bd8b 100644 --- a/src/rotator_library/providers/firmware_provider.py +++ b/src/rotator_library/providers/firmware_provider.py @@ -19,7 +19,7 @@ from .utilities.firmware_quota_tracker import FirmwareQuotaTracker if TYPE_CHECKING: - from ..usage_manager import UsageManager + from ..usage import UsageManager import logging @@ -184,12 +184,23 @@ async def refresh_single_credential( # Store baseline in usage manager # Since Firmware.ai uses credential-level quota, we use a virtual model name + if remaining_fraction <= 0.0 and reset_ts: + stable_id = usage_manager.registry.get_stable_id( + api_key, usage_manager.provider + ) + state = usage_manager.states.get(stable_id) + if state: + await usage_manager.tracking.apply_cooldown( + state=state, + reason="quota_exhausted", + until=reset_ts, + model_or_group="firmware/_quota", + source="api_quota", + ) await usage_manager.update_quota_baseline( api_key, "firmware/_quota", # Virtual model for credential-level tracking - remaining_fraction, - # No max_requests - Firmware.ai doesn't expose this - reset_timestamp=reset_ts, + quota_reset_ts=reset_ts, ) lib_logger.debug( @@ -199,7 +210,9 @@ async def refresh_single_credential( ) except Exception as e: - lib_logger.warning(f"Failed to refresh Firmware.ai quota usage: {e}") + lib_logger.warning( + f"Failed to refresh Firmware.ai quota usage: {e}" + ) # Fetch all credentials in parallel with shared HTTP client async with httpx.AsyncClient(timeout=30.0) as client: diff --git a/src/rotator_library/providers/gemini_auth_base.py b/src/rotator_library/providers/gemini_auth_base.py index e07d09d9..be0f27b3 100644 --- a/src/rotator_library/providers/gemini_auth_base.py +++ b/src/rotator_library/providers/gemini_auth_base.py @@ -13,7 +13,23 @@ import httpx from .google_oauth_base import GoogleOAuthBase -from .utilities.gemini_shared_utils import CODE_ASSIST_ENDPOINT +from .utilities.gemini_shared_utils import ( + CODE_ASSIST_ENDPOINT, + # Tier utilities + normalize_tier_name, + is_free_tier, + is_paid_tier, + get_tier_full_name, + # Project ID extraction + extract_project_id_from_response, + # Credential loading helpers + load_persisted_project_metadata, + # Env file helpers + build_project_tier_env_lines, +) + +# Service Usage API for checking enabled APIs +SERVICE_USAGE_API = "https://serviceusage.googleapis.com/v1" lib_logger = logging.getLogger("rotator_library") @@ -33,7 +49,7 @@ # # Source: gemini-cli/packages/core/src/code_assist/server.ts:284-290 GEMINI_CLI_AUTH_HEADERS = { - "User-Agent": "GeminiCLI/0.26.0 (win32; x64)", + "User-Agent": "GeminiCLI/0.28.0 (win32; x64)", # ------------------------------------------------------------------------- # COMMENTED OUT - Not sent by native gemini-cli for OAuth/Code Assist path # ------------------------------------------------------------------------- @@ -41,7 +57,7 @@ # "Client-Metadata": ( # Sent in body, not as header # "ideType=IDE_UNSPECIFIED," # "pluginType=GEMINI," - # "ideVersion=0.26.0," + # "ideVersion=0.28.0," # "platform=WINDOWS_AMD64," # "updateChannel=stable" # ), @@ -76,6 +92,147 @@ def __init__(self): # Project and tier caches - shared between auth base and provider self.project_id_cache: Dict[str, str] = {} self.project_tier_cache: Dict[str, str] = {} + self.tier_full_cache: Dict[str, str] = {} # Full tier names for display + self.tier_full_cache: Dict[str, str] = {} # Full tier names for display + + # ========================================================================= + # GCP PROJECT SCANNING + # ========================================================================= + + async def _scan_gcp_projects_for_code_assist( + self, access_token: str, headers: Dict[str, str] + ) -> Optional[tuple]: + """ + Scan GCP projects to find one with cloudaicompanion.googleapis.com enabled. + + This is used as a fallback when loadCodeAssist doesn't return a project + (e.g., for accounts with manually created projects that have Code Assist enabled). + + Args: + access_token: Valid OAuth access token + headers: Request headers for Code Assist API calls + + Returns: + Tuple of (project_id, tier) if found, or (None, None) if not found + """ + lib_logger.debug("Scanning GCP projects for Code Assist API...") + + async with httpx.AsyncClient() as client: + # Step 1: List all active GCP projects + try: + response = await client.get( + "https://cloudresourcemanager.googleapis.com/v1/projects", + headers={"Authorization": f"Bearer {access_token}"}, + timeout=20, + ) + if response.status_code != 200: + lib_logger.debug( + f"Failed to list GCP projects: {response.status_code}" + ) + return None, None + + projects = [ + p + for p in response.json().get("projects", []) + if p.get("lifecycleState") == "ACTIVE" + ] + lib_logger.debug(f"Found {len(projects)} active GCP projects") + + if not projects: + return None, None + + except Exception as e: + lib_logger.debug(f"Error listing GCP projects: {e}") + return None, None + + # Step 2: Check which projects have cloudaicompanion.googleapis.com enabled + candidate_projects = [] + for project in projects: + project_id = project.get("projectId") + service_url = f"{SERVICE_USAGE_API}/projects/{project_id}/services/cloudaicompanion.googleapis.com" + + try: + svc_response = await client.get( + service_url, + headers={"Authorization": f"Bearer {access_token}"}, + timeout=10, + ) + if svc_response.status_code == 200: + state = svc_response.json().get("state", "") + if state == "ENABLED": + lib_logger.debug( + f"Project '{project_id}' has cloudaicompanion.googleapis.com ENABLED" + ) + candidate_projects.append(project_id) + else: + lib_logger.debug( + f"Project '{project_id}' cloudaicompanion state: {state}" + ) + except Exception as e: + lib_logger.debug( + f"Error checking cloudaicompanion API for '{project_id}': {e}" + ) + + if not candidate_projects: + lib_logger.debug( + "No GCP projects with cloudaicompanion.googleapis.com enabled found" + ) + return None, None + + # Step 3: Test candidate projects with loadCodeAssist to verify and get tier + lib_logger.debug( + f"Testing {len(candidate_projects)} candidate projects with loadCodeAssist..." + ) + + for project_id in candidate_projects: + try: + test_request = { + "cloudaicompanionProject": project_id, + "metadata": { + "ideType": "IDE_UNSPECIFIED", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI", + "duetProject": project_id, + }, + } + + response = await client.post( + f"{CODE_ASSIST_ENDPOINT}:loadCodeAssist", + headers=headers, + json=test_request, + timeout=15, + ) + + if response.status_code == 200: + data = response.json() + current_tier = data.get("currentTier", {}) + paid_tier = data.get("paidTier", {}) + + # Determine effective tier (paidTier > currentTier) + effective_tier_id = None + if paid_tier and paid_tier.get("id"): + effective_tier_id = paid_tier.get("id") + elif current_tier and current_tier.get("id"): + effective_tier_id = current_tier.get("id") + + if effective_tier_id: + canonical_tier = ( + normalize_tier_name(effective_tier_id) + or effective_tier_id + ) + lib_logger.info( + f"Found Code Assist project via GCP scan: {project_id} (tier={canonical_tier})" + ) + # Return raw tier ID for full name lookup, not canonical + return project_id, effective_tier_id + + except Exception as e: + lib_logger.debug( + f"Error testing project '{project_id}' with loadCodeAssist: {e}" + ) + + lib_logger.debug("No working Code Assist projects found via GCP scan") + return None, None # ========================================================================= # POST-AUTH DISCOVERY HOOK @@ -115,7 +272,25 @@ async def _post_auth_discovery( credential_path, access_token, litellm_params={} ) - tier = self.project_tier_cache.get(credential_path, "unknown") + # Use full tier name for post-auth log (one-time display) + tier_full = self.tier_full_cache.get(credential_path) + tier = tier_full or self.project_tier_cache.get(credential_path, "unknown") + lib_logger.info( + f"Post-auth discovery complete for {Path(credential_path).name}: " + f"tier={tier}, project={project_id}" + ) + + # Use full tier name for post-auth log (one-time display) + tier_full = self.tier_full_cache.get(credential_path) + tier = tier_full or self.project_tier_cache.get(credential_path, "unknown") + lib_logger.info( + f"Post-auth discovery complete for {Path(credential_path).name}: " + f"tier={tier}, project={project_id}" + ) + + # Use full tier name for post-auth log (one-time display) + tier_full = self.tier_full_cache.get(credential_path) + tier = tier_full or self.project_tier_cache.get(credential_path, "unknown") lib_logger.info( f"Post-auth discovery complete for {Path(credential_path).name}: " f"tier={tier}, project={project_id}" @@ -166,53 +341,16 @@ async def _discover_project_id( # Load credentials to check for persisted/configured project_id and tier credential_index = self._parse_env_credential_path(credential_path) - if credential_index is None: - # File-based credentials: load from file - try: - with open(credential_path, "r") as f: - creds = json.load(f) - - metadata = creds.get("_proxy_metadata", {}) - persisted_project_id = metadata.get("project_id") - persisted_tier = metadata.get("tier") - - if persisted_project_id: - lib_logger.debug( - f"Loaded persisted project ID from credential file: {persisted_project_id}" - ) - self.project_id_cache[credential_path] = persisted_project_id - - # Also load tier if available - if persisted_tier: - self.project_tier_cache[credential_path] = persisted_tier - lib_logger.debug(f"Loaded persisted tier: {persisted_tier}") - - return persisted_project_id - except (FileNotFoundError, json.JSONDecodeError, KeyError) as e: - lib_logger.debug(f"Could not load persisted project ID from file: {e}") - else: - # Env-based credentials: load from credentials cache - # The credentials were already loaded by _load_from_env() which reads - # {PREFIX}_{N}_PROJECT_ID and {PREFIX}_{N}_TIER into _proxy_metadata - if credential_path in self._credentials_cache: - creds = self._credentials_cache[credential_path] - metadata = creds.get("_proxy_metadata", {}) - env_project_id = metadata.get("project_id") - env_tier = metadata.get("tier") - - if env_project_id: - lib_logger.debug( - f"Loaded project ID from env credential metadata: {env_project_id}" - ) - self.project_id_cache[credential_path] = env_project_id - - if env_tier: - self.project_tier_cache[credential_path] = env_tier - lib_logger.debug( - f"Loaded tier from env credential metadata: {env_tier}" - ) - - return env_project_id + persisted_project_id = load_persisted_project_metadata( + credential_path, + credential_index, + self._credentials_cache, + self.project_id_cache, + self.project_tier_cache, + self.tier_full_cache, + ) + if persisted_project_id: + return persisted_project_id lib_logger.debug( "No cached or configured project ID found, initiating discovery..." @@ -225,6 +363,8 @@ async def _discover_project_id( discovered_project_id = None discovered_tier = None + discovered_tier_full = None + discovered_tier_full = None async with httpx.AsyncClient() as client: # 1. Try discovery endpoint with loadCodeAssist @@ -241,11 +381,13 @@ async def _discover_project_id( if configured_project_id: core_client_metadata["duetProject"] = configured_project_id - # Build load request - pass configured_project_id if available, otherwise None + # Build load request - only include cloudaicompanionProject if configured + # Native CLI omits this field entirely when no project is configured load_request = { - "cloudaicompanionProject": configured_project_id, # Can be None "metadata": core_client_metadata, } + if configured_project_id: + load_request["cloudaicompanionProject"] = configured_project_id lib_logger.debug( f"Sending loadCodeAssist request with cloudaicompanionProject={configured_project_id}" @@ -265,10 +407,15 @@ async def _discover_project_id( ) # Extract and log ALL tier information for debugging + # Canonical prioritizes paidTier over currentTier for accurate subscription detection allowed_tiers = data.get("allowedTiers", []) current_tier = data.get("currentTier") + paid_tier = data.get( + "paidTier" + ) # Added: Canonical-style tier detection lib_logger.debug(f"=== Tier Information ===") + lib_logger.debug(f"paidTier: {paid_tier}") lib_logger.debug(f"currentTier: {current_tier}") lib_logger.debug(f"allowedTiers count: {len(allowed_tiers)}") for i, tier in enumerate(allowed_tiers): @@ -280,83 +427,92 @@ async def _discover_project_id( ) lib_logger.debug(f"========================") - # Determine the current tier ID - current_tier_id = None - if current_tier: - current_tier_id = current_tier.get("id") - lib_logger.debug(f"User has currentTier: {current_tier_id}") - - # Check if user is already known to server (has currentTier) - if current_tier_id: - # User is already onboarded - check for project from server - server_project = data.get("cloudaicompanionProject") - - # Check if this tier requires user-defined project (paid tiers) - requires_user_project = any( - t.get("id") == current_tier_id - and t.get("userDefinedCloudaicompanionProject", False) - for t in allowed_tiers + # Determine tier ID with Canonical-style priority: paidTier > currentTier + # This matches quota.rs:88-91 logic for accurate subscription detection + effective_tier_id = None + if paid_tier and paid_tier.get("id"): + effective_tier_id = paid_tier.get("id") + lib_logger.debug(f"Using paidTier: {effective_tier_id}") + elif current_tier and current_tier.get("id"): + effective_tier_id = current_tier.get("id") + lib_logger.debug( + f"Using currentTier (paidTier not available): {effective_tier_id}" + ) + + # Normalize to canonical tier name (ULTRA, PRO, FREE) + if effective_tier_id: + canonical_tier = normalize_tier_name(effective_tier_id) + lib_logger.debug( + f"Canonical tier: {canonical_tier} (from {effective_tier_id})" ) - is_free_tier = current_tier_id == "free-tier" + # Check if user is already known to server (has tier info) + if effective_tier_id: + # User has tier info - check for project from server + # Use helper to handle both string and object formats + server_project = extract_project_id_from_response(data) + + # Project selection: server > configured > GCP scan > onboarding if server_project: - # Server returned a project - use it (server wins) - # This is the normal case for FREE tier users project_id = server_project lib_logger.debug(f"Server returned project: {project_id}") elif configured_project_id: - # No server project but we have configured one - use it - # This is the PAID TIER case where server doesn't return a project project_id = configured_project_id + lib_logger.debug(f"Using configured project: {project_id}") + else: + # No project from server or config - try scanning GCP projects + # This handles accounts with manually created Code Assist projects lib_logger.debug( - f"No server project, using configured: {project_id}" - ) - elif is_free_tier: - # Free tier user without server project - this shouldn't happen normally - # but let's not fail, just proceed to onboarding - lib_logger.debug( - "Free tier user with currentTier but no project - will try onboarding" - ) - project_id = None - elif requires_user_project: - # Paid tier requires a project ID to be set - raise ValueError( - f"Paid tier '{current_tier_id}' requires setting GEMINI_CLI_PROJECT_ID environment variable. " - "See https://goo.gle/gemini-cli-auth-docs#workspace-gca" + f"Tier '{effective_tier_id}' detected but no project - scanning GCP projects..." ) - else: - # Unknown tier without project - proceed carefully - lib_logger.warning( - f"Tier '{current_tier_id}' has no project and none configured - will try onboarding" + ( + scanned_project, + scanned_tier, + ) = await self._scan_gcp_projects_for_code_assist( + access_token, headers ) - project_id = None - - if project_id: - # Cache tier info - self.project_tier_cache[credential_path] = current_tier_id - discovered_tier = current_tier_id - - # Log appropriately based on tier - is_paid = current_tier_id and current_tier_id not in [ - "free-tier", - "legacy-tier", - "unknown", - ] - if is_paid: + if scanned_project: + project_id = scanned_project + # Use scanned tier if available, otherwise use what we have + if scanned_tier: + effective_tier_id = scanned_tier lib_logger.info( - f"Using Gemini paid tier '{current_tier_id}' with project: {project_id}" + f"Discovered project via GCP scan: {project_id}" ) else: - lib_logger.info( - f"Discovered Gemini project ID via loadCodeAssist: {project_id}" + # No project found via GCP scan either - will fall through to onboarding + lib_logger.debug( + f"No Code Assist project found via GCP scan - will try onboarding" ) + project_id = None + + if project_id: + # Cache tier info - use canonical tier name for consistency + canonical = ( + normalize_tier_name(effective_tier_id) or effective_tier_id + ) + self.project_tier_cache[credential_path] = canonical + discovered_tier = canonical + + # Get and cache full tier name for display + tier_full = get_tier_full_name(effective_tier_id) + self.tier_full_cache[credential_path] = tier_full + discovered_tier_full = tier_full + + # Log with full tier name for discovery messages + lib_logger.info( + f"Discovered Gemini tier '{tier_full}' with project: {project_id}" + ) self.project_id_cache[credential_path] = project_id discovered_project_id = project_id # Persist to credential file await self._persist_project_metadata( - credential_path, project_id, discovered_tier + credential_path, + project_id, + discovered_tier, + discovered_tier_full, ) return project_id @@ -397,39 +553,51 @@ async def _discover_project_id( ) # Build onboard request based on tier type (following official CLI logic) - # FREE tier: cloudaicompanionProject = None (server-managed) - # PAID tier: cloudaicompanionProject = configured_project_id (user must provide) - is_free_tier = tier_id == "free-tier" - - if is_free_tier: - # Free tier uses server-managed project - onboard_request = { - "tierId": tier_id, - "cloudaicompanionProject": None, # Server will create/manage - "metadata": core_client_metadata, - } + # For ALL tiers (free and paid): cloudaicompanionProject can be None + # The server will create a project automatically if none is provided + # If user has configured a project, use it; otherwise let server decide + tier_is_free = is_free_tier(tier_id) + + # For paid tiers, first try to find an existing Code Assist project + onboard_project_id = configured_project_id + if not tier_is_free and not onboard_project_id: lib_logger.debug( - "Free tier onboarding: using server-managed project" + "Paid tier with no configured project - checking for existing Code Assist projects..." ) - else: - # Paid/legacy tier requires user-provided project - if not configured_project_id and requires_user_project: - raise ValueError( - f"Tier '{tier_id}' requires setting GEMINI_CLI_PROJECT_ID environment variable. " - "See https://goo.gle/gemini-cli-auth-docs#workspace-gca" + scanned_project, _ = await self._scan_gcp_projects_for_code_assist( + access_token, headers + ) + if scanned_project: + onboard_project_id = scanned_project + lib_logger.info( + f"Found existing Code Assist project for onboarding: {scanned_project}" ) - onboard_request = { - "tierId": tier_id, - "cloudaicompanionProject": configured_project_id, - "metadata": { - **core_client_metadata, - "duetProject": configured_project_id, - } - if configured_project_id - else core_client_metadata, - } + else: + lib_logger.debug( + "No existing Code Assist project found - server will create one" + ) + + # Build onboard request - only include cloudaicompanionProject if we have one + # CRITICAL: For free-tier, cloudaicompanionProject MUST be omitted entirely. + # Setting it to null causes 412 Precondition Failed error. + # Native CLI behavior: field is undefined (omitted) for free tier. + onboard_request = { + "tierId": tier_id, + "metadata": core_client_metadata.copy(), + } + + # Only add cloudaicompanionProject and duetProject if we have a project ID + if onboard_project_id: + onboard_request["cloudaicompanionProject"] = onboard_project_id + onboard_request["metadata"]["duetProject"] = onboard_project_id + + if onboard_project_id: lib_logger.debug( - f"Paid tier onboarding: using project {configured_project_id}" + f"Onboarding with user-provided project: {onboard_project_id}" + ) + else: + lib_logger.debug( + "Onboarding with server-managed project (will be created by server)" ) lib_logger.debug("Initiating onboardUser request...") @@ -474,15 +642,10 @@ async def _discover_project_id( "Onboarding process timed out after 5 minutes. Please try again or contact support." ) - # Extract project ID from LRO response - # Note: onboardUser returns response.cloudaicompanionProject as an object with .id + # Extract project ID from LRO response using helper + # This handles both string and object formats for cloudaicompanionProject lro_response_data = lro_data.get("response", {}) - lro_project_obj = lro_response_data.get("cloudaicompanionProject", {}) - project_id = ( - lro_project_obj.get("id") - if isinstance(lro_project_obj, dict) - else None - ) + project_id = extract_project_id_from_response(lro_response_data) # Fallback to configured project if LRO didn't return one if not project_id and configured_project_id: @@ -504,28 +667,30 @@ async def _discover_project_id( f"Successfully extracted project ID from onboarding response: {project_id}" ) - # Cache tier info - self.project_tier_cache[credential_path] = tier_id - discovered_tier = tier_id - lib_logger.debug(f"Cached tier information: {tier_id}") + # Cache tier info - use canonical tier name for consistency + canonical_tier = normalize_tier_name(tier_id) or tier_id + self.project_tier_cache[credential_path] = canonical_tier + discovered_tier = canonical_tier - # Log concise message for paid projects - is_paid = tier_id and tier_id not in ["free-tier", "legacy-tier"] - if is_paid: - lib_logger.info( - f"Using Gemini paid tier '{tier_id}' with project: {project_id}" - ) - else: - lib_logger.info( - f"Successfully onboarded user and discovered project ID: {project_id}" - ) + # Get and cache full tier name for display + tier_full = get_tier_full_name(tier_id) + self.tier_full_cache[credential_path] = tier_full + discovered_tier_full = tier_full + lib_logger.debug( + f"Cached tier information: {canonical_tier} (full: {tier_full})" + ) + + # Log with full tier name for onboarding messages + lib_logger.info( + f"Onboarded Gemini credential with tier '{tier_full}', project: {project_id}" + ) self.project_id_cache[credential_path] = project_id discovered_project_id = project_id # Persist to credential file await self._persist_project_metadata( - credential_path, project_id, discovered_tier + credential_path, project_id, discovered_tier, discovered_tier_full ) return project_id @@ -628,43 +793,6 @@ async def _discover_project_id( "To manually specify a project, set GEMINI_CLI_PROJECT_ID in your .env file." ) - async def _persist_project_metadata( - self, credential_path: str, project_id: str, tier: Optional[str] - ): - """Persists project ID and tier to the credential file for faster future startups.""" - # Skip persistence for env:// paths (environment-based credentials) - credential_index = self._parse_env_credential_path(credential_path) - if credential_index is not None: - lib_logger.debug( - f"Skipping project metadata persistence for env:// credential path: {credential_path}" - ) - return - - try: - # Load current credentials - with open(credential_path, "r") as f: - creds = json.load(f) - - # Update metadata - if "_proxy_metadata" not in creds: - creds["_proxy_metadata"] = {} - - creds["_proxy_metadata"]["project_id"] = project_id - if tier: - creds["_proxy_metadata"]["tier"] = tier - - # Save back using the existing save method (handles atomic writes and permissions) - await self._save_credentials(credential_path, creds) - - lib_logger.debug( - f"Persisted project_id and tier to credential file: {credential_path}" - ) - except Exception as e: - lib_logger.warning( - f"Failed to persist project metadata to credential file: {e}" - ) - # Non-fatal - just means slower startup next time - # ========================================================================= # CREDENTIAL MANAGEMENT OVERRIDES # ========================================================================= @@ -682,16 +810,7 @@ def build_env_lines(self, creds: Dict[str, Any], cred_number: int) -> List[str]: # Get base lines from parent class lines = super().build_env_lines(creds, cred_number) - # Add Gemini-specific fields (tier and project_id) - metadata = creds.get("_proxy_metadata", {}) - prefix = f"{self.ENV_PREFIX}_{cred_number}" - - project_id = metadata.get("project_id", "") - tier = metadata.get("tier", "") - - if project_id: - lines.append(f"{prefix}_PROJECT_ID={project_id}") - if tier: - lines.append(f"{prefix}_TIER={tier}") + # Add project_id and tier using shared helper + lines.extend(build_project_tier_env_lines(creds, self.ENV_PREFIX, cred_number)) return lines diff --git a/src/rotator_library/providers/gemini_cli_provider.py b/src/rotator_library/providers/gemini_cli_provider.py index 944cae82..e8562a66 100644 --- a/src/rotator_library/providers/gemini_cli_provider.py +++ b/src/rotator_library/providers/gemini_cli_provider.py @@ -24,6 +24,9 @@ FINISH_REASON_MAP, CODE_ASSIST_ENDPOINT, GEMINI_CLI_ENDPOINT_FALLBACKS, + # Tier utilities + TIER_PRIORITIES, + DEFAULT_TIER_PRIORITY, ) from ..transaction_logger import ProviderLogger from .utilities.gemini_tool_handler import GeminiToolHandler @@ -146,22 +149,12 @@ class GeminiCliProvider( # Provider name for env var lookups (QUOTA_GROUPS_GEMINI_CLI_*) provider_env_name: str = "gemini_cli" - # Tier name -> priority mapping (Single Source of Truth) - # Same tier names as Antigravity (coincidentally), but defined separately - tier_priorities = { - # Priority 1: Highest paid tier (Google AI Ultra - name unconfirmed) - # "google-ai-ultra": 1, # Uncomment when tier name is confirmed - # Priority 2: Standard paid tier - "standard-tier": 2, - # Priority 3: Free tier - "free-tier": 3, - # Priority 10: Legacy/Unknown (lowest) - "legacy-tier": 10, - "unknown": 10, - } + # Tier name -> priority mapping (from centralized tier utilities) + # Lower numbers = higher priority (ULTRA=1 > PRO=2 > FREE=3) + tier_priorities = TIER_PRIORITIES # Default priority for tiers not in the mapping - default_tier_priority: int = 10 + default_tier_priority: int = DEFAULT_TIER_PRIORITY # Usage reset configs for Gemini CLI # Verified 2026-01-07: 24-hour fixed window from first request for ALL tiers @@ -193,11 +186,11 @@ class GeminiCliProvider( # Priority 1 (paid ultra): 5x concurrent requests # Priority 2 (standard paid): 3x concurrent requests # Others: Use sequential fallback (2x) or balanced default (1x) - default_priority_multipliers = {1: 5, 2: 3} + default_priority_multipliers = {1: 2, 2: 1} # For sequential mode, lower priority tiers still get 2x to maintain stickiness # For balanced mode, this doesn't apply (falls back to 1x) - default_sequential_fallback_multiplier = 2 + default_sequential_fallback_multiplier = 1 @staticmethod def parse_quota_error( @@ -413,7 +406,6 @@ def __init__(self): self._learned_costs: Dict[str, Dict[str, float]] = {} self._learned_costs_loaded: bool = False - # ========================================================================= # CREDENTIAL TIER LOOKUP (Provider-specific - uses cache) # ========================================================================= @@ -536,7 +528,7 @@ def _get_gemini_cli_request_headers(self, model: str) -> Dict[str, str]: # Hardcoded to Windows x64 platform (matching common development environment) # Native format: GeminiCLI/${version}/${model} (${platform}; ${arch}) - user_agent = f"GeminiCLI/0.26.0/{model_name} (win32; x64)" + user_agent = f"GeminiCLI/0.28.0/{model_name} (win32; x64)" # ========================================================================= # COMMENTED OUT HEADERS - Not sent by native gemini-cli for Code Assist path @@ -554,7 +546,7 @@ def _get_gemini_cli_request_headers(self, model: str) -> Dict[str, str]: # client_metadata = ( # "ideType=IDE_UNSPECIFIED," # "pluginType=GEMINI," - # "ideVersion=0.26.0," + # "ideVersion=0.28.0," # "platform=WINDOWS_AMD64," # "updateChannel=stable" # ) diff --git a/src/rotator_library/providers/google_oauth_base.py b/src/rotator_library/providers/google_oauth_base.py index 0ac75590..a0a836be 100644 --- a/src/rotator_library/providers/google_oauth_base.py +++ b/src/rotator_library/providers/google_oauth_base.py @@ -214,26 +214,20 @@ def __init__(self): str, float ] = {} # Track backoff timers (Unix timestamp) - # [QUEUE SYSTEM] Sequential refresh processing with two separate queues + # [QUEUE SYSTEM] Sequential refresh processing # Normal refresh queue: for proactive token refresh (old token still valid) self._refresh_queue: asyncio.Queue = asyncio.Queue() self._queue_processor_task: Optional[asyncio.Task] = None - # Re-auth queue: for invalid refresh tokens (requires user interaction) - self._reauth_queue: asyncio.Queue = asyncio.Queue() - self._reauth_processor_task: Optional[asyncio.Task] = None - # Tracking sets/dicts - self._queued_credentials: set = set() # Track credentials in either queue - # Only credentials in re-auth queue are marked unavailable (not normal refresh) - # TTL cleanup is defense-in-depth for edge cases where re-auth processor crashes - self._unavailable_credentials: Dict[ - str, float - ] = {} # Maps credential path -> timestamp when marked unavailable - # TTL should exceed reauth timeout (300s) to avoid premature cleanup - self._unavailable_ttl_seconds: int = 360 # 6 minutes TTL for stale entries + self._queued_credentials: set = set() # Track credentials in refresh queue self._queue_tracking_lock = asyncio.Lock() # Protects queue sets + # [PERMANENTLY EXPIRED] Track credentials that have been permanently removed from rotation + # These credentials have invalid/revoked refresh tokens and require manual re-authentication + # via credential_tool.py. They will NOT be selected for rotation until proxy restart. + self._permanently_expired_credentials: set = set() + # Retry tracking for normal refresh queue self._queue_retry_count: Dict[ str, int @@ -243,7 +237,6 @@ def __init__(self): self._refresh_timeout_seconds: int = 15 # Max time for single refresh self._refresh_interval_seconds: int = 30 # Delay between queue items self._refresh_max_retries: int = 3 # Attempts before kicked out - self._reauth_timeout_seconds: int = 300 # Time for user to complete OAuth def _parse_env_credential_path(self, path: str) -> Optional[str]: """ @@ -483,39 +476,31 @@ async def _refresh_token( status_code = e.response.status_code error_body = e.response.text - # [INVALID GRANT HANDLING] Handle 400/401/403 by queuing for re-auth - # We must NOT call initialize_token from here as we hold a lock (would deadlock) + # [INVALID GRANT HANDLING] Handle 400/401/403 by marking as expired + # These errors indicate the refresh token is invalid/revoked + # Mark as permanently expired - no interactive re-auth during proxy operation if status_code == 400: # Check if this is an invalid_grant error if "invalid_grant" in error_body.lower(): - lib_logger.info( - f"Credential '{Path(path).name}' needs re-auth (HTTP 400: invalid_grant). " - f"Queued for re-authentication, rotating to next credential." - ) - asyncio.create_task( - self._queue_refresh( - path, force=True, needs_reauth=True - ) + self._mark_credential_expired( + path, + f"Refresh token invalid (HTTP 400: invalid_grant)", ) raise CredentialNeedsReauthError( credential_path=path, - message=f"Refresh token invalid for '{Path(path).name}'. Re-auth queued.", + message=f"Refresh token invalid for '{Path(path).name}'. Credential removed from rotation.", ) else: # Other 400 error - raise it raise elif status_code in (401, 403): - lib_logger.info( - f"Credential '{Path(path).name}' needs re-auth (HTTP {status_code}). " - f"Queued for re-authentication, rotating to next credential." - ) - asyncio.create_task( - self._queue_refresh(path, force=True, needs_reauth=True) + self._mark_credential_expired( + path, f"Credential unauthorized (HTTP {status_code})" ) raise CredentialNeedsReauthError( credential_path=path, - message=f"Token invalid for '{Path(path).name}' (HTTP {status_code}). Re-auth queued.", + message=f"Token invalid for '{Path(path).name}' (HTTP {status_code}). Credential removed from rotation.", ) elif status_code == 429: @@ -634,7 +619,7 @@ async def proactively_refresh(self, credential_path: str): creds = await self._load_credentials(credential_path) if self._is_token_expired(creds): # lib_logger.info(f"Proactive refresh triggered for '{Path(credential_path).name}'") - await self._queue_refresh(credential_path, force=False, needs_reauth=False) + await self._queue_refresh(credential_path, force=False) async def _get_lock(self, path: str) -> asyncio.Lock: # [FIX RACE CONDITION] Protect lock creation with a master lock @@ -657,40 +642,65 @@ def _is_token_truly_expired(self, creds: Dict[str, Any]) -> bool: expiry_timestamp = time.mktime(time.strptime(expiry, "%Y-%m-%dT%H:%M:%SZ")) return expiry_timestamp < time.time() + def _mark_credential_expired(self, path: str, reason: str) -> None: + """ + Permanently mark a credential as expired and remove it from rotation. + + This is called when a credential's refresh token is invalid or revoked, + meaning normal token refresh cannot work. The credential is removed from + rotation entirely and requires manual re-authentication via credential_tool.py. + + The proxy must be restarted after fixing the credential. + + Args: + path: Credential file path or env:// path + reason: Human-readable reason for expiration (e.g., "invalid_grant", "HTTP 401") + """ + # Add to permanently expired set + self._permanently_expired_credentials.add(path) + + # Clean up other tracking structures + self._queued_credentials.discard(path) + + # Get display name + if path.startswith("env://"): + display_name = path + else: + display_name = Path(path).name + + # Rich-formatted output for high visibility + console.print( + Panel( + f"[bold red]Credential:[/bold red] {display_name}\n" + f"[bold red]Reason:[/bold red] {reason}\n\n" + f"[yellow]This credential has been removed from rotation.[/yellow]\n" + f"[yellow]To fix: Run 'python credential_tool.py' to re-authenticate,[/yellow]\n" + f"[yellow]then restart the proxy.[/yellow]", + title="[bold red]⚠ CREDENTIAL EXPIRED - REMOVED FROM ROTATION[/bold red]", + border_style="red", + ) + ) + + # Also log at ERROR level for log files + lib_logger.error( + f"CREDENTIAL EXPIRED - REMOVED FROM ROTATION | " + f"Credential: {display_name} | Reason: {reason} | " + f"Action: Run 'credential_tool.py' to re-authenticate, then restart proxy" + ) + def is_credential_available(self, path: str) -> bool: """Check if a credential is available for rotation. Credentials are unavailable if: - 1. In re-auth queue (token is truly broken, requires user interaction) + 1. Permanently expired (refresh token invalid/revoked, requires manual re-auth) 2. Token is TRULY expired (past actual expiry, not just threshold) Note: Credentials in normal refresh queue are still available because the old token is valid until actual expiry. - - TTL cleanup (defense-in-depth): If a credential has been in the re-auth - queue longer than _unavailable_ttl_seconds without being processed, it's - cleaned up. This should only happen if the re-auth processor crashes or - is cancelled without proper cleanup. """ - # Check if in re-auth queue (truly unavailable) - if path in self._unavailable_credentials: - marked_time = self._unavailable_credentials.get(path) - if marked_time is not None: - now = time.time() - if now - marked_time > self._unavailable_ttl_seconds: - # Entry is stale - clean it up and return available - # This is a defense-in-depth for edge cases where re-auth - # processor crashed or was cancelled without cleanup - lib_logger.warning( - f"Credential '{Path(path).name}' stuck in re-auth queue for " - f"{int(now - marked_time)}s (TTL: {self._unavailable_ttl_seconds}s). " - f"Re-auth processor may have crashed. Auto-cleaning stale entry." - ) - # Clean up both tracking structures for consistency - self._unavailable_credentials.pop(path, None) - self._queued_credentials.discard(path) - else: - return False # Still in re-auth, not available + # Check if permanently expired (requires manual re-authentication) + if path in self._permanently_expired_credentials: + return False # Check if token is TRULY expired (not just threshold-expired) creds = self._credentials_cache.get(path) @@ -698,12 +708,7 @@ def is_credential_available(self, path: str) -> bool: # Token is actually expired - should not be used # Queue for refresh if not already queued if path not in self._queued_credentials: - # lib_logger.debug( - # f"Credential '{Path(path).name}' is truly expired, queueing for refresh" - # ) - asyncio.create_task( - self._queue_refresh(path, force=True, needs_reauth=False) - ) + asyncio.create_task(self._queue_refresh(path, force=True)) return False return True @@ -715,63 +720,26 @@ async def _ensure_queue_processor_running(self): self._process_refresh_queue() ) - async def _ensure_reauth_processor_running(self): - """Lazily starts the re-auth queue processor if not already running.""" - if self._reauth_processor_task is None or self._reauth_processor_task.done(): - self._reauth_processor_task = asyncio.create_task( - self._process_reauth_queue() - ) - - async def _queue_refresh( - self, path: str, force: bool = False, needs_reauth: bool = False - ): - """Add a credential to the appropriate refresh queue if not already queued. + async def _queue_refresh(self, path: str, force: bool = False): + """Add a credential to the refresh queue if not already queued. Args: path: Credential file path force: Force refresh even if not expired - needs_reauth: True if full re-authentication needed (routes to re-auth queue) - - Queue routing: - - needs_reauth=True: Goes to re-auth queue, marks as unavailable - - needs_reauth=False: Goes to normal refresh queue, does NOT mark unavailable - (old token is still valid until actual expiry) """ - # IMPORTANT: Only check backoff for simple automated refreshes - # Re-authentication (interactive OAuth) should BYPASS backoff since it needs user input - if not needs_reauth: - now = time.time() - if path in self._next_refresh_after: - backoff_until = self._next_refresh_after[path] - if now < backoff_until: - # Credential is in backoff for automated refresh, do not queue - # remaining = int(backoff_until - now) - # lib_logger.debug( - # f"Skipping automated refresh for '{Path(path).name}' (in backoff for {remaining}s)" - # ) - return + # Check backoff for automated refreshes + now = time.time() + if path in self._next_refresh_after: + backoff_until = self._next_refresh_after[path] + if now < backoff_until: + # Credential is in backoff, do not queue + return async with self._queue_tracking_lock: if path not in self._queued_credentials: self._queued_credentials.add(path) - - if needs_reauth: - # Re-auth queue: mark as unavailable (token is truly broken) - self._unavailable_credentials[path] = time.time() - # lib_logger.debug( - # f"Queued '{Path(path).name}' for RE-AUTH (marked unavailable). " - # f"Total unavailable: {len(self._unavailable_credentials)}" - # ) - await self._reauth_queue.put(path) - await self._ensure_reauth_processor_running() - else: - # Normal refresh queue: do NOT mark unavailable (old token still valid) - # lib_logger.debug( - # f"Queued '{Path(path).name}' for refresh (still available). " - # f"Queue size: {self._refresh_queue.qsize() + 1}" - # ) - await self._refresh_queue.put((path, force)) - await self._ensure_queue_processor_running() + await self._refresh_queue.put((path, force)) + await self._ensure_queue_processor_running() async def _process_refresh_queue(self): """Background worker that processes normal refresh requests sequentially. @@ -780,8 +748,8 @@ async def _process_refresh_queue(self): - 15s timeout per refresh operation - 30s delay between processing credentials (prevents thundering herd) - On failure: back of queue, max 3 retries before kicked - - If 401/403 detected: routes to re-auth queue - - Does NOT mark credentials unavailable (old token still valid) + - If 401/403 detected: marks credential as permanently expired + - Does NOT mark credentials unavailable for normal refresh (old token still valid) """ # lib_logger.info("Refresh queue processor started") while True: @@ -806,9 +774,6 @@ async def _process_refresh_queue(self): creds = self._credentials_cache.get(path) if creds and not self._is_token_expired(creds): # No longer expired, skip refresh - # lib_logger.debug( - # f"Credential '{Path(path).name}' no longer expired, skipping refresh" - # ) # Clear retry count on skip (not a failure) self._queue_retry_count.pop(path, None) continue @@ -834,19 +799,28 @@ async def _process_refresh_queue(self): except httpx.HTTPStatusError as e: status_code = e.response.status_code if status_code in (401, 403): - # Invalid refresh token - route to re-auth queue - lib_logger.warning( - f"Refresh token invalid for '{Path(path).name}' (HTTP {status_code}). " - f"Routing to re-auth queue." - ) + # Invalid refresh token - mark as permanently expired self._queue_retry_count.pop(path, None) # Clear retry count async with self._queue_tracking_lock: - self._queued_credentials.discard( - path - ) # Remove from queued - await self._queue_refresh( - path, force=True, needs_reauth=True + self._queued_credentials.discard(path) + self._mark_credential_expired( + path, f"Refresh token invalid (HTTP {status_code})" ) + elif status_code == 400: + # Check for invalid_grant + error_body = e.response.text + if "invalid_grant" in error_body.lower(): + self._queue_retry_count.pop(path, None) + async with self._queue_tracking_lock: + self._queued_credentials.discard(path) + self._mark_credential_expired( + path, + f"Refresh token invalid (HTTP 400: invalid_grant)", + ) + else: + await self._handle_refresh_failure( + path, force, f"HTTP {status_code}" + ) else: await self._handle_refresh_failure( path, force, f"HTTP {status_code}" @@ -907,65 +881,6 @@ async def _handle_refresh_failure(self, path: str, force: bool, error: str): # Keep in queued_credentials set, add back to queue await self._refresh_queue.put((path, force)) - async def _process_reauth_queue(self): - """Background worker that processes re-auth requests. - - Key behaviors: - - Credentials ARE marked unavailable (token is truly broken) - - Uses ReauthCoordinator for interactive OAuth - - No automatic retry (requires user action) - - Cleans up unavailable status when done - """ - # lib_logger.info("Re-auth queue processor started") - while True: - path = None - try: - # Wait for an item with timeout to allow graceful shutdown - try: - path = await asyncio.wait_for( - self._reauth_queue.get(), timeout=60.0 - ) - except asyncio.TimeoutError: - # Queue is empty and idle for 60s - exit - self._reauth_processor_task = None - # lib_logger.debug("Re-auth queue processor idle, shutting down") - return - - try: - lib_logger.info(f"Starting re-auth for '{Path(path).name}'...") - await self.initialize_token(path, force_interactive=True) - lib_logger.info(f"Re-auth SUCCESS for '{Path(path).name}'") - - except Exception as e: - lib_logger.error(f"Re-auth FAILED for '{Path(path).name}': {e}") - # No automatic retry for re-auth (requires user action) - - finally: - # Always clean up - async with self._queue_tracking_lock: - self._queued_credentials.discard(path) - self._unavailable_credentials.pop(path, None) - # lib_logger.debug( - # f"Re-auth cleanup for '{Path(path).name}'. " - # f"Remaining unavailable: {len(self._unavailable_credentials)}" - # ) - self._reauth_queue.task_done() - - except asyncio.CancelledError: - # Clean up current credential before breaking - if path: - async with self._queue_tracking_lock: - self._queued_credentials.discard(path) - self._unavailable_credentials.pop(path, None) - # lib_logger.debug("Re-auth queue processor cancelled") - break - except Exception as e: - lib_logger.error(f"Error in re-auth queue processor: {e}") - if path: - async with self._queue_tracking_lock: - self._queued_credentials.discard(path) - self._unavailable_credentials.pop(path, None) - async def _perform_interactive_oauth( self, path: str, creds: Dict[str, Any], display_name: str ) -> Dict[str, Any]: @@ -1139,15 +1054,15 @@ async def handle_callback(reader, writer): lib_logger.info(f"Attempting to exchange authorization code for tokens...") async with httpx.AsyncClient() as client: # [PKCE + HEADERS] Include code_verifier and explicit headers for token exchange - # Uses GEMINI_CLI style headers for consistent fingerprinting + # Uses google-auth-library default headers to match native gemini-cli behavior + # Library version: google-auth-library@9.15.1 (as used by gemini-cli v0.28.0) response = await client.post( self.TOKEN_URI, headers={ "Content-Type": "application/x-www-form-urlencoded;charset=UTF-8", "Accept": "*/*", - "Accept-Encoding": "gzip, deflate, br", - "User-Agent": "google-api-nodejs-client/10.3.0", - "X-Goog-Api-Client": "gl-node/22.18.0", + "User-Agent": "google-api-nodejs-client/9.15.1", + "X-Goog-Api-Client": "gl-node/22.16.0 auth/9.15.1", }, data={ "code": auth_code.strip(), @@ -1176,13 +1091,13 @@ async def handle_callback(reader, writer): new_creds["universe_domain"] = "googleapis.com" # Fetch user info and add metadata - # Uses GEMINI_CLI style headers per PR 246 + # Uses google-auth-library default headers to match native gemini-cli behavior user_info_response = await client.get( self.USER_INFO_URI, headers={ "Authorization": f"Bearer {new_creds['access_token']}", - "User-Agent": "google-api-nodejs-client/10.3.0", - "X-Goog-Api-Client": "gl-node/22.18.0", + "User-Agent": "google-api-nodejs-client/9.15.1", + "X-Goog-Api-Client": "gl-node/22.16.0 auth/9.15.1", }, ) user_info_response.raise_for_status() @@ -1219,15 +1134,18 @@ async def initialize_token( """ Initialize OAuth token, triggering interactive OAuth flow if needed. - If interactive OAuth is required (expired refresh token, missing credentials, etc.), - the flow is coordinated globally via ReauthCoordinator to ensure only one - interactive OAuth flow runs at a time across all providers. + For new credential setup (CLI tool), interactive OAuth is used when: + - Token is expired and refresh fails + - Refresh token is missing + + For proxy operation with force_interactive=True (deprecated): + - The credential is marked as permanently expired instead of interactive OAuth + - This prevents breaking the proxy flow with browser prompts Args: creds_or_path: Either a credentials dict or path to credentials file. - force_interactive: If True, skip expiry checks and force interactive OAuth. - Use this when the refresh token is known to be invalid - (e.g., after HTTP 400 from token endpoint). + force_interactive: If True, mark credential as expired (for proxy context). + For CLI tool, use the normal path (force_interactive=False). """ path = creds_or_path if isinstance(creds_or_path, str) else None @@ -1246,12 +1164,21 @@ async def initialize_token( creds = ( await self._load_credentials(creds_or_path) if path else creds_or_path ) - reason = "" + + # If force_interactive is True, this was called from proxy context + # where re-auth was requested. Instead of interactive OAuth, mark as expired. if force_interactive: - reason = ( - "re-authentication was explicitly requested (refresh token invalid)" + if path: + self._mark_credential_expired( + path, "Refresh token invalid - re-authentication required" + ) + raise ValueError( + f"Credential '{display_name}' requires re-authentication. " + f"Run 'credential_tool.py' to manually re-authenticate, then restart proxy." ) - elif not creds.get("refresh_token"): + + reason = "" + if not creds.get("refresh_token"): reason = "refresh token is missing" elif self._is_token_expired(creds): reason = "token is expired" @@ -1262,30 +1189,29 @@ async def initialize_token( return await self._refresh_token(path, creds) except Exception as e: lib_logger.warning( - f"Automatic token refresh for '{display_name}' failed: {e}. Proceeding to interactive login." + f"Automatic token refresh for '{display_name}' failed: {e}." ) + # Fall through to handle expired credential - lib_logger.warning( - f"{self.ENV_PREFIX} OAuth token for '{display_name}' needs setup: {reason}." - ) - - # [GLOBAL REAUTH COORDINATION] Use the global coordinator to ensure - # only one interactive OAuth flow runs at a time across all providers - coordinator = get_reauth_coordinator() - - # Define the interactive OAuth function to be executed by coordinator - async def _do_interactive_oauth(): - return await self._perform_interactive_oauth( - path, creds, display_name + # Distinguish between proxy context (has path) and credential tool context (no path) + # - Proxy context: mark as expired and fail (no interactive OAuth during proxy operation) + # - Credential tool context: do interactive OAuth for new credential setup + if path: + # [NO AUTO-REAUTH] Proxy context - mark as permanently expired + self._mark_credential_expired( + path, + f"{reason}. Manual re-authentication required via credential_tool.py", + ) + raise ValueError( + f"Credential '{display_name}' is expired and requires manual re-authentication. " + f"Run 'python credential_tool.py' to fix, then restart the proxy." ) - # Execute via global coordinator (ensures only one at a time) - return await coordinator.execute_reauth( - credential_path=path or display_name, - provider_name=self.ENV_PREFIX, - reauth_func=_do_interactive_oauth, - timeout=300.0, # 5 minute timeout for user to complete OAuth + # Credential tool context - do interactive OAuth for new credential setup + lib_logger.warning( + f"{self.ENV_PREFIX} OAuth token for '{display_name}' needs setup: {reason}." ) + return await self._perform_interactive_oauth(path, creds, display_name) lib_logger.info( f"{self.ENV_PREFIX} OAuth token at '{display_name}' is valid." @@ -1343,6 +1269,60 @@ async def _post_auth_discovery( # Default implementation does nothing - subclasses can override pass + async def _persist_project_metadata( + self, + credential_path: str, + project_id: str, + tier: Optional[str], + tier_full: Optional[str] = None, + ) -> None: + """ + Persist project ID and tier to the credential file for faster future startups. + + This is a shared implementation for Google Cloud OAuth providers that need + to cache project and tier information (e.g., Gemini CLI, Antigravity). + + Args: + credential_path: Path to the credential file + project_id: The Google Cloud project ID to persist + tier: Optional tier identifier (e.g., "PRO", "FREE", "ULTRA") + tier_full: Optional full tier name for display (e.g., "Google One AI PRO") + """ + # Skip persistence for env:// paths (environment-based credentials) + credential_index = self._parse_env_credential_path(credential_path) + if credential_index is not None: + lib_logger.debug( + f"Skipping project metadata persistence for env:// credential path: {credential_path}" + ) + return + + try: + # Load current credentials + with open(credential_path, "r") as f: + creds = json.load(f) + + # Update metadata + if "_proxy_metadata" not in creds: + creds["_proxy_metadata"] = {} + + creds["_proxy_metadata"]["project_id"] = project_id + if tier: + creds["_proxy_metadata"]["tier"] = tier + if tier_full: + creds["_proxy_metadata"]["tier_full"] = tier_full + + # Save back using the existing save method (handles atomic writes and permissions) + await self._save_credentials(credential_path, creds) + + lib_logger.debug( + f"Persisted project_id and tier to credential file: {credential_path}" + ) + except Exception as e: + lib_logger.warning( + f"Failed to persist project metadata to credential file: {e}" + ) + # Non-fatal - just means slower startup next time + async def get_user_info( self, creds_or_path: Union[Dict[str, Any], str] ) -> Dict[str, Any]: diff --git a/src/rotator_library/providers/iflow_auth_base.py b/src/rotator_library/providers/iflow_auth_base.py index 7f5a133b..0ac03208 100644 --- a/src/rotator_library/providers/iflow_auth_base.py +++ b/src/rotator_library/providers/iflow_auth_base.py @@ -33,12 +33,16 @@ lib_logger = logging.getLogger("rotator_library") +# OAuth endpoints IFLOW_OAUTH_AUTHORIZE_ENDPOINT = "https://iflow.cn/oauth" IFLOW_OAUTH_TOKEN_ENDPOINT = "https://iflow.cn/oauth/token" IFLOW_USER_INFO_ENDPOINT = "https://iflow.cn/api/oauth/getUserInfo" IFLOW_SUCCESS_REDIRECT_URL = "https://iflow.cn/oauth/success" IFLOW_ERROR_REDIRECT_URL = "https://iflow.cn/oauth/error" +# Cookie-based authentication endpoint +IFLOW_API_KEY_ENDPOINT = "https://platform.iflow.cn/api/openapi/apikey" + # Client credentials provided by iFlow IFLOW_CLIENT_ID = "10009311001" IFLOW_CLIENT_SECRET = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW" @@ -46,6 +50,9 @@ # Local callback server port CALLBACK_PORT = 11451 +# Cookie API key refresh buffer (48 hours before expiry) +COOKIE_REFRESH_BUFFER_HOURS = 48 + @dataclass class IFlowCredentialSetupResult: @@ -79,6 +86,96 @@ def get_callback_port() -> int: return CALLBACK_PORT +def normalize_cookie(raw: str) -> str: + """ + Normalize and validate a cookie string for iFlow authentication. + + Ensures the cookie contains the required BXAuth field and is properly formatted. + + Args: + raw: Raw cookie string from user input + + Returns: + Normalized cookie string ending with semicolon + + Raises: + ValueError: If cookie is empty or missing BXAuth field + """ + trimmed = raw.strip() + if not trimmed: + raise ValueError("Cookie cannot be empty") + + # Normalize whitespace + combined = " ".join(trimmed.split()) + + # Ensure ends with semicolon + if not combined.endswith(";"): + combined += ";" + + # Validate BXAuth field is present + if "BXAuth=" not in combined: + raise ValueError( + "Cookie missing required 'BXAuth' field. " + "Please copy the complete cookie including BXAuth." + ) + + return combined + + +def extract_bx_auth(cookie: str) -> Optional[str]: + """ + Extract the BXAuth value from a cookie string. + + Args: + cookie: Cookie string (e.g., "BXAuth=abc123; other=value;") + + Returns: + The BXAuth value, or None if not found + """ + parts = cookie.split(";") + for part in parts: + part = part.strip() + if part.startswith("BXAuth="): + return part[7:] # Remove "BXAuth=" prefix + return None + + +def should_refresh_cookie_api_key(expire_time: str) -> Tuple[bool, float]: + """ + Check if a cookie-based API key needs refresh. + + Uses a 48-hour buffer to proactively refresh + API keys before they expire. + + Args: + expire_time: Expiry time string in format "YYYY-MM-DD HH:MM" + + Returns: + Tuple of (needs_refresh, seconds_until_expiry) + - needs_refresh: True if key expires within 48 hours + - seconds_until_expiry: Time until expiry (negative if already expired) + """ + if not expire_time or not expire_time.strip(): + return True, 0 + + try: + from datetime import datetime + + # Parse iFlow's expire time format: "YYYY-MM-DD HH:MM" + expire_dt = datetime.strptime(expire_time.strip(), "%Y-%m-%d %H:%M") + now = datetime.now() + + seconds_until_expiry = (expire_dt - now).total_seconds() + buffer_seconds = COOKIE_REFRESH_BUFFER_HOURS * 3600 + + needs_refresh = seconds_until_expiry < buffer_seconds + return needs_refresh, seconds_until_expiry + + except (ValueError, AttributeError) as e: + lib_logger.warning(f"Could not parse cookie expire_time '{expire_time}': {e}") + return True, 0 + + # Refresh tokens 24 hours before expiry REFRESH_EXPIRY_BUFFER_SECONDS = 24 * 60 * 60 @@ -210,26 +307,20 @@ def __init__(self): str, float ] = {} # Track backoff timers (Unix timestamp) - # [QUEUE SYSTEM] Sequential refresh processing with two separate queues + # [QUEUE SYSTEM] Sequential refresh processing # Normal refresh queue: for proactive token refresh (old token still valid) self._refresh_queue: asyncio.Queue = asyncio.Queue() self._queue_processor_task: Optional[asyncio.Task] = None - # Re-auth queue: for invalid refresh tokens (requires user interaction) - self._reauth_queue: asyncio.Queue = asyncio.Queue() - self._reauth_processor_task: Optional[asyncio.Task] = None - # Tracking sets/dicts - self._queued_credentials: set = set() # Track credentials in either queue - # Only credentials in re-auth queue are marked unavailable (not normal refresh) - # TTL cleanup is defense-in-depth for edge cases where re-auth processor crashes - self._unavailable_credentials: Dict[ - str, float - ] = {} # Maps credential path -> timestamp when marked unavailable - # TTL should exceed reauth timeout (300s) to avoid premature cleanup - self._unavailable_ttl_seconds: int = 360 # 6 minutes TTL for stale entries + self._queued_credentials: set = set() # Track credentials in refresh queue self._queue_tracking_lock = asyncio.Lock() # Protects queue sets + # [PERMANENTLY EXPIRED] Track credentials that have been permanently removed from rotation + # These credentials have invalid/revoked refresh tokens and require manual re-authentication + # via credential_tool.py. They will NOT be selected for rotation until proxy restart. + self._permanently_expired_credentials: set = set() + # Retry tracking for normal refresh queue self._queue_retry_count: Dict[ str, int @@ -239,7 +330,6 @@ def __init__(self): self._refresh_timeout_seconds: int = 15 # Max time for single refresh self._refresh_interval_seconds: int = 30 # Delay between queue items self._refresh_max_retries: int = 3 # Attempts before kicked out - self._reauth_timeout_seconds: int = 300 # Time for user to complete OAuth def _parse_env_credential_path(self, path: str) -> Optional[str]: """ @@ -318,6 +408,7 @@ def _load_from_env( "last_check_timestamp": time.time(), "loaded_from_env": True, "env_credential_index": credential_index or "0", + "credential_type": "oauth", }, } @@ -435,6 +526,14 @@ def _is_token_expired(self, creds: Dict[str, Any]) -> bool: return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS + async def _get_lock(self, path: str) -> asyncio.Lock: + # [FIX RACE CONDITION] Protect lock creation with a master lock + # This prevents TOCTOU bug where multiple coroutines check and create simultaneously + async with self._locks_lock: + if path not in self._refresh_locks: + self._refresh_locks[path] = asyncio.Lock() + return self._refresh_locks[path] + def _is_token_truly_expired(self, creds: Dict[str, Any]) -> bool: """Check if token is TRULY expired (past actual expiry, not just threshold). @@ -458,6 +557,52 @@ def _is_token_truly_expired(self, creds: Dict[str, Any]) -> bool: return expiry_timestamp < time.time() + def _mark_credential_expired(self, path: str, reason: str) -> None: + """ + Permanently mark a credential as expired and remove it from rotation. + + This is called when a credential's refresh token is invalid or revoked, + meaning normal token refresh cannot work. The credential is removed from + rotation entirely and requires manual re-authentication via credential_tool.py. + + The proxy must be restarted after fixing the credential. + + Args: + path: Credential file path or env:// path + reason: Human-readable reason for expiration (e.g., "invalid_grant", "HTTP 401") + """ + # Add to permanently expired set + self._permanently_expired_credentials.add(path) + + # Clean up other tracking structures + self._queued_credentials.discard(path) + + # Get display name + if path.startswith("env://"): + display_name = path + else: + display_name = Path(path).name + + # Rich-formatted output for high visibility + console.print( + Panel( + f"[bold red]Credential:[/bold red] {display_name}\n" + f"[bold red]Reason:[/bold red] {reason}\n\n" + f"[yellow]This credential has been removed from rotation.[/yellow]\n" + f"[yellow]To fix: Run 'python credential_tool.py' to re-authenticate,[/yellow]\n" + f"[yellow]then restart the proxy.[/yellow]", + title="[bold red]⚠ CREDENTIAL EXPIRED - REMOVED FROM ROTATION[/bold red]", + border_style="red", + ) + ) + + # Also log at ERROR level for log files + lib_logger.error( + f"CREDENTIAL EXPIRED - REMOVED FROM ROTATION | " + f"Credential: {display_name} | Reason: {reason} | " + f"Action: Run 'credential_tool.py' to re-authenticate, then restart proxy" + ) + async def _fetch_user_info(self, access_token: str) -> Dict[str, Any]: """ Fetches user info (including API key) from iFlow API. @@ -477,7 +622,7 @@ async def _fetch_user_info(self, access_token: str) -> Dict[str, Any]: if not result.get("success"): raise ValueError("iFlow user info request not successful") - data = result.get("data", {}) + data = result.get("data") or {} api_key = data.get("apiKey", "").strip() if not api_key: raise ValueError("Missing API key in user info response") @@ -490,6 +635,256 @@ async def _fetch_user_info(self, access_token: str) -> Dict[str, Any]: return {"api_key": api_key, "email": email} + # ========================================================================= + # COOKIE-BASED AUTHENTICATION METHODS + # ========================================================================= + + async def _fetch_api_key_info_with_cookie(self, cookie: str) -> Dict[str, Any]: + """ + Fetch API key info using browser cookie (GET request). + + This retrieves the current API key information including name, + masked key, and expiry time. + + Args: + cookie: Cookie string containing BXAuth + + Returns: + Dict with keys: name, apiKey, apiKeyMask, expireTime, hasExpired + """ + headers = { + "Cookie": cookie, + "Accept": "application/json, text/plain, */*", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", + "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", + } + + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(IFLOW_API_KEY_ENDPOINT, headers=headers) + + if response.status_code != 200: + lib_logger.error( + f"iFlow cookie GET failed: {response.status_code} {response.text}" + ) + raise ValueError( + f"Cookie authentication failed: HTTP {response.status_code}" + ) + + result = response.json() + + if not result.get("success"): + error_msg = result.get("message", "Unknown error") + raise ValueError(f"Cookie authentication failed: {error_msg}") + + data = result.get("data") or {} + + # Handle case where apiKey is masked - use apiKeyMask if apiKey is empty + if not data.get("apiKey") and data.get("apiKeyMask"): + data["apiKey"] = data["apiKeyMask"] + + return data + + async def _refresh_api_key_with_cookie( + self, cookie: str, name: str + ) -> Dict[str, Any]: + """ + Refresh/regenerate API key using browser cookie (POST request). + + This requests a new API key from iFlow using the session cookie. + + Args: + cookie: Cookie string containing BXAuth + name: The API key name (obtained from GET request) + + Returns: + Dict with keys: name, apiKey, expireTime, hasExpired + """ + if not name or not name.strip(): + raise ValueError("API key name is required for refresh") + + headers = { + "Cookie": cookie, + "Content-Type": "application/json", + "Accept": "application/json, text/plain, */*", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", + "Origin": "https://platform.iflow.cn", + "Referer": "https://platform.iflow.cn/", + } + + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.post( + IFLOW_API_KEY_ENDPOINT, + headers=headers, + json={"name": name}, + ) + + if response.status_code != 200: + lib_logger.error( + f"iFlow cookie POST failed: {response.status_code} {response.text}" + ) + raise ValueError( + f"Cookie API key refresh failed: HTTP {response.status_code}" + ) + + result = response.json() + + if not result.get("success"): + error_msg = result.get("message", "Unknown error") + raise ValueError(f"Cookie API key refresh failed: {error_msg}") + + return result.get("data") or {} + + async def authenticate_with_cookie(self, cookie: str) -> Dict[str, Any]: + """ + Authenticate using browser cookie and obtain API key. + + This performs the full cookie-based authentication flow: + 1. Validate and normalize the cookie + 2. GET request to fetch current API key info + 3. POST request to refresh/get full API key + + Args: + cookie: Raw cookie string from browser (must contain BXAuth) + + Returns: + Dict with credential data including: + - cookie: Normalized cookie string (BXAuth only) + - api_key: The API key for iFlow API calls + - name: Account/key name + - expire_time: When the API key expires + - type: "iflow_cookie" + """ + # Normalize and validate cookie + try: + normalized_cookie = normalize_cookie(cookie) + except ValueError as e: + raise ValueError(f"Invalid cookie: {e}") + + # Extract BXAuth value for storage (only store what's needed) + bx_auth = extract_bx_auth(normalized_cookie) + if not bx_auth: + raise ValueError("Could not extract BXAuth from cookie") + + # Store only BXAuth for security (don't store other cookies) + cookie_to_store = f"BXAuth={bx_auth};" + + lib_logger.debug("Fetching API key info with cookie...") + + # GET request to fetch current info + key_info = await self._fetch_api_key_info_with_cookie(cookie_to_store) + name = key_info.get("name", "") + + if not name: + raise ValueError("Could not get API key name from iFlow") + + lib_logger.debug(f"Got API key info for '{name}', refreshing key...") + + # POST request to refresh/get full API key + refreshed = await self._refresh_api_key_with_cookie(cookie_to_store, name) + + api_key = refreshed.get("apiKey", "") + if not api_key: + raise ValueError("Could not get API key from iFlow") + + expire_time = refreshed.get("expireTime", "") + + return { + "cookie": cookie_to_store, + "api_key": api_key, + "name": name, + "expire_time": expire_time, + "_proxy_metadata": { + "email": name, # Use name as identifier + "last_check_timestamp": time.time(), + "credential_type": "cookie", + }, + } + + async def _refresh_cookie_credential(self, path: str) -> Dict[str, Any]: + """ + Refresh API key for a cookie-based credential. + + This is called when the API key is approaching expiry. + Note: If the browser session cookie (BXAuth) expires, the user + will need to re-authenticate manually. + + Args: + path: Path to the credential file + + Returns: + Updated credentials dict + """ + async with await self._get_lock(path): + # Read current credentials + creds = await self._load_credentials(path) + + if not self._is_cookie_credential(creds): + raise ValueError(f"Credential at '{path}' is not a cookie credential") + + cookie = creds.get("cookie", "") + name = creds.get("name", "") + + if not cookie or not name: + raise ValueError("Cookie credential missing cookie or name") + + # Check if refresh is actually needed + expire_time = creds.get("expire_time", "") + needs_refresh, seconds_until = should_refresh_cookie_api_key(expire_time) + + if not needs_refresh: + lib_logger.debug( + f"Cookie API key for '{name}' not due for refresh " + f"({seconds_until / 3600:.1f}h until expiry)" + ) + return creds + + lib_logger.info(f"Refreshing cookie API key for '{name}'...") + + try: + # Refresh the API key + refreshed = await self._refresh_api_key_with_cookie(cookie, name) + + # Update credentials + creds["api_key"] = refreshed.get("apiKey", creds["api_key"]) + creds["expire_time"] = refreshed.get("expireTime", creds["expire_time"]) + creds["_proxy_metadata"]["last_check_timestamp"] = time.time() + + # Save to disk + if not await self._save_credentials(path, creds): + raise IOError(f"Failed to save refreshed cookie credentials") + + lib_logger.info( + f"Successfully refreshed cookie API key for '{name}'. " + f"New expiry: {creds['expire_time']}" + ) + return creds + + except Exception as e: + # If refresh fails, the session cookie may be expired + lib_logger.error(f"Failed to refresh cookie API key for '{name}': {e}") + # Mark as expired if it's an auth error + if ( + "401" in str(e) + or "403" in str(e) + or "authentication" in str(e).lower() + ): + self._mark_credential_expired( + path, + f"Cookie session expired. Please re-authenticate with a fresh cookie.", + ) + raise + + def _is_cookie_credential(self, creds: Dict[str, Any]) -> bool: + """Check if credentials are cookie-based (vs OAuth-based).""" + # Primary check: explicit credential_type in metadata + cred_type = creds.get("_proxy_metadata", {}).get("credential_type") + if cred_type: + return cred_type == "cookie" + + # Fallback: infer from fields (for backwards compatibility) + # Cookie creds have 'cookie' field but no 'refresh_token' + return "cookie" in creds and "refresh_token" not in creds + async def _exchange_code_for_tokens( self, code: str, redirect_uri: str ) -> Dict[str, Any]: @@ -644,8 +1039,9 @@ async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any] ) # [STATUS CODE HANDLING] - # [INVALID GRANT HANDLING] Handle 400/401/403 by raising - # Queue for re-auth in background so credential gets fixed automatically + # [INVALID GRANT HANDLING] Handle 400/401/403 by marking as expired + # These errors indicate the refresh token is invalid/revoked + # Mark as permanently expired - no interactive re-auth during proxy operation if status_code == 400: # Check if this is an invalid refresh token error try: @@ -662,39 +1058,25 @@ async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any] "invalid" in error_desc.lower() or error_type == "invalid_request" ): - lib_logger.info( - f"Credential '{Path(path).name}' needs re-auth (HTTP 400: {error_desc}). " - f"Queued for re-authentication, rotating to next credential." - ) - # Queue for re-auth in background (non-blocking, fire-and-forget) - # This ensures credential gets fixed even if caller doesn't handle it - asyncio.create_task( - self._queue_refresh( - path, force=True, needs_reauth=True - ) + self._mark_credential_expired( + path, + f"Refresh token invalid (HTTP 400: {error_desc})", ) - # Raise rotatable error instead of raw HTTPStatusError raise CredentialNeedsReauthError( credential_path=path, - message=f"Refresh token invalid for '{Path(path).name}'. Re-auth queued.", + message=f"Refresh token invalid for '{Path(path).name}'. Credential removed from rotation.", ) else: # Other 400 error - raise it raise elif status_code in (401, 403): - lib_logger.info( - f"Credential '{Path(path).name}' needs re-auth (HTTP {status_code}). " - f"Queued for re-authentication, rotating to next credential." + self._mark_credential_expired( + path, f"Credential unauthorized (HTTP {status_code})" ) - # Queue for re-auth in background (non-blocking, fire-and-forget) - asyncio.create_task( - self._queue_refresh(path, force=True, needs_reauth=True) - ) - # Raise rotatable error instead of raw HTTPStatusError raise CredentialNeedsReauthError( credential_path=path, - message=f"Token invalid for '{Path(path).name}' (HTTP {status_code}). Re-auth queued.", + message=f"Token invalid for '{Path(path).name}' (HTTP {status_code}). Credential removed from rotation.", ) elif status_code == 429: @@ -827,25 +1209,41 @@ async def get_api_details(self, credential_identifier: str) -> Tuple[str, str]: Returns the API base URL and API key (NOT access_token). CRITICAL: iFlow uses the api_key for API requests, not the OAuth access_token. - Supports both credential types: - - OAuth: credential_identifier is a file path to JSON credentials + Supports three credential types: + - OAuth: credential_identifier is a file path to JSON credentials with refresh_token + - Cookie: credential_identifier is a file path to JSON credentials with cookie - API Key: credential_identifier is the API key string itself """ # Detect credential type if os.path.isfile(credential_identifier): - # OAuth credential: file path to JSON - lib_logger.debug( - f"Using OAuth credentials from file: {credential_identifier}" - ) creds = await self._load_credentials(credential_identifier) - # Check if token needs refresh - if self._is_token_expired(creds): - creds = await self._refresh_token(credential_identifier) + # Check if this is a cookie-based credential + if self._is_cookie_credential(creds): + lib_logger.debug( + f"Using cookie credentials from file: {credential_identifier}" + ) + # Check if API key needs refresh + expire_time = creds.get("expire_time", "") + needs_refresh, _ = should_refresh_cookie_api_key(expire_time) + if needs_refresh: + creds = await self._refresh_cookie_credential(credential_identifier) + + api_key = creds.get("api_key") + if not api_key: + raise ValueError("Missing api_key in iFlow cookie credentials") + else: + # OAuth credential + lib_logger.debug( + f"Using OAuth credentials from file: {credential_identifier}" + ) + # Check if token needs refresh + if self._is_token_expired(creds): + creds = await self._refresh_token(credential_identifier) - api_key = creds.get("api_key") - if not api_key: - raise ValueError("Missing api_key in iFlow OAuth credentials") + api_key = creds.get("api_key") + if not api_key: + raise ValueError("Missing api_key in iFlow OAuth credentials") else: # Direct API key: use as-is lib_logger.debug("Using direct API key for iFlow") @@ -856,159 +1254,74 @@ async def get_api_details(self, credential_identifier: str) -> Tuple[str, str]: async def proactively_refresh(self, credential_identifier: str): """ - Proactively refreshes tokens if they're close to expiry. - Only applies to OAuth credentials (file paths or env:// paths). Direct API keys are skipped. - """ - # lib_logger.debug(f"proactively_refresh called for: {credential_identifier}") + Proactively refreshes tokens/API keys if they're close to expiry. + + Handles both credential types: + - OAuth credentials: Refresh access token using refresh_token + - Cookie credentials: Refresh API key using browser session cookie + Direct API keys are skipped. + """ # Try to load credentials - this will fail for direct API keys - # and succeed for OAuth credentials (file paths or env:// paths) try: creds = await self._load_credentials(credential_identifier) except IOError as e: # Not a valid credential path (likely a direct API key string) - # lib_logger.debug( - # f"Skipping refresh for '{credential_identifier}' - not an OAuth credential: {e}" - # ) return - is_expired = self._is_token_expired(creds) - # lib_logger.debug( - # f"Token expired check for '{Path(credential_identifier).name}': {is_expired}" - # ) - - if is_expired: - # lib_logger.debug( - # f"Queueing refresh for '{Path(credential_identifier).name}'" - # ) - # lib_logger.info(f"Proactive refresh triggered for '{Path(credential_identifier).name}'") - await self._queue_refresh( - credential_identifier, force=False, needs_reauth=False - ) - - async def _get_lock(self, path: str) -> asyncio.Lock: - """Gets or creates a lock for the given credential path.""" - # [FIX RACE CONDITION] Protect lock creation with a master lock - async with self._locks_lock: - if path not in self._refresh_locks: - self._refresh_locks[path] = asyncio.Lock() - return self._refresh_locks[path] - - def is_credential_available(self, path: str) -> bool: - """Check if a credential is available for rotation. + # Handle cookie-based credentials + if self._is_cookie_credential(creds): + expire_time = creds.get("expire_time", "") + needs_refresh, seconds_until = should_refresh_cookie_api_key(expire_time) - Credentials are unavailable if: - 1. In re-auth queue (token is truly broken, requires user interaction) - 2. Token is TRULY expired (past actual expiry, not just threshold) - - Note: Credentials in normal refresh queue are still available because - the old token is valid until actual expiry. - - TTL cleanup (defense-in-depth): If a credential has been in the re-auth - queue longer than _unavailable_ttl_seconds without being processed, it's - cleaned up. This should only happen if the re-auth processor crashes or - is cancelled without proper cleanup. - """ - # Check if in re-auth queue (truly unavailable) - if path in self._unavailable_credentials: - marked_time = self._unavailable_credentials.get(path) - if marked_time is not None: - now = time.time() - if now - marked_time > self._unavailable_ttl_seconds: - # Entry is stale - clean it up and return available - # This is a defense-in-depth for edge cases where re-auth - # processor crashed or was cancelled without cleanup + if needs_refresh: + lib_logger.debug( + f"Proactive cookie API key refresh triggered for " + f"'{Path(credential_identifier).name}' " + f"({seconds_until / 3600:.1f}h until expiry)" + ) + try: + await self._refresh_cookie_credential(credential_identifier) + except Exception as e: lib_logger.warning( - f"Credential '{Path(path).name}' stuck in re-auth queue for " - f"{int(now - marked_time)}s (TTL: {self._unavailable_ttl_seconds}s). " - f"Re-auth processor may have crashed. Auto-cleaning stale entry." + f"Proactive cookie refresh failed for " + f"'{Path(credential_identifier).name}': {e}" ) - # Clean up both tracking structures for consistency - self._unavailable_credentials.pop(path, None) - self._queued_credentials.discard(path) - else: - return False # Still in re-auth, not available - - # Check if token is TRULY expired (not just threshold-expired) - creds = self._credentials_cache.get(path) - if creds and self._is_token_truly_expired(creds): - # Token is actually expired - should not be used - # Queue for refresh if not already queued - if path not in self._queued_credentials: - # lib_logger.debug( - # f"Credential '{Path(path).name}' is truly expired, queueing for refresh" - # ) - asyncio.create_task( - self._queue_refresh(path, force=True, needs_reauth=False) - ) - return False - - return True + return - async def _ensure_queue_processor_running(self): - """Lazily starts the queue processor if not already running.""" - if self._queue_processor_task is None or self._queue_processor_task.done(): - self._queue_processor_task = asyncio.create_task( - self._process_refresh_queue() - ) + # Handle OAuth credentials + is_expired = self._is_token_expired(creds) - async def _ensure_reauth_processor_running(self): - """Lazily starts the re-auth queue processor if not already running.""" - if self._reauth_processor_task is None or self._reauth_processor_task.done(): - self._reauth_processor_task = asyncio.create_task( - self._process_reauth_queue() - ) + if is_expired: + await self._queue_refresh(credential_identifier, force=False) - async def _queue_refresh( - self, path: str, force: bool = False, needs_reauth: bool = False - ): - """Add a credential to the appropriate refresh queue if not already queued. + async def _queue_refresh(self, path: str, force: bool = False): + """Add a credential to the refresh queue if not already queued. Args: path: Credential file path force: Force refresh even if not expired - needs_reauth: True if full re-authentication needed (routes to re-auth queue) - - Queue routing: - - needs_reauth=True: Goes to re-auth queue, marks as unavailable - - needs_reauth=False: Goes to normal refresh queue, does NOT mark unavailable - (old token is still valid until actual expiry) """ - # IMPORTANT: Only check backoff for simple automated refreshes - # Re-authentication (interactive OAuth) should BYPASS backoff since it needs user input - if not needs_reauth: - now = time.time() - if path in self._next_refresh_after: - backoff_until = self._next_refresh_after[path] - if now < backoff_until: - # Credential is in backoff for automated refresh, do not queue - # remaining = int(backoff_until - now) - # lib_logger.debug( - # f"Skipping automated refresh for '{Path(path).name}' (in backoff for {remaining}s)" - # ) - return + # Check backoff for automated refreshes + now = time.time() + if path in self._next_refresh_after: + backoff_until = self._next_refresh_after[path] + if now < backoff_until: + # Credential is in backoff, do not queue + return async with self._queue_tracking_lock: if path not in self._queued_credentials: self._queued_credentials.add(path) + await self._refresh_queue.put((path, force)) + await self._ensure_queue_processor_running() - if needs_reauth: - # Re-auth queue: mark as unavailable (token is truly broken) - self._unavailable_credentials[path] = time.time() - # lib_logger.debug( - # f"Queued '{Path(path).name}' for RE-AUTH (marked unavailable). " - # f"Total unavailable: {len(self._unavailable_credentials)}" - # ) - await self._reauth_queue.put(path) - await self._ensure_reauth_processor_running() - else: - # Normal refresh queue: do NOT mark unavailable (old token still valid) - # lib_logger.debug( - # f"Queued '{Path(path).name}' for refresh (still available). " - # f"Queue size: {self._refresh_queue.qsize() + 1}" - # ) - await self._refresh_queue.put((path, force)) - await self._ensure_queue_processor_running() + async def _ensure_queue_processor_running(self): + """Lazily starts the queue processor if not already running.""" + if self._queue_processor_task is None or self._queue_processor_task.done(): + self._queue_processor_task = asyncio.create_task( + self._process_refresh_queue() + ) async def _process_refresh_queue(self): """Background worker that processes normal refresh requests sequentially. @@ -1105,8 +1418,10 @@ async def _process_refresh_queue(self): self._queued_credentials.discard( path ) # Remove from queued - await self._queue_refresh( - path, force=True, needs_reauth=True + # Mark credential as permanently expired (no auto-reauth) + self._mark_credential_expired( + path, + f"Refresh token invalid (HTTP {status_code}). Requires manual re-authentication.", ) else: await self._handle_refresh_failure( @@ -1168,65 +1483,6 @@ async def _handle_refresh_failure(self, path: str, force: bool, error: str): # Keep in queued_credentials set, add back to queue await self._refresh_queue.put((path, force)) - async def _process_reauth_queue(self): - """Background worker that processes re-auth requests. - - Key behaviors: - - Credentials ARE marked unavailable (token is truly broken) - - Uses ReauthCoordinator for interactive OAuth - - No automatic retry (requires user action) - - Cleans up unavailable status when done - """ - # lib_logger.info("Re-auth queue processor started") - while True: - path = None - try: - # Wait for an item with timeout to allow graceful shutdown - try: - path = await asyncio.wait_for( - self._reauth_queue.get(), timeout=60.0 - ) - except asyncio.TimeoutError: - # Queue is empty and idle for 60s - exit - self._reauth_processor_task = None - # lib_logger.debug("Re-auth queue processor idle, shutting down") - return - - try: - lib_logger.info(f"Starting re-auth for '{Path(path).name}'...") - await self.initialize_token(path, force_interactive=True) - lib_logger.info(f"Re-auth SUCCESS for '{Path(path).name}'") - - except Exception as e: - lib_logger.error(f"Re-auth FAILED for '{Path(path).name}': {e}") - # No automatic retry for re-auth (requires user action) - - finally: - # Always clean up - async with self._queue_tracking_lock: - self._queued_credentials.discard(path) - self._unavailable_credentials.pop(path, None) - # lib_logger.debug( - # f"Re-auth cleanup for '{Path(path).name}'. " - # f"Remaining unavailable: {len(self._unavailable_credentials)}" - # ) - self._reauth_queue.task_done() - - except asyncio.CancelledError: - # Clean up current credential before breaking - if path: - async with self._queue_tracking_lock: - self._queued_credentials.discard(path) - self._unavailable_credentials.pop(path, None) - # lib_logger.debug("Re-auth queue processor cancelled") - break - except Exception as e: - lib_logger.error(f"Error in re-auth queue processor: {e}") - if path: - async with self._queue_tracking_lock: - self._queued_credentials.discard(path) - self._unavailable_credentials.pop(path, None) - async def _perform_interactive_oauth( self, path: str, creds: Dict[str, Any], display_name: str ) -> Dict[str, Any]: @@ -1336,6 +1592,8 @@ async def _perform_interactive_oauth( "email": token_data["email"], "last_check_timestamp": time.time(), } + # Always set credential_type for OAuth credentials + creds["_proxy_metadata"]["credential_type"] = "oauth" if path: if not await self._save_credentials(path, creds): @@ -1387,6 +1645,59 @@ async def initialize_token( await self._load_credentials(creds_or_path) if path else creds_or_path ) + # ========================================================= + # COOKIE CREDENTIAL HANDLING - check first before OAuth logic + # ========================================================= + if self._is_cookie_credential(creds): + # Validate required fields for cookie credentials + if not creds.get("cookie") or not creds.get("api_key"): + error_msg = ( + "Cookie credential missing required fields (cookie or api_key)" + ) + if path: + self._mark_credential_expired(path, error_msg) + raise ValueError( + f"Credential '{display_name}' is invalid: {error_msg}. " + f"Run 'python credential_tool.py' to re-authenticate." + ) + raise ValueError(error_msg) + + # Check if API key needs refresh (48-hour buffer) + if path: + expire_time = creds.get("expire_time", "") + needs_refresh, seconds_until = should_refresh_cookie_api_key( + expire_time + ) + if needs_refresh: + try: + lib_logger.info( + f"Cookie API key for '{display_name}' needs refresh " + f"({seconds_until / 3600:.1f}h until expiry)" + ) + creds = await self._refresh_cookie_credential(path) + except Exception as e: + lib_logger.warning( + f"Cookie API key refresh for '{display_name}' failed: {e}" + ) + # If API key is already expired (negative seconds), mark as expired + if seconds_until < 0: + self._mark_credential_expired( + path, + f"Cookie API key expired and refresh failed: {e}. " + f"Please re-authenticate with a fresh cookie.", + ) + raise ValueError( + f"Credential '{display_name}' cookie API key expired. " + f"Run 'python credential_tool.py' to re-authenticate." + ) + # Otherwise continue with existing (still valid) API key + + lib_logger.info(f"Cookie credential at '{display_name}' is valid.") + return creds + + # ========================================================= + # OAUTH CREDENTIAL HANDLING - existing logic + # ========================================================= reason = "" if force_interactive: reason = ( @@ -1404,31 +1715,29 @@ async def initialize_token( return await self._refresh_token(path) except Exception as e: lib_logger.warning( - f"Automatic token refresh for '{display_name}' failed: {e}. Proceeding to interactive login." + f"Automatic token refresh for '{display_name}' failed: {e}." ) + # Fall through to handle expired credential - # Interactive OAuth flow - lib_logger.warning( - f"iFlow OAuth token for '{display_name}' needs setup: {reason}." - ) - - # [GLOBAL REAUTH COORDINATION] Use the global coordinator to ensure - # only one interactive OAuth flow runs at a time across all providers - coordinator = get_reauth_coordinator() - - # Define the interactive OAuth function to be executed by coordinator - async def _do_interactive_oauth(): - return await self._perform_interactive_oauth( - path, creds, display_name + # Distinguish between proxy context (has path) and credential tool context (no path) + # - Proxy context: mark as expired and fail (no interactive OAuth during proxy operation) + # - Credential tool context: do interactive OAuth for new credential setup + if path: + # [NO AUTO-REAUTH] Proxy context - mark as permanently expired + self._mark_credential_expired( + path, + f"{reason}. Manual re-authentication required via credential_tool.py", + ) + raise ValueError( + f"Credential '{display_name}' is expired and requires manual re-authentication. " + f"Run 'python credential_tool.py' to fix, then restart the proxy." ) - # Execute via global coordinator (ensures only one at a time) - return await coordinator.execute_reauth( - credential_path=path or display_name, - provider_name="IFLOW", - reauth_func=_do_interactive_oauth, - timeout=300.0, # 5 minute timeout for user to complete OAuth + # Credential tool context - do interactive OAuth for new credential setup + lib_logger.warning( + f"iFlow OAuth token for '{display_name}' needs setup: {reason}." ) + return await self._perform_interactive_oauth(path, creds, display_name) lib_logger.info(f"iFlow OAuth token at '{display_name}' is valid.") return creds @@ -1440,10 +1749,24 @@ async def get_auth_header(self, credential_path: str) -> Dict[str, str]: """ Returns auth header with API key (NOT OAuth access_token). CRITICAL: iFlow API requests use the api_key, not the OAuth tokens. + + Handles both OAuth and cookie-based credentials: + - OAuth: checks token expiry and refreshes OAuth tokens if needed + - Cookie: checks API key expiry and refreshes via cookie if needed """ creds = await self._load_credentials(credential_path) - if self._is_token_expired(creds): - creds = await self._refresh_token(credential_path) + + # Handle credential refresh based on type + if self._is_cookie_credential(creds): + # Cookie credential: check API key expiry + expire_time = creds.get("expire_time", "") + needs_refresh, _ = should_refresh_cookie_api_key(expire_time) + if needs_refresh: + creds = await self._refresh_cookie_credential(credential_path) + else: + # OAuth credential: check token expiry + if self._is_token_expired(creds): + creds = await self._refresh_token(credential_path) api_key = creds.get("api_key") if not api_key: @@ -1592,6 +1915,16 @@ async def setup_credential( if is_update: file_path = existing_path + # Check if existing credential is Cookie type (will be replaced) + try: + with open(existing_path, "r") as f: + existing_creds = json.load(f) + if self._is_cookie_credential(existing_creds): + lib_logger.info( + f"Replacing existing Cookie credential for {email} with OAuth credential" + ) + except Exception: + pass lib_logger.info( f"Found existing credential for {email}, updating {file_path.name}" ) @@ -1620,6 +1953,220 @@ async def setup_credential( lib_logger.error(f"Credential setup failed: {e}") return IFlowCredentialSetupResult(success=False, error=str(e)) + async def setup_cookie_credential( + self, base_dir: Optional[Path] = None + ) -> IFlowCredentialSetupResult: + """ + Complete cookie-based credential setup flow with manual paste. + + This guides the user through obtaining the BXAuth cookie from their + browser and uses it to authenticate and get an API key. + + Args: + base_dir: Directory to save credential file (defaults to oauth_creds/) + + Returns: + IFlowCredentialSetupResult with success status and file path + """ + if base_dir is None: + base_dir = self._get_oauth_base_dir() + + # Ensure directory exists + base_dir.mkdir(exist_ok=True) + + try: + # Display instructions for cookie extraction + console.print( + Panel( + Text.from_markup( + "[bold]To get your iFlow session cookie:[/bold]\n\n" + "1. Open [link=https://platform.iflow.cn]https://platform.iflow.cn[/link] in your browser\n" + "2. Make sure you are [bold]logged in[/bold]\n" + "3. Press [bold]F12[/bold] to open Developer Tools\n" + "4. Go to: [bold]Application[/bold] (tab) → [bold]Cookies[/bold] → [bold]platform.iflow.cn[/bold]\n" + " [dim](In Firefox: Storage → Cookies)[/dim]\n" + "5. Find the row with Name = [bold cyan]'BXAuth'[/bold cyan]\n" + "6. Double-click the [bold]Value[/bold] cell and copy it (Ctrl+C)\n" + "7. Paste it below\n\n" + "[dim]Note: The cookie typically starts with 'eyJ' and is a long string.[/dim]" + ), + title="[bold blue]iFlow Cookie Setup[/bold blue]", + border_style="blue", + ) + ) + + # Prompt for cookie value + while True: + cookie_value = Prompt.ask( + "\n[bold]Paste your BXAuth cookie value[/bold] (or 'q' to quit)" + ) + + if cookie_value.lower() == "q": + return IFlowCredentialSetupResult( + success=False, error="Setup cancelled by user" + ) + + if not cookie_value.strip(): + console.print( + "[yellow]Cookie value cannot be empty. Please try again.[/yellow]" + ) + continue + + # Clean up common paste issues + cookie_value = cookie_value.strip() + if cookie_value.startswith("BXAuth="): + cookie_value = cookie_value[7:] + if cookie_value.endswith(";"): + cookie_value = cookie_value[:-1] + + if len(cookie_value) < 20: + console.print( + "[yellow]Cookie value seems too short. " + "Make sure you copied the complete BXAuth value.[/yellow]" + ) + continue + + break + + # Build the full cookie string + cookie_string = f"BXAuth={cookie_value};" + + console.print("\n[dim]Validating cookie...[/dim]") + + # Authenticate with the cookie + try: + new_creds = await self.authenticate_with_cookie(cookie_string) + except ValueError as e: + return IFlowCredentialSetupResult( + success=False, error=f"Cookie authentication failed: {e}" + ) + + # Get identifier for deduplication + name = new_creds.get("name", "") + if not name: + return IFlowCredentialSetupResult( + success=False, error="Could not retrieve account name from cookie" + ) + + console.print(f"[green]✓ Cookie validated for account: {name}[/green]") + + # Check for existing credential with same name/email + # Use name as the email identifier for deduplication + existing_path = self._find_existing_credential_by_email(name, base_dir) + is_update = existing_path is not None + + if is_update: + file_path = existing_path + # Check if existing credential is OAuth type (will be replaced) + try: + with open(existing_path, "r") as f: + existing_creds = json.load(f) + if not self._is_cookie_credential(existing_creds): + console.print( + f"[yellow]Replacing existing OAuth credential for {name} with Cookie credential[/yellow]" + ) + except Exception: + pass + console.print( + f"[yellow]Found existing credential for {name}, updating {file_path.name}[/yellow]" + ) + else: + file_path = self._build_credential_path(base_dir) + console.print( + f"[green]Creating new credential for {name} at {file_path.name}[/green]" + ) + + # Set email field to name for consistency with OAuth credentials + new_creds["email"] = name + + # Save credentials to file + if not await self._save_credentials(str(file_path), new_creds): + return IFlowCredentialSetupResult( + success=False, + error=f"Failed to save credentials to disk at {file_path.name}", + ) + + console.print( + Panel( + f"[bold green]Cookie credential saved successfully![/bold green]\n\n" + f"Account: {name}\n" + f"API Key: {new_creds.get('api_key', '')[:20]}...\n" + f"Expires: {new_creds.get('expire_time', 'Unknown')}\n" + f"File: {file_path.name}", + title="[bold green]Success[/bold green]", + border_style="green", + ) + ) + + return IFlowCredentialSetupResult( + success=True, + file_path=str(file_path), + email=name, + is_update=is_update, + credentials=new_creds, + ) + + except Exception as e: + lib_logger.error(f"Cookie credential setup failed: {e}") + return IFlowCredentialSetupResult(success=False, error=str(e)) + + def _find_existing_cookie_credential_by_name( + self, name: str, base_dir: Optional[Path] = None + ) -> Optional[Path]: + """Find an existing cookie credential file for the given name.""" + if base_dir is None: + base_dir = self._get_oauth_base_dir() + + prefix = self._get_provider_file_prefix() + pattern = str(base_dir / f"{prefix}_cookie_*.json") + + for cred_file in glob(pattern): + try: + with open(cred_file, "r") as f: + creds = json.load(f) + existing_name = creds.get("name", "") + if existing_name == name: + return Path(cred_file) + except (json.JSONDecodeError, IOError) as e: + lib_logger.debug(f"Could not read credential file {cred_file}: {e}") + continue + + return None + + def _get_next_cookie_credential_number( + self, base_dir: Optional[Path] = None + ) -> int: + """Get the next available cookie credential number.""" + if base_dir is None: + base_dir = self._get_oauth_base_dir() + + prefix = self._get_provider_file_prefix() + pattern = str(base_dir / f"{prefix}_cookie_*.json") + + existing_numbers = [] + for cred_file in glob(pattern): + match = re.search(r"_cookie_(\d+)\.json$", cred_file) + if match: + existing_numbers.append(int(match.group(1))) + + if not existing_numbers: + return 1 + return max(existing_numbers) + 1 + + def _build_cookie_credential_path( + self, base_dir: Optional[Path] = None, number: Optional[int] = None + ) -> Path: + """Build a path for a new cookie credential file.""" + if base_dir is None: + base_dir = self._get_oauth_base_dir() + + if number is None: + number = self._get_next_cookie_credential_number(base_dir) + + prefix = self._get_provider_file_prefix() + filename = f"{prefix}_cookie_{number}.json" + return base_dir / filename + def build_env_lines(self, creds: Dict[str, Any], cred_number: int) -> List[str]: """Generate .env file lines for an iFlow credential.""" email = creds.get("email") or creds.get("_proxy_metadata", {}).get( @@ -1687,15 +2234,16 @@ def export_credential_to_env( return None def list_credentials(self, base_dir: Optional[Path] = None) -> List[Dict[str, Any]]: - """List all iFlow credential files.""" + """List all iFlow credential files (both OAuth and cookie-based).""" if base_dir is None: base_dir = self._get_oauth_base_dir() prefix = self._get_provider_file_prefix() - pattern = str(base_dir / f"{prefix}_oauth_*.json") - credentials = [] - for cred_file in sorted(glob(pattern)): + + # List all credentials (both OAuth and cookie are stored as *_oauth_*.json) + oauth_pattern = str(base_dir / f"{prefix}_oauth_*.json") + for cred_file in sorted(glob(oauth_pattern)): try: with open(cred_file, "r") as f: creds = json.load(f) @@ -1704,17 +2252,31 @@ def list_credentials(self, base_dir: Optional[Path] = None) -> List[Dict[str, An "email", "unknown" ) + # Determine credential type from _proxy_metadata + cred_type = creds.get("_proxy_metadata", {}).get("credential_type") + if not cred_type: + # Fallback: infer from fields + if "cookie" in creds and "refresh_token" not in creds: + cred_type = "cookie" + else: + cred_type = "oauth" + # Extract number from filename match = re.search(r"_oauth_(\d+)\.json$", cred_file) number = int(match.group(1)) if match else 0 - credentials.append( - { - "file_path": cred_file, - "email": email, - "number": number, - } - ) + cred_info = { + "file_path": cred_file, + "email": email, + "number": number, + "type": cred_type, + } + + # Add expire_time for cookie credentials + if cred_type == "cookie": + cred_info["expire_time"] = creds.get("expire_time", "") + + credentials.append(cred_info) except Exception as e: lib_logger.debug(f"Could not read credential file {cred_file}: {e}") continue @@ -1722,7 +2284,7 @@ def list_credentials(self, base_dir: Optional[Path] = None) -> List[Dict[str, An return credentials def delete_credential(self, credential_path: str) -> bool: - """Delete a credential file.""" + """Delete a credential file (OAuth or cookie-based).""" try: cred_path = Path(credential_path) diff --git a/src/rotator_library/providers/iflow_provider.py b/src/rotator_library/providers/iflow_provider.py index 932b8844..ae292719 100644 --- a/src/rotator_library/providers/iflow_provider.py +++ b/src/rotator_library/providers/iflow_provider.py @@ -4,14 +4,17 @@ # src/rotator_library/providers/iflow_provider.py import copy +import hmac +import hashlib import json import time import os import httpx import logging -from typing import Union, AsyncGenerator, List, Dict, Any +from typing import Union, AsyncGenerator, List, Dict, Any, Optional from .provider_interface import ProviderInterface from .iflow_auth_base import IFlowAuthBase +from .provider_cache import ProviderCache from ..model_definitions import ModelDefinitions from ..timeout_config import TimeoutConfig from ..transaction_logger import ProviderLogger @@ -27,15 +30,20 @@ # Model list can be expanded as iFlow supports more models HARDCODED_MODELS = [ "glm-4.6", + "glm-4.7", "minimax-m2", + "minimax-m2.1", "qwen3-coder-plus", "kimi-k2", "kimi-k2-0905", - "kimi-k2-thinking", + "kimi-k2-thinking", # Seems to not work, but should + "kimi-k2.5", # Seems to not work, but should "qwen3-max", + "qwen3-max-preview", "qwen3-235b-a22b-thinking-2507", + "deepseek-v3.2-reasoner", "deepseek-v3.2-chat", - "deepseek-v3.2", + "deepseek-v3.2", # seems to not work, but should. Use above variants instead "deepseek-v3.1", "deepseek-v3", "deepseek-r1", @@ -62,6 +70,41 @@ "response_format", } +IFLOW_USER_AGENT = "iFlow-Cli" +IFLOW_HEADER_SESSION_ID = "session-id" +IFLOW_HEADER_TIMESTAMP = "x-iflow-timestamp" +IFLOW_HEADER_SIGNATURE = "x-iflow-signature" + +# ============================================================================= +# THINKING MODE CONFIGURATION +# ============================================================================= +# Models using chat_template_kwargs.enable_thinking (boolean toggle) +# Based on Go implementation: internal/thinking/provider/iflow/apply.go +ENABLE_THINKING_MODELS = { + "glm-4.6", + "glm-4.7", + "qwen3-max-preview", + "deepseek-v3.2", + "deepseek-v3.1", +} + +# GLM models need additional clear_thinking=false when thinking is enabled +GLM_MODELS = {"glm-4.6", "glm-4.7"} + +# Models using reasoning_split (boolean) instead of enable_thinking +REASONING_SPLIT_MODELS = {"minimax-m2", "minimax-m2.1"} + +# Models that benefit from reasoning_content preservation in message history +# (for multi-turn conversations) +REASONING_PRESERVATION_MODELS_PREFIXES = ("glm-4", "minimax-m2") + +# Cache file path for reasoning content preservation +_REASONING_CACHE_FILE = ( + Path(__file__).resolve().parent.parent.parent.parent + / "cache" + / "iflow_reasoning.json" +) + class IFlowProvider(IFlowAuthBase, ProviderInterface): """ @@ -75,6 +118,15 @@ def __init__(self): super().__init__() self.model_definitions = ModelDefinitions() + # Initialize reasoning cache for multi-turn conversation support + # Created in __init__ (not module level) to ensure event loop exists + self._reasoning_cache = ProviderCache( + cache_file=_REASONING_CACHE_FILE, + memory_ttl_seconds=3600, # 1 hour in memory + disk_ttl_seconds=86400, # 24 hours on disk + env_prefix="IFLOW_REASONING_CACHE", + ) + def has_custom_logic(self) -> bool: return True @@ -218,10 +270,211 @@ def _clean_schema_properties(self, properties: Dict[str, Any]) -> None: if "items" in prop_schema and isinstance(prop_schema["items"], dict): self._clean_schema_properties({"item": prop_schema["items"]}) - def _build_request_payload(self, **kwargs) -> Dict[str, Any]: + # ========================================================================= + # THINKING MODE SUPPORT + # ========================================================================= + + def _should_enable_thinking(self, kwargs: Dict[str, Any]) -> Optional[bool]: + """ + Check if thinking should be enabled based on request parameters. + + Uses OpenAI-compatible format. Checks for reasoning_effort parameter. + Thinking is enabled for any value except "none", "disabled", or "0". + + Returns: + True: Enable thinking + False: Disable thinking explicitly + None: No thinking params (passthrough - don't modify payload) + """ + # Check reasoning_effort (OpenAI-style) + reasoning_effort = kwargs.get("reasoning_effort") + if reasoning_effort is not None: + effort_lower = str(reasoning_effort).lower().strip() + # Disabled values + if effort_lower in ("none", "disabled", "0", "off", "false"): + # lib_logger.info( + # f"iFlow: Detected reasoning_effort='{reasoning_effort}' → thinking DISABLED" + # ) + return False + # Any other value enables thinking + # lib_logger.info( + # f"iFlow: Detected reasoning_effort='{reasoning_effort}' → thinking ENABLED" + # ) + return True + + # Check extra_body for thinking config (Claude-style, for compatibility) + extra_body = kwargs.get("extra_body", {}) + if extra_body and "thinking" in extra_body: + thinking = extra_body["thinking"] + if isinstance(thinking, dict): + budget = thinking.get("budget_tokens", 0) + return budget != 0 + return bool(thinking) + + return None # No thinking params specified + + def _apply_thinking_config( + self, payload: Dict[str, Any], model_name: str, kwargs: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Apply thinking configuration for supported iFlow models. + + Logic matches Go implementation (internal/thinking/provider/iflow/apply.go): + - GLM models: enable_thinking + clear_thinking=false (when enabled) + - Qwen/DeepSeek: enable_thinking only + - MiniMax: reasoning_split + + Args: + payload: The request payload to modify + model_name: Model name (without provider prefix) + kwargs: Original request kwargs containing reasoning_effort, etc. + + Returns: + Modified payload with thinking config applied + """ + model_lower = model_name.lower() + enable_thinking = self._should_enable_thinking(kwargs) + + if enable_thinking is None: + return payload # No thinking params, passthrough + + # Check model type + is_glm = model_lower in GLM_MODELS + is_enable_thinking = model_lower in ENABLE_THINKING_MODELS or is_glm + is_minimax = model_lower in REASONING_SPLIT_MODELS + + if is_enable_thinking: + # Models using chat_template_kwargs.enable_thinking + if "chat_template_kwargs" not in payload: + payload["chat_template_kwargs"] = {} + payload["chat_template_kwargs"]["enable_thinking"] = enable_thinking + + # GLM models: strip clear_thinking first (like Go does with DeleteBytes), + # then set it to false only when thinking is enabled + if is_glm: + payload["chat_template_kwargs"].pop("clear_thinking", None) + if enable_thinking: + payload["chat_template_kwargs"]["clear_thinking"] = False + + lib_logger.info( + f"iFlow: Applied enable_thinking={enable_thinking} for {model_name}" + ) + elif is_minimax: + # MiniMax models use reasoning_split + payload["reasoning_split"] = enable_thinking + lib_logger.info( + f"iFlow: Applied reasoning_split={enable_thinking} for {model_name}" + ) + + return payload + + # ========================================================================= + # REASONING CONTENT PRESERVATION + # ========================================================================= + + def _get_conversation_signature(self, messages: List[Dict[str, Any]]) -> str: + """ + Generate a stable conversation signature from the first user message. + + This provides conversation-level uniqueness for cache keys. + """ + for msg in messages: + if msg.get("role") == "user": + content = msg.get("content", "") + if isinstance(content, str) and content: + # Use first 100 chars of first user message as conversation signature + return hashlib.md5(content[:100].encode()).hexdigest()[:8] + return "default" + + def _get_message_cache_key(self, message: Dict[str, Any], conv_sig: str) -> str: + """ + Generate a cache key for a message to look up cached reasoning. + + Combines: + - Conversation signature (stable per conversation) + - Message content hash (identifies specific message) + """ + content = message.get("content", "") or "" + role = message.get("role", "") + # Use content[:200] + role for message identity + msg_hash = hashlib.md5(f"{role}:{content[:200]}".encode()).hexdigest()[:12] + return f"{conv_sig}:{msg_hash}" + + def _store_reasoning_content(self, message: Dict[str, Any], conv_sig: str) -> None: + """ + Store reasoning_content from an assistant message for later retrieval. + + Args: + message: The assistant message dict containing reasoning_content + conv_sig: Conversation signature for the cache key + """ + reasoning = message.get("reasoning_content") + if reasoning and message.get("role") == "assistant": + key = self._get_message_cache_key(message, conv_sig) + self._reasoning_cache.store(key, reasoning) + lib_logger.debug(f"iFlow: Cached reasoning_content for message {key}") + + def _inject_reasoning_content( + self, messages: List[Dict[str, Any]], model_name: str + ) -> List[Dict[str, Any]]: + """ + Inject cached reasoning_content into assistant messages. + + Only for models that benefit from reasoning preservation (GLM-4.x, MiniMax-M2.x). + This is helpful for multi-turn conversations where the model may benefit + from seeing its previous reasoning to maintain coherent thought chains. + + Args: + messages: List of messages in the conversation + model_name: Model name (without provider prefix) + + Returns: + Messages list with reasoning_content restored where available + """ + model_lower = model_name.lower() + + # Only for models that benefit from reasoning preservation + if not any( + model_lower.startswith(prefix) + for prefix in REASONING_PRESERVATION_MODELS_PREFIXES + ): + return messages + + # Get conversation signature + conv_sig = self._get_conversation_signature(messages) + + result = [] + restored_count = 0 + for msg in messages: + if msg.get("role") == "assistant" and not msg.get("reasoning_content"): + key = self._get_message_cache_key(msg, conv_sig) + cached = self._reasoning_cache.retrieve(key) + if cached: + msg = {**msg, "reasoning_content": cached} + restored_count += 1 + result.append(msg) + + if restored_count > 0: + lib_logger.debug( + f"iFlow: Restored reasoning_content for {restored_count} messages in {model_name}" + ) + + return result + + def _build_request_payload( + self, model_name: str, full_kwargs: Dict[str, Any], **kwargs + ) -> Dict[str, Any]: """ Builds a clean request payload with only supported parameters. - This prevents 400 Bad Request errors from litellm-internal parameters. + Also applies thinking mode and reasoning content preservation. + + Args: + model_name: Model name without provider prefix (for thinking/reasoning logic) + full_kwargs: Original kwargs (for extracting reasoning_effort, etc.) + **kwargs: Filtered kwargs with stripped model name + + Returns: + Complete payload ready for iFlow API """ # Extract only supported OpenAI parameters payload = {k: v for k, v in kwargs.items() if k in SUPPORTED_PARAMS} @@ -254,70 +507,266 @@ def _build_request_payload(self, **kwargs) -> Dict[str, Any]: ] lib_logger.debug("Injected placeholder tool for empty tools array") + # Inject cached reasoning_content into messages for multi-turn conversations + if "messages" in payload: + payload["messages"] = self._inject_reasoning_content( + payload["messages"], model_name + ) + + # Apply thinking mode configuration based on reasoning_effort + payload = self._apply_thinking_config(payload, model_name, full_kwargs) + return payload - def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str): + def _create_iflow_signature( + self, user_agent: str, session_id: str, timestamp_ms: int, api_key: str + ) -> str: + """Generate iFlow HMAC-SHA256 signature: userAgent:sessionId:timestamp.""" + if not api_key: + return "" + + payload = f"{user_agent}:{session_id}:{timestamp_ms}" + return hmac.new( + api_key.encode("utf-8"), payload.encode("utf-8"), hashlib.sha256 + ).hexdigest() + + def _build_iflow_headers(self, api_key: str, stream: bool) -> Dict[str, str]: + """Build iFlow request headers, including signed auth headers.""" + session_id = f"session-{uuid.uuid4()}" + timestamp_ms = int(time.time() * 1000) + signature = self._create_iflow_signature( + IFLOW_USER_AGENT, session_id, timestamp_ms, api_key + ) + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "User-Agent": IFLOW_USER_AGENT, + IFLOW_HEADER_SESSION_ID: session_id, + IFLOW_HEADER_TIMESTAMP: str(timestamp_ms), + "Accept": "text/event-stream" if stream else "application/json", + } + if signature: + headers[IFLOW_HEADER_SIGNATURE] = signature + + return headers + + def _extract_finish_reason_from_chunk(self, chunk: Dict[str, Any]) -> Optional[str]: + """ + Extract finish_reason from a raw iFlow chunk by searching all possible locations. + + Args: + chunk: Raw chunk from iFlow API + + Returns: + The finish_reason if found, None otherwise + """ + choices = chunk.get("choices", []) + for choice in choices: + if isinstance(choice, dict): + finish_reason = choice.get("finish_reason") + if finish_reason: + return finish_reason + return None + + def _convert_chunk_to_openai( + self, + chunk: Dict[str, Any], + model_id: str, + stream_state: Optional[Dict[str, Any]] = None, + ): """ Converts a raw iFlow SSE chunk to an OpenAI-compatible chunk. Since iFlow is OpenAI-compatible, minimal conversion is needed. - CRITICAL FIX: Handle chunks with BOTH usage and choices (final chunk) - without early return to ensure finish_reason is properly processed. + ROBUST FINISH_REASON HANDLING: + - Tracks finish_reason across all chunks in stream_state + - tool_calls takes priority over stop + - Always sets finish_reason on final chunks (with usage) + - Logs warning if no finish_reason found + + Args: + chunk: Raw chunk from iFlow API + model_id: Model identifier for response + stream_state: Mutable dict to track state across chunks """ if not isinstance(chunk, dict): return + # Initialize stream_state if not provided + if stream_state is None: + stream_state = {} + # Get choices and usage data choices = chunk.get("choices", []) usage_data = chunk.get("usage") - # Handle chunks with BOTH choices and usage (typical for final chunk) - # CRITICAL: Process choices FIRST to capture finish_reason, then yield usage - if choices and usage_data: - # Yield the choice chunk first (contains finish_reason) - yield { - "choices": choices, - "model": model_id, - "object": "chat.completion.chunk", - "id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"), - "created": chunk.get("created", int(time.time())), + # IMPORTANT: Empty dict {} is falsy in Python, but "usage": {} still indicates final chunk + # Use "is not None" check instead of truthiness + has_usage = usage_data is not None + is_final_chunk = has_usage + + # Extract and track finish_reason from raw chunk + raw_finish_reason = self._extract_finish_reason_from_chunk(chunk) + if raw_finish_reason: + stream_state["last_finish_reason"] = raw_finish_reason + # lib_logger.debug( + # f"iFlow: Found finish_reason='{raw_finish_reason}' in raw chunk" + # ) + + def normalize_choices( + choices_list: List[Dict[str, Any]], + force_final: bool = False, + ) -> List[Dict[str, Any]]: + """ + Normalizes choices array with robust finish_reason handling. + + Priority for finish_reason: + 1. tool_calls (if any tool_calls were seen in the stream) + 2. Explicit finish_reason from this chunk + 3. Last tracked finish_reason from stream_state + 4. Default to 'stop' (with warning) + """ + normalized = [] + for choice in choices_list: + choice_copy = dict(choice) if isinstance(choice, dict) else choice + delta = choice_copy.get("delta", {}) + + # Track tool_calls presence + if delta.get("tool_calls"): + stream_state["has_tool_calls"] = True + + # Track reasoning_content presence (for logging) + reasoning_content = delta.get("reasoning_content") + if reasoning_content and reasoning_content.strip(): + if not stream_state.get("has_reasoning_logged"): + # lib_logger.debug( + # f"iFlow: Chunk contains reasoning_content " + # f"({len(reasoning_content)} chars)" + # ) + stream_state["has_reasoning_logged"] = True + + # Get current finish_reason + finish_reason = choice_copy.get("finish_reason") + + # Track any finish_reason we see + if finish_reason: + stream_state["last_finish_reason"] = finish_reason + + # For final chunks, ensure finish_reason is ALWAYS set + if force_final: + # Priority: tool_calls > explicit > tracked > stop (with warning) + if stream_state.get("has_tool_calls"): + # Tool calls take highest priority + final_reason = "tool_calls" + if finish_reason and finish_reason != "tool_calls": + pass # Silently override - tool_calls takes priority + # lib_logger.debug( + # f"iFlow: Overriding finish_reason '{finish_reason}' " + # f"with 'tool_calls' (tool_calls present)" + # ) + elif finish_reason: + # Use explicit finish_reason from this chunk + final_reason = finish_reason + elif stream_state.get("last_finish_reason"): + # Use tracked finish_reason from earlier chunk + final_reason = stream_state["last_finish_reason"] + # lib_logger.debug( + # f"iFlow: Using tracked finish_reason '{final_reason}' " + # f"for final chunk" + # ) + else: + # No finish_reason found anywhere - default to stop with warning + final_reason = "stop" + lib_logger.warning( + f"iFlow: No finish_reason found in stream, defaulting to 'stop'" + ) + + choice_copy = {**choice_copy, "finish_reason": final_reason} + # lib_logger.debug( + # f"iFlow: Final chunk finish_reason set to '{final_reason}'" + # ) + else: + # For non-final chunks, normalize tool_calls if needed + if ( + finish_reason + and stream_state.get("has_tool_calls") + and finish_reason != "tool_calls" + ): + choice_copy = {**choice_copy, "finish_reason": "tool_calls"} + + normalized.append(choice_copy) + return normalized + + # Handle chunks with usage (final chunk indicator) + # Note: "usage": {} (empty dict) still indicates final chunk + if choices and has_usage: + # Normalize choices for final chunk - MUST set finish_reason + normalized_choices = normalize_choices(choices, force_final=True) + # Build usage dict, handling empty usage gracefully + usage_dict = { + "prompt_tokens": usage_data.get("prompt_tokens", 0) + if usage_data + else 0, + "completion_tokens": usage_data.get("completion_tokens", 0) + if usage_data + else 0, + "total_tokens": usage_data.get("total_tokens", 0) if usage_data else 0, } - # Then yield the usage chunk + + # CRITICAL FIX: If usage is empty/all-zeros (e.g., MiniMax sends "usage": {}), + # set placeholder non-zero values to ensure downstream processing + # (litellm/client) recognizes this as a final chunk and preserves finish_reason + if not any( + usage_dict.get(k, 0) > 0 + for k in ["prompt_tokens", "completion_tokens", "total_tokens"] + ): + usage_dict = { + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2, + } + # lib_logger.debug( + # "iFlow: Empty usage detected, using placeholder values for final chunk" + # ) + yield { - "choices": [], + "choices": normalized_choices, "model": model_id, "object": "chat.completion.chunk", "id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"), "created": chunk.get("created", int(time.time())), - "usage": { - "prompt_tokens": usage_data.get("prompt_tokens", 0), - "completion_tokens": usage_data.get("completion_tokens", 0), - "total_tokens": usage_data.get("total_tokens", 0), - }, + "usage": usage_dict, } return - # Handle usage-only chunks - if usage_data: + # Handle usage-only chunks (no choices) + if has_usage and not choices: + usage_dict = { + "prompt_tokens": usage_data.get("prompt_tokens", 0) + if usage_data + else 0, + "completion_tokens": usage_data.get("completion_tokens", 0) + if usage_data + else 0, + "total_tokens": usage_data.get("total_tokens", 0) if usage_data else 0, + } yield { "choices": [], "model": model_id, "object": "chat.completion.chunk", "id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"), "created": chunk.get("created", int(time.time())), - "usage": { - "prompt_tokens": usage_data.get("prompt_tokens", 0), - "completion_tokens": usage_data.get("completion_tokens", 0), - "total_tokens": usage_data.get("total_tokens", 0), - }, + "usage": usage_dict, } return - # Handle content-only chunks + # Handle content-only chunks (no usage) if choices: - # iFlow returns OpenAI-compatible format, so we can mostly pass through + # Normalize choices - not final, so finish_reason not forced + normalized_choices = normalize_choices(choices, force_final=False) yield { - "choices": choices, + "choices": normalized_choices, "model": model_id, "object": "chat.completion.chunk", "id": chunk.get("id", f"chatcmpl-iflow-{time.time()}"), @@ -355,7 +804,25 @@ def _stream_to_completion_response( continue choice = chunk.choices[0] - delta = choice.get("delta", {}) + # Handle both dict and object access patterns for choice.delta + if hasattr(choice, "get"): + delta = choice.get("delta", {}) + choice_finish = choice.get("finish_reason") + elif hasattr(choice, "delta"): + delta = choice.delta if choice.delta else {} + # Convert delta to dict if it's an object + if hasattr(delta, "__dict__") and not isinstance(delta, dict): + delta = { + k: v + for k, v in delta.__dict__.items() + if not k.startswith("_") and v is not None + } + elif hasattr(delta, "model_dump"): + delta = delta.model_dump(exclude_none=True) + choice_finish = getattr(choice, "finish_reason", None) + else: + delta = {} + choice_finish = None # Aggregate content if "content" in delta and delta["content"] is not None: @@ -419,8 +886,8 @@ def _stream_to_completion_response( ]["arguments"] # Track finish_reason from chunks (for reference only) - if choice.get("finish_reason"): - chunk_finish_reason = choice["finish_reason"] + if choice_finish: + chunk_finish_reason = choice_finish # Handle usage data from the last chunk that has it for chunk in reversed(chunks): @@ -437,6 +904,10 @@ def _stream_to_completion_response( if field not in final_message: final_message[field] = None + # Remove MiniMax-specific reasoning_details - we have the full reasoning_content + # The reasoning_details array only contains partial data from the last chunk + final_message.pop("reasoning_details", None) + # Determine finish_reason based on accumulated state # Priority: tool_calls wins if present, then chunk's finish_reason, then default to "stop" if aggregated_tool_calls: @@ -485,20 +956,21 @@ async def make_request(): kwargs_with_stripped_model = {**kwargs, "model": model_name} # Build clean payload with only supported parameters - payload = self._build_request_payload(**kwargs_with_stripped_model) + # Pass original kwargs for thinking detection (reasoning_effort, etc.) + payload = self._build_request_payload( + model_name, kwargs, **kwargs_with_stripped_model + ) - headers = { - "Authorization": f"Bearer {api_key}", # Uses api_key from user info - "Content-Type": "application/json", - "Accept": "text/event-stream", - "User-Agent": "iFlow-Cli", - } + headers = self._build_iflow_headers( + api_key=api_key, + stream=bool(payload.get("stream")), + ) url = f"{api_base.rstrip('/')}/chat/completions" # Log request to dedicated file file_logger.log_request(payload) - lib_logger.debug(f"iFlow Request URL: {url}") + # lib_logger.debug(f"iFlow Request URL: {url}") return client.stream( "POST", @@ -510,6 +982,8 @@ async def make_request(): async def stream_handler(response_stream, attempt=1): """Handles the streaming response and converts chunks.""" + # Track state across chunks for finish_reason normalization + stream_state: Dict[str, Any] = {} try: async with response_stream as response: # Check for HTTP errors before processing stream @@ -546,6 +1020,11 @@ async def stream_handler(response_stream, attempt=1): # Handle other errors else: + if not error_text: + content_type = response.headers.get("content-type", "") + error_text = ( + f"(empty response body, content-type={content_type})" + ) error_msg = ( f"iFlow HTTP {response.status_code} error: {error_text}" ) @@ -569,11 +1048,13 @@ async def stream_handler(response_stream, attempt=1): data_str = line[5:] # Skip "data:" if data_str.strip() == "[DONE]": + # lib_logger.debug("iFlow: Received [DONE] marker") break try: chunk = json.loads(data_str) + for openai_chunk in self._convert_chunk_to_openai( - chunk, model + chunk, model, stream_state ): yield litellm.ModelResponse(**openai_chunk) except json.JSONDecodeError: @@ -591,7 +1072,7 @@ async def stream_handler(response_stream, attempt=1): raise async def logging_stream_wrapper(): - """Wraps the stream to log the final reassembled response.""" + """Wraps the stream to log the final reassembled response and cache reasoning.""" openai_chunks = [] try: async for chunk in stream_handler(await make_request()): @@ -602,6 +1083,34 @@ async def logging_stream_wrapper(): final_response = self._stream_to_completion_response(openai_chunks) file_logger.log_final_response(final_response.dict()) + # Store reasoning_content from the response for future multi-turn conversations + # This enables reasoning preservation in subsequent requests + model_name = model.split("/")[-1] + messages = kwargs.get("messages", []) + if messages: + conv_sig = self._get_conversation_signature(messages) + # Get the assistant message from the final response + if final_response.choices and len(final_response.choices) > 0: + choice = final_response.choices[0] + message = getattr(choice, "message", None) + if message: + # Convert to dict if needed + if hasattr(message, "model_dump"): + msg_dict = message.model_dump() + elif hasattr(message, "__dict__"): + msg_dict = { + k: v + for k, v in message.__dict__.items() + if not k.startswith("_") + } + else: + msg_dict = ( + dict(message) + if isinstance(message, dict) + else {} + ) + self._store_reasoning_content(msg_dict, conv_sig) + if kwargs.get("stream"): return logging_stream_wrapper() else: diff --git a/src/rotator_library/providers/nanogpt_provider.py b/src/rotator_library/providers/nanogpt_provider.py index 52a648e4..456117de 100644 --- a/src/rotator_library/providers/nanogpt_provider.py +++ b/src/rotator_library/providers/nanogpt_provider.py @@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, TYPE_CHECKING if TYPE_CHECKING: - from ..usage_manager import UsageManager + from ..usage import UsageManager from .provider_interface import ProviderInterface, UsageResetConfigDef from .utilities.nanogpt_quota_tracker import NanoGptQuotaTracker @@ -74,8 +74,8 @@ class NanoGptProvider(NanoGptQuotaTracker, ProviderInterface): # Active subscriptions get highest priority tier_priorities = { "subscription-active": 1, # Active subscription - "subscription-grace": 2, # Grace period (subscription lapsed but still has access) - "no-subscription": 3, # No active subscription (pay-as-you-go only) + "subscription-grace": 2, # Grace period (subscription lapsed but still has access) + "no-subscription": 3, # No active subscription (pay-as-you-go only) } default_tier_priority = 3 @@ -86,8 +86,6 @@ class NanoGptProvider(NanoGptQuotaTracker, ProviderInterface): "monthly": ["_monthly"], } - - def __init__(self): self.model_definitions = ModelDefinitions() @@ -410,29 +408,49 @@ async def refresh_single_credential( monthly_remaining = monthly_data.get("remaining", 0) # Calculate remaining fractions - daily_fraction = daily_remaining / daily_limit if daily_limit > 0 else 1.0 - monthly_fraction = monthly_remaining / monthly_limit if monthly_limit > 0 else 1.0 + daily_fraction = ( + daily_remaining / daily_limit if daily_limit > 0 else 1.0 + ) + monthly_fraction = ( + monthly_remaining / monthly_limit + if monthly_limit > 0 + else 1.0 + ) # Get reset timestamps daily_reset_ts = daily_data.get("reset_at", 0) monthly_reset_ts = monthly_data.get("reset_at", 0) # Store daily quota baseline + daily_used = ( + int((1.0 - daily_fraction) * daily_limit) + if daily_limit > 0 + else 0 + ) await usage_manager.update_quota_baseline( api_key, "nanogpt/_daily", - daily_fraction, - max_requests=daily_limit, - reset_timestamp=daily_reset_ts if daily_reset_ts > 0 else None, + quota_max_requests=daily_limit, + quota_reset_ts=daily_reset_ts + if daily_reset_ts > 0 + else None, + quota_used=daily_used, ) # Store monthly quota baseline + monthly_used = ( + int((1.0 - monthly_fraction) * monthly_limit) + if monthly_limit > 0 + else 0 + ) await usage_manager.update_quota_baseline( api_key, "nanogpt/_monthly", - monthly_fraction, - max_requests=monthly_limit, - reset_timestamp=monthly_reset_ts if monthly_reset_ts > 0 else None, + quota_max_requests=monthly_limit, + quota_reset_ts=monthly_reset_ts + if monthly_reset_ts > 0 + else None, + quota_used=monthly_used, ) lib_logger.debug( diff --git a/src/rotator_library/providers/openai_codex_auth_base.py b/src/rotator_library/providers/openai_codex_auth_base.py new file mode 100644 index 00000000..58920392 --- /dev/null +++ b/src/rotator_library/providers/openai_codex_auth_base.py @@ -0,0 +1,1519 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +# src/rotator_library/providers/openai_codex_auth_base.py + +import asyncio +import base64 +import copy +import hashlib +import json +import logging +import os +import re +import secrets +import time +import webbrowser +from dataclasses import dataclass, field +from glob import glob +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union +from urllib.parse import urlencode + +import httpx +from aiohttp import web +from rich.console import Console +from rich.markup import escape as rich_escape +from rich.panel import Panel +from rich.text import Text + +from ..error_handler import CredentialNeedsReauthError +from ..utils.headless_detection import is_headless_environment +from ..utils.reauth_coordinator import get_reauth_coordinator +from ..utils.resilient_io import safe_write_json + +lib_logger = logging.getLogger("rotator_library") + +# OAuth constants +CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" +SCOPE = "openid profile email offline_access" +AUTHORIZATION_ENDPOINT = "https://auth.openai.com/oauth/authorize" +TOKEN_ENDPOINT = "https://auth.openai.com/oauth/token" +# OpenAI Codex OAuth redirect path registered for this client. +# Keep legacy `/oauth2callback` handler for backward compatibility with old URLs. +CALLBACK_PATH = "/auth/callback" +LEGACY_CALLBACK_PATH = "/oauth2callback" +CALLBACK_PORT = 1455 +CALLBACK_ENV_VAR = "OPENAI_CODEX_OAUTH_PORT" + +# API constants +DEFAULT_API_BASE = "https://chatgpt.com/backend-api" +RESPONSES_ENDPOINT_PATH = "/codex/responses" + +# JWT claims +AUTH_CLAIM = "https://api.openai.com/auth" +ACCOUNT_ID_CLAIM = "https://api.openai.com/auth.chatgpt_account_id" + +# Refresh when token is close to expiry +REFRESH_EXPIRY_BUFFER_SECONDS = 5 * 60 # 5 minutes + +console = Console() + + +@dataclass +class OpenAICodexCredentialSetupResult: + """Standardized result structure for OpenAI Codex credential setup operations.""" + + success: bool + file_path: Optional[str] = None + email: Optional[str] = None + is_update: bool = False + error: Optional[str] = None + credentials: Optional[Dict[str, Any]] = field(default=None, repr=False) + + +class OAuthCallbackServer: + """Minimal HTTP server for handling OpenAI Codex OAuth callbacks.""" + + SUCCESS_HTML = """ + + + + + Authentication successful + + +

Authentication successful. Return to your terminal to continue.

+ +""" + + def __init__(self, port: int = CALLBACK_PORT): + self.port = port + self.app = web.Application() + self.runner: Optional[web.AppRunner] = None + self.site: Optional[web.TCPSite] = None + self.result_future: Optional[asyncio.Future] = None + self.expected_state: Optional[str] = None + + async def start(self, expected_state: str): + """Start callback server on localhost:.""" + self.expected_state = expected_state + self.result_future = asyncio.Future() + + for callback_path in {CALLBACK_PATH, LEGACY_CALLBACK_PATH}: + self.app.router.add_get(callback_path, self._handle_callback) + + self.runner = web.AppRunner(self.app) + await self.runner.setup() + self.site = web.TCPSite(self.runner, "localhost", self.port) + await self.site.start() + + lib_logger.debug( + "OpenAI Codex OAuth callback server started on " + f"localhost:{self.port}{CALLBACK_PATH} " + f"(legacy alias: {LEGACY_CALLBACK_PATH})" + ) + + async def stop(self): + """Stop callback server.""" + if self.site: + await self.site.stop() + if self.runner: + await self.runner.cleanup() + lib_logger.debug("OpenAI Codex OAuth callback server stopped") + + async def _handle_callback(self, request: web.Request) -> web.Response: + query = request.query + + if "error" in query: + error = query.get("error", "unknown_error") + error_desc = query.get("error_description", "") + if not self.result_future.done(): + self.result_future.set_exception( + ValueError(f"OAuth error: {error} ({error_desc})") + ) + return web.Response(status=400, text=f"OAuth error: {error}") + + code = query.get("code") + state = query.get("state", "") + + if not code: + if not self.result_future.done(): + self.result_future.set_exception( + ValueError("Missing authorization code") + ) + return web.Response(status=400, text="Missing authorization code") + + if state != self.expected_state: + if not self.result_future.done(): + self.result_future.set_exception(ValueError("State parameter mismatch")) + return web.Response(status=400, text="State mismatch") + + if not self.result_future.done(): + self.result_future.set_result(code) + + return web.Response( + status=200, + text=self.SUCCESS_HTML, + content_type="text/html", + ) + + async def wait_for_callback(self, timeout: float = 300.0) -> str: + """Wait for OAuth callback and return auth code.""" + try: + code = await asyncio.wait_for(self.result_future, timeout=timeout) + return code + except asyncio.TimeoutError: + raise TimeoutError("Timeout waiting for OAuth callback") + + +def get_callback_port() -> int: + """Get OAuth callback port from env or fallback default.""" + env_value = os.getenv(CALLBACK_ENV_VAR) + if env_value: + try: + return int(env_value) + except ValueError: + lib_logger.warning( + f"Invalid {CALLBACK_ENV_VAR} value: {env_value}, using default {CALLBACK_PORT}" + ) + return CALLBACK_PORT + + +class OpenAICodexAuthBase: + """ + OpenAI Codex OAuth authentication base class. + + Supports: + - Interactive OAuth Authorization Code + PKCE + - Token refresh with retry/backoff + - File + env credential loading (`env://openai_codex/N`) + - Queue-based refresh and re-auth workflows + - Credential management APIs for credential_tool + """ + + CALLBACK_PORT = CALLBACK_PORT + CALLBACK_ENV_VAR = CALLBACK_ENV_VAR + + def __init__(self): + self._credentials_cache: Dict[str, Dict[str, Any]] = {} + self._refresh_locks: Dict[str, asyncio.Lock] = {} + self._locks_lock = asyncio.Lock() + + # Backoff tracking + self._refresh_failures: Dict[str, int] = {} + self._next_refresh_after: Dict[str, float] = {} + + # Queue system (normal refresh + interactive re-auth) + self._refresh_queue: asyncio.Queue = asyncio.Queue() + self._queue_processor_task: Optional[asyncio.Task] = None + + self._reauth_queue: asyncio.Queue = asyncio.Queue() + self._reauth_processor_task: Optional[asyncio.Task] = None + + self._queued_credentials: set = set() + self._unavailable_credentials: Dict[str, float] = {} + self._unavailable_ttl_seconds: int = 360 + self._queue_tracking_lock = asyncio.Lock() + + self._queue_retry_count: Dict[str, int] = {} + + # Queue configuration + self._refresh_timeout_seconds: int = 20 + self._refresh_interval_seconds: int = 20 + self._refresh_max_retries: int = 3 + self._reauth_timeout_seconds: int = 300 + + # ========================================================================= + # JWT + metadata helpers + # ========================================================================= + + @staticmethod + def _decode_jwt_unverified(token: str) -> Optional[Dict[str, Any]]: + """Decode JWT payload without signature verification.""" + if not token or not isinstance(token, str): + return None + + parts = token.split(".") + if len(parts) < 2: + return None + + payload_segment = parts[1] + padding = "=" * (-len(payload_segment) % 4) + + try: + payload_bytes = base64.urlsafe_b64decode(payload_segment + padding) + payload = json.loads(payload_bytes.decode("utf-8")) + return payload if isinstance(payload, dict) else None + except Exception: + return None + + @staticmethod + def _extract_account_id_from_payload(payload: Optional[Dict[str, Any]]) -> Optional[str]: + """Extract account ID from JWT claims.""" + if not payload: + return None + + # 1) Direct dotted claim format (requested by plan) + direct = payload.get(ACCOUNT_ID_CLAIM) + if isinstance(direct, str) and direct.strip(): + return direct.strip() + + # 2) Nested object claim format observed in real tokens + auth_claim = payload.get(AUTH_CLAIM) + if isinstance(auth_claim, dict): + nested = auth_claim.get("chatgpt_account_id") + if isinstance(nested, str) and nested.strip(): + return nested.strip() + + # 3) Fallback organizations[0].id if present + orgs = payload.get("organizations") + if isinstance(orgs, list) and orgs: + first = orgs[0] + if isinstance(first, dict): + org_id = first.get("id") + if isinstance(org_id, str) and org_id.strip(): + return org_id.strip() + + return None + + @staticmethod + def _extract_explicit_email_from_payload( + payload: Optional[Dict[str, Any]], + ) -> Optional[str]: + """Extract explicit email claim only (no sub fallback).""" + if not payload: + return None + + email = payload.get("email") + if isinstance(email, str) and email.strip(): + return email.strip() + + return None + + @staticmethod + def _extract_email_from_payload(payload: Optional[Dict[str, Any]]) -> Optional[str]: + """Extract email from JWT payload using fallback chain: email -> sub.""" + if not payload: + return None + + email = payload.get("email") + if isinstance(email, str) and email.strip(): + return email.strip() + + sub = payload.get("sub") + if isinstance(sub, str) and sub.strip(): + return sub.strip() + + return None + + @staticmethod + def _extract_expiry_ms_from_payload(payload: Optional[Dict[str, Any]]) -> Optional[int]: + """Extract JWT exp claim and convert to milliseconds.""" + if not payload: + return None + + exp = payload.get("exp") + if isinstance(exp, (int, float)): + return int(float(exp) * 1000) + + return None + + def _populate_metadata_from_tokens(self, creds: Dict[str, Any]) -> None: + """Populate _proxy_metadata (email/account_id) from access_token or id_token.""" + metadata = creds.setdefault("_proxy_metadata", {}) + + access_payload = self._decode_jwt_unverified(creds.get("access_token", "")) + id_payload = self._decode_jwt_unverified(creds.get("id_token", "")) + + account_id = self._extract_account_id_from_payload( + access_payload + ) or self._extract_account_id_from_payload(id_payload) + + # Prefer explicit email claim from id_token first (most user-specific), + # then explicit access-token email, then fall back to sub-based extraction. + email = ( + self._extract_explicit_email_from_payload(id_payload) + or self._extract_explicit_email_from_payload(access_payload) + or self._extract_email_from_payload(id_payload) + or self._extract_email_from_payload(access_payload) + ) + + if account_id: + metadata["account_id"] = account_id + + if email: + metadata["email"] = email + + # Keep top-level expiry_date synchronized from token exp as fallback + if not creds.get("expiry_date"): + expiry_ms = self._extract_expiry_ms_from_payload(access_payload) or self._extract_expiry_ms_from_payload( + id_payload + ) + if expiry_ms: + creds["expiry_date"] = expiry_ms + + metadata["last_check_timestamp"] = time.time() + + def _ensure_proxy_metadata(self, creds: Dict[str, Any]) -> Dict[str, Any]: + """Ensure credentials include normalized _proxy_metadata fields.""" + metadata = creds.setdefault("_proxy_metadata", {}) + metadata.setdefault("loaded_from_env", False) + metadata.setdefault("env_credential_index", None) + + self._populate_metadata_from_tokens(creds) + + # Keep top-level token_uri stable for schema consistency + creds.setdefault("token_uri", TOKEN_ENDPOINT) + + return creds + + # ========================================================================= + # Env + file credential loading + # ========================================================================= + + def _parse_env_credential_path(self, path: str) -> Optional[str]: + """ + Parse a virtual env:// path and return the credential index. + + Supported formats: + - env://openai_codex/0 (legacy single) + - env://openai_codex/1 (numbered) + """ + if not path.startswith("env://"): + return None + + raw = path[6:] + parts = raw.split("/") + if not parts: + return None + + provider = parts[0].strip().lower() + if provider != "openai_codex": + return None + + if len(parts) >= 2 and parts[1].strip(): + return parts[1].strip() + + return "0" + + def _load_from_env( + self, credential_index: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + """ + Load OpenAI Codex OAuth credentials from environment variables. + + Legacy single credential: + - OPENAI_CODEX_ACCESS_TOKEN + - OPENAI_CODEX_REFRESH_TOKEN + - OPENAI_CODEX_EXPIRY_DATE (optional) + - OPENAI_CODEX_ID_TOKEN (optional) + - OPENAI_CODEX_ACCOUNT_ID (optional) + - OPENAI_CODEX_EMAIL (optional) + + Numbered credentials (N): + - OPENAI_CODEX_N_ACCESS_TOKEN + - OPENAI_CODEX_N_REFRESH_TOKEN + - OPENAI_CODEX_N_EXPIRY_DATE (optional) + - OPENAI_CODEX_N_ID_TOKEN (optional) + - OPENAI_CODEX_N_ACCOUNT_ID (optional) + - OPENAI_CODEX_N_EMAIL (optional) + """ + if credential_index and credential_index != "0": + prefix = f"OPENAI_CODEX_{credential_index}" + default_email = f"env-user-{credential_index}" + env_index = credential_index + else: + prefix = "OPENAI_CODEX" + default_email = "env-user" + env_index = "0" + + access_token = os.getenv(f"{prefix}_ACCESS_TOKEN") + refresh_token = os.getenv(f"{prefix}_REFRESH_TOKEN") + + if not (access_token and refresh_token): + return None + + expiry_raw = os.getenv(f"{prefix}_EXPIRY_DATE", "") + expiry_date: Optional[int] = None + if expiry_raw: + try: + expiry_date = int(float(expiry_raw)) + except ValueError: + lib_logger.warning(f"Invalid {prefix}_EXPIRY_DATE: {expiry_raw}") + + id_token = os.getenv(f"{prefix}_ID_TOKEN") + account_id = os.getenv(f"{prefix}_ACCOUNT_ID") + email = os.getenv(f"{prefix}_EMAIL") + + creds: Dict[str, Any] = { + "access_token": access_token, + "refresh_token": refresh_token, + "id_token": id_token, + "token_uri": TOKEN_ENDPOINT, + "expiry_date": expiry_date or 0, + "_proxy_metadata": { + "email": email or default_email, + "account_id": account_id, + "last_check_timestamp": time.time(), + "loaded_from_env": True, + "env_credential_index": env_index, + }, + } + + # Fill missing metadata/expiry from JWT claims + self._populate_metadata_from_tokens(creds) + + # If expiry still missing, set conservative short expiry to trigger refresh soon + if not creds.get("expiry_date"): + creds["expiry_date"] = int((time.time() + 300) * 1000) + + return creds + + async def _read_creds_from_file(self, path: str) -> Dict[str, Any]: + """Read credentials from disk and update cache.""" + try: + with open(path, "r") as f: + creds = json.load(f) + + if not isinstance(creds, dict): + raise ValueError("Credential file root must be a JSON object") + + creds = self._ensure_proxy_metadata(creds) + self._credentials_cache[path] = creds + return creds + + except FileNotFoundError: + raise IOError(f"OpenAI Codex credential file not found at '{path}'") + except Exception as e: + raise IOError( + f"Failed to load OpenAI Codex credentials from '{path}': {e}" + ) + + async def _load_credentials(self, path: str) -> Dict[str, Any]: + """Load credentials from cache, env, or file.""" + if path in self._credentials_cache: + return self._credentials_cache[path] + + async with await self._get_lock(path): + if path in self._credentials_cache: + return self._credentials_cache[path] + + credential_index = self._parse_env_credential_path(path) + if credential_index is not None: + env_creds = self._load_from_env(credential_index) + if env_creds: + self._credentials_cache[path] = env_creds + lib_logger.info( + f"Using OpenAI Codex env credential index {credential_index}" + ) + return env_creds + raise IOError( + f"Environment variables for OpenAI Codex credential index {credential_index} not found" + ) + + # File-based path, with legacy env fallback for backwards compatibility + try: + return await self._read_creds_from_file(path) + except IOError: + env_creds = self._load_from_env("0") + if env_creds: + self._credentials_cache[path] = env_creds + lib_logger.info( + f"File '{path}' not found; using legacy OPENAI_CODEX_* environment credentials" + ) + return env_creds + raise + + async def _save_credentials(self, path: str, creds: Dict[str, Any]) -> bool: + """ + Save credentials to disk, then update cache. + + Critical semantics: + - For rotating refresh tokens, disk write must succeed before cache update. + - Env-backed creds skip disk writes and update in-memory cache only. + """ + creds = self._ensure_proxy_metadata(copy.deepcopy(creds)) + + loaded_from_env = creds.get("_proxy_metadata", {}).get("loaded_from_env", False) + if loaded_from_env or self._parse_env_credential_path(path) is not None: + self._credentials_cache[path] = creds + lib_logger.debug( + f"OpenAI Codex credential '{path}' is env-backed; skipping disk write" + ) + return True + + if not safe_write_json( + path, + creds, + lib_logger, + secure_permissions=True, + buffer_on_failure=False, + ): + lib_logger.error( + f"Failed to persist OpenAI Codex credentials for '{Path(path).name}'. Cache not updated." + ) + return False + + self._credentials_cache[path] = creds + return True + + # ========================================================================= + # Expiry / refresh helpers + # ========================================================================= + + def _is_token_expired(self, creds: Dict[str, Any]) -> bool: + """Proactive expiry check using refresh buffer.""" + expiry_timestamp = float(creds.get("expiry_date", 0)) / 1000 + return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS + + def _is_token_truly_expired(self, creds: Dict[str, Any]) -> bool: + """Strict expiry check without proactive buffer.""" + expiry_timestamp = float(creds.get("expiry_date", 0)) / 1000 + return expiry_timestamp < time.time() + + async def _exchange_code_for_tokens( + self, code: str, code_verifier: str, redirect_uri: str + ) -> Dict[str, Any]: + """Exchange OAuth authorization code for tokens.""" + payload = { + "grant_type": "authorization_code", + "code": code, + "client_id": CLIENT_ID, + "redirect_uri": redirect_uri, + "code_verifier": code_verifier, + } + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + "User-Agent": "LLM-API-Key-Proxy/OpenAICodex", + } + + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.post(TOKEN_ENDPOINT, headers=headers, data=payload) + response.raise_for_status() + token_data = response.json() + + access_token = token_data.get("access_token") + refresh_token = token_data.get("refresh_token") + expires_in = token_data.get("expires_in") + + if not access_token or not refresh_token or not isinstance(expires_in, (int, float)): + raise ValueError("Token exchange response missing required fields") + + return token_data + + async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any]: + """Refresh access token using refresh_token with retry/backoff.""" + async with await self._get_lock(path): + cached_creds = self._credentials_cache.get(path) + if not force and cached_creds and not self._is_token_expired(cached_creds): + return cached_creds + + # Always load freshest source before refresh attempt + is_env = self._parse_env_credential_path(path) is not None + if is_env: + source_creds = copy.deepcopy(await self._load_credentials(path)) + else: + await self._read_creds_from_file(path) + source_creds = copy.deepcopy(self._credentials_cache[path]) + + refresh_token = source_creds.get("refresh_token") + if not refresh_token: + raise ValueError("No refresh_token found in OpenAI Codex credentials") + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + "User-Agent": "LLM-API-Key-Proxy/OpenAICodex", + } + + max_retries = 3 + token_data = None + last_error: Optional[Exception] = None + + async with httpx.AsyncClient(timeout=30.0) as client: + for attempt in range(max_retries): + try: + response = await client.post( + TOKEN_ENDPOINT, + headers=headers, + data={ + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": CLIENT_ID, + }, + ) + response.raise_for_status() + token_data = response.json() + break + + except httpx.HTTPStatusError as e: + last_error = e + status_code = e.response.status_code + + error_type = "" + error_desc = "" + try: + payload = e.response.json() + error_type = payload.get("error", "") + error_desc = payload.get("error_description", "") or payload.get( + "message", "" + ) + except Exception: + error_desc = e.response.text + + # invalid_grant and authorization failures should trigger re-auth queue + if status_code == 400: + if ( + error_type == "invalid_grant" + or "invalid_grant" in error_desc.lower() + or "invalid" in error_desc.lower() + ): + asyncio.create_task( + self._queue_refresh(path, force=True, needs_reauth=True) + ) + raise CredentialNeedsReauthError( + credential_path=path, + message=( + f"OpenAI Codex refresh token invalid for '{Path(path).name}'. Re-auth queued." + ), + ) + raise + + if status_code in (401, 403): + asyncio.create_task( + self._queue_refresh(path, force=True, needs_reauth=True) + ) + raise CredentialNeedsReauthError( + credential_path=path, + message=( + f"OpenAI Codex credential '{Path(path).name}' unauthorized (HTTP {status_code}). Re-auth queued." + ), + ) + + if status_code == 429: + retry_after = e.response.headers.get("Retry-After", "60") + try: + wait_seconds = max(1, int(float(retry_after))) + except ValueError: + wait_seconds = 60 + + if attempt < max_retries - 1: + await asyncio.sleep(wait_seconds) + continue + raise + + if 500 <= status_code < 600: + if attempt < max_retries - 1: + await asyncio.sleep(2**attempt) + continue + raise + + raise + + except (httpx.RequestError, httpx.TimeoutException) as e: + last_error = e + if attempt < max_retries - 1: + await asyncio.sleep(2**attempt) + continue + raise + + if token_data is None: + self._refresh_failures[path] = self._refresh_failures.get(path, 0) + 1 + backoff_seconds = min(300, 30 * (2 ** self._refresh_failures[path])) + self._next_refresh_after[path] = time.time() + backoff_seconds + raise last_error or Exception("OpenAI Codex token refresh failed") + + access_token = token_data.get("access_token") + if not access_token: + raise ValueError("Refresh response missing access_token") + + expires_in = token_data.get("expires_in") + if not isinstance(expires_in, (int, float)): + raise ValueError("Refresh response missing expires_in") + + # Build UPDATED credential object (do not mutate cached source in-place) + updated_creds = copy.deepcopy(source_creds) + updated_creds["access_token"] = access_token + updated_creds["refresh_token"] = token_data.get( + "refresh_token", updated_creds.get("refresh_token") + ) + + if token_data.get("id_token"): + updated_creds["id_token"] = token_data.get("id_token") + + updated_creds["expiry_date"] = int((time.time() + float(expires_in)) * 1000) + updated_creds["token_uri"] = TOKEN_ENDPOINT + + self._ensure_proxy_metadata(updated_creds) + + if not updated_creds.get("access_token") or not updated_creds.get( + "refresh_token" + ): + raise ValueError("Refreshed credentials missing required token fields") + + # Successful refresh clears backoff tracking + self._refresh_failures.pop(path, None) + self._next_refresh_after.pop(path, None) + + # Persist before mutating shared cache state + if not await self._save_credentials(path, updated_creds): + raise IOError( + f"Failed to persist refreshed OpenAI Codex credential '{Path(path).name}'" + ) + + return self._credentials_cache[path] + + # ========================================================================= + # Interactive OAuth flow + # ========================================================================= + + async def _perform_interactive_oauth( + self, + path: Optional[str], + creds: Dict[str, Any], + display_name: str, + ) -> Dict[str, Any]: + """Perform interactive OpenAI Codex OAuth authorization code flow with PKCE.""" + is_headless = is_headless_environment() + + # PKCE verifier/challenge (base64url, no padding) + code_verifier = ( + base64.urlsafe_b64encode(secrets.token_bytes(32)) + .decode("utf-8") + .rstrip("=") + ) + code_challenge = ( + base64.urlsafe_b64encode( + hashlib.sha256(code_verifier.encode("utf-8")).digest() + ) + .decode("utf-8") + .rstrip("=") + ) + state = secrets.token_hex(32) + + callback_port = get_callback_port() + redirect_uri = f"http://localhost:{callback_port}{CALLBACK_PATH}" + + auth_params = { + "response_type": "code", + "client_id": CLIENT_ID, + "redirect_uri": redirect_uri, + "scope": SCOPE, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "state": state, + "id_token_add_organizations": "true", + "codex_cli_simplified_flow": "true", + "originator": "pi", + } + auth_url = f"{AUTHORIZATION_ENDPOINT}?{urlencode(auth_params)}" + + callback_server = OAuthCallbackServer(port=callback_port) + + try: + await callback_server.start(expected_state=state) + + if is_headless: + help_text = Text.from_markup( + "Running in headless environment.\n" + "Open the URL below in a browser on another machine and complete login." + ) + else: + help_text = Text.from_markup( + "Open the URL below, complete sign-in, and return here." + ) + + console.print( + Panel( + help_text, + title=f"OpenAI Codex OAuth Setup for [bold yellow]{display_name}[/bold yellow]", + style="bold blue", + ) + ) + escaped_url = rich_escape(auth_url) + console.print(f"[bold]URL:[/bold] [link={auth_url}]{escaped_url}[/link]\n") + + if not is_headless: + try: + webbrowser.open(auth_url) + lib_logger.info("Browser opened for OpenAI Codex OAuth flow") + except Exception as e: + lib_logger.warning( + f"Failed to auto-open browser for OpenAI Codex OAuth: {e}" + ) + + code = await callback_server.wait_for_callback( + timeout=float(self._reauth_timeout_seconds) + ) + + token_data = await self._exchange_code_for_tokens( + code=code, + code_verifier=code_verifier, + redirect_uri=redirect_uri, + ) + + # Build updated credential object + updated_creds = copy.deepcopy(creds) + metadata = updated_creds.setdefault("_proxy_metadata", {}) + loaded_from_env = metadata.get("loaded_from_env", False) + env_index = metadata.get("env_credential_index") + + updated_creds.update( + { + "access_token": token_data.get("access_token"), + "refresh_token": token_data.get("refresh_token"), + "id_token": token_data.get("id_token"), + "token_uri": TOKEN_ENDPOINT, + "expiry_date": int( + (time.time() + float(token_data.get("expires_in", 3600))) * 1000 + ), + } + ) + + # Restore env metadata flags if this credential originated from env + updated_creds.setdefault("_proxy_metadata", {}) + updated_creds["_proxy_metadata"]["loaded_from_env"] = loaded_from_env + updated_creds["_proxy_metadata"]["env_credential_index"] = env_index + + self._ensure_proxy_metadata(updated_creds) + + if path: + if not await self._save_credentials(path, updated_creds): + raise IOError( + f"Failed to save OpenAI Codex OAuth credentials for '{display_name}'" + ) + else: + # in-memory setup flow + creds.clear() + creds.update(updated_creds) + + lib_logger.info( + f"OpenAI Codex OAuth initialized successfully for '{display_name}'" + ) + return updated_creds + + finally: + await callback_server.stop() + + async def initialize_token( + self, + creds_or_path: Union[Dict[str, Any], str], + force_interactive: bool = False, + ) -> Dict[str, Any]: + """ + Initialize OAuth token, refreshing or running interactive flow as needed. + + Interactive re-auth is globally coordinated via ReauthCoordinator so only + one flow runs at a time across all providers. + """ + path = creds_or_path if isinstance(creds_or_path, str) else None + + if isinstance(creds_or_path, dict): + display_name = creds_or_path.get("_proxy_metadata", {}).get( + "display_name", "in-memory OpenAI Codex credential" + ) + else: + display_name = Path(path).name if path else "in-memory OpenAI Codex credential" + + try: + creds = ( + await self._load_credentials(creds_or_path) if path else copy.deepcopy(creds_or_path) + ) + + reason = "" + if force_interactive: + reason = "interactive re-auth explicitly requested" + elif not creds.get("refresh_token"): + reason = "refresh token is missing" + elif self._is_token_expired(creds): + reason = "token is expired" + + if reason: + # Prefer non-interactive refresh when we have a refresh token and this is simple expiry + if reason == "token is expired" and creds.get("refresh_token") and path: + try: + return await self._refresh_token(path) + except CredentialNeedsReauthError: + # Explicitly fall through into interactive re-auth path + pass + except Exception as e: + lib_logger.warning( + f"Automatic OpenAI Codex token refresh failed for '{display_name}': {e}. Falling back to interactive login." + ) + + coordinator = get_reauth_coordinator() + + async def _do_interactive_oauth(): + return await self._perform_interactive_oauth(path, creds, display_name) + + result = await coordinator.execute_reauth( + credential_path=path or display_name, + provider_name="OPENAI_CODEX", + reauth_func=_do_interactive_oauth, + timeout=float(self._reauth_timeout_seconds), + ) + + # Persist cache when path-based + if path and isinstance(result, dict): + self._credentials_cache[path] = self._ensure_proxy_metadata(result) + + return result + + # Token is already valid + creds = self._ensure_proxy_metadata(creds) + if path: + self._credentials_cache[path] = creds + return creds + + except Exception as e: + raise ValueError( + f"Failed to initialize OpenAI Codex OAuth credential '{display_name}': {e}" + ) + + async def get_auth_header(self, credential_identifier: str) -> Dict[str, str]: + creds = await self._load_credentials(credential_identifier) + if self._is_token_expired(creds): + creds = await self._refresh_token(credential_identifier) + return {"Authorization": f"Bearer {creds['access_token']}"} + + async def get_user_info( + self, creds_or_path: Union[Dict[str, Any], str] + ) -> Dict[str, Any]: + """Retrieve user info from _proxy_metadata.""" + try: + path = creds_or_path if isinstance(creds_or_path, str) else None + creds = await self._load_credentials(path) if path else copy.deepcopy(creds_or_path) + + if path: + await self.initialize_token(path) + creds = await self._load_credentials(path) + + metadata = creds.get("_proxy_metadata", {}) + email = metadata.get("email") + account_id = metadata.get("account_id") + + # Update timestamp in cache only (non-critical metadata) + if path and "_proxy_metadata" in creds: + creds["_proxy_metadata"]["last_check_timestamp"] = time.time() + self._credentials_cache[path] = creds + + return { + "email": email, + "account_id": account_id, + } + except Exception as e: + lib_logger.error(f"Failed to get OpenAI Codex user info: {e}") + return {"email": None, "account_id": None} + + async def proactively_refresh(self, credential_identifier: str): + """Queue proactive refresh for credentials near expiry.""" + try: + creds = await self._load_credentials(credential_identifier) + except IOError: + return + + if self._is_token_expired(creds): + await self._queue_refresh( + credential_identifier, + force=False, + needs_reauth=False, + ) + + # ========================================================================= + # Queue + availability plumbing + # ========================================================================= + + async def _get_lock(self, path: str) -> asyncio.Lock: + async with self._locks_lock: + if path not in self._refresh_locks: + self._refresh_locks[path] = asyncio.Lock() + return self._refresh_locks[path] + + def is_credential_available(self, path: str) -> bool: + """ + Check if credential is available for rotation. + + Unavailable when: + - In re-auth queue + - Truly expired (past actual expiry) + """ + if path in self._unavailable_credentials: + marked_time = self._unavailable_credentials.get(path) + if marked_time is not None: + now = time.time() + if now - marked_time > self._unavailable_ttl_seconds: + lib_logger.warning( + f"OpenAI Codex credential '{Path(path).name}' stuck in re-auth queue for {int(now - marked_time)}s. Auto-cleaning stale entry." + ) + self._unavailable_credentials.pop(path, None) + self._queued_credentials.discard(path) + else: + return False + + creds = self._credentials_cache.get(path) + if creds and self._is_token_truly_expired(creds): + if path not in self._queued_credentials: + try: + loop = asyncio.get_running_loop() + loop.create_task( + self._queue_refresh(path, force=True, needs_reauth=False) + ) + except RuntimeError: + # No running event loop (e.g., sync context); caller can still + # trigger refresh through normal async request flow. + pass + return False + + return True + + async def _ensure_queue_processor_running(self): + if self._queue_processor_task is None or self._queue_processor_task.done(): + self._queue_processor_task = asyncio.create_task(self._process_refresh_queue()) + + async def _ensure_reauth_processor_running(self): + if self._reauth_processor_task is None or self._reauth_processor_task.done(): + self._reauth_processor_task = asyncio.create_task(self._process_reauth_queue()) + + async def _queue_refresh( + self, + path: str, + force: bool = False, + needs_reauth: bool = False, + ): + """Queue credential for refresh or re-auth.""" + if not needs_reauth: + now = time.time() + backoff_until = self._next_refresh_after.get(path) + if backoff_until and now < backoff_until: + return + + async with self._queue_tracking_lock: + if path in self._queued_credentials: + return + + self._queued_credentials.add(path) + + if needs_reauth: + self._unavailable_credentials[path] = time.time() + await self._reauth_queue.put(path) + await self._ensure_reauth_processor_running() + else: + await self._refresh_queue.put((path, force)) + await self._ensure_queue_processor_running() + + async def _process_refresh_queue(self): + """Sequential background worker for normal refresh queue.""" + while True: + path = None + try: + try: + path, force = await asyncio.wait_for(self._refresh_queue.get(), timeout=60.0) + except asyncio.TimeoutError: + async with self._queue_tracking_lock: + self._queue_retry_count.clear() + self._queue_processor_task = None + return + + try: + creds = self._credentials_cache.get(path) + if creds and not self._is_token_expired(creds): + self._queue_retry_count.pop(path, None) + continue + + try: + async with asyncio.timeout(self._refresh_timeout_seconds): + await self._refresh_token(path, force=force) + self._queue_retry_count.pop(path, None) + + except asyncio.TimeoutError: + await self._handle_refresh_failure(path, force, "timeout") + + except httpx.HTTPStatusError as e: + status_code = e.response.status_code + needs_reauth = False + + if status_code == 400: + try: + payload = e.response.json() + error_type = payload.get("error", "") + error_desc = payload.get("error_description", "") + except Exception: + error_type = "" + error_desc = str(e) + + if ( + error_type == "invalid_grant" + or "invalid_grant" in error_desc.lower() + or "invalid" in error_desc.lower() + ): + needs_reauth = True + + elif status_code in (401, 403): + needs_reauth = True + + if needs_reauth: + self._queue_retry_count.pop(path, None) + async with self._queue_tracking_lock: + self._queued_credentials.discard(path) + await self._queue_refresh(path, force=True, needs_reauth=True) + else: + await self._handle_refresh_failure(path, force, f"HTTP {status_code}") + + except CredentialNeedsReauthError: + self._queue_retry_count.pop(path, None) + async with self._queue_tracking_lock: + self._queued_credentials.discard(path) + await self._queue_refresh(path, force=True, needs_reauth=True) + + except Exception as e: + await self._handle_refresh_failure(path, force, str(e)) + + finally: + async with self._queue_tracking_lock: + if ( + path in self._queued_credentials + and self._queue_retry_count.get(path, 0) == 0 + ): + self._queued_credentials.discard(path) + self._refresh_queue.task_done() + + await asyncio.sleep(self._refresh_interval_seconds) + + except asyncio.CancelledError: + break + except Exception as e: + lib_logger.error(f"Error in OpenAI Codex refresh queue processor: {e}") + if path: + async with self._queue_tracking_lock: + self._queued_credentials.discard(path) + + async def _handle_refresh_failure(self, path: str, force: bool, error: str): + retry_count = self._queue_retry_count.get(path, 0) + 1 + self._queue_retry_count[path] = retry_count + + if retry_count >= self._refresh_max_retries: + lib_logger.error( + f"OpenAI Codex refresh max retries reached for '{Path(path).name}' (last error: {error})." + ) + self._queue_retry_count.pop(path, None) + async with self._queue_tracking_lock: + self._queued_credentials.discard(path) + return + + lib_logger.warning( + f"OpenAI Codex refresh failed for '{Path(path).name}' ({error}). Retry {retry_count}/{self._refresh_max_retries}." + ) + await self._refresh_queue.put((path, force)) + + async def _process_reauth_queue(self): + """Sequential background worker for interactive re-auth queue.""" + while True: + path = None + try: + try: + path = await asyncio.wait_for(self._reauth_queue.get(), timeout=60.0) + except asyncio.TimeoutError: + self._reauth_processor_task = None + return + + try: + lib_logger.info( + f"Starting OpenAI Codex interactive re-auth for '{Path(path).name}'" + ) + await self.initialize_token(path, force_interactive=True) + lib_logger.info( + f"OpenAI Codex re-auth succeeded for '{Path(path).name}'" + ) + except Exception as e: + lib_logger.error( + f"OpenAI Codex re-auth failed for '{Path(path).name}': {e}" + ) + finally: + async with self._queue_tracking_lock: + self._queued_credentials.discard(path) + self._unavailable_credentials.pop(path, None) + self._reauth_queue.task_done() + + except asyncio.CancelledError: + if path: + async with self._queue_tracking_lock: + self._queued_credentials.discard(path) + self._unavailable_credentials.pop(path, None) + break + except Exception as e: + lib_logger.error(f"Error in OpenAI Codex re-auth queue processor: {e}") + if path: + async with self._queue_tracking_lock: + self._queued_credentials.discard(path) + self._unavailable_credentials.pop(path, None) + + # ========================================================================= + # Credential management methods for credential_tool + # ========================================================================= + + def _get_provider_file_prefix(self) -> str: + return "openai_codex" + + def _get_oauth_base_dir(self) -> Path: + return Path.cwd() / "oauth_creds" + + def _find_existing_credential_by_identity( + self, + email: Optional[str], + account_id: Optional[str], + base_dir: Optional[Path] = None, + ) -> Optional[Path]: + """ + Find an existing local credential to update. + + Matching policy (multi-account safe): + - If both email and account_id are available, require BOTH to match. + - If one identity field is missing on either side, use the other as a fallback. + + This avoids collisions when different users/accounts share a workspace + account_id while keeping backward compatibility for legacy files that may + miss one metadata field. + """ + if base_dir is None: + base_dir = self._get_oauth_base_dir() + + prefix = self._get_provider_file_prefix() + pattern = str(base_dir / f"{prefix}_oauth_*.json") + + email_fallback_match: Optional[Path] = None + account_fallback_match: Optional[Path] = None + + for cred_file in glob(pattern): + try: + with open(cred_file, "r") as f: + creds = json.load(f) + + metadata = creds.get("_proxy_metadata", {}) + existing_email = metadata.get("email") + existing_account_id = metadata.get("account_id") + + same_email = ( + bool(email) + and bool(existing_email) + and str(existing_email).strip() == str(email).strip() + ) + same_account = ( + bool(account_id) + and bool(existing_account_id) + and str(existing_account_id).strip() == str(account_id).strip() + ) + + # Strongest match: both identifiers present + matching + if same_email and same_account: + return Path(cred_file) + + # Fallbacks only when one identity dimension is missing + if same_email and (not account_id or not existing_account_id): + email_fallback_match = Path(cred_file) + + if same_account and (not email or not existing_email): + account_fallback_match = Path(cred_file) + + except Exception: + continue + + # Prefer email-based fallback over account fallback when both are possible + return email_fallback_match or account_fallback_match + + def _get_next_credential_number(self, base_dir: Optional[Path] = None) -> int: + if base_dir is None: + base_dir = self._get_oauth_base_dir() + + prefix = self._get_provider_file_prefix() + pattern = str(base_dir / f"{prefix}_oauth_*.json") + + existing_numbers = [] + for cred_file in glob(pattern): + match = re.search(r"_oauth_(\d+)\.json$", cred_file) + if match: + existing_numbers.append(int(match.group(1))) + + return (max(existing_numbers) + 1) if existing_numbers else 1 + + def _build_credential_path( + self, + base_dir: Optional[Path] = None, + number: Optional[int] = None, + ) -> Path: + if base_dir is None: + base_dir = self._get_oauth_base_dir() + + if number is None: + number = self._get_next_credential_number(base_dir) + + filename = f"{self._get_provider_file_prefix()}_oauth_{number}.json" + return base_dir / filename + + async def setup_credential( + self, + base_dir: Optional[Path] = None, + ) -> OpenAICodexCredentialSetupResult: + """Complete OpenAI Codex credential setup flow.""" + if base_dir is None: + base_dir = self._get_oauth_base_dir() + + base_dir.mkdir(parents=True, exist_ok=True) + + try: + temp_creds = { + "_proxy_metadata": { + "display_name": "new OpenAI Codex credential", + "loaded_from_env": False, + "env_credential_index": None, + } + } + new_creds = await self.initialize_token(temp_creds) + + metadata = new_creds.get("_proxy_metadata", {}) + email = metadata.get("email") + account_id = metadata.get("account_id") + + existing_path = self._find_existing_credential_by_identity( + email=email, + account_id=account_id, + base_dir=base_dir, + ) + is_update = existing_path is not None + file_path = existing_path if is_update else self._build_credential_path(base_dir) + + if not await self._save_credentials(str(file_path), new_creds): + return OpenAICodexCredentialSetupResult( + success=False, + error=f"Failed to save OpenAI Codex credential to {file_path.name}", + ) + + return OpenAICodexCredentialSetupResult( + success=True, + file_path=str(file_path), + email=email, + is_update=is_update, + credentials=new_creds, + ) + + except Exception as e: + lib_logger.error(f"OpenAI Codex credential setup failed: {e}") + return OpenAICodexCredentialSetupResult(success=False, error=str(e)) + + def build_env_lines(self, creds: Dict[str, Any], cred_number: int) -> List[str]: + """Build OPENAI_CODEX_N_* env lines from credential JSON.""" + metadata = creds.get("_proxy_metadata", {}) + email = metadata.get("email", "unknown") + account_id = metadata.get("account_id", "") + + prefix = f"OPENAI_CODEX_{cred_number}" + + lines = [ + f"# OPENAI_CODEX Credential #{cred_number} for: {email}", + f"# Exported from: openai_codex_oauth_{cred_number}.json", + f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}", + "", + f"{prefix}_ACCESS_TOKEN={creds.get('access_token', '')}", + f"{prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}", + f"{prefix}_EXPIRY_DATE={int(float(creds.get('expiry_date', 0)))}", + f"{prefix}_ID_TOKEN={creds.get('id_token', '')}", + f"{prefix}_ACCOUNT_ID={account_id}", + f"{prefix}_EMAIL={email}", + ] + + return lines + + def export_credential_to_env( + self, + credential_path: str, + output_dir: Optional[Path] = None, + ) -> Optional[str]: + """Export a credential JSON file to .env format.""" + try: + cred_path = Path(credential_path) + with open(cred_path, "r") as f: + creds = json.load(f) + + metadata = creds.get("_proxy_metadata", {}) + email = metadata.get("email", "unknown") + + match = re.search(r"_oauth_(\d+)\.json$", cred_path.name) + cred_number = int(match.group(1)) if match else 1 + + if output_dir is None: + output_dir = cred_path.parent + + safe_email = str(email).replace("@", "_at_").replace(".", "_") + env_filename = f"openai_codex_{cred_number}_{safe_email}.env" + env_path = output_dir / env_filename + + env_lines = self.build_env_lines(creds, cred_number) + with open(env_path, "w") as f: + f.write("\n".join(env_lines)) + + lib_logger.info(f"Exported OpenAI Codex credential to {env_path}") + return str(env_path) + + except Exception as e: + lib_logger.error(f"Failed to export OpenAI Codex credential: {e}") + return None + + def list_credentials(self, base_dir: Optional[Path] = None) -> List[Dict[str, Any]]: + """List all local OpenAI Codex credential files.""" + if base_dir is None: + base_dir = self._get_oauth_base_dir() + + prefix = self._get_provider_file_prefix() + pattern = str(base_dir / f"{prefix}_oauth_*.json") + + credentials: List[Dict[str, Any]] = [] + for cred_file in sorted(glob(pattern)): + try: + with open(cred_file, "r") as f: + creds = json.load(f) + + metadata = creds.get("_proxy_metadata", {}) + match = re.search(r"_oauth_(\d+)\.json$", cred_file) + number = int(match.group(1)) if match else 0 + + credentials.append( + { + "file_path": cred_file, + "email": metadata.get("email", "unknown"), + "account_id": metadata.get("account_id"), + "number": number, + } + ) + except Exception: + continue + + return credentials + + def delete_credential(self, credential_path: str) -> bool: + """Delete an OpenAI Codex credential file.""" + try: + cred_path = Path(credential_path) + prefix = self._get_provider_file_prefix() + + if not cred_path.name.startswith(f"{prefix}_oauth_"): + lib_logger.error( + f"File {cred_path.name} does not appear to be an OpenAI Codex credential" + ) + return False + + if not cred_path.exists(): + lib_logger.warning( + f"OpenAI Codex credential file does not exist: {credential_path}" + ) + return False + + self._credentials_cache.pop(credential_path, None) + cred_path.unlink() + lib_logger.info(f"Deleted OpenAI Codex credential file: {credential_path}") + return True + + except Exception as e: + lib_logger.error(f"Failed to delete OpenAI Codex credential: {e}") + return False diff --git a/src/rotator_library/providers/openai_codex_provider.py b/src/rotator_library/providers/openai_codex_provider.py new file mode 100644 index 00000000..d8c12b82 --- /dev/null +++ b/src/rotator_library/providers/openai_codex_provider.py @@ -0,0 +1,1229 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +# src/rotator_library/providers/openai_codex_provider.py + +import copy +import json +import logging +import os +import time +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union + +import httpx +import litellm + +from .openai_codex_auth_base import ( + AUTH_CLAIM, + DEFAULT_API_BASE, + RESPONSES_ENDPOINT_PATH, + OpenAICodexAuthBase, +) +from .provider_interface import ProviderInterface, UsageResetConfigDef, QuotaGroupMap +from ..model_definitions import ModelDefinitions +from ..timeout_config import TimeoutConfig +from ..transaction_logger import ProviderLogger + +lib_logger = logging.getLogger("rotator_library") + +# Conservative fallback model list (can be overridden via OPENAI_CODEX_MODELS) +HARDCODED_MODELS = [ + "gpt-5.1-codex", + "gpt-5-codex", + "gpt-4.1-codex", +] + + +class CodexStreamError(Exception): + """Terminal Codex stream error that should abort the stream.""" + + def __init__(self, message: str, status_code: int = 500, error_body: Optional[str] = None): + self.status_code = status_code + self.error_body = error_body or message + super().__init__(message) + + +class CodexSSETranslator: + """ + Translates OpenAI Codex SSE events into OpenAI chat.completion chunks. + + Supports both currently observed events and planned fallback aliases: + - response.output_text.delta (observed) + - response.content_part.delta (planned alias) + - response.function_call_arguments.delta / .done + """ + + def __init__(self, model_id: str): + self.model_id = model_id + self.response_id: Optional[str] = None + self.created: int = int(time.time()) + self._tool_index_by_call_id: Dict[str, int] = {} + self._tool_names_by_call_id: Dict[str, str] = {} + + def _build_chunk( + self, + *, + delta: Optional[Dict[str, Any]] = None, + finish_reason: Optional[str] = None, + usage: Optional[Dict[str, int]] = None, + ) -> Dict[str, Any]: + if not self.response_id: + self.response_id = f"chatcmpl-codex-{int(time.time() * 1000)}" + + choice = { + "index": 0, + "delta": delta or {}, + "finish_reason": finish_reason, + } + + chunk = { + "id": self.response_id, + "object": "chat.completion.chunk", + "created": self.created, + "model": self.model_id, + "choices": [choice], + } + + if usage is not None: + chunk["usage"] = usage + + return chunk + + def _extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]: + event_type = event.get("type") + + if event_type == "response.output_text.delta": + delta = event.get("delta") + if isinstance(delta, str): + return delta + + if event_type == "response.content_part.delta": + # Compatibility with planned taxonomy + if isinstance(event.get("delta"), str): + return event["delta"] + part = event.get("part") + if isinstance(part, dict): + if isinstance(part.get("delta"), str): + return part["delta"] + if isinstance(part.get("text"), str): + return part["text"] + + if event_type == "response.content_part.added": + part = event.get("part") + if isinstance(part, dict): + text = part.get("text") + if isinstance(text, str) and text: + return text + + return None + + def _map_incomplete_reason(self, reason: Optional[str]) -> str: + if not reason: + return "length" + + normalized = reason.strip().lower() + if normalized in {"stop", "completed"}: + return "stop" + if normalized in {"max_output_tokens", "max_tokens", "length"}: + return "length" + if normalized in {"tool_calls", "tool_call"}: + return "tool_calls" + if normalized in {"content_filter", "content_filtered"}: + return "content_filter" + return "length" + + def _extract_usage(self, event: Dict[str, Any]) -> Optional[Dict[str, int]]: + response = event.get("response") + if not isinstance(response, dict): + return None + + usage = response.get("usage") + if not isinstance(usage, dict): + return None + + prompt_tokens = int(usage.get("input_tokens", 0) or 0) + completion_tokens = int(usage.get("output_tokens", 0) or 0) + total_tokens = int(usage.get("total_tokens", 0) or 0) + + return { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + + def _get_response_status(self, event: Dict[str, Any]) -> str: + response = event.get("response") + if isinstance(response, dict): + status = response.get("status") + if isinstance(status, str) and status: + return status + + event_type = event.get("type") + if event_type == "response.incomplete": + return "incomplete" + if event_type == "response.failed": + return "failed" + return "completed" + + def _get_or_create_tool_index(self, call_id: str) -> int: + if call_id not in self._tool_index_by_call_id: + self._tool_index_by_call_id[call_id] = len(self._tool_index_by_call_id) + return self._tool_index_by_call_id[call_id] + + def _extract_tool_call_id(self, event: Dict[str, Any]) -> Optional[str]: + for key in ("call_id", "item_id", "id"): + value = event.get(key) + if isinstance(value, str) and value: + return value + + item = event.get("item") + if isinstance(item, dict): + for key in ("call_id", "id"): + value = item.get(key) + if isinstance(value, str) and value: + return value + + return None + + def _extract_error_payload(self, event: Dict[str, Any]) -> Dict[str, Any]: + # Common formats: + # {type:"error", error:{...}} + # {type:"response.failed", response:{error:{...}}} + payload = event.get("error") + if isinstance(payload, dict): + return payload + + response = event.get("response") + if isinstance(response, dict): + nested = response.get("error") + if isinstance(nested, dict): + return nested + + return {} + + def _classify_error_status(self, error_payload: Dict[str, Any]) -> int: + code = str(error_payload.get("code", "") or "").lower() + err_type = str(error_payload.get("type", "") or "").lower() + message = str(error_payload.get("message", "") or "").lower() + text = " ".join([code, err_type, message]) + + if any(token in text for token in ["rate_limit", "usage_limit", "quota"]): + return 429 + if any(token in text for token in ["auth", "unauthorized", "invalid_api_key"]): + return 401 + if "forbidden" in text: + return 403 + if "context" in text or "max_output_tokens" in text: + return 400 + return 500 + + def process_event(self, event: Dict[str, Any]) -> List[Dict[str, Any]]: + """Process a single SSE event and return zero or more translated chunks.""" + chunks: List[Dict[str, Any]] = [] + + event_type = event.get("type") + if not isinstance(event_type, str): + return chunks + + # Capture response id/created as early as possible + response = event.get("response") + if isinstance(response, dict): + if isinstance(response.get("id"), str) and response.get("id"): + self.response_id = response["id"] + if isinstance(response.get("created_at"), (int, float)): + self.created = int(response["created_at"]) + + if event_type == "response.output_item.added": + item = event.get("item") + if isinstance(item, dict) and item.get("type") == "function_call": + call_id = self._extract_tool_call_id(item) + if call_id: + index = self._get_or_create_tool_index(call_id) + name = item.get("name") if isinstance(item.get("name"), str) else "" + if name: + self._tool_names_by_call_id[call_id] = name + + initial_args = item.get("arguments") + if not isinstance(initial_args, str): + initial_args = "" + + tool_delta = { + "tool_calls": [ + { + "index": index, + "id": call_id, + "type": "function", + "function": { + "name": name, + "arguments": initial_args, + }, + } + ] + } + chunks.append(self._build_chunk(delta=tool_delta)) + return chunks + + if event_type == "response.function_call_arguments.delta": + call_id = self._extract_tool_call_id(event) + delta = event.get("delta") + if call_id and isinstance(delta, str): + index = self._get_or_create_tool_index(call_id) + name = self._tool_names_by_call_id.get(call_id, "") + tool_delta = { + "tool_calls": [ + { + "index": index, + "id": call_id, + "type": "function", + "function": { + "name": name, + "arguments": delta, + }, + } + ] + } + chunks.append(self._build_chunk(delta=tool_delta)) + return chunks + + if event_type == "response.function_call_arguments.done": + call_id = self._extract_tool_call_id(event) + if call_id: + index = self._get_or_create_tool_index(call_id) + name = self._tool_names_by_call_id.get(call_id, "") + arguments = event.get("arguments") + if not isinstance(arguments, str): + arguments = "" + + tool_delta = { + "tool_calls": [ + { + "index": index, + "id": call_id, + "type": "function", + "function": { + "name": name, + "arguments": arguments, + }, + } + ] + } + chunks.append(self._build_chunk(delta=tool_delta)) + return chunks + + text_delta = self._extract_text_delta(event) + if text_delta: + chunks.append(self._build_chunk(delta={"content": text_delta})) + return chunks + + if event_type in ("error", "response.failed"): + error_payload = self._extract_error_payload(event) + status_code = self._classify_error_status(error_payload) + message = ( + error_payload.get("message") + if isinstance(error_payload.get("message"), str) + else f"Codex stream failed ({event_type})" + ) + raise CodexStreamError( + message=message, + status_code=status_code, + error_body=json.dumps({"error": error_payload} if error_payload else event), + ) + + if event_type in ("response.completed", "response.incomplete"): + usage = self._extract_usage(event) + status = self._get_response_status(event) + finish_reason = "stop" + + if status == "incomplete": + incomplete_details = None + if isinstance(response, dict): + incomplete_details = response.get("incomplete_details") + reason = None + if isinstance(incomplete_details, dict): + reason = incomplete_details.get("reason") + if isinstance(reason, str): + finish_reason = self._map_incomplete_reason(reason) + else: + finish_reason = "length" + + chunks.append( + self._build_chunk(delta={}, finish_reason=finish_reason, usage=usage) + ) + return chunks + + # Ignore all other event families safely + return chunks + + +class OpenAICodexProvider(OpenAICodexAuthBase, ProviderInterface): + """OpenAI Codex provider via ChatGPT backend `/codex/responses`.""" + + skip_cost_calculation = True + default_rotation_mode: str = "sequential" + provider_env_name: str = "openai_codex" + + # Conservative placeholders (MVP-safe defaults) + tier_priorities = { + "unknown": 10, + } + + usage_reset_configs = { + "default": UsageResetConfigDef( + window_seconds=24 * 60 * 60, + mode="credential", + description="TODO: tune OpenAI Codex quota window from observed behavior", + field_name="daily", + ) + } + + model_quota_groups: QuotaGroupMap = { + # TODO: tune once quota sharing behavior is empirically validated + } + + def __init__(self): + super().__init__() + self.model_definitions = ModelDefinitions() + + def has_custom_logic(self) -> bool: + return True + + # ========================================================================= + # Model discovery + # ========================================================================= + + async def get_models(self, credential: str, client: httpx.AsyncClient) -> List[str]: + """ + Returns OpenAI Codex models from: + 1) OPENAI_CODEX_MODELS env definitions (priority) + 2) hardcoded fallback list + 3) optional dynamic /models discovery (best-effort) + """ + models: List[str] = [] + env_model_ids = set() + + static_models = self.model_definitions.get_all_provider_models("openai_codex") + if static_models: + for model in static_models: + model_name = model.split("/")[-1] if "/" in model else model + model_id = self.model_definitions.get_model_id("openai_codex", model_name) + models.append(model) + if model_id: + env_model_ids.add(model_id) + + lib_logger.info( + f"Loaded {len(static_models)} static models for openai_codex from OPENAI_CODEX_MODELS" + ) + + for model_id in HARDCODED_MODELS: + if model_id not in env_model_ids: + models.append(f"openai_codex/{model_id}") + env_model_ids.add(model_id) + + # Optional dynamic discovery (Codex backend may not support this endpoint) + try: + await self.initialize_token(credential) + creds = await self._load_credentials(credential) + access_token, account_id = self._extract_runtime_auth(creds) + + api_base = self._resolve_api_base() + models_url = f"{api_base.rstrip('/')}/models" + + headers = self._build_request_headers( + access_token=access_token, + account_id=account_id, + stream=False, + ) + + response = await client.get(models_url, headers=headers, timeout=20.0) + response.raise_for_status() + + payload = response.json() + data = payload.get("data") if isinstance(payload, dict) else payload + + discovered = 0 + if isinstance(data, list): + for item in data: + model_id = None + if isinstance(item, dict): + model_id = item.get("id") or item.get("name") + elif isinstance(item, str): + model_id = item + + if isinstance(model_id, str) and model_id and model_id not in env_model_ids: + models.append(f"openai_codex/{model_id}") + env_model_ids.add(model_id) + discovered += 1 + + if discovered > 0: + lib_logger.debug( + f"Discovered {discovered} additional models for openai_codex via dynamic /models" + ) + + except Exception as e: + lib_logger.debug(f"Dynamic model discovery failed for openai_codex: {e}") + + return models + + async def initialize_credentials(self, credential_paths: List[str]) -> None: + """Preload credentials and queue refresh/reauth where needed.""" + ready = 0 + refreshing = 0 + reauth_required = 0 + + for cred_path in credential_paths: + try: + creds = await self._load_credentials(cred_path) + self._ensure_proxy_metadata(creds) + + if not creds.get("refresh_token"): + await self._queue_refresh(cred_path, force=True, needs_reauth=True) + reauth_required += 1 + continue + + if self._is_token_expired(creds): + await self._queue_refresh(cred_path, force=False, needs_reauth=False) + refreshing += 1 + else: + ready += 1 + + # ensure metadata caches are populated + self._credentials_cache[cred_path] = creds + + except Exception as e: + lib_logger.warning( + f"Failed to initialize OpenAI Codex credential '{cred_path}': {e}" + ) + await self._queue_refresh(cred_path, force=True, needs_reauth=True) + reauth_required += 1 + + lib_logger.info( + "OpenAI Codex credential initialization: " + f"ready={ready}, refreshing={refreshing}, reauth_required={reauth_required}" + ) + + # ========================================================================= + # Request mapping helpers + # ========================================================================= + + def _resolve_api_base(self) -> str: + return os.getenv("OPENAI_CODEX_API_BASE", DEFAULT_API_BASE) + + def _extract_runtime_auth(self, creds: Dict[str, Any]) -> Tuple[str, str]: + access_token = creds.get("access_token") + if not isinstance(access_token, str) or not access_token: + raise ValueError("OpenAI Codex credential missing access_token") + + metadata = creds.get("_proxy_metadata", {}) + account_id = metadata.get("account_id") + + if not account_id: + # Fallback parse from access_token + payload = self._decode_jwt_unverified(access_token) + if payload: + direct = payload.get("https://api.openai.com/auth.chatgpt_account_id") + nested = None + claim = payload.get(AUTH_CLAIM) + if isinstance(claim, dict): + nested = claim.get("chatgpt_account_id") + + account_id = direct or nested + + if not isinstance(account_id, str) or not account_id: + raise ValueError( + "OpenAI Codex credential missing account_id. Re-authenticate to refresh token metadata." + ) + + return access_token, account_id + + def _build_request_headers( + self, + *, + access_token: str, + account_id: str, + stream: bool, + extra_headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, str]: + headers = { + "Authorization": f"Bearer {access_token}", + "chatgpt-account-id": account_id, + "OpenAI-Beta": "responses=experimental", + "originator": "pi", + "Content-Type": "application/json", + "Accept": "text/event-stream" if stream else "application/json", + "User-Agent": "LLM-API-Key-Proxy/OpenAICodex", + } + + if extra_headers: + headers.update({k: str(v) for k, v in extra_headers.items()}) + + return headers + + def _extract_text(self, content: Any) -> str: + if content is None: + return "" + + if isinstance(content, str): + return content + + if isinstance(content, list): + parts: List[str] = [] + for item in content: + if isinstance(item, dict): + # OpenAI chat content blocks + if item.get("type") == "text" and isinstance(item.get("text"), str): + parts.append(item["text"]) + elif item.get("type") in {"input_text", "output_text"} and isinstance( + item.get("text"), str + ): + parts.append(item["text"]) + elif item.get("type") == "refusal" and isinstance(item.get("refusal"), str): + parts.append(item["refusal"]) + elif isinstance(item, str): + parts.append(item) + return "\n".join(parts) + + if isinstance(content, dict): + if isinstance(content.get("text"), str): + return content["text"] + return json.dumps(content) + + return str(content) + + def _convert_user_content_to_input_parts(self, content: Any) -> List[Dict[str, Any]]: + if isinstance(content, str): + return [{"type": "input_text", "text": content}] + + if isinstance(content, list): + parts: List[Dict[str, Any]] = [] + for item in content: + if not isinstance(item, dict): + continue + + item_type = item.get("type") + if item_type in ("text", "input_text") and isinstance(item.get("text"), str): + parts.append({"type": "input_text", "text": item["text"]}) + elif item_type == "image_url": + image_url = item.get("image_url") + if isinstance(image_url, dict): + image_url = image_url.get("url") + if isinstance(image_url, str) and image_url: + parts.append({"type": "input_image", "image_url": image_url, "detail": "auto"}) + elif item_type == "input_image": + image_url = item.get("image_url") + if isinstance(image_url, str) and image_url: + part = {"type": "input_image", "image_url": image_url} + if isinstance(item.get("detail"), str): + part["detail"] = item["detail"] + else: + part["detail"] = "auto" + parts.append(part) + + if parts: + return parts + + text = self._extract_text(content) + return [{"type": "input_text", "text": text}] + + def _convert_messages_to_codex_input( + self, + messages: List[Dict[str, Any]], + ) -> Tuple[str, List[Dict[str, Any]]]: + instructions: List[str] = [] + codex_input: List[Dict[str, Any]] = [] + + for message in messages: + role = message.get("role") + content = message.get("content") + + if role in ("system", "developer"): + text = self._extract_text(content) + if text.strip(): + instructions.append(text.strip()) + continue + + if role == "user": + codex_input.append( + { + "role": "user", + "content": self._convert_user_content_to_input_parts(content), + } + ) + continue + + if role == "assistant": + text = self._extract_text(content) + if text.strip(): + codex_input.append( + { + "role": "assistant", + "content": [{"type": "output_text", "text": text}], + } + ) + + # Carry forward assistant tool calls where provided + tool_calls = message.get("tool_calls") + if isinstance(tool_calls, list): + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + continue + + call_id = tool_call.get("id") + function = tool_call.get("function", {}) + if not isinstance(function, dict): + continue + + name = function.get("name") + arguments = function.get("arguments") + if not isinstance(arguments, str): + arguments = json.dumps(arguments or {}) + + if isinstance(call_id, str) and isinstance(name, str): + codex_input.append( + { + "type": "function_call", + "call_id": call_id, + "name": name, + "arguments": arguments, + } + ) + continue + + if role == "tool": + call_id = message.get("tool_call_id") + if not isinstance(call_id, str) or not call_id: + continue + + output_text = self._extract_text(content) + codex_input.append( + { + "type": "function_call_output", + "call_id": call_id, + "output": output_text, + } + ) + + # Codex endpoint currently requires non-empty instructions + instructions_text = "\n\n".join(instructions).strip() + if not instructions_text: + instructions_text = "You are a helpful assistant." + + if not codex_input: + codex_input = [ + { + "role": "user", + "content": [ + { + "type": "input_text", + "text": "", + } + ], + } + ] + + return instructions_text, codex_input + + def _convert_tools(self, tools: Any) -> Optional[List[Dict[str, Any]]]: + if not isinstance(tools, list) or not tools: + return None + + converted: List[Dict[str, Any]] = [] + + for tool in tools: + if not isinstance(tool, dict): + continue + + # OpenAI chat format: {type:"function", function:{name,description,parameters}} + if tool.get("type") == "function" and isinstance(tool.get("function"), dict): + fn = tool["function"] + name = fn.get("name") + if not isinstance(name, str) or not name: + continue + + schema = fn.get("parameters") + if not isinstance(schema, dict): + schema = {"type": "object", "properties": {}} + + # Remove OpenAI-specific strict flag if present + schema = copy.deepcopy(schema) + schema.pop("additionalProperties", None) + + converted.append( + { + "type": "function", + "name": name, + "description": fn.get("description", ""), + "parameters": schema, + } + ) + continue + + # Already in responses format + if tool.get("type") == "function" and isinstance(tool.get("name"), str): + converted.append(copy.deepcopy(tool)) + + return converted or None + + def _normalize_tool_choice(self, tool_choice: Any, has_tools: bool) -> Any: + if not has_tools: + return None + + if isinstance(tool_choice, str): + # Codex endpoint handles "auto" reliably; map required -> auto + if tool_choice in {"auto", "none"}: + return tool_choice + if tool_choice == "required": + return "auto" + return "auto" + + if isinstance(tool_choice, dict): + if tool_choice.get("type") == "function": + fn = tool_choice.get("function") + if isinstance(fn, dict) and isinstance(fn.get("name"), str): + return {"type": "function", "name": fn["name"]} + if isinstance(tool_choice.get("name"), str): + return {"type": "function", "name": tool_choice["name"]} + if isinstance(tool_choice.get("name"), str): + return {"type": "function", "name": tool_choice["name"]} + + return "auto" + + def _build_codex_payload(self, model_name: str, **kwargs) -> Dict[str, Any]: + messages = kwargs.get("messages") or [] + instructions, codex_input = self._convert_messages_to_codex_input(messages) + + payload: Dict[str, Any] = { + "model": model_name, + "stream": True, # Endpoint currently requires stream=true + "store": False, + "instructions": instructions, + "input": codex_input, + "tool_choice": "auto", + "parallel_tool_calls": True, + } + + # Keep verbosity at medium by default (gpt-5.1-codex rejects low) + text_verbosity = os.getenv("OPENAI_CODEX_TEXT_VERBOSITY", "medium") + payload["text"] = {"verbosity": text_verbosity} + + # OpenAI chat params -> Codex responses equivalents + if kwargs.get("temperature") is not None: + payload["temperature"] = kwargs["temperature"] + if kwargs.get("top_p") is not None: + payload["top_p"] = kwargs["top_p"] + # Note: max_output_tokens is NOT supported by the Codex Responses API + # (gpt-5.3-codex returns 400 "Unsupported parameter: max_output_tokens"). + # Omit it and let the API use its default. + + converted_tools = self._convert_tools(kwargs.get("tools")) + if converted_tools: + payload["tools"] = converted_tools + payload["tool_choice"] = self._normalize_tool_choice( + kwargs.get("tool_choice"), + has_tools=True, + ) + payload["parallel_tool_calls"] = True + else: + payload.pop("tools", None) + payload.pop("tool_choice", None) + payload.pop("parallel_tool_calls", None) + + # Optional session pinning for cache affinity + session_id = kwargs.get("session_id") or kwargs.get("conversation_id") + if isinstance(session_id, str) and session_id: + payload["prompt_cache_key"] = session_id + payload["prompt_cache_retention"] = "in-memory" + + return payload + + # ========================================================================= + # SSE parsing + response conversion + # ========================================================================= + + async def _iter_sse_events( + self, response: httpx.Response + ) -> AsyncGenerator[Dict[str, Any], None]: + """Parse SSE stream into event dictionaries.""" + event_lines: List[str] = [] + + async for line in response.aiter_lines(): + if line is None: + continue + + if line == "": + if not event_lines: + continue + + data_lines = [] + for entry in event_lines: + if entry.startswith("data:"): + data_lines.append(entry[5:].lstrip()) + + event_lines = [] + if not data_lines: + continue + + payload = "\n".join(data_lines).strip() + if not payload or payload == "[DONE]": + if payload == "[DONE]": + return + continue + + try: + parsed = json.loads(payload) + if isinstance(parsed, dict): + yield parsed + except json.JSONDecodeError: + lib_logger.debug(f"OpenAI Codex SSE non-JSON payload ignored: {payload[:200]}") + continue + + event_lines.append(line) + + # Flush trailing event if stream closes without blank line + if event_lines: + data_lines = [entry[5:].lstrip() for entry in event_lines if entry.startswith("data:")] + payload = "\n".join(data_lines).strip() + if payload and payload != "[DONE]": + try: + parsed = json.loads(payload) + if isinstance(parsed, dict): + yield parsed + except json.JSONDecodeError: + pass + + def _stream_to_completion_response( + self, chunks: List[litellm.ModelResponse] + ) -> litellm.ModelResponse: + """Reassemble streamed chunks into a non-streaming ModelResponse.""" + if not chunks: + raise ValueError("No chunks provided for reassembly") + + final_message: Dict[str, Any] = {"role": "assistant"} + aggregated_tool_calls: Dict[int, Dict[str, Any]] = {} + usage_data = None + chunk_finish_reason = None + + first_chunk = chunks[0] + + for chunk in chunks: + if not hasattr(chunk, "choices") or not chunk.choices: + continue + + choice = chunk.choices[0] + delta = choice.get("delta", {}) + + if "content" in delta and delta["content"] is not None: + final_message["content"] = final_message.get("content", "") + delta["content"] + + if "tool_calls" in delta and delta["tool_calls"]: + for tc_chunk in delta["tool_calls"]: + index = tc_chunk.get("index", 0) + if index not in aggregated_tool_calls: + aggregated_tool_calls[index] = { + "type": "function", + "function": {"name": "", "arguments": ""}, + } + + if tc_chunk.get("id"): + aggregated_tool_calls[index]["id"] = tc_chunk["id"] + + if tc_chunk.get("type"): + aggregated_tool_calls[index]["type"] = tc_chunk["type"] + + if isinstance(tc_chunk.get("function"), dict): + fn = tc_chunk["function"] + if fn.get("name") is not None: + aggregated_tool_calls[index]["function"]["name"] += str(fn["name"]) + if fn.get("arguments") is not None: + aggregated_tool_calls[index]["function"]["arguments"] += str( + fn["arguments"] + ) + + if choice.get("finish_reason"): + chunk_finish_reason = choice["finish_reason"] + + for chunk in reversed(chunks): + if hasattr(chunk, "usage") and chunk.usage: + usage_data = chunk.usage + break + + if aggregated_tool_calls: + final_message["tool_calls"] = list(aggregated_tool_calls.values()) + + for field in ["content", "tool_calls", "function_call"]: + if field not in final_message: + final_message[field] = None + + if aggregated_tool_calls: + finish_reason = "tool_calls" + elif chunk_finish_reason: + finish_reason = chunk_finish_reason + else: + finish_reason = "stop" + + final_choice = { + "index": 0, + "message": final_message, + "finish_reason": finish_reason, + } + + final_response_data = { + "id": first_chunk.id, + "object": "chat.completion", + "created": first_chunk.created, + "model": first_chunk.model, + "choices": [final_choice], + "usage": usage_data, + } + + return litellm.ModelResponse(**final_response_data) + + # ========================================================================= + # Main completion flow + # ========================================================================= + + async def acompletion( + self, client: httpx.AsyncClient, **kwargs + ) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]: + credential_identifier = kwargs.pop("credential_identifier") + transaction_context = kwargs.pop("transaction_context", None) + model = kwargs["model"] + + file_logger = ProviderLogger(transaction_context) + + async def make_request() -> Any: + # Ensure token initialized/refreshed before request + await self.initialize_token(credential_identifier) + creds = await self._load_credentials(credential_identifier) + if self._is_token_expired(creds): + creds = await self._refresh_token(credential_identifier) + + access_token, account_id = self._extract_runtime_auth(creds) + + model_name = model.split("/")[-1] + payload = self._build_codex_payload(model_name=model_name, **kwargs) + + headers = self._build_request_headers( + access_token=access_token, + account_id=account_id, + stream=True, + ) + + url = f"{self._resolve_api_base().rstrip('/')}{RESPONSES_ENDPOINT_PATH}" + file_logger.log_request(payload) + + return client.stream( + "POST", + url, + headers=headers, + json=payload, + timeout=TimeoutConfig.streaming(), + ) + + async def stream_handler( + response_stream: Any, + attempt: int = 1, + ): + try: + async with response_stream as response: + if response.status_code >= 400: + raw_error = await response.aread() + error_text = ( + raw_error.decode("utf-8", "replace") + if isinstance(raw_error, bytes) + else str(raw_error) + ) + + # Try a single forced token refresh on auth failures + if response.status_code in (401, 403) and attempt == 1: + lib_logger.warning( + "OpenAI Codex returned 401/403; forcing refresh and retrying once" + ) + await self._refresh_token(credential_identifier, force=True) + retry_stream = await make_request() + async for chunk in stream_handler(retry_stream, attempt=2): + yield chunk + return + + # Surface typed HTTPStatusError for classify_error() + raise httpx.HTTPStatusError( + f"OpenAI Codex HTTP {response.status_code}: {error_text}", + request=response.request, + response=response, + ) + + translator = CodexSSETranslator(model_id=model) + + async for event in self._iter_sse_events(response): + try: + file_logger.log_response_chunk(json.dumps(event)) + except Exception: + pass + + try: + translated_chunks = translator.process_event(event) + except CodexStreamError as stream_error: + synthetic_response = httpx.Response( + status_code=stream_error.status_code, + request=response.request, + text=stream_error.error_body, + ) + raise httpx.HTTPStatusError( + str(stream_error), + request=response.request, + response=synthetic_response, + ) + + for chunk_dict in translated_chunks: + yield litellm.ModelResponse(**chunk_dict) + + except httpx.HTTPStatusError: + raise + except Exception as e: + file_logger.log_error(f"Error during OpenAI Codex stream processing: {e}") + raise + + async def logging_stream_wrapper(): + chunks: List[litellm.ModelResponse] = [] + try: + async for chunk in stream_handler(await make_request()): + chunks.append(chunk) + yield chunk + finally: + if chunks: + try: + final_response = self._stream_to_completion_response(chunks) + if hasattr(final_response, "model_dump"): + file_logger.log_final_response(final_response.model_dump()) + else: + file_logger.log_final_response(final_response.dict()) + except Exception: + pass + + if kwargs.get("stream"): + return logging_stream_wrapper() + + async def non_stream_wrapper() -> litellm.ModelResponse: + chunks = [chunk async for chunk in logging_stream_wrapper()] + return self._stream_to_completion_response(chunks) + + return await non_stream_wrapper() + + # ========================================================================= + # Provider-specific quota parsing + # ========================================================================= + + @staticmethod + def parse_quota_error( + error: Exception, + error_body: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: + """ + Parse OpenAI Codex quota/rate-limit errors. + + Supports: + - Retry-After header + - error.resets_at (unix seconds) + - error.retry_after / retry_after_seconds fields + - usage_limit / quota / rate_limit style error codes + """ + now_ts = time.time() + + response = None + if isinstance(error, httpx.HTTPStatusError): + response = error.response + + headers = response.headers if response is not None else {} + + retry_after: Optional[int] = None + retry_header = headers.get("Retry-After") or headers.get("retry-after") + if retry_header: + try: + retry_after = max(1, int(float(retry_header))) + except ValueError: + retry_after = None + + body_text = error_body + if body_text is None and response is not None: + try: + body_text = response.text + except Exception: + body_text = None + + if not body_text: + if retry_after is not None: + return { + "retry_after": retry_after, + "reason": "RATE_LIMIT", + "reset_timestamp": None, + "quota_reset_timestamp": None, + } + return None + + parsed = None + try: + parsed = json.loads(body_text) + except Exception: + parsed = None + + if not isinstance(parsed, dict): + if retry_after is not None: + return { + "retry_after": retry_after, + "reason": "RATE_LIMIT", + "reset_timestamp": None, + "quota_reset_timestamp": None, + } + return None + + err = parsed.get("error") if isinstance(parsed.get("error"), dict) else {} + + code = str(err.get("code", "") or "").lower() + err_type = str(err.get("type", "") or "").lower() + message = str(err.get("message", "") or "").lower() + combined = " ".join([code, err_type, message]) + + # Look for codex-specific reset timestamp + reset_ts = err.get("resets_at") + quota_reset_timestamp: Optional[float] = None + reset_timestamp_iso: Optional[str] = None + if isinstance(reset_ts, (int, float)): + quota_reset_timestamp = float(reset_ts) + retry_after_from_reset = int(max(1, quota_reset_timestamp - now_ts)) + retry_after = retry_after or retry_after_from_reset + reset_timestamp_iso = datetime.fromtimestamp( + quota_reset_timestamp, tz=timezone.utc + ).isoformat() + + if retry_after is None: + for key in ("retry_after", "retry_after_seconds", "retryAfter"): + value = err.get(key) + if isinstance(value, (int, float)): + retry_after = max(1, int(value)) + break + if isinstance(value, str): + try: + retry_after = max(1, int(float(value))) + break + except ValueError: + continue + + if retry_after is None and any( + token in combined for token in ["usage_limit", "rate_limit", "quota"] + ): + retry_after = 60 + + if retry_after is None: + return None + + reason = ( + str(err.get("code") or err.get("type") or "RATE_LIMIT").upper() + ) + + return { + "retry_after": retry_after, + "reason": reason, + "reset_timestamp": reset_timestamp_iso, + "quota_reset_timestamp": quota_reset_timestamp, + } diff --git a/src/rotator_library/providers/provider_interface.py b/src/rotator_library/providers/provider_interface.py index f53f91e9..22056b49 100644 --- a/src/rotator_library/providers/provider_interface.py +++ b/src/rotator_library/providers/provider_interface.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-only # Copyright (c) 2026 Mirrowel -from abc import ABC, abstractmethod +from abc import ABC, ABCMeta, abstractmethod from dataclasses import dataclass from typing import ( List, @@ -19,7 +19,33 @@ import litellm if TYPE_CHECKING: - from ..usage_manager import UsageManager + from ..usage import UsageManager + + +# ============================================================================= +# SINGLETON METACLASS FOR PROVIDERS +# ============================================================================= + + +class SingletonABCMeta(ABCMeta): + """ + Metaclass that combines ABC functionality with singleton pattern. + + All classes using this metaclass (including subclasses of ProviderInterface) + will be singletons - only one instance per class exists. + + This prevents the bug where multiple provider instances are created + by different components (RotatingClient, UsageManager, Hooks, etc.), + each with their own caches and state. + """ + + _instances: Dict[type, Any] = {} + + def __call__(cls, *args, **kwargs): + if cls not in SingletonABCMeta._instances: + SingletonABCMeta._instances[cls] = super().__call__(*args, **kwargs) + return SingletonABCMeta._instances[cls] + from ..config import ( DEFAULT_ROTATION_MODE, @@ -68,7 +94,7 @@ class UsageResetConfigDef: QuotaGroupMap = Dict[str, List[str]] # group_name -> [models] -class ProviderInterface(ABC): +class ProviderInterface(ABC, metaclass=SingletonABCMeta): """ An interface for API provider-specific functionality, including model discovery and custom API call handling for non-standard providers. diff --git a/src/rotator_library/providers/qwen_auth_base.py b/src/rotator_library/providers/qwen_auth_base.py index acb42606..f31ead3c 100644 --- a/src/rotator_library/providers/qwen_auth_base.py +++ b/src/rotator_library/providers/qwen_auth_base.py @@ -71,26 +71,20 @@ def __init__(self): str, float ] = {} # Track backoff timers (Unix timestamp) - # [QUEUE SYSTEM] Sequential refresh processing with two separate queues + # [QUEUE SYSTEM] Sequential refresh processing # Normal refresh queue: for proactive token refresh (old token still valid) self._refresh_queue: asyncio.Queue = asyncio.Queue() self._queue_processor_task: Optional[asyncio.Task] = None - # Re-auth queue: for invalid refresh tokens (requires user interaction) - self._reauth_queue: asyncio.Queue = asyncio.Queue() - self._reauth_processor_task: Optional[asyncio.Task] = None - # Tracking sets/dicts - self._queued_credentials: set = set() # Track credentials in either queue - # Only credentials in re-auth queue are marked unavailable (not normal refresh) - # TTL cleanup is defense-in-depth for edge cases where re-auth processor crashes - self._unavailable_credentials: Dict[ - str, float - ] = {} # Maps credential path -> timestamp when marked unavailable - # TTL should exceed reauth timeout (300s) to avoid premature cleanup - self._unavailable_ttl_seconds: int = 360 # 6 minutes TTL for stale entries + self._queued_credentials: set = set() # Track credentials in refresh queue self._queue_tracking_lock = asyncio.Lock() # Protects queue sets + # [PERMANENTLY EXPIRED] Track credentials that have been permanently removed from rotation + # These credentials have invalid/revoked refresh tokens and require manual re-authentication + # via credential_tool.py. They will NOT be selected for rotation until proxy restart. + self._permanently_expired_credentials: set = set() + # Retry tracking for normal refresh queue self._queue_retry_count: Dict[ str, int @@ -100,7 +94,6 @@ def __init__(self): self._refresh_timeout_seconds: int = 15 # Max time for single refresh self._refresh_interval_seconds: int = 30 # Delay between queue items self._refresh_max_retries: int = 3 # Attempts before kicked out - self._reauth_timeout_seconds: int = 300 # Time for user to complete OAuth def _parse_env_credential_path(self, path: str) -> Optional[str]: """ @@ -280,6 +273,14 @@ def _is_token_expired(self, creds: Dict[str, Any]) -> bool: expiry_timestamp = creds.get("expiry_date", 0) / 1000 return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS + async def _get_lock(self, path: str) -> asyncio.Lock: + # [FIX RACE CONDITION] Protect lock creation with a master lock + # This prevents TOCTOU bug where multiple coroutines check and create simultaneously + async with self._locks_lock: + if path not in self._refresh_locks: + self._refresh_locks[path] = asyncio.Lock() + return self._refresh_locks[path] + def _is_token_truly_expired(self, creds: Dict[str, Any]) -> bool: """Check if token is TRULY expired (past actual expiry, not just threshold). @@ -289,6 +290,52 @@ def _is_token_truly_expired(self, creds: Dict[str, Any]) -> bool: expiry_timestamp = creds.get("expiry_date", 0) / 1000 return expiry_timestamp < time.time() + def _mark_credential_expired(self, path: str, reason: str) -> None: + """ + Permanently mark a credential as expired and remove it from rotation. + + This is called when a credential's refresh token is invalid or revoked, + meaning normal token refresh cannot work. The credential is removed from + rotation entirely and requires manual re-authentication via credential_tool.py. + + The proxy must be restarted after fixing the credential. + + Args: + path: Credential file path or env:// path + reason: Human-readable reason for expiration (e.g., "invalid_grant", "HTTP 401") + """ + # Add to permanently expired set + self._permanently_expired_credentials.add(path) + + # Clean up other tracking structures + self._queued_credentials.discard(path) + + # Get display name + if path.startswith("env://"): + display_name = path + else: + display_name = Path(path).name + + # Rich-formatted output for high visibility + console.print( + Panel( + f"[bold red]Credential:[/bold red] {display_name}\n" + f"[bold red]Reason:[/bold red] {reason}\n\n" + f"[yellow]This credential has been removed from rotation.[/yellow]\n" + f"[yellow]To fix: Run 'python credential_tool.py' to re-authenticate,[/yellow]\n" + f"[yellow]then restart the proxy.[/yellow]", + title="[bold red]⚠ CREDENTIAL EXPIRED - REMOVED FROM ROTATION[/bold red]", + border_style="red", + ) + ) + + # Also log at ERROR level for log files + lib_logger.error( + f"CREDENTIAL EXPIRED - REMOVED FROM ROTATION | " + f"Credential: {display_name} | Reason: {reason} | " + f"Action: Run 'credential_tool.py' to re-authenticate, then restart proxy" + ) + async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any]: async with await self._get_lock(path): cached_creds = self._credentials_cache.get(path) @@ -344,9 +391,9 @@ async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any] f"HTTP {status_code} for '{Path(path).name}': {error_body}" ) - # [INVALID GRANT HANDLING] Handle 400/401/403 by raising - # The caller (_process_refresh_queue or initialize_token) will handle re-auth - # We must NOT call initialize_token from here as we hold a lock (would deadlock) + # [INVALID GRANT HANDLING] Handle 400/401/403 by marking as expired + # These errors indicate the refresh token is invalid/revoked + # Mark as permanently expired - no interactive re-auth during proxy operation if status_code == 400: # Check if this is an invalid refresh token error try: @@ -361,39 +408,25 @@ async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any] "invalid" in error_desc.lower() or error_type == "invalid_request" ): - lib_logger.info( - f"Credential '{Path(path).name}' needs re-auth (HTTP 400: {error_desc}). " - f"Queued for re-authentication, rotating to next credential." + self._mark_credential_expired( + path, + f"Refresh token invalid (HTTP 400: {error_desc})", ) - # Queue for re-auth in background (non-blocking, fire-and-forget) - # This ensures credential gets fixed even if caller doesn't handle it - asyncio.create_task( - self._queue_refresh( - path, force=True, needs_reauth=True - ) - ) - # Raise rotatable error instead of raw HTTPStatusError raise CredentialNeedsReauthError( credential_path=path, - message=f"Refresh token invalid for '{Path(path).name}'. Re-auth queued.", + message=f"Refresh token invalid for '{Path(path).name}'. Credential removed from rotation.", ) else: # Other 400 error - raise it raise elif status_code in (401, 403): - lib_logger.info( - f"Credential '{Path(path).name}' needs re-auth (HTTP {status_code}). " - f"Queued for re-authentication, rotating to next credential." - ) - # Queue for re-auth in background (non-blocking, fire-and-forget) - asyncio.create_task( - self._queue_refresh(path, force=True, needs_reauth=True) + self._mark_credential_expired( + path, f"Credential unauthorized (HTTP {status_code})" ) - # Raise rotatable error instead of raw HTTPStatusError raise CredentialNeedsReauthError( credential_path=path, - message=f"Token invalid for '{Path(path).name}' (HTTP {status_code}). Re-auth queued.", + message=f"Token invalid for '{Path(path).name}' (HTTP {status_code}). Credential removed from rotation.", ) elif status_code == 429: @@ -546,67 +579,28 @@ async def proactively_refresh(self, credential_identifier: str): # f"Queueing refresh for '{Path(credential_identifier).name}'" # ) # lib_logger.info(f"Proactive refresh triggered for '{Path(credential_identifier).name}'") - await self._queue_refresh( - credential_identifier, force=False, needs_reauth=False - ) - - async def _get_lock(self, path: str) -> asyncio.Lock: - # [FIX RACE CONDITION] Protect lock creation with a master lock - async with self._locks_lock: - if path not in self._refresh_locks: - self._refresh_locks[path] = asyncio.Lock() - return self._refresh_locks[path] - - def is_credential_available(self, path: str) -> bool: - """Check if a credential is available for rotation. + await self._queue_refresh(credential_identifier, force=False) - Credentials are unavailable if: - 1. In re-auth queue (token is truly broken, requires user interaction) - 2. Token is TRULY expired (past actual expiry, not just threshold) + async def _queue_refresh(self, path: str, force: bool = False): + """Add a credential to the refresh queue if not already queued. - Note: Credentials in normal refresh queue are still available because - the old token is valid until actual expiry. - - TTL cleanup (defense-in-depth): If a credential has been in the re-auth - queue longer than _unavailable_ttl_seconds without being processed, it's - cleaned up. This should only happen if the re-auth processor crashes or - is cancelled without proper cleanup. + Args: + path: Credential file path + force: Force refresh even if not expired """ - # Check if in re-auth queue (truly unavailable) - if path in self._unavailable_credentials: - marked_time = self._unavailable_credentials.get(path) - if marked_time is not None: - now = time.time() - if now - marked_time > self._unavailable_ttl_seconds: - # Entry is stale - clean it up and return available - # This is a defense-in-depth for edge cases where re-auth - # processor crashed or was cancelled without cleanup - lib_logger.warning( - f"Credential '{Path(path).name}' stuck in re-auth queue for " - f"{int(now - marked_time)}s (TTL: {self._unavailable_ttl_seconds}s). " - f"Re-auth processor may have crashed. Auto-cleaning stale entry." - ) - # Clean up both tracking structures for consistency - self._unavailable_credentials.pop(path, None) - self._queued_credentials.discard(path) - else: - return False # Still in re-auth, not available + # Check backoff for automated refreshes + now = time.time() + if path in self._next_refresh_after: + backoff_until = self._next_refresh_after[path] + if now < backoff_until: + # Credential is in backoff, do not queue + return - # Check if token is TRULY expired (not just threshold-expired) - creds = self._credentials_cache.get(path) - if creds and self._is_token_truly_expired(creds): - # Token is actually expired - should not be used - # Queue for refresh if not already queued + async with self._queue_tracking_lock: if path not in self._queued_credentials: - # lib_logger.debug( - # f"Credential '{Path(path).name}' is truly expired, queueing for refresh" - # ) - asyncio.create_task( - self._queue_refresh(path, force=True, needs_reauth=False) - ) - return False - - return True + self._queued_credentials.add(path) + await self._refresh_queue.put((path, force)) + await self._ensure_queue_processor_running() async def _ensure_queue_processor_running(self): """Lazily starts the queue processor if not already running.""" @@ -615,64 +609,6 @@ async def _ensure_queue_processor_running(self): self._process_refresh_queue() ) - async def _ensure_reauth_processor_running(self): - """Lazily starts the re-auth queue processor if not already running.""" - if self._reauth_processor_task is None or self._reauth_processor_task.done(): - self._reauth_processor_task = asyncio.create_task( - self._process_reauth_queue() - ) - - async def _queue_refresh( - self, path: str, force: bool = False, needs_reauth: bool = False - ): - """Add a credential to the appropriate refresh queue if not already queued. - - Args: - path: Credential file path - force: Force refresh even if not expired - needs_reauth: True if full re-authentication needed (routes to re-auth queue) - - Queue routing: - - needs_reauth=True: Goes to re-auth queue, marks as unavailable - - needs_reauth=False: Goes to normal refresh queue, does NOT mark unavailable - (old token is still valid until actual expiry) - """ - # IMPORTANT: Only check backoff for simple automated refreshes - # Re-authentication (interactive OAuth) should BYPASS backoff since it needs user input - if not needs_reauth: - now = time.time() - if path in self._next_refresh_after: - backoff_until = self._next_refresh_after[path] - if now < backoff_until: - # Credential is in backoff for automated refresh, do not queue - # remaining = int(backoff_until - now) - # lib_logger.debug( - # f"Skipping automated refresh for '{Path(path).name}' (in backoff for {remaining}s)" - # ) - return - - async with self._queue_tracking_lock: - if path not in self._queued_credentials: - self._queued_credentials.add(path) - - if needs_reauth: - # Re-auth queue: mark as unavailable (token is truly broken) - self._unavailable_credentials[path] = time.time() - # lib_logger.debug( - # f"Queued '{Path(path).name}' for RE-AUTH (marked unavailable). " - # f"Total unavailable: {len(self._unavailable_credentials)}" - # ) - await self._reauth_queue.put(path) - await self._ensure_reauth_processor_running() - else: - # Normal refresh queue: do NOT mark unavailable (old token still valid) - # lib_logger.debug( - # f"Queued '{Path(path).name}' for refresh (still available). " - # f"Queue size: {self._refresh_queue.qsize() + 1}" - # ) - await self._refresh_queue.put((path, force)) - await self._ensure_queue_processor_running() - async def _process_refresh_queue(self): """Background worker that processes normal refresh requests sequentially. @@ -731,8 +667,7 @@ async def _process_refresh_queue(self): except httpx.HTTPStatusError as e: status_code = e.response.status_code # Check for invalid refresh token errors (400/401/403) - # These need to be routed to re-auth queue for interactive OAuth - needs_reauth = False + # These indicate the refresh token is invalid/revoked - mark as expired if status_code == 400: # Check if this is an invalid refresh token error @@ -748,26 +683,23 @@ async def _process_refresh_queue(self): "invalid" in error_desc.lower() or error_type == "invalid_request" ): - needs_reauth = True - lib_logger.info( - f"Credential '{Path(path).name}' needs re-auth (HTTP 400: {error_desc}). " - f"Routing to re-auth queue." + self._queue_retry_count.pop(path, None) + async with self._queue_tracking_lock: + self._queued_credentials.discard(path) + self._mark_credential_expired( + path, + f"Refresh token invalid (HTTP 400: {error_desc})", + ) + else: + await self._handle_refresh_failure( + path, force, f"HTTP {status_code}" ) elif status_code in (401, 403): - needs_reauth = True - lib_logger.info( - f"Credential '{Path(path).name}' needs re-auth (HTTP {status_code}). " - f"Routing to re-auth queue." - ) - - if needs_reauth: - self._queue_retry_count.pop(path, None) # Clear retry count + self._queue_retry_count.pop(path, None) async with self._queue_tracking_lock: - self._queued_credentials.discard( - path - ) # Remove from queued - await self._queue_refresh( - path, force=True, needs_reauth=True + self._queued_credentials.discard(path) + self._mark_credential_expired( + path, f"Refresh token invalid (HTTP {status_code})" ) else: await self._handle_refresh_failure( @@ -829,65 +761,6 @@ async def _handle_refresh_failure(self, path: str, force: bool, error: str): # Keep in queued_credentials set, add back to queue await self._refresh_queue.put((path, force)) - async def _process_reauth_queue(self): - """Background worker that processes re-auth requests. - - Key behaviors: - - Credentials ARE marked unavailable (token is truly broken) - - Uses ReauthCoordinator for interactive OAuth - - No automatic retry (requires user action) - - Cleans up unavailable status when done - """ - # lib_logger.info("Re-auth queue processor started") - while True: - path = None - try: - # Wait for an item with timeout to allow graceful shutdown - try: - path = await asyncio.wait_for( - self._reauth_queue.get(), timeout=60.0 - ) - except asyncio.TimeoutError: - # Queue is empty and idle for 60s - exit - self._reauth_processor_task = None - # lib_logger.debug("Re-auth queue processor idle, shutting down") - return - - try: - lib_logger.info(f"Starting re-auth for '{Path(path).name}'...") - await self.initialize_token(path, force_interactive=True) - lib_logger.info(f"Re-auth SUCCESS for '{Path(path).name}'") - - except Exception as e: - lib_logger.error(f"Re-auth FAILED for '{Path(path).name}': {e}") - # No automatic retry for re-auth (requires user action) - - finally: - # Always clean up - async with self._queue_tracking_lock: - self._queued_credentials.discard(path) - self._unavailable_credentials.pop(path, None) - # lib_logger.debug( - # f"Re-auth cleanup for '{Path(path).name}'. " - # f"Remaining unavailable: {len(self._unavailable_credentials)}" - # ) - self._reauth_queue.task_done() - - except asyncio.CancelledError: - # Clean up current credential before breaking - if path: - async with self._queue_tracking_lock: - self._queued_credentials.discard(path) - self._unavailable_credentials.pop(path, None) - # lib_logger.debug("Re-auth queue processor cancelled") - break - except Exception as e: - lib_logger.error(f"Error in re-auth queue processor: {e}") - if path: - async with self._queue_tracking_lock: - self._queued_credentials.discard(path) - self._unavailable_credentials.pop(path, None) - async def _perform_interactive_oauth( self, path: str, creds: Dict[str, Any], display_name: str ) -> Dict[str, Any]: @@ -1085,15 +958,18 @@ async def initialize_token( """ Initialize OAuth token, triggering interactive device flow if needed. - If interactive OAuth is required (expired refresh token, missing credentials, etc.), - the flow is coordinated globally via ReauthCoordinator to ensure only one - interactive OAuth flow runs at a time across all providers. + For new credential setup (CLI tool), interactive OAuth is used when: + - Token is expired and refresh fails + - Refresh token is missing + + For proxy operation with force_interactive=True (deprecated): + - The credential is marked as permanently expired instead of interactive OAuth + - This prevents breaking the proxy flow with browser prompts Args: creds_or_path: Either a credentials dict or path to credentials file. - force_interactive: If True, skip expiry checks and force interactive OAuth. - Use this when the refresh token is known to be invalid - (e.g., after HTTP 400 from token endpoint). + force_interactive: If True, mark credential as expired (for proxy context). + For CLI tool, use the normal path (force_interactive=False). """ path = creds_or_path if isinstance(creds_or_path, str) else None @@ -1111,12 +987,20 @@ async def initialize_token( await self._load_credentials(creds_or_path) if path else creds_or_path ) - reason = "" + # If force_interactive is True, this was called from proxy context + # where re-auth was requested. Instead of interactive OAuth, mark as expired. if force_interactive: - reason = ( - "re-authentication was explicitly requested (refresh token invalid)" + if path: + self._mark_credential_expired( + path, "Refresh token invalid - re-authentication required" + ) + raise ValueError( + f"Credential '{display_name}' requires re-authentication. " + f"Run 'credential_tool.py' to manually re-authenticate, then restart proxy." ) - elif not creds.get("refresh_token"): + + reason = "" + if not creds.get("refresh_token"): reason = "refresh token is missing" elif self._is_token_expired(creds): reason = "token is expired" @@ -1127,30 +1011,29 @@ async def initialize_token( return await self._refresh_token(path) except Exception as e: lib_logger.warning( - f"Automatic token refresh for '{display_name}' failed: {e}. Proceeding to interactive login." + f"Automatic token refresh for '{display_name}' failed: {e}." ) + # Fall through to handle expired credential - lib_logger.warning( - f"Qwen OAuth token for '{display_name}' needs setup: {reason}." - ) - - # [GLOBAL REAUTH COORDINATION] Use the global coordinator to ensure - # only one interactive OAuth flow runs at a time across all providers - coordinator = get_reauth_coordinator() - - # Define the interactive OAuth function to be executed by coordinator - async def _do_interactive_oauth(): - return await self._perform_interactive_oauth( - path, creds, display_name + # Distinguish between proxy context (has path) and credential tool context (no path) + # - Proxy context: mark as expired and fail (no interactive OAuth during proxy operation) + # - Credential tool context: do interactive OAuth for new credential setup + if path: + # [NO AUTO-REAUTH] Proxy context - mark as permanently expired + self._mark_credential_expired( + path, + f"{reason}. Manual re-authentication required via credential_tool.py", + ) + raise ValueError( + f"Credential '{display_name}' is expired and requires manual re-authentication. " + f"Run 'python credential_tool.py' to fix, then restart the proxy." ) - # Execute via global coordinator (ensures only one at a time) - return await coordinator.execute_reauth( - credential_path=path or display_name, - provider_name="QWEN_CODE", - reauth_func=_do_interactive_oauth, - timeout=300.0, # 5 minute timeout for user to complete OAuth + # Credential tool context - do interactive OAuth for new credential setup + lib_logger.warning( + f"Qwen OAuth token for '{display_name}' needs setup: {reason}." ) + return await self._perform_interactive_oauth(path, creds, display_name) lib_logger.info(f"Qwen OAuth token at '{display_name}' is valid.") return creds diff --git a/src/rotator_library/providers/qwen_code_provider.py b/src/rotator_library/providers/qwen_code_provider.py index 83630f9a..f96aa88b 100644 --- a/src/rotator_library/providers/qwen_code_provider.py +++ b/src/rotator_library/providers/qwen_code_provider.py @@ -9,7 +9,7 @@ import os import httpx import logging -from typing import Union, AsyncGenerator, List, Dict, Any +from typing import Union, AsyncGenerator, List, Dict, Any, Optional from .provider_interface import ProviderInterface from .qwen_auth_base import QwenAuthBase from ..model_definitions import ModelDefinitions @@ -233,16 +233,30 @@ def _build_request_payload(self, **kwargs) -> Dict[str, Any]: return payload - def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str): + def _convert_chunk_to_openai( + self, + chunk: Dict[str, Any], + model_id: str, + stream_state: Optional[Dict[str, Any]] = None, + ): """ Converts a raw Qwen SSE chunk to an OpenAI-compatible chunk. CRITICAL FIX: Handle chunks with BOTH usage and choices (final chunk) without early return to ensure finish_reason is properly processed. + + Args: + chunk: Raw chunk from Qwen API + model_id: Model identifier for response + stream_state: Mutable dict to track state across chunks (e.g., tool_calls seen) """ if not isinstance(chunk, dict): return + # Initialize stream_state if not provided + if stream_state is None: + stream_state = {} + # Get choices and usage data choices = chunk.get("choices", []) usage_data = chunk.get("usage") @@ -250,13 +264,24 @@ def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str): chunk_created = chunk.get("created", int(time.time())) # Handle chunks with BOTH choices and usage (typical for final chunk) - # CRITICAL: Process choices FIRST to capture finish_reason, then yield usage + # CRITICAL: Keep as single chunk - don't split! Client needs usage to detect final chunk. if choices and usage_data: choice = choices[0] delta = choice.get("delta", {}) finish_reason = choice.get("finish_reason") - # Yield the choice chunk first (contains finish_reason) + # Track tool_calls presence for finish_reason normalization + if delta.get("tool_calls"): + stream_state["has_tool_calls"] = True + + # Ensure finish_reason is set for final chunks (with usage data) + # Priority: tool_calls > original finish_reason > default "stop" + if stream_state.get("has_tool_calls"): + finish_reason = "tool_calls" + elif not finish_reason: + finish_reason = "stop" + + # Yield single chunk with BOTH choices and usage yield { "choices": [ {"index": 0, "delta": delta, "finish_reason": finish_reason} @@ -265,14 +290,6 @@ def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str): "object": "chat.completion.chunk", "id": chunk_id, "created": chunk_created, - } - # Then yield the usage chunk - yield { - "choices": [], - "model": model_id, - "object": "chat.completion.chunk", - "id": chunk_id, - "created": chunk_created, "usage": { "prompt_tokens": usage_data.get("prompt_tokens", 0), "completion_tokens": usage_data.get("completion_tokens", 0), @@ -281,20 +298,51 @@ def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str): } return - # Handle usage-only chunks - if usage_data: - yield { - "choices": [], - "model": model_id, - "object": "chat.completion.chunk", - "id": chunk_id, - "created": chunk_created, - "usage": { - "prompt_tokens": usage_data.get("prompt_tokens", 0), - "completion_tokens": usage_data.get("completion_tokens", 0), - "total_tokens": usage_data.get("total_tokens", 0), - }, - } + # Handle usage-only chunks (Qwen API sends finish_reason and usage separately) + # Check if we have a buffered finish_reason chunk to combine with + if usage_data and not choices: + pending = stream_state.pop("pending_final_chunk", None) + if pending: + # Combine buffered finish_reason chunk with this usage chunk + finish_reason = pending["finish_reason"] + # Apply tool_calls priority + if stream_state.get("has_tool_calls"): + finish_reason = "tool_calls" + elif not finish_reason: + finish_reason = "stop" + + yield { + "choices": [ + { + "index": 0, + "delta": pending["delta"], + "finish_reason": finish_reason, + } + ], + "model": model_id, + "object": "chat.completion.chunk", + "id": pending.get("id", chunk_id), + "created": pending.get("created", chunk_created), + "usage": { + "prompt_tokens": usage_data.get("prompt_tokens", 0), + "completion_tokens": usage_data.get("completion_tokens", 0), + "total_tokens": usage_data.get("total_tokens", 0), + }, + } + else: + # No pending chunk - yield usage-only as fallback + yield { + "choices": [], + "model": model_id, + "object": "chat.completion.chunk", + "id": chunk_id, + "created": chunk_created, + "usage": { + "prompt_tokens": usage_data.get("prompt_tokens", 0), + "completion_tokens": usage_data.get("completion_tokens", 0), + "total_tokens": usage_data.get("total_tokens", 0), + }, + } return # Handle content-only chunks @@ -305,6 +353,30 @@ def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str): delta = choice.get("delta", {}) finish_reason = choice.get("finish_reason") + # Track tool_calls presence for finish_reason normalization + if delta.get("tool_calls"): + stream_state["has_tool_calls"] = True + + # Normalize finish_reason: if tool_calls were seen, ensure finish_reason reflects this + if ( + finish_reason + and stream_state.get("has_tool_calls") + and finish_reason != "tool_calls" + ): + finish_reason = "tool_calls" + + # If this chunk has finish_reason but no usage, buffer it for combining with usage chunk + # Qwen API sends finish_reason and usage in separate chunks + if finish_reason: + stream_state["pending_final_chunk"] = { + "delta": delta, + "finish_reason": finish_reason, + "id": chunk_id, + "created": chunk_created, + } + # Don't yield yet - wait for usage chunk to combine + return + # Handle tags for reasoning content content = delta.get("content") if content and ("" in content or "" in content): @@ -337,11 +409,9 @@ def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str): "created": chunk_created, } else: - # Standard content chunk + # Standard content chunk (no finish_reason) yield { - "choices": [ - {"index": 0, "delta": delta, "finish_reason": finish_reason} - ], + "choices": [{"index": 0, "delta": delta, "finish_reason": None}], "model": model_id, "object": "chat.completion.chunk", "id": chunk_id, @@ -379,7 +449,25 @@ def _stream_to_completion_response( continue choice = chunk.choices[0] - delta = choice.get("delta", {}) + # Handle both dict and object access patterns for choice.delta + if hasattr(choice, "get"): + delta = choice.get("delta", {}) + choice_finish = choice.get("finish_reason") + elif hasattr(choice, "delta"): + delta = choice.delta if choice.delta else {} + # Convert delta to dict if it's an object + if hasattr(delta, "__dict__") and not isinstance(delta, dict): + delta = { + k: v + for k, v in delta.__dict__.items() + if not k.startswith("_") and v is not None + } + elif hasattr(delta, "model_dump"): + delta = delta.model_dump(exclude_none=True) + choice_finish = getattr(choice, "finish_reason", None) + else: + delta = {} + choice_finish = None # Aggregate content if "content" in delta and delta["content"] is not None: @@ -443,8 +531,8 @@ def _stream_to_completion_response( ]["arguments"] # Track finish_reason from chunks (for reference only) - if choice.get("finish_reason"): - chunk_finish_reason = choice["finish_reason"] + if choice_finish: + chunk_finish_reason = choice_finish # Handle usage data from the last chunk that has it for chunk in reversed(chunks): @@ -535,6 +623,8 @@ async def make_request(): async def stream_handler(response_stream, attempt=1): """Handles the streaming response and converts chunks.""" + # Track state across chunks for finish_reason normalization + stream_state: Dict[str, Any] = {} try: async with response_stream as response: # Check for HTTP errors before processing stream @@ -589,7 +679,7 @@ async def stream_handler(response_stream, attempt=1): try: chunk = json.loads(data_str) for openai_chunk in self._convert_chunk_to_openai( - chunk, model + chunk, model, stream_state ): yield litellm.ModelResponse(**openai_chunk) except json.JSONDecodeError: diff --git a/src/rotator_library/providers/utilities/antigravity_quota_tracker.py b/src/rotator_library/providers/utilities/antigravity_quota_tracker.py index f399222d..e9711bce 100644 --- a/src/rotator_library/providers/utilities/antigravity_quota_tracker.py +++ b/src/rotator_library/providers/utilities/antigravity_quota_tracker.py @@ -34,7 +34,7 @@ from .base_quota_tracker import BaseQuotaTracker, QUOTA_DISCOVERY_DELAY_SECONDS if TYPE_CHECKING: - from ...usage_manager import UsageManager + from ...usage import UsageManager # Use the shared rotator_library logger lib_logger = logging.getLogger("rotator_library") @@ -51,14 +51,18 @@ # Learned values (from file) override these defaults if available. DEFAULT_MAX_REQUESTS: Dict[str, Dict[str, int]] = { - "standard-tier": { + # Canonical tier names + "PRO": { # Claude/GPT-OSS group (verified: 0.6667% per request = 150 requests) "claude-sonnet-4-5": 150, "claude-sonnet-4-5-thinking": 150, "claude-opus-4-5": 150, "claude-opus-4-5-thinking": 150, + "claude-opus-4-6": 150, + "claude-opus-4-6-thinking": 150, "claude-sonnet-4.5": 150, "claude-opus-4.5": 150, + "claude-opus-4.6": 150, "gpt-oss-120b-medium": 150, # Gemini 3 Pro group (verified: 0.3125% per request = 320 requests) "gemini-3-pro-high": 320, @@ -74,14 +78,17 @@ # Gemini 2.5 Pro - UNVERIFIED/UNUSED (assumed 0.1% = 1000 requests) "gemini-2.5-pro": 1, }, - "free-tier": { + "FREE": { # Claude/GPT-OSS group (verified: 2.0% per request = 50 requests) "claude-sonnet-4-5": 50, "claude-sonnet-4-5-thinking": 50, "claude-opus-4-5": 50, "claude-opus-4-5-thinking": 50, + "claude-opus-4-6": 50, + "claude-opus-4-6-thinking": 50, "claude-sonnet-4.5": 50, "claude-opus-4.5": 50, + "claude-opus-4.6": 50, "gpt-oss-120b-medium": 50, # Gemini 3 Pro group (verified: 0.6667% per request = 150 requests) "gemini-3-pro-high": 150, @@ -99,6 +106,10 @@ }, } +# Legacy tier name aliases (backwards compatibility) +DEFAULT_MAX_REQUESTS["standard-tier"] = DEFAULT_MAX_REQUESTS["PRO"] +DEFAULT_MAX_REQUESTS["free-tier"] = DEFAULT_MAX_REQUESTS["FREE"] + # Default max requests for unknown models (1% = 100 requests) DEFAULT_MAX_REQUESTS_UNKNOWN = 100 @@ -112,6 +123,8 @@ _USER_TO_API_MODEL_MAP: Dict[str, str] = { "claude-opus-4-5": "claude-opus-4-5-thinking", # Opus only exists as -thinking in API (legacy) "claude-opus-4.5": "claude-opus-4-5-thinking", # Opus only exists as -thinking in API (new format) + "claude-opus-4-6": "claude-opus-4-6-thinking", # Opus only exists as -thinking in API (legacy) + "claude-opus-4.6": "claude-opus-4-6-thinking", # Opus only exists as -thinking in API (new format) "gemini-3-pro-preview": "gemini-3-pro-high", # Preview maps to high by default } @@ -119,6 +132,8 @@ _API_TO_USER_MODEL_MAP: Dict[str, str] = { "claude-opus-4-5-thinking": "claude-opus-4.5", # Normalize to new user-facing name "claude-opus-4-5": "claude-opus-4.5", # Normalize old format to new + "claude-opus-4-6-thinking": "claude-opus-4.6", # Normalize to new user-facing name + "claude-opus-4-6": "claude-opus-4.6", # Normalize old format to new "claude-sonnet-4-5-thinking": "claude-sonnet-4.5", # Normalize to new user-facing name "claude-sonnet-4-5": "claude-sonnet-4.5", # Normalize old format to new "gemini-3-pro-high": "gemini-3-pro-preview", # Could map to preview (but high is valid too) @@ -393,7 +408,7 @@ async def _make_test_request( headers = { "Authorization": f"Bearer {access_token}", "Content-Type": "application/json", - **self._get_antigravity_headers(), + **self._get_antigravity_headers(credential_path), } payload = { @@ -479,7 +494,7 @@ async def fetch_quota_from_api( headers = { "Authorization": f"Bearer {access_token}", "Content-Type": "application/json", - **self._get_antigravity_headers(), + **self._get_antigravity_headers(credential_path), } payload = {"project": project_id} if project_id else {} @@ -990,13 +1005,22 @@ async def _store_baselines_to_usage_manager( self, quota_results: Dict[str, Dict[str, Any]], usage_manager: "UsageManager", + force: bool = False, + is_initial_fetch: bool = False, ) -> int: """ Store fetched quota baselines into UsageManager. + Antigravity-specific override: API quota only updates in ~20% increments, + so local tracking is more accurate. Exhaustion check is only applied on + initial fetch (restart), not on subsequent background refreshes. + Args: quota_results: Dict from fetch_quota_from_api or fetch_initial_baselines usage_manager: UsageManager instance to store baselines in + force: If True, always use API values (for manual refresh) + is_initial_fetch: If True, this is the first fetch on startup. + Exhaustion check is only applied when this is True. Returns: Number of baselines successfully stored @@ -1009,6 +1033,8 @@ async def _store_baselines_to_usage_manager( # Aggregate cooldown info for consolidated logging # Structure: {short_cred_name: {group_or_model: hours_until_reset}} cooldowns_by_cred: Dict[str, Dict[str, float]] = {} + # Track cleared cooldowns for consolidated logging + cleared_cooldowns_by_cred: Dict[str, List[str]] = {} for cred_path, quota_data in quota_results.items(): if quota_data.get("status") != "success": @@ -1052,16 +1078,53 @@ async def _store_baselines_to_usage_manager( # Extract reset_timestamp (already parsed to float in fetch_quota_from_api) reset_timestamp = model_info.get("reset_timestamp") + # Only use reset_timestamp when quota is actually used + # (remaining == 1.0 means 100% left, timer is bogus) + valid_reset_ts = reset_timestamp if remaining < 1.0 else None + # Store with provider prefix for consistency with usage tracking prefixed_model = f"antigravity/{user_model}" + quota_used = None + if max_requests is not None: + quota_used = int((1.0 - remaining) * max_requests) + quota_group = self.get_model_quota_group(user_model) + + # ANTIGRAVITY-SPECIFIC: Only apply exhaustion on initial fetch + # (API only updates in ~20% increments, so we rely on local tracking + # for subsequent refreshes) + apply_exhaustion = is_initial_fetch and (remaining == 0.0) + + # Clear cooldowns if API shows quota is available + # This handles both force refresh and baseline refresh (proxy restart) + if remaining > 0.0: + cooldown_target = quota_group or prefixed_model + cleared = await usage_manager.clear_cooldown_if_exists( + cred_path, + model_or_group=cooldown_target, + ) + if cleared: + if short_cred not in cleared_cooldowns_by_cred: + cleared_cooldowns_by_cred[short_cred] = [] + if cooldown_target not in cleared_cooldowns_by_cred[short_cred]: + cleared_cooldowns_by_cred[short_cred].append( + cooldown_target + ) + cooldown_info = await usage_manager.update_quota_baseline( - cred_path, prefixed_model, remaining, max_requests, reset_timestamp + cred_path, + prefixed_model, + quota_max_requests=max_requests, + quota_reset_ts=valid_reset_ts, + quota_used=quota_used, + quota_group=quota_group, + force=force, + apply_exhaustion=apply_exhaustion, ) # Aggregate cooldown info if returned if cooldown_info: - group_or_model = cooldown_info["group_or_model"] - hours = cooldown_info["hours_until_reset"] + group_or_model = cooldown_info["model"] + hours = cooldown_info["cooldown_hours"] if short_cred not in cooldowns_by_cred: cooldowns_by_cred[short_cred] = {} # Only keep first occurrence per group/model (avoids duplicates) @@ -1073,15 +1136,19 @@ async def _store_baselines_to_usage_manager( # Log consolidated message for all cooldowns if cooldowns_by_cred: - # Build message: "oauth_1[claude 3.4h, gemini-3-pro 2.1h], oauth_2[claude 5.2h]" - parts = [] - for cred_name, groups in sorted(cooldowns_by_cred.items()): - group_strs = [f"{g} {h:.1f}h" for g, h in sorted(groups.items())] - parts.append(f"{cred_name}[{', '.join(group_strs)}]") - lib_logger.info(f"Antigravity quota exhausted: {', '.join(parts)}") + lib_logger.debug("Antigravity quota baseline refresh: cooldowns recorded") else: lib_logger.debug("Antigravity quota baseline refresh: no cooldowns needed") + # Log consolidated message for cleared cooldowns + if cleared_cooldowns_by_cred: + total_cleared = sum(len(v) for v in cleared_cooldowns_by_cred.values()) + lib_logger.info( + f"Antigravity baseline refresh: cleared cooldowns for " + f"{total_cleared} model/group(s) across " + f"{len(cleared_cooldowns_by_cred)} credential(s)" + ) + return stored_count async def discover_quota_costs( diff --git a/src/rotator_library/providers/utilities/base_quota_tracker.py b/src/rotator_library/providers/utilities/base_quota_tracker.py index a155890f..c089e74e 100644 --- a/src/rotator_library/providers/utilities/base_quota_tracker.py +++ b/src/rotator_library/providers/utilities/base_quota_tracker.py @@ -40,7 +40,7 @@ from ...utils.paths import get_cache_dir if TYPE_CHECKING: - from ...usage_manager import UsageManager + from ...usage import UsageManager # Use the shared rotator_library logger lib_logger = logging.getLogger("rotator_library") @@ -268,7 +268,7 @@ def get_quota_cost(self, model: str, tier: str) -> float: # Fall back to defaults tier_costs = self.default_quota_costs.get( - tier, self.default_quota_costs.get("standard-tier", {}) + tier, self.default_quota_costs.get("PRO", {}) ) return tier_costs.get(clean_model, self.default_quota_cost_unknown) @@ -510,6 +510,8 @@ async def _store_baselines_to_usage_manager( self, quota_results: Dict[str, Dict[str, Any]], usage_manager: "UsageManager", + force: bool = False, + is_initial_fetch: bool = False, ) -> int: """ Store fetched quota baselines into UsageManager. @@ -517,6 +519,10 @@ async def _store_baselines_to_usage_manager( Args: quota_results: Dict from _fetch_quota_for_credential or fetch_initial_baselines usage_manager: UsageManager instance to store baselines in + force: If True, always use API values (for manual refresh) + is_initial_fetch: If True, this is the first fetch on startup. + For default providers (API is authoritative), exhaustion + is applied whenever remaining == 0.0 regardless of this flag. Returns: Number of baselines successfully stored @@ -529,7 +535,7 @@ async def _store_baselines_to_usage_manager( continue # Get tier for this credential - tier = self.project_tier_cache.get(cred_path, "standard-tier") + tier = self.project_tier_cache.get(cred_path, "PRO") # Extract model quota data using subclass implementation model_quotas = self._extract_model_quota_from_response(quota_data, tier) @@ -543,13 +549,47 @@ async def _store_baselines_to_usage_manager( max_requests = self.get_max_requests_for_model(user_model, tier) # Store baseline + quota_used = None + if max_requests is not None: + quota_used = int((1.0 - remaining) * max_requests) + quota_group = self.get_model_quota_group(user_model) + + # Only use reset_timestamp when quota is actually used + # (remaining == 1.0 means 100% left, timer is bogus) + bucket = self._find_bucket_for_model(quota_data, user_model) + reset_timestamp = bucket.get("reset_timestamp") if bucket else None + valid_reset_ts = reset_timestamp if remaining < 1.0 else None + + # DEFAULT: Always apply exhaustion if remaining == 0.0 exactly + # (API is authoritative for most providers) + apply_exhaustion = remaining == 0.0 + await usage_manager.update_quota_baseline( - cred_path, prefixed_model, remaining, max_requests=max_requests + cred_path, + prefixed_model, + quota_max_requests=max_requests, + quota_reset_ts=valid_reset_ts, + quota_used=quota_used, + quota_group=quota_group, + force=force, + apply_exhaustion=apply_exhaustion, ) stored_count += 1 return stored_count + def _find_bucket_for_model( + self, quota_data: Dict[str, Any], user_model: str + ) -> Optional[Dict[str, Any]]: + """Find the bucket data for a specific model in quota response.""" + for bucket in quota_data.get("buckets", []): + model_id = bucket.get("model_id") + if model_id: + bucket_user_model = self._api_to_user_model(model_id) + if bucket_user_model == user_model: + return bucket + return None + # ========================================================================= # QUOTA COST DISCOVERY # ========================================================================= @@ -787,7 +827,7 @@ def _get_model_remaining_from_quota( Returns: Remaining fraction (0.0 to 1.0) or None if not found """ - tier = quota_data.get("tier", "standard-tier") + tier = quota_data.get("tier", "PRO") model_quotas = self._extract_model_quota_from_response(quota_data, tier) clean_model = model.split("/")[-1] if "/" in model else model diff --git a/src/rotator_library/providers/utilities/device_profile.py b/src/rotator_library/providers/utilities/device_profile.py new file mode 100644 index 00000000..c864a8d4 --- /dev/null +++ b/src/rotator_library/providers/utilities/device_profile.py @@ -0,0 +1,764 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +# src/rotator_library/providers/utilities/device_profile.py +""" +Device fingerprint generation, binding, and storage for Antigravity provider. + +This module provides complete device fingerprinting for rate-limit mitigation. + +Each credential gets a unique, persistent fingerprint that includes: +- User-Agent: antigravity/{FIXED_VERSION} {platform}/{arch} +- X-Goog-Api-Client: randomized SDK client string +- X-Goog-QuotaUser: device-{random_hex} +- X-Client-Device-Id: UUID v4 +- Client-Metadata: JSON with IDE/platform/OS info + legacy hardware IDs + +Fingerprints are stored per-credential in cache/device_profiles/{email_hash}.json +with version history for audit purposes. +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import random +import secrets +import time +import uuid +from dataclasses import dataclass, asdict, field +from pathlib import Path +from typing import Any, Dict, List, Optional + +from ...utils.paths import get_cache_dir + +lib_logger = logging.getLogger("rotator_library") + +# Cache subdirectory for device profiles +DEVICE_PROFILES_SUBDIR = "device_profiles" + +# ============================================================================= +# FINGERPRINT CONSTANTS +# ============================================================================= + +# Fixed version - does NOT randomize per user request +ANTIGRAVITY_VERSION = "1.15.8" + +# Platform configurations with OS versions +PLATFORMS = { + "win32": { + "name": "WINDOWS", + "os_versions": [ + "10.0.19041", + "10.0.19042", + "10.0.19043", + "10.0.22000", + "10.0.22621", + "10.0.22631", + ], + }, + "darwin": { + "name": "MACOS", + "os_versions": ["10.15.7", "11.6.8", "12.6.3", "13.5.2", "14.2.1", "14.5"], + }, + "linux": { + "name": "LINUX", + "os_versions": ["5.15.0", "5.19.0", "6.1.0", "6.2.0", "6.5.0", "6.6.0"], + }, +} + +# Architecture options +ARCHITECTURES = ["x64", "arm64"] + +# SDK client strings (randomized per credential) +SDK_CLIENTS = [ + "google-cloud-sdk vscode_cloudshelleditor/0.1", + "google-cloud-sdk vscode/1.96.0", + "google-cloud-sdk vscode/1.95.0", + #"google-cloud-sdk jetbrains/2024.3", + #"google-cloud-sdk intellij/2024.1", + #"google-cloud-sdk android-studio/2024.1", +] + +# IDE types for Client-Metadata +IDE_TYPES = [ + "VSCODE", + #"INTELLIJ", + #"ANDROID_STUDIO", + #"CLOUD_SHELL_EDITOR", +] + + +# ============================================================================= +# LEGACY DEVICE PROFILE (kept for backward compatibility) +# ============================================================================= + + +@dataclass +class DeviceProfile: + """ + Legacy device profile containing 4 hardware identifiers. + + Kept for backward compatibility with existing stored profiles. + New code should use DeviceFingerprint instead. + """ + + machine_id: str + mac_machine_id: str + dev_device_id: str + sqm_id: str + + def to_dict(self) -> Dict[str, str]: + """Convert to dictionary for JSON serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict[str, str]) -> "DeviceProfile": + """Create from dictionary.""" + return cls( + machine_id=data["machine_id"], + mac_machine_id=data["mac_machine_id"], + dev_device_id=data["dev_device_id"], + sqm_id=data["sqm_id"], + ) + + +# ============================================================================= +# DEVICE FINGERPRINT (new complete implementation) +# ============================================================================= + + +@dataclass +class DeviceFingerprint: + """ + Complete device fingerprint for a credential. + + Contains all necessary hardware identifiers and metadata for API authentication. + """ + + # === HTTP Header Fields === + user_agent: str # "antigravity/1.15.8 win32/x64" + api_client: str # "google-cloud-sdk vscode/1.96.0" + quota_user: str # "device-a1b2c3d4e5f6" + device_id: str # UUID v4 for X-Client-Device-Id header + + # === Client-Metadata Fields === + ide_type: str # "VSCODE", "INTELLIJ", etc. + platform: str # "WINDOWS", "MACOS", "LINUX" + platform_raw: str # "win32", "darwin", "linux" (for UA) + arch: str # "x64", "arm64" + os_version: str # "10.0.22631", "14.5", etc. + sqm_id: str # "{UUID}" uppercase in braces + plugin_type: str # "GEMINI" (always) + + # === Legacy Hardware IDs (kept for compatibility) === + machine_id: str # "auth0|user_{hex}" + mac_machine_id: str # Custom UUID v4 + dev_device_id: str # Standard UUID v4 + + # === Metadata === + created_at: int # Unix timestamp + session_token: str # 16-byte hex + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DeviceFingerprint": + """Create from dictionary.""" + return cls( + user_agent=data["user_agent"], + api_client=data["api_client"], + quota_user=data["quota_user"], + device_id=data["device_id"], + ide_type=data["ide_type"], + platform=data["platform"], + platform_raw=data["platform_raw"], + arch=data["arch"], + os_version=data["os_version"], + sqm_id=data["sqm_id"], + plugin_type=data["plugin_type"], + machine_id=data["machine_id"], + mac_machine_id=data["mac_machine_id"], + dev_device_id=data["dev_device_id"], + created_at=data["created_at"], + session_token=data["session_token"], + ) + + def to_legacy_profile(self) -> DeviceProfile: + """Convert to legacy DeviceProfile for backward compatibility.""" + return DeviceProfile( + machine_id=self.machine_id, + mac_machine_id=self.mac_machine_id, + dev_device_id=self.dev_device_id, + sqm_id=self.sqm_id, + ) + + +@dataclass +class DeviceFingerprintVersion: + """ + Versioned device fingerprint with metadata for history tracking. + """ + + id: str # Random UUID v4 for this version + created_at: int # Unix timestamp + label: str # e.g., "auto_generated", "upgraded", "regenerated" + fingerprint: DeviceFingerprint + is_current: bool = True + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "id": self.id, + "created_at": self.created_at, + "label": self.label, + "fingerprint": self.fingerprint.to_dict(), + "is_current": self.is_current, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DeviceFingerprintVersion": + """Create from dictionary.""" + return cls( + id=data["id"], + created_at=data["created_at"], + label=data["label"], + fingerprint=DeviceFingerprint.from_dict(data["fingerprint"]), + is_current=data.get("is_current", False), + ) + + +@dataclass +class CredentialDeviceData: + """ + Complete device data for a credential, including current fingerprint and history. + """ + + email: str + current_fingerprint: Optional[DeviceFingerprint] = None + fingerprint_history: List[DeviceFingerprintVersion] = field(default_factory=list) + + # Legacy fields for migration + current_profile: Optional[DeviceProfile] = None + device_history: List[Dict[str, Any]] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "email": self.email, + "current_fingerprint": ( + self.current_fingerprint.to_dict() if self.current_fingerprint else None + ), + "fingerprint_history": [v.to_dict() for v in self.fingerprint_history], + # Legacy fields (kept for backward compat, but prefer new fields) + "current_profile": ( + self.current_profile.to_dict() if self.current_profile else None + ), + "device_history": self.device_history, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "CredentialDeviceData": + """Create from dictionary with silent upgrade support.""" + current_fp = data.get("current_fingerprint") + fp_history = data.get("fingerprint_history", []) + + # Legacy profile data + current_profile_data = data.get("current_profile") + device_history = data.get("device_history", []) + + return cls( + email=data["email"], + current_fingerprint=( + DeviceFingerprint.from_dict(current_fp) if current_fp else None + ), + fingerprint_history=[ + DeviceFingerprintVersion.from_dict(v) for v in fp_history + ], + current_profile=( + DeviceProfile.from_dict(current_profile_data) + if current_profile_data + else None + ), + device_history=device_history, + ) + + +# ============================================================================= +# ID GENERATION FUNCTIONS +# ============================================================================= + + +def random_hex(length: int) -> str: + """ + Generate a random lowercase alphanumeric string. + + Args: + length: Number of characters to generate + + Returns: + Random alphanumeric string (lowercase) + """ + import string + + chars = string.ascii_lowercase + string.digits + return "".join(random.choice(chars) for _ in range(length)) + + +def new_standard_machine_id() -> str: + """ + Generate a UUID v4 format string with custom builder. + + Format: xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx + where x is random hex [0-f] and y is random hex [8-b] + + Returns: + UUID v4 format string + """ + + def rand_hex(n: int) -> str: + return "".join(random.choice("0123456789abcdef") for _ in range(n)) + + # y must be in range 8-b (UUID v4 variant bits) + y = random.choice("89ab") + + return f"{rand_hex(8)}-{rand_hex(4)}-4{rand_hex(3)}-{y}{rand_hex(3)}-{rand_hex(12)}" + + +def generate_device_fingerprint() -> DeviceFingerprint: + """ + Generate a complete device fingerprint. + + Returns: + New DeviceFingerprint with all fields populated + """ + # Pick platform (raw + name + os_version together for consistency) + platform_raw = random.choice(list(PLATFORMS.keys())) # "win32" + platform_info = PLATFORMS[platform_raw] + platform_name = platform_info["name"] # "WINDOWS" + os_version = random.choice(platform_info["os_versions"]) + + # Pick arch + arch = random.choice(ARCHITECTURES) # "x64" + + # Build user agent with FIXED version + user_agent = f"antigravity/{ANTIGRAVITY_VERSION} {platform_raw}/{arch}" + + # Other randomized fields (picked once, persisted) + api_client = random.choice(SDK_CLIENTS) + quota_user = f"device-{secrets.token_hex(8)}" + device_id = str(uuid.uuid4()) + ide_type = random.choice(IDE_TYPES) + sqm_id = f"{{{str(uuid.uuid4()).upper()}}}" + session_token = secrets.token_hex(16) + + # Legacy hardware IDs (for compatibility) + machine_id = f"auth0|user_{random_hex(32)}" + mac_machine_id = new_standard_machine_id() + dev_device_id = str(uuid.uuid4()) + + return DeviceFingerprint( + user_agent=user_agent, + api_client=api_client, + quota_user=quota_user, + device_id=device_id, + ide_type=ide_type, + platform=platform_name, + platform_raw=platform_raw, + arch=arch, + os_version=os_version, + sqm_id=sqm_id, + plugin_type="GEMINI", + machine_id=machine_id, + mac_machine_id=mac_machine_id, + dev_device_id=dev_device_id, + created_at=int(time.time()), + session_token=session_token, + ) + + +def upgrade_legacy_profile(profile: DeviceProfile) -> DeviceFingerprint: + """ + Upgrade a legacy DeviceProfile to a full DeviceFingerprint. + + Preserves the legacy hardware IDs and generates the missing fields. + + Args: + profile: Legacy DeviceProfile to upgrade + + Returns: + New DeviceFingerprint with legacy IDs preserved + """ + # Pick platform (raw + name + os_version together for consistency) + platform_raw = random.choice(list(PLATFORMS.keys())) + platform_info = PLATFORMS[platform_raw] + platform_name = platform_info["name"] + os_version = random.choice(platform_info["os_versions"]) + + # Pick arch + arch = random.choice(ARCHITECTURES) + + # Build user agent with FIXED version + user_agent = f"antigravity/{ANTIGRAVITY_VERSION} {platform_raw}/{arch}" + + # Generate missing fields + api_client = random.choice(SDK_CLIENTS) + quota_user = f"device-{secrets.token_hex(8)}" + device_id = str(uuid.uuid4()) + ide_type = random.choice(IDE_TYPES) + session_token = secrets.token_hex(16) + + return DeviceFingerprint( + user_agent=user_agent, + api_client=api_client, + quota_user=quota_user, + device_id=device_id, + ide_type=ide_type, + platform=platform_name, + platform_raw=platform_raw, + arch=arch, + os_version=os_version, + sqm_id=profile.sqm_id, # Preserve + plugin_type="GEMINI", + machine_id=profile.machine_id, # Preserve + mac_machine_id=profile.mac_machine_id, # Preserve + dev_device_id=profile.dev_device_id, # Preserve + created_at=int(time.time()), + session_token=session_token, + ) + + +# ============================================================================= +# HEADER BUILDER +# ============================================================================= + + +def build_fingerprint_headers(fp: DeviceFingerprint) -> Dict[str, str]: + """ + Build all 5 HTTP headers from a fingerprint. + + Args: + fp: DeviceFingerprint to build headers from + + Returns: + Dict with User-Agent, X-Goog-Api-Client, Client-Metadata, + X-Goog-QuotaUser, X-Client-Device-Id + """ + client_metadata = { + "ideType": fp.ide_type, + "platform": fp.platform, + "pluginType": fp.plugin_type, + "osVersion": fp.os_version, + "arch": fp.arch, + "sqmId": fp.sqm_id, + } + + return { + "User-Agent": fp.user_agent, + "X-Goog-Api-Client": fp.api_client, + "Client-Metadata": json.dumps(client_metadata), + "X-Goog-QuotaUser": fp.quota_user, + "X-Client-Device-Id": fp.device_id, + } + + +# ============================================================================= +# STORAGE AND RETRIEVAL +# ============================================================================= + + +def _get_email_hash(email: str) -> str: + """Get a safe filename hash for an email address.""" + return hashlib.sha256(email.lower().encode()).hexdigest()[:16] + + +def _get_profile_path(email: str) -> Path: + """Get the path to the device profile file for an email.""" + cache_dir = get_cache_dir(subdir=DEVICE_PROFILES_SUBDIR) + return cache_dir / f"{_get_email_hash(email)}.json" + + +def load_credential_device_data(email: str) -> Optional[CredentialDeviceData]: + """ + Load device data for a credential from disk. + + Args: + email: Email address of the credential + + Returns: + CredentialDeviceData if found, None otherwise + """ + profile_path = _get_profile_path(email) + if not profile_path.exists(): + return None + + try: + with open(profile_path, "r") as f: + data = json.load(f) + return CredentialDeviceData.from_dict(data) + except (json.JSONDecodeError, KeyError, FileNotFoundError) as e: + lib_logger.warning(f"Failed to load device profile for {email}: {e}") + return None + + +def save_credential_device_data(data: CredentialDeviceData) -> bool: + """ + Save device data for a credential to disk. + + Args: + data: CredentialDeviceData to save + + Returns: + True if saved successfully, False otherwise + """ + profile_path = _get_profile_path(data.email) + + try: + # Ensure directory exists + profile_path.parent.mkdir(parents=True, exist_ok=True) + + # Write atomically + temp_path = profile_path.with_suffix(".tmp") + with open(temp_path, "w") as f: + json.dump(data.to_dict(), f, indent=2) + + # Atomic rename + temp_path.replace(profile_path) + + lib_logger.debug(f"Saved device fingerprint for {data.email}") + return True + except Exception as e: + lib_logger.error(f"Failed to save device fingerprint for {data.email}: {e}") + return False + + +# ============================================================================= +# HIGH-LEVEL API +# ============================================================================= + + +def get_or_create_fingerprint( + email: str, auto_generate: bool = True +) -> Optional[DeviceFingerprint]: + """ + Get the current device fingerprint for a credential, optionally creating one. + + Handles silent upgrade from legacy DeviceProfile to DeviceFingerprint. + + Args: + email: Email address of the credential + auto_generate: If True and no fingerprint exists, generate one + + Returns: + DeviceFingerprint if available, None otherwise + """ + data = load_credential_device_data(email) + + # Check for existing fingerprint + if data and data.current_fingerprint: + return data.current_fingerprint + + # Check for legacy profile and upgrade (silent upgrade) + if data and data.current_profile: + lib_logger.info(f"Upgrading legacy device profile to fingerprint for {email}") + fingerprint = upgrade_legacy_profile(data.current_profile) + _save_fingerprint(data, fingerprint, label="upgraded") + return fingerprint + + if not auto_generate: + return None + + # Generate new fingerprint + return bind_new_fingerprint(email, label="auto_generated") + + +def bind_new_fingerprint( + email: str, + label: str = "auto_generated", + fingerprint: Optional[DeviceFingerprint] = None, +) -> DeviceFingerprint: + """ + Bind a new device fingerprint to a credential. + + Creates a new fingerprint (or uses provided one), marks it as current, + and adds it to the version history. + + Args: + email: Email address of the credential + label: Label for this version (e.g., "auto_generated", "regenerated") + fingerprint: Optional fingerprint to bind. If None, generates a new one. + + Returns: + The bound DeviceFingerprint + """ + # Load existing data or create new + data = load_credential_device_data(email) + if not data: + data = CredentialDeviceData(email=email) + + # Generate fingerprint if not provided + if fingerprint is None: + fingerprint = generate_device_fingerprint() + + _save_fingerprint(data, fingerprint, label) + + lib_logger.info( + f"Bound new device fingerprint for {email} (label={label}, " + f"ua={fingerprint.user_agent})" + ) + + return fingerprint + + +def _save_fingerprint( + data: CredentialDeviceData, fingerprint: DeviceFingerprint, label: str +) -> None: + """ + Internal helper to save a fingerprint to credential data. + """ + # Mark all existing versions as not current + for version in data.fingerprint_history: + version.is_current = False + + # Create new version + version = DeviceFingerprintVersion( + id=str(uuid.uuid4()), + created_at=int(time.time()), + label=label, + fingerprint=fingerprint, + is_current=True, + ) + + # Update data + data.current_fingerprint = fingerprint + data.fingerprint_history.append(version) + + # Also update legacy profile for backward compatibility + data.current_profile = fingerprint.to_legacy_profile() + + # Save + save_credential_device_data(data) + + +def get_fingerprint_history(email: str) -> List[DeviceFingerprintVersion]: + """ + Get the device fingerprint version history for a credential. + + Args: + email: Email address of the credential + + Returns: + List of DeviceFingerprintVersion entries + """ + data = load_credential_device_data(email) + return data.fingerprint_history if data else [] + + +def regenerate_fingerprint(email: str) -> DeviceFingerprint: + """ + Regenerate the device fingerprint for a credential. + + Call this to get a fresh identity (e.g., after rate limiting). + + Args: + email: Email address of the credential + + Returns: + New DeviceFingerprint + """ + return bind_new_fingerprint(email, label="regenerated") + + +# ============================================================================= +# LEGACY API (kept for backward compatibility) +# ============================================================================= + + +def get_or_create_device_profile( + email: str, auto_generate: bool = True +) -> Optional[DeviceProfile]: + """ + Get the current device profile for a credential. + + DEPRECATED: Use get_or_create_fingerprint() instead. + + Args: + email: Email address of the credential + auto_generate: If True and no profile exists, generate one + + Returns: + DeviceProfile if available, None otherwise + """ + fingerprint = get_or_create_fingerprint(email, auto_generate) + if fingerprint: + return fingerprint.to_legacy_profile() + return None + + +def generate_profile() -> DeviceProfile: + """ + Generate a new random device profile. + + DEPRECATED: Use generate_device_fingerprint() instead. + + Returns: + New DeviceProfile with random identifiers + """ + fingerprint = generate_device_fingerprint() + return fingerprint.to_legacy_profile() + + +def build_client_metadata( + profile: Optional[DeviceProfile] = None, + ide_type: str = "ANTIGRAVITY", + platform: str = "WINDOWS_AMD64", + plugin_type: str = "GEMINI", +) -> Dict[str, Any]: + """ + Build Client-Metadata dict with device profile information. + + DEPRECATED: Use build_fingerprint_headers() instead. + + Args: + profile: Optional DeviceProfile to include. + ide_type: IDE type identifier + platform: Platform identifier + plugin_type: Plugin type identifier + + Returns: + Client metadata dictionary + """ + metadata = { + "ideType": ide_type if profile else "IDE_UNSPECIFIED", + "platform": platform if profile else "PLATFORM_UNSPECIFIED", + "pluginType": plugin_type, + } + + if profile: + metadata["machineId"] = profile.machine_id + metadata["macMachineId"] = profile.mac_machine_id + metadata["devDeviceId"] = profile.dev_device_id + metadata["sqmId"] = profile.sqm_id + + return metadata + + +def build_client_metadata_header( + profile: Optional[DeviceProfile] = None, **kwargs +) -> str: + """ + Build Client-Metadata header value as JSON string. + + DEPRECATED: Use build_fingerprint_headers() instead. + + Args: + profile: Optional DeviceProfile to include + **kwargs: Additional arguments passed to build_client_metadata + + Returns: + JSON string for Client-Metadata header + """ + return json.dumps(build_client_metadata(profile, **kwargs)) diff --git a/src/rotator_library/providers/utilities/gemini_cli_quota_tracker.py b/src/rotator_library/providers/utilities/gemini_cli_quota_tracker.py index 3f86014f..80b54ff6 100644 --- a/src/rotator_library/providers/utilities/gemini_cli_quota_tracker.py +++ b/src/rotator_library/providers/utilities/gemini_cli_quota_tracker.py @@ -39,7 +39,7 @@ from .gemini_shared_utils import CODE_ASSIST_ENDPOINT if TYPE_CHECKING: - from ...usage_manager import UsageManager + from ...usage import UsageManager # Use the shared rotator_library logger lib_logger = logging.getLogger("rotator_library") @@ -55,7 +55,8 @@ # Learned values (from file) override these defaults if available. DEFAULT_MAX_REQUESTS: Dict[str, Dict[str, int]] = { - "standard-tier": { + # Canonical tier names + "PRO": { # Pro group (verified: 0.4% per request = 250 requests) "gemini-2.5-pro": 250, "gemini-3-pro-preview": 250, @@ -67,7 +68,7 @@ # 3-Flash group (verified: ~0.0667% per request = 1500 requests) "gemini-3-flash-preview": 1500, }, - "free-tier": { + "FREE": { # Pro group (verified: 1.0% per request = 100 requests) "gemini-2.5-pro": 100, "gemini-3-pro-preview": 100, @@ -80,6 +81,10 @@ }, } +# Legacy tier name aliases (backwards compatibility) +DEFAULT_MAX_REQUESTS["standard-tier"] = DEFAULT_MAX_REQUESTS["PRO"] +DEFAULT_MAX_REQUESTS["free-tier"] = DEFAULT_MAX_REQUESTS["FREE"] + # Default max requests for unknown models (1% = 100 requests) DEFAULT_MAX_REQUESTS_UNKNOWN = 1000 diff --git a/src/rotator_library/providers/utilities/gemini_credential_manager.py b/src/rotator_library/providers/utilities/gemini_credential_manager.py index f8f4dfcc..8002b677 100644 --- a/src/rotator_library/providers/utilities/gemini_credential_manager.py +++ b/src/rotator_library/providers/utilities/gemini_credential_manager.py @@ -17,7 +17,7 @@ from typing import Any, Dict, List, Optional, TYPE_CHECKING if TYPE_CHECKING: - from ...usage_manager import UsageManager + from ...usage import UsageManager lib_logger = logging.getLogger("rotator_library") @@ -60,12 +60,18 @@ def _load_tier_from_file(self, credential_path: str) -> Optional[str]: This is used as a fallback when the tier isn't in the memory cache, typically on first access before initialize_credentials() has run. + Also performs tier name migration: old tier names (e.g., "g1-pro-tier") + are normalized to canonical names (e.g., "PRO") and the file is updated. + Args: credential_path: Path to the credential file Returns: Tier string if found, None otherwise """ + # Import here to avoid circular imports + from .gemini_shared_utils import normalize_tier_name + # Skip env:// paths (environment-based credentials) if self._parse_env_credential_path(credential_path) is not None: return None @@ -79,6 +85,23 @@ def _load_tier_from_file(self, credential_path: str) -> Optional[str]: project_id = metadata.get("project_id") if tier: + # Migrate old tier names to canonical format + canonical_tier = normalize_tier_name(tier) + if canonical_tier and canonical_tier != tier: + # Tier name changed - update file and log migration + lib_logger.info( + f"Migrating tier '{tier}' -> '{canonical_tier}' for credential: {Path(credential_path).name}" + ) + creds["_proxy_metadata"]["tier"] = canonical_tier + try: + with open(credential_path, "w") as f: + json.dump(creds, f, indent=2) + except Exception as write_err: + lib_logger.warning( + f"Could not persist tier migration to {credential_path}: {write_err}" + ) + tier = canonical_tier + self.project_tier_cache[credential_path] = tier lib_logger.debug( f"Lazy-loaded tier '{tier}' for credential: {Path(credential_path).name}" @@ -148,10 +171,27 @@ async def initialize_credentials(self, credential_paths: List[str]) -> None: await self._discover_project_id( credential_path, access_token, litellm_params={} ) - discovered_tier = self.project_tier_cache.get( + # Use full tier name for discovery log (one-time display) + tier_full_cache = getattr(self, "tier_full_cache", {}) + tier_full = tier_full_cache.get(credential_path) + discovered_tier = tier_full or self.project_tier_cache.get( credential_path, "unknown" ) - lib_logger.debug( + lib_logger.info( + f"Discovered tier '{discovered_tier}' for {Path(credential_path).name}" + ) + except Exception as e: + lib_logger.warning( + f"Failed to discover tier for {Path(credential_path).name}: {e}. " + f"Credential will use default priority." + ) + # Use full tier name for discovery log (one-time display) + tier_full_cache = getattr(self, "tier_full_cache", {}) + tier_full = tier_full_cache.get(credential_path) + discovered_tier = tier_full or self.project_tier_cache.get( + credential_path, "unknown" + ) + lib_logger.info( f"Discovered tier '{discovered_tier}' for {Path(credential_path).name}" ) except Exception as e: @@ -193,6 +233,13 @@ async def _load_persisted_tiers( if tier: self.project_tier_cache[path] = tier loaded[path] = tier + + # Also load tier_full if available + tier_full = metadata.get("tier_full") + tier_full_cache = getattr(self, "tier_full_cache", None) + if tier_full and tier_full_cache is not None: + tier_full_cache[path] = tier_full + lib_logger.debug( f"Loaded persisted tier '{tier}' for credential: {Path(path).name}" ) @@ -267,20 +314,27 @@ async def run_background_job( f"{provider_name}: Fetching initial quota baselines for {len(credentials)} credentials..." ) quota_results = await self.fetch_initial_baselines(credentials) + is_initial_fetch = True self._initial_quota_fetch_done = True else: # Subsequent runs: only recently used credentials (incremental updates) - usage_data = await usage_manager._get_usage_data_snapshot() + usage_data = await usage_manager.get_usage_snapshot() quota_results = await self.refresh_active_quota_baselines( credentials, usage_data ) + is_initial_fetch = False if not quota_results: return # Store new baselines in UsageManager + # On initial fetch: force=True overwrites with API data, is_initial_fetch enables exhaustion check + # On subsequent: force=False uses max logic, no exhaustion check stored = await self._store_baselines_to_usage_manager( - quota_results, usage_manager + quota_results, + usage_manager, + force=is_initial_fetch, # Force on initial fetch + is_initial_fetch=is_initial_fetch, ) if stored > 0: lib_logger.debug( @@ -320,7 +374,11 @@ async def refresh_active_quota_baselines( ) async def _store_baselines_to_usage_manager( - self, quota_results: Dict[str, Any], usage_manager: "UsageManager" + self, + quota_results: Dict[str, Any], + usage_manager: "UsageManager", + force: bool = False, + is_initial_fetch: bool = False, ) -> int: """Store quota baselines to usage manager. Must be implemented by quota tracker.""" raise NotImplementedError( diff --git a/src/rotator_library/providers/utilities/gemini_shared_utils.py b/src/rotator_library/providers/utilities/gemini_shared_utils.py index 6f0fe112..e1ca00f6 100644 --- a/src/rotator_library/providers/utilities/gemini_shared_utils.py +++ b/src/rotator_library/providers/utilities/gemini_shared_utils.py @@ -376,3 +376,382 @@ def recursively_parse_json_strings( except (json.JSONDecodeError, ValueError): pass return obj + + +# ============================================================================= +# TIER NAMING AND PRIORITY CONSTANTS +# ============================================================================= +# Shared tier handling for Google/Gemini-based providers (Gemini CLI, Antigravity) +# +# Canonical tier names are uppercase: ULTRA, PRO, FREE +# API returns various formats: g1-pro-tier, standard-tier, free-tier, etc. +# This module normalizes all tier names and provides priority ordering. + +# Canonical tier names (short) +TIER_ULTRA = "ULTRA" +TIER_PRO = "PRO" +TIER_FREE = "FREE" + +# Full tier names for display (based on tier source/subscription type) +# Used for one-time displays like credential discovery logging +TIER_ID_TO_FULL_NAME: Dict[str, str] = { + # Google One AI subscription tiers (from paidTier API response) + "g1-pro-tier": "Google One AI PRO", + "g1-ultra-tier": "Google One AI ULTRA", + "g1-free-tier": TIER_FREE, # Free tiers are just "FREE" + # Gemini Code Assist subscription tiers + "gemini-code-assist-pro": "Code Assist PRO", + "gemini-code-assist-ultra": "Code Assist ULTRA", + "gemini-code-assist-free": TIER_FREE, + # Legacy/standard tier names (no special prefix) + "standard-tier": TIER_PRO, + "pro-tier": TIER_PRO, + "ultra-tier": TIER_ULTRA, + "enterprise-tier": TIER_ULTRA, + "free-tier": TIER_FREE, + "legacy-tier": TIER_FREE, + # Already canonical - return as-is + TIER_FREE: TIER_FREE, + TIER_PRO: TIER_PRO, + TIER_ULTRA: TIER_ULTRA, +} + +# Mapping from API/legacy tier names to canonical names +# Handles all known tier name formats from various API responses +TIER_NAME_TO_CANONICAL: Dict[str, str] = { + # Legacy Python names + "free-tier": TIER_FREE, + "legacy-tier": TIER_FREE, # Legacy is treated as free + "standard-tier": TIER_PRO, + "pro-tier": TIER_PRO, + "ultra-tier": TIER_ULTRA, + "enterprise-tier": TIER_ULTRA, + # Google One AI tier names (from paidTier API response) + "g1-pro-tier": TIER_PRO, + "g1-ultra-tier": TIER_ULTRA, + "g1-free-tier": TIER_FREE, + # Gemini Code Assist tier names + "gemini-code-assist-pro": TIER_PRO, + "gemini-code-assist-ultra": TIER_ULTRA, + "gemini-code-assist-free": TIER_FREE, + # Already canonical (uppercase) + TIER_FREE: TIER_FREE, + TIER_PRO: TIER_PRO, + TIER_ULTRA: TIER_ULTRA, +} + +# Reverse mapping for backwards compatibility (canonical -> legacy) +CANONICAL_TO_LEGACY: Dict[str, str] = { + TIER_FREE: "free-tier", + TIER_PRO: "standard-tier", + TIER_ULTRA: "enterprise-tier", +} + +# Free tier identifiers (all naming conventions) +FREE_TIER_IDS: set = {TIER_FREE, "free-tier", "legacy-tier", "g1-free-tier"} + +# Tier priorities for credential selection (lower number = higher priority) +# ULTRA (Google One AI Premium) > PRO (Google One AI / Paid) > FREE +TIER_PRIORITIES: Dict[str, int] = { + # Canonical names + TIER_ULTRA: 1, # Highest priority - Google One AI Premium + TIER_PRO: 2, # Standard paid tier - Google One AI + TIER_FREE: 3, # Free tier + # API/legacy names mapped to same priorities for backwards compatibility + "g1-ultra-tier": 1, + "g1-pro-tier": 2, + "standard-tier": 2, + "free-tier": 3, + "legacy-tier": 10, # Legacy/unknown treated as lowest + "unknown": 10, +} + +# Default priority for tiers not in the mapping +DEFAULT_TIER_PRIORITY: int = 10 + + +# ============================================================================= +# TIER HELPER FUNCTIONS +# ============================================================================= + + +def normalize_tier_name(tier_id: Optional[str]) -> Optional[str]: + """ + Normalize tier name to canonical format (ULTRA, PRO, FREE). + + Supports all tier name formats: + - Legacy Python names: free-tier, standard-tier, legacy-tier + - Google One AI names: g1-pro-tier, g1-ultra-tier + - Gemini Code Assist names: gemini-code-assist-pro + - Already canonical: FREE, PRO, ULTRA + + Args: + tier_id: Tier identifier from API response or config + + Returns: + Canonical tier name (ULTRA, PRO, FREE) or original if unknown + """ + if not tier_id: + return None + return TIER_NAME_TO_CANONICAL.get(tier_id, tier_id) + + +def is_free_tier(tier_id: Optional[str]) -> bool: + """ + Check if tier is a free tier (any naming convention). + + Args: + tier_id: Tier identifier to check + + Returns: + True if tier is free, False otherwise + """ + if not tier_id: + return False + return tier_id in FREE_TIER_IDS or normalize_tier_name(tier_id) == TIER_FREE + + +def is_paid_tier(tier_id: Optional[str]) -> bool: + """ + Check if tier is a paid tier (PRO or ULTRA). + + Args: + tier_id: Tier identifier to check + + Returns: + True if tier is paid (PRO or ULTRA), False otherwise + """ + if not tier_id or tier_id == "unknown": + return False + canonical = normalize_tier_name(tier_id) + return canonical in (TIER_PRO, TIER_ULTRA) + + +def get_tier_priority(tier_id: Optional[str]) -> int: + """ + Get priority for a tier (lower number = higher priority). + + Priority order: ULTRA (1) > PRO (2) > FREE (3) > unknown (10) + + Args: + tier_id: Tier identifier + + Returns: + Priority number (1-10), lower is better + """ + if not tier_id: + return DEFAULT_TIER_PRIORITY + # Try direct lookup first (handles both canonical and API names) + if tier_id in TIER_PRIORITIES: + return TIER_PRIORITIES[tier_id] + # Normalize and try again + canonical = normalize_tier_name(tier_id) + return TIER_PRIORITIES.get(canonical, DEFAULT_TIER_PRIORITY) + + +def format_tier_for_display(tier_id: Optional[str]) -> str: + """ + Format tier name for display (lowercase canonical). + + Args: + tier_id: Tier identifier + + Returns: + Display-friendly tier name: "ultra", "pro", "free", or "unknown" + """ + if not tier_id: + return "unknown" + canonical = normalize_tier_name(tier_id) + if canonical in (TIER_ULTRA, TIER_PRO, TIER_FREE): + return canonical.lower() + return "unknown" + + +def get_tier_full_name(tier_id: Optional[str]) -> str: + """ + Get the full/descriptive tier name for display. + + Used for one-time displays like credential discovery logging where + we want to show the subscription source (e.g., "Google One AI PRO"). + + Args: + tier_id: Original tier identifier from API response (e.g., "g1-pro-tier") + + Returns: + Full tier name (e.g., "Google One AI PRO") or canonical short name as fallback + """ + if not tier_id: + return "unknown" + # Try direct lookup for full name + if tier_id in TIER_ID_TO_FULL_NAME: + return TIER_ID_TO_FULL_NAME[tier_id] + # Fallback to canonical short name + canonical = normalize_tier_name(tier_id) + return canonical if canonical else tier_id + + +# ============================================================================= +# PROJECT ID EXTRACTION +# ============================================================================= + + +def extract_project_id_from_response( + data: Dict[str, Any], key: str = "cloudaicompanionProject" +) -> Optional[str]: + """ + Extract project ID from API response, handling both string and object formats. + + The API may return cloudaicompanionProject as either: + - A string: "project-id-123" + - An object: {"id": "project-id-123", ...} + + Args: + data: API response data + key: Key to extract from (default: "cloudaicompanionProject") + + Returns: + Project ID string or None if not found + """ + value = data.get(key) + if isinstance(value, str) and value: + return value + if isinstance(value, dict): + return value.get("id") + return None + + +# ============================================================================= +# CREDENTIAL LOADING HELPERS +# ============================================================================= + + +def load_persisted_project_metadata( + credential_path: str, + credential_index: Optional[int], + credentials_cache: Dict[str, Any], + project_id_cache: Dict[str, str], + project_tier_cache: Dict[str, str], + tier_full_cache: Optional[Dict[str, str]] = None, +) -> Optional[str]: + """ + Load persisted project_id and tier from credential file or env cache. + + This helper handles the common pattern of checking for already-persisted + project metadata in both file-based and env-based credentials. + + Args: + credential_path: Path to the credential (file path or env:// path) + credential_index: Result of _parse_env_credential_path() - None for file, int for env + credentials_cache: Dict of loaded credentials (for env-based) + project_id_cache: Dict to populate with project_id + project_tier_cache: Dict to populate with tier + tier_full_cache: Optional dict to populate with tier_full (full display name) + + Returns: + Project ID if found and cached, None otherwise (caller should do discovery) + """ + if credential_index is None: + # File-based credentials: load from file + try: + with open(credential_path, "r") as f: + creds = json.load(f) + + metadata = creds.get("_proxy_metadata", {}) + persisted_project_id = metadata.get("project_id") + persisted_tier = metadata.get("tier") + persisted_tier_full = metadata.get("tier_full") + + if persisted_project_id: + lib_logger.debug( + f"Loaded persisted project ID from credential file: {persisted_project_id}" + ) + project_id_cache[credential_path] = persisted_project_id + + # Also load tier if available + if persisted_tier: + project_tier_cache[credential_path] = persisted_tier + lib_logger.debug(f"Loaded persisted tier: {persisted_tier}") + + # Load tier_full if available and cache provided + if persisted_tier_full and tier_full_cache is not None: + tier_full_cache[credential_path] = persisted_tier_full + lib_logger.debug( + f"Loaded persisted tier_full: {persisted_tier_full}" + ) + + return persisted_project_id + except (FileNotFoundError, json.JSONDecodeError, KeyError) as e: + lib_logger.debug(f"Could not load persisted project ID from file: {e}") + else: + # Env-based credentials: load from credentials cache + # The credentials were already loaded by _load_from_env() which reads + # {PREFIX}_{N}_PROJECT_ID and {PREFIX}_{N}_TIER into _proxy_metadata + if credential_path in credentials_cache: + creds = credentials_cache[credential_path] + metadata = creds.get("_proxy_metadata", {}) + env_project_id = metadata.get("project_id") + env_tier = metadata.get("tier") + env_tier_full = metadata.get("tier_full") + + if env_project_id: + lib_logger.debug( + f"Loaded project ID from env credential metadata: {env_project_id}" + ) + project_id_cache[credential_path] = env_project_id + + if env_tier: + project_tier_cache[credential_path] = env_tier + lib_logger.debug( + f"Loaded tier from env credential metadata: {env_tier}" + ) + + # Load tier_full if available and cache provided + if env_tier_full and tier_full_cache is not None: + tier_full_cache[credential_path] = env_tier_full + lib_logger.debug( + f"Loaded tier_full from env credential metadata: {env_tier_full}" + ) + + return env_project_id + + return None + + +# ============================================================================= +# ENV FILE HELPERS +# ============================================================================= + + +def build_project_tier_env_lines( + creds: Dict[str, Any], env_prefix: str, cred_number: int +) -> List[str]: + """ + Build env lines for project_id and tier from credential metadata. + + Used by Google OAuth providers (Gemini CLI, Antigravity) to generate + environment variable lines for project and tier information. + + Args: + creds: Credential dict containing _proxy_metadata + env_prefix: Environment variable prefix (e.g., "GEMINI_CLI", "ANTIGRAVITY") + cred_number: Credential number for env var naming + + Returns: + List of env lines like ["PREFIX_N_PROJECT_ID=...", "PREFIX_N_TIER=..."] + """ + lines = [] + metadata = creds.get("_proxy_metadata", {}) + prefix = f"{env_prefix}_{cred_number}" + + project_id = metadata.get("project_id", "") + tier = metadata.get("tier", "") + tier_full = metadata.get("tier_full", "") + + if project_id: + lines.append(f"{prefix}_PROJECT_ID={project_id}") + if tier: + lines.append(f"{prefix}_TIER={tier}") + if tier_full: + lines.append(f"{prefix}_TIER_FULL={tier_full}") + + return lines diff --git a/src/rotator_library/pyproject.toml b/src/rotator_library/pyproject.toml index 4c051069..85f9bffb 100644 --- a/src/rotator_library/pyproject.toml +++ b/src/rotator_library/pyproject.toml @@ -4,14 +4,14 @@ build-backend = "setuptools.build_meta" [project] name = "rotator_library" -version = "1.15" +version = "1.7" authors = [ { name="Mirrowel", email="mirrowel-github.appraiser015@aleeas.com" }, ] description = "A robust Python client for intelligent API key rotation and retry logic, leveraging LiteLLM. It manages usage, handles various API errors (rate limits, server errors, authentication), and supports dynamic model discovery across multiple LLM providers." license = "LGPL-3.0-only" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.12" classifiers = [ "Programming Language :: Python :: 3", "Operating System :: OS Independent", diff --git a/src/rotator_library/usage/__init__.py b/src/rotator_library/usage/__init__.py new file mode 100644 index 00000000..f8ae947f --- /dev/null +++ b/src/rotator_library/usage/__init__.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Usage tracking and credential selection package. + +This package provides the UsageManager facade and associated components +for tracking API usage, enforcing limits, and selecting credentials. + +Public API: + UsageManager: Main facade for usage tracking and credential selection + CredentialContext: Context manager for credential lifecycle + +Components (for advanced usage): + CredentialRegistry: Stable credential identity management + TrackingEngine: Usage recording and window management + LimitEngine: Limit checking and enforcement + SelectionEngine: Credential selection with strategies + UsageStorage: JSON file persistence +""" + +# Types first (no dependencies on other modules) +from .types import ( + WindowStats, + TotalStats, + ModelStats, + GroupStats, + CredentialState, + CooldownInfo, + FairCycleState, + UsageUpdate, + SelectionContext, + LimitCheckResult, + RotationMode, + ResetMode, + LimitResult, +) + +# Config +from .config import ( + ProviderUsageConfig, + FairCycleConfig, + CustomCapConfig, + WindowDefinition, + load_provider_usage_config, +) + +# Components +from .identity.registry import CredentialRegistry +from .tracking.windows import WindowManager +from .tracking.engine import TrackingEngine +from .limits.engine import LimitEngine +from .selection.engine import SelectionEngine +from .persistence.storage import UsageStorage +from .integration.api import UsageAPI + +# Main facade (imports components above) +from .manager import UsageManager, CredentialContext + +__all__ = [ + # Main public API + "UsageManager", + "CredentialContext", + # Types + "WindowStats", + "TotalStats", + "ModelStats", + "GroupStats", + "CredentialState", + "UsageUpdate", + "CooldownInfo", + "FairCycleState", + "SelectionContext", + "LimitCheckResult", + "RotationMode", + "ResetMode", + "LimitResult", + # Config + "ProviderUsageConfig", + "FairCycleConfig", + "CustomCapConfig", + "WindowDefinition", + "load_provider_usage_config", + # Engines + "CredentialRegistry", + "WindowManager", + "TrackingEngine", + "LimitEngine", + "SelectionEngine", + "UsageStorage", + "UsageAPI", +] diff --git a/src/rotator_library/usage/config.py b/src/rotator_library/usage/config.py new file mode 100644 index 00000000..de6cfab1 --- /dev/null +++ b/src/rotator_library/usage/config.py @@ -0,0 +1,853 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Default configurations for the usage tracking package. + +This module contains default values and configuration loading +for usage tracking, limits, and credential selection. +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union + +from ..core.constants import ( + DEFAULT_FAIR_CYCLE_DURATION, + DEFAULT_FAIR_CYCLE_QUOTA_THRESHOLD, + DEFAULT_FAIR_CYCLE_RESET_COOLDOWN_THRESHOLD, + DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, + DEFAULT_ROTATION_TOLERANCE, + DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER, +) +from .types import ResetMode, RotationMode, TrackingMode, CooldownMode, CapMode + + +# ============================================================================= +# WINDOW CONFIGURATION +# ============================================================================= + + +@dataclass +class WindowDefinition: + """ + Definition of a usage tracking window. + + Used to configure how usage is tracked and when it resets. + """ + + name: str # e.g., "5h", "daily", "weekly" + duration_seconds: Optional[int] # None for infinite/total + reset_mode: ResetMode + is_primary: bool = False # Primary window used for rotation decisions + applies_to: str = "model" # "credential", "model", "group" + + @classmethod + def rolling( + cls, + name: str, + duration_seconds: int, + is_primary: bool = False, + applies_to: str = "model", + ) -> "WindowDefinition": + """Create a rolling window definition.""" + return cls( + name=name, + duration_seconds=duration_seconds, + reset_mode=ResetMode.ROLLING, + is_primary=is_primary, + applies_to=applies_to, + ) + + @classmethod + def daily( + cls, + name: str = "daily", + applies_to: str = "model", + ) -> "WindowDefinition": + """Create a daily fixed window definition.""" + return cls( + name=name, + duration_seconds=86400, + reset_mode=ResetMode.FIXED_DAILY, + applies_to=applies_to, + ) + + +# ============================================================================= +# FAIR CYCLE CONFIGURATION +# ============================================================================= + + +@dataclass +class FairCycleConfig: + """ + Fair cycle rotation configuration. + + Controls how credentials are cycled to ensure fair usage distribution. + """ + + enabled: Optional[bool] = ( + None # None = derive from rotation mode (on for sequential) + ) + tracking_mode: TrackingMode = TrackingMode.MODEL_GROUP + cross_tier: bool = False # Track across all tiers + duration: int = DEFAULT_FAIR_CYCLE_DURATION # Cycle duration in seconds + quota_threshold: float = ( + DEFAULT_FAIR_CYCLE_QUOTA_THRESHOLD # Multiplier of window limit for exhaustion + ) + reset_cooldown_threshold: int = ( + DEFAULT_FAIR_CYCLE_RESET_COOLDOWN_THRESHOLD # Min cooldown to count for reset + ) + + +# ============================================================================= +# CUSTOM CAP CONFIGURATION +# ============================================================================= + + +def _parse_duration_string(duration_str: str) -> Optional[int]: + """ + Parse duration strings in various formats to total seconds. + + Handles: + - Plain seconds (no unit): '300', '562476' + - Simple durations: '3600s', '60m', '2h', '1d' + - Compound durations: '2h30m', '1h30m45s', '2d1h30m' + + Args: + duration_str: Duration string to parse + + Returns: + Total seconds as integer, or None if parsing fails. + """ + import re + + if not duration_str: + return None + + remaining = duration_str.strip().lower() + + # Try parsing as plain number first (no units) + try: + return int(float(remaining)) + except ValueError: + pass + + total_seconds = 0.0 + + # Parse days component + day_match = re.match(r"(\d+)d", remaining) + if day_match: + total_seconds += int(day_match.group(1)) * 86400 + remaining = remaining[day_match.end() :] + + # Parse hours component + hour_match = re.match(r"(\d+)h", remaining) + if hour_match: + total_seconds += int(hour_match.group(1)) * 3600 + remaining = remaining[hour_match.end() :] + + # Parse minutes component - use negative lookahead to avoid matching 'ms' + min_match = re.match(r"(\d+)m(?!s)", remaining) + if min_match: + total_seconds += int(min_match.group(1)) * 60 + remaining = remaining[min_match.end() :] + + # Parse seconds component (including decimals) + sec_match = re.match(r"([\d.]+)s", remaining) + if sec_match: + total_seconds += float(sec_match.group(1)) + + if total_seconds > 0: + return int(total_seconds) + return None + + +def _parse_cooldown_config( + mode: Optional[str], + value: Any, +) -> Tuple[CooldownMode, int]: + """ + Parse cooldown configuration from config dict values. + + Supports comprehensive cooldown_value parsing: + - Flat duration: 300, "300", "1h", "30m", "1h30m", "2d1h30m" → fixed seconds + - Offset with sign: "+300", "+1h30m", "-300", "-5m" → offset from natural reset + - Percentage: "+50%", "-20%" → percentage of window duration as offset + (stored as negative value with special encoding: -1000 - percentage) + - String "quota_reset" → use natural reset time + + The cooldown_mode is auto-detected from the value format if not explicitly set: + - Starts with '+' or '-' → CooldownMode.OFFSET + - Just a duration → CooldownMode.FIXED + - "quota_reset" string → CooldownMode.QUOTA_RESET + + Args: + mode: Explicit cooldown mode string, or None to auto-detect + value: Cooldown value (int, str, or various formats) + + Returns: + Tuple of (CooldownMode, cooldown_value in seconds) + """ + # Handle explicit mode with simple value + if mode is not None: + try: + cooldown_mode = CooldownMode(mode) + except ValueError: + cooldown_mode = CooldownMode.QUOTA_RESET + + # Parse value + if isinstance(value, int): + return cooldown_mode, value + elif isinstance(value, str): + parsed = _parse_duration_string(value.lstrip("+-")) + return cooldown_mode, parsed or 0 + else: + return cooldown_mode, 0 + + # Auto-detect mode from value format + if isinstance(value, int): + if value == 0: + return CooldownMode.QUOTA_RESET, 0 + return CooldownMode.FIXED, value + + if isinstance(value, str): + value = value.strip() + + # Check for "quota_reset" string + if value.lower() in ("quota_reset", "quota-reset", "quotareset"): + return CooldownMode.QUOTA_RESET, 0 + + # Check for percentage format: "+50%", "-20%" + if value.endswith("%"): + sign = 1 + val_str = value.rstrip("%") + if val_str.startswith("+"): + val_str = val_str[1:] + elif val_str.startswith("-"): + sign = -1 + val_str = val_str[1:] + try: + percentage = int(val_str) + # Encode percentage as special value: -1000 - (sign * percentage) + # This allows the custom_caps checker to detect and handle percentages + # Range: -1001 to -1100 for +1% to +100%, -999 to -900 for -1% to -100% + encoded = -1000 - (sign * percentage) + return CooldownMode.OFFSET, encoded + except ValueError: + pass + + # Check for offset format: "+300", "+1h30m", "-5m" + if value.startswith("+") or value.startswith("-"): + sign = 1 if value.startswith("+") else -1 + duration_str = value[1:] + parsed = _parse_duration_string(duration_str) + if parsed is not None: + return CooldownMode.OFFSET, sign * parsed + return CooldownMode.OFFSET, 0 + + # Plain duration: "300", "1h30m" + parsed = _parse_duration_string(value) + if parsed is not None: + return CooldownMode.FIXED, parsed + + return CooldownMode.QUOTA_RESET, 0 + + +import logging + +_config_logger = logging.getLogger("rotator_library") + + +def _parse_max_requests( + raw_value: Any, tier_key: str, model_or_group: str +) -> Optional[Tuple[int, "CapMode"]]: + """ + Parse max_requests value and determine its mode. + + Formats supported: + - 130 (int) → ABSOLUTE, 130 + - "130" → ABSOLUTE, 130 + - "-130" → OFFSET, -130 (means max - 130) + - "+130" → OFFSET, +130 (means max + 130) + - "80%" → PERCENTAGE, 80 + + Returns: + Tuple of (value, mode) or None if invalid (logs error). + """ + # Handle None or missing + if raw_value is None: + _config_logger.error( + f"Custom cap for tier={tier_key} model={model_or_group}: " + "max_requests is None, skipping cap" + ) + return None + + # Already an int + if isinstance(raw_value, int): + return (raw_value, CapMode.ABSOLUTE) + + # Float - convert to int + if isinstance(raw_value, float): + return (int(raw_value), CapMode.ABSOLUTE) + + # Must be a string from here + if not isinstance(raw_value, str): + _config_logger.error( + f"Custom cap for tier={tier_key} model={model_or_group}: " + f"max_requests has invalid type {type(raw_value).__name__}, skipping cap" + ) + return None + + # Strip whitespace + value_str = raw_value.strip() + + # Empty string is invalid + if not value_str: + _config_logger.error( + f"Custom cap for tier={tier_key} model={model_or_group}: " + "max_requests is empty string, skipping cap" + ) + return None + + # Percentage format: "80%" + if value_str.endswith("%"): + try: + percentage = int(value_str.rstrip("%")) + if percentage < 0 or percentage > 100: + _config_logger.error( + f"Custom cap for tier={tier_key} model={model_or_group}: " + f"percentage {percentage}% out of range (0-100), skipping cap" + ) + return None + return (percentage, CapMode.PERCENTAGE) + except ValueError: + _config_logger.error( + f"Custom cap for tier={tier_key} model={model_or_group}: " + f"invalid percentage '{value_str}', skipping cap" + ) + return None + + # Offset format: "+130" or "-130" + if value_str.startswith("+") or value_str.startswith("-"): + try: + offset = int(value_str) + return (offset, CapMode.OFFSET) + except ValueError: + # Try float conversion + try: + offset = int(float(value_str)) + return (offset, CapMode.OFFSET) + except ValueError: + _config_logger.error( + f"Custom cap for tier={tier_key} model={model_or_group}: " + f"invalid offset '{value_str}', skipping cap" + ) + return None + + # Absolute format: plain number string "130" + try: + value = int(value_str) + return (value, CapMode.ABSOLUTE) + except ValueError: + # Try float conversion + try: + value = int(float(value_str)) + return (value, CapMode.ABSOLUTE) + except ValueError: + _config_logger.error( + f"Custom cap for tier={tier_key} model={model_or_group}: " + f"invalid value '{value_str}', skipping cap" + ) + return None + + +@dataclass +class CustomCapConfig: + """ + Custom cap configuration for a tier/model combination. + + Allows setting usage limits that can be absolute, offset from API limits, + or percentage of API limits. + """ + + tier_key: str # Priority as string or "default" + model_or_group: str # Model name or quota group name + max_requests: int # The numeric value + max_requests_mode: CapMode = CapMode.ABSOLUTE # How to interpret max_requests + cooldown_mode: CooldownMode = CooldownMode.QUOTA_RESET + cooldown_value: int = 0 # Seconds for offset/fixed modes + + @classmethod + def from_dict( + cls, tier_key: str, model_or_group: str, config: Dict[str, Any] + ) -> Optional["CustomCapConfig"]: + """ + Create from dictionary config. + + max_requests formats: + - 130 or "130" → ABSOLUTE mode, exactly 130 requests + - "-130" → OFFSET mode, max - 130 requests + - "+130" → OFFSET mode, max + 130 requests + - "80%" → PERCENTAGE mode, 80% of max requests + + cooldown_value formats: + - Flat duration: 300, "300", "1h", "30m", "1h30m", "2d1h30m" → fixed seconds + - Offset with sign: "+300", "+1h30m", "-300", "-5m" → offset from natural reset + - String "quota_reset" → use natural reset time + + Returns: + CustomCapConfig instance, or None if max_requests is invalid. + """ + raw_max_requests = config.get("max_requests") + + # Check if mode is already explicitly provided (for round-trip serialization) + explicit_mode = config.get("max_requests_mode") + if explicit_mode is not None: + # Mode was explicitly provided - use it directly + try: + if isinstance(explicit_mode, CapMode): + max_requests_mode = explicit_mode + else: + max_requests_mode = CapMode(explicit_mode) + # Still need to validate max_requests is a valid number + if isinstance(raw_max_requests, int): + max_requests_value = raw_max_requests + elif isinstance(raw_max_requests, float): + max_requests_value = int(raw_max_requests) + elif isinstance(raw_max_requests, str): + try: + max_requests_value = int( + float(raw_max_requests.lstrip("+-").rstrip("%")) + ) + except ValueError: + _config_logger.error( + f"Custom cap for tier={tier_key} model={model_or_group}: " + f"invalid max_requests value '{raw_max_requests}', skipping cap" + ) + return None + else: + max_requests_value = 0 + except ValueError: + # Invalid mode string, fall through to parsing + explicit_mode = None + + if explicit_mode is None: + # Parse max_requests with mode detection + parsed = _parse_max_requests(raw_max_requests, tier_key, model_or_group) + if parsed is None: + return None + max_requests_value, max_requests_mode = parsed + + # Parse cooldown configuration + cooldown_mode, cooldown_value = _parse_cooldown_config( + config.get("cooldown_mode"), + config.get("cooldown_value", 0), + ) + + return cls( + tier_key=tier_key, + model_or_group=model_or_group, + max_requests=max_requests_value, + max_requests_mode=max_requests_mode, + cooldown_mode=cooldown_mode, + cooldown_value=cooldown_value, + ) + + +# ============================================================================= +# PROVIDER USAGE CONFIG +# ============================================================================= + + +@dataclass +class ProviderUsageConfig: + """ + Complete usage configuration for a provider. + + Combines all settings needed for usage tracking and credential selection. + """ + + # Rotation settings + rotation_mode: RotationMode = RotationMode.BALANCED + rotation_tolerance: float = DEFAULT_ROTATION_TOLERANCE + sequential_fallback_multiplier: int = DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER + + # Priority multipliers (priority -> max concurrent) + priority_multipliers: Dict[int, int] = field(default_factory=dict) + priority_multipliers_by_mode: Dict[str, Dict[int, int]] = field( + default_factory=dict + ) + + # Fair cycle + fair_cycle: FairCycleConfig = field(default_factory=FairCycleConfig) + + # Custom caps + custom_caps: List[CustomCapConfig] = field(default_factory=list) + + # Exhaustion threshold (cooldown must exceed this to count as "exhausted") + exhaustion_cooldown_threshold: int = DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD + + # Window limits blocking (if True, block credentials when window quota exhausted locally) + # Default False: only API errors (cooldowns) should block, not local tracking + window_limits_enabled: bool = False + + # Window definitions + windows: List[WindowDefinition] = field(default_factory=list) + + def get_effective_multiplier(self, priority: int) -> int: + """ + Get the effective multiplier for a priority level. + + Checks mode-specific overrides first, then universal multipliers, + then falls back to sequential_fallback_multiplier. + """ + mode_key = self.rotation_mode.value + mode_multipliers = self.priority_multipliers_by_mode.get(mode_key, {}) + + # Check mode-specific first + if priority in mode_multipliers: + return mode_multipliers[priority] + + # Check universal + if priority in self.priority_multipliers: + return self.priority_multipliers[priority] + + # Fall back + return self.sequential_fallback_multiplier + + +# ============================================================================= +# DEFAULT WINDOWS +# ============================================================================= + + +def get_default_windows() -> List[WindowDefinition]: + """ + Get default window definitions. + + Only used when provider doesn't define custom windows via + usage_reset_configs or get_usage_reset_config(). + """ + return [ + WindowDefinition.rolling("daily", 86400, is_primary=True, applies_to="model"), + ] + + +# ============================================================================= +# CONFIG LOADER INTEGRATION +# ============================================================================= + + +def load_provider_usage_config( + provider: str, + provider_plugins: Dict[str, Any], +) -> ProviderUsageConfig: + """ + Load usage configuration for a provider. + + Merges: + 1. System defaults + 2. Provider class attributes + 3. Environment variables (always win) + + Args: + provider: Provider name (e.g., "gemini", "openai") + provider_plugins: Dict of provider plugin classes + + Returns: + Complete configuration for the provider + """ + import os + + config = ProviderUsageConfig() + + # Get plugin class + plugin_class = provider_plugins.get(provider) + + # Apply provider defaults + if plugin_class: + # Rotation mode + if hasattr(plugin_class, "default_rotation_mode"): + config.rotation_mode = RotationMode(plugin_class.default_rotation_mode) + + # Priority multipliers + if hasattr(plugin_class, "default_priority_multipliers"): + config.priority_multipliers = dict( + plugin_class.default_priority_multipliers + ) + + if hasattr(plugin_class, "default_priority_multipliers_by_mode"): + config.priority_multipliers_by_mode = { + k: dict(v) + for k, v in plugin_class.default_priority_multipliers_by_mode.items() + } + + # Sequential fallback multiplier + if hasattr(plugin_class, "default_sequential_fallback_multiplier"): + fallback = plugin_class.default_sequential_fallback_multiplier + if fallback is not None: + config.sequential_fallback_multiplier = fallback + + # Fair cycle + if hasattr(plugin_class, "default_fair_cycle_config"): + fc_config = plugin_class.default_fair_cycle_config + config.fair_cycle = FairCycleConfig( + enabled=fc_config.get("enabled"), + tracking_mode=TrackingMode( + fc_config.get("tracking_mode", "model_group") + ), + cross_tier=fc_config.get("cross_tier", False), + duration=fc_config.get("duration", DEFAULT_FAIR_CYCLE_DURATION), + quota_threshold=fc_config.get( + "quota_threshold", DEFAULT_FAIR_CYCLE_QUOTA_THRESHOLD + ), + reset_cooldown_threshold=fc_config.get( + "reset_cooldown_threshold", + DEFAULT_FAIR_CYCLE_RESET_COOLDOWN_THRESHOLD, + ), + ) + else: + if hasattr(plugin_class, "default_fair_cycle_enabled"): + config.fair_cycle.enabled = plugin_class.default_fair_cycle_enabled + if hasattr(plugin_class, "default_fair_cycle_tracking_mode"): + config.fair_cycle.tracking_mode = TrackingMode( + plugin_class.default_fair_cycle_tracking_mode + ) + if hasattr(plugin_class, "default_fair_cycle_cross_tier"): + config.fair_cycle.cross_tier = ( + plugin_class.default_fair_cycle_cross_tier + ) + if hasattr(plugin_class, "default_fair_cycle_duration"): + config.fair_cycle.duration = plugin_class.default_fair_cycle_duration + if hasattr(plugin_class, "default_fair_cycle_quota_threshold"): + config.fair_cycle.quota_threshold = ( + plugin_class.default_fair_cycle_quota_threshold + ) + + # Custom caps + if hasattr(plugin_class, "default_custom_caps"): + for tier_key, models in plugin_class.default_custom_caps.items(): + tier_keys: Tuple[Union[int, str], ...] + if isinstance(tier_key, tuple): + tier_keys = tuple(tier_key) + else: + tier_keys = (tier_key,) + for model_or_group, cap_config in models.items(): + for resolved_tier in tier_keys: + cap = CustomCapConfig.from_dict( + str(resolved_tier), model_or_group, cap_config + ) + if cap is not None: + config.custom_caps.append(cap) + if cap is not None: + config.custom_caps.append(cap) + + # Windows + if hasattr(plugin_class, "usage_window_definitions"): + config.windows = [] + for wdef in plugin_class.usage_window_definitions: + config.windows.append( + WindowDefinition( + name=wdef.get("name", "default"), + duration_seconds=wdef.get("duration_seconds"), + reset_mode=ResetMode(wdef.get("reset_mode", "rolling")), + is_primary=wdef.get("is_primary", False), + applies_to=wdef.get("applies_to", "model"), + ) + ) + + # Use default windows if none defined + if not config.windows: + config.windows = get_default_windows() + + # Apply environment variable overrides + provider_upper = provider.upper() + + # Rotation mode from env + env_mode = os.getenv(f"ROTATION_MODE_{provider_upper}") + if env_mode: + config.rotation_mode = RotationMode(env_mode.lower()) + + # Sequential fallback multiplier + env_fallback = os.getenv(f"SEQUENTIAL_FALLBACK_MULTIPLIER_{provider_upper}") + if env_fallback: + try: + config.sequential_fallback_multiplier = int(env_fallback) + except ValueError: + pass + + # Fair cycle enabled from env + env_fc = os.getenv(f"FAIR_CYCLE_{provider_upper}") + if env_fc is None: + env_fc = os.getenv(f"FAIR_CYCLE_ENABLED_{provider_upper}") + if env_fc: + config.fair_cycle.enabled = env_fc.lower() in ("true", "1", "yes") + + # Fair cycle tracking mode + env_fc_mode = os.getenv(f"FAIR_CYCLE_TRACKING_MODE_{provider_upper}") + if env_fc_mode: + try: + config.fair_cycle.tracking_mode = TrackingMode(env_fc_mode.lower()) + except ValueError: + pass + + # Fair cycle cross-tier + env_fc_cross = os.getenv(f"FAIR_CYCLE_CROSS_TIER_{provider_upper}") + if env_fc_cross: + config.fair_cycle.cross_tier = env_fc_cross.lower() in ("true", "1", "yes") + + # Fair cycle duration from env + env_fc_duration = os.getenv(f"FAIR_CYCLE_DURATION_{provider_upper}") + if env_fc_duration: + try: + config.fair_cycle.duration = int(env_fc_duration) + except ValueError: + pass + + # Fair cycle quota threshold from env + env_fc_quota = os.getenv(f"FAIR_CYCLE_QUOTA_THRESHOLD_{provider_upper}") + if env_fc_quota: + try: + config.fair_cycle.quota_threshold = float(env_fc_quota) + except ValueError: + pass + + # Fair cycle reset cooldown threshold from env + env_fc_reset_cd = os.getenv(f"FAIR_CYCLE_RESET_COOLDOWN_THRESHOLD_{provider_upper}") + if env_fc_reset_cd: + try: + config.fair_cycle.reset_cooldown_threshold = int(env_fc_reset_cd) + except ValueError: + pass + + # Exhaustion threshold from env + env_threshold = os.getenv(f"EXHAUSTION_COOLDOWN_THRESHOLD_{provider_upper}") + if env_threshold: + try: + config.exhaustion_cooldown_threshold = int(env_threshold) + except ValueError: + pass + + # Priority multipliers from env + # Format: CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N}=value + # Format: CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N}_{MODE}=value + for key, value in os.environ.items(): + prefix = f"CONCURRENCY_MULTIPLIER_{provider_upper}_PRIORITY_" + if key.startswith(prefix): + try: + remainder = key[len(prefix) :] + multiplier = int(value) + if multiplier < 1: + continue + if "_" in remainder: + priority_str, mode = remainder.rsplit("_", 1) + priority = int(priority_str) + mode = mode.lower() + if mode in ("sequential", "balanced"): + config.priority_multipliers_by_mode.setdefault(mode, {})[ + priority + ] = multiplier + else: + config.priority_multipliers[priority] = multiplier + else: + priority = int(remainder) + config.priority_multipliers[priority] = multiplier + except ValueError: + pass + + # Custom caps from env + if os.environ: + cap_map: Dict[str, Dict[str, Dict[str, Any]]] = {} + for cap in config.custom_caps: + cap_entry = cap_map.setdefault(str(cap.tier_key), {}) + cap_entry[cap.model_or_group] = { + "max_requests": cap.max_requests, + "max_requests_mode": cap.max_requests_mode.value, + "cooldown_mode": cap.cooldown_mode.value, + "cooldown_value": cap.cooldown_value, + } + + cap_prefix = f"CUSTOM_CAP_{provider_upper}_T" + cooldown_prefix = f"CUSTOM_CAP_COOLDOWN_{provider_upper}_T" + for env_key, env_value in os.environ.items(): + if env_key.startswith(cap_prefix) and not env_key.startswith( + cooldown_prefix + ): + remainder = env_key[len(cap_prefix) :] + tier_key, model_key = _parse_custom_cap_env_key(remainder) + if tier_key is None or not model_key: + continue + cap_entry = cap_map.setdefault(str(tier_key), {}) + cap_entry.setdefault(model_key, {})["max_requests"] = env_value + elif env_key.startswith(cooldown_prefix): + remainder = env_key[len(cooldown_prefix) :] + tier_key, model_key = _parse_custom_cap_env_key(remainder) + if tier_key is None or not model_key: + continue + if ":" in env_value: + mode, value_str = env_value.split(":", 1) + try: + value = int(value_str) + except ValueError: + continue + else: + mode = env_value + value = 0 + cap_entry = cap_map.setdefault(str(tier_key), {}) + cap_entry.setdefault(model_key, {})["cooldown_mode"] = mode + cap_entry.setdefault(model_key, {})["cooldown_value"] = value + + config.custom_caps = [] + for tier_key, models in cap_map.items(): + for model_or_group, cap_config in models.items(): + cap = CustomCapConfig.from_dict(tier_key, model_or_group, cap_config) + if cap is not None: + config.custom_caps.append(cap) + + # Derive fair cycle enabled from rotation mode if not explicitly set + if config.fair_cycle.enabled is None: + config.fair_cycle.enabled = config.rotation_mode == RotationMode.SEQUENTIAL + + return config + + +def _parse_custom_cap_env_key( + remainder: str, +) -> Tuple[Optional[Union[int, Tuple[int, ...], str]], Optional[str]]: + """Parse the tier and model/group from a custom cap env var remainder.""" + if not remainder: + return None, None + + remaining_parts = remainder.split("_") + if len(remaining_parts) < 2: + return None, None + + tier_key: Union[int, Tuple[int, ...], str, None] = None + model_key: Optional[str] = None + tier_parts: List[int] = [] + + for i, part in enumerate(remaining_parts): + if part == "DEFAULT": + tier_key = "default" + model_key = "_".join(remaining_parts[i + 1 :]) + break + if part.isdigit(): + tier_parts.append(int(part)) + continue + + if not tier_parts: + return None, None + if len(tier_parts) == 1: + tier_key = tier_parts[0] + else: + tier_key = tuple(tier_parts) + model_key = "_".join(remaining_parts[i:]) + break + else: + return None, None + + if model_key: + model_key = model_key.lower().replace("_", "-") + + return tier_key, model_key diff --git a/src/rotator_library/usage/identity/__init__.py b/src/rotator_library/usage/identity/__init__.py new file mode 100644 index 00000000..b42af7ff --- /dev/null +++ b/src/rotator_library/usage/identity/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Credential identity management.""" + +from .registry import CredentialRegistry + +__all__ = ["CredentialRegistry"] diff --git a/src/rotator_library/usage/identity/registry.py b/src/rotator_library/usage/identity/registry.py new file mode 100644 index 00000000..ddd867d4 --- /dev/null +++ b/src/rotator_library/usage/identity/registry.py @@ -0,0 +1,270 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Credential identity registry. + +Provides stable identifiers for credentials that persist across +file path changes (for OAuth) and hide sensitive data (for API keys). +""" + +import hashlib +import json +import logging +from pathlib import Path +from typing import Any, Dict, Optional, Set + +from ...core.types import CredentialInfo + +lib_logger = logging.getLogger("rotator_library") + + +class CredentialRegistry: + """ + Manages stable identifiers for credentials. + + Stable IDs are: + - For OAuth credentials: The email address from _proxy_metadata.email + - For API keys: SHA-256 hash of the key (truncated for readability) + + This ensures usage data persists even when: + - OAuth credential files are moved/renamed + - API keys are passed in different orders + """ + + def __init__(self): + # Cache: accessor -> CredentialInfo + self._cache: Dict[str, CredentialInfo] = {} + # Reverse index: stable_id -> accessor + self._id_to_accessor: Dict[str, str] = {} + + def get_stable_id(self, accessor: str, provider: str) -> str: + """ + Get or create a stable ID for a credential accessor. + + Args: + accessor: The credential accessor (file path or API key) + provider: Provider name + + Returns: + Stable identifier string + """ + # Check cache first + if accessor in self._cache: + return self._cache[accessor].stable_id + + # Determine if OAuth or API key + if self._is_oauth_path(accessor): + stable_id = self._get_oauth_stable_id(accessor) + else: + stable_id = self._get_api_key_stable_id(accessor) + + # Cache the result + info = CredentialInfo( + accessor=accessor, + stable_id=stable_id, + provider=provider, + ) + self._cache[accessor] = info + self._id_to_accessor[stable_id] = accessor + + return stable_id + + def get_info(self, accessor: str, provider: str) -> CredentialInfo: + """ + Get complete credential info for an accessor. + + Args: + accessor: The credential accessor + provider: Provider name + + Returns: + CredentialInfo with stable_id and metadata + """ + # Ensure stable ID is computed + self.get_stable_id(accessor, provider) + return self._cache[accessor] + + def get_accessor(self, stable_id: str) -> Optional[str]: + """ + Get the current accessor for a stable ID. + + Args: + stable_id: The stable identifier + + Returns: + Current accessor string, or None if not found + """ + return self._id_to_accessor.get(stable_id) + + def update_accessor(self, stable_id: str, new_accessor: str) -> None: + """ + Update the accessor for a stable ID. + + Used when an OAuth credential file is moved/renamed. + + Args: + stable_id: The stable identifier + new_accessor: New accessor path + """ + old_accessor = self._id_to_accessor.get(stable_id) + if old_accessor and old_accessor in self._cache: + info = self._cache.pop(old_accessor) + info.accessor = new_accessor + self._cache[new_accessor] = info + self._id_to_accessor[stable_id] = new_accessor + + def update_metadata( + self, + accessor: str, + provider: str, + tier: Optional[str] = None, + priority: Optional[int] = None, + display_name: Optional[str] = None, + ) -> None: + """ + Update metadata for a credential. + + Args: + accessor: The credential accessor + provider: Provider name + tier: Tier name (e.g., "standard-tier") + priority: Priority level (lower = higher priority) + display_name: Human-readable name + """ + info = self.get_info(accessor, provider) + if tier is not None: + info.tier = tier + if priority is not None: + info.priority = priority + if display_name is not None: + info.display_name = display_name + + def get_all_accessors(self) -> Set[str]: + """Get all registered accessors.""" + return set(self._cache.keys()) + + def get_all_stable_ids(self) -> Set[str]: + """Get all registered stable IDs.""" + return set(self._id_to_accessor.keys()) + + def clear_cache(self) -> None: + """Clear the internal cache.""" + self._cache.clear() + self._id_to_accessor.clear() + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _is_oauth_path(self, accessor: str) -> bool: + """ + Check if accessor is an OAuth credential file path. + + OAuth paths typically end with .json and exist on disk. + API keys are typically raw strings. + """ + # Simple heuristic: if it looks like a file path with .json, it's OAuth + if accessor.endswith(".json"): + return True + # If it contains path separators, it's likely a file path + if "/" in accessor or "\\" in accessor: + return True + return False + + def _get_oauth_stable_id(self, accessor: str) -> str: + """ + Get stable ID for an OAuth credential. + + Reads the email from _proxy_metadata.email in the credential file. + Falls back to file hash if email not found. + """ + try: + path = Path(accessor) + if path.exists(): + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + + # Try to get email from _proxy_metadata + metadata = data.get("_proxy_metadata", {}) + email = metadata.get("email") + if email: + return email + + # Fallback: try common OAuth fields + for field in ["email", "client_email", "account"]: + if field in data: + return data[field] + + # Last resort: hash the file content + lib_logger.debug( + f"No email found in OAuth credential {accessor}, using content hash" + ) + return self._hash_content(json.dumps(data, sort_keys=True)) + + except Exception as e: + lib_logger.warning(f"Failed to read OAuth credential {accessor}: {e}") + + # Fallback: hash the path + return self._hash_content(accessor) + + def _get_api_key_stable_id(self, accessor: str) -> str: + """ + Get stable ID for an API key. + + Uses truncated SHA-256 hash to hide the actual key. + """ + return self._hash_content(accessor) + + def _hash_content(self, content: str) -> str: + """ + Create a stable hash of content. + + Uses first 12 characters of SHA-256 for readability. + """ + return hashlib.sha256(content.encode()).hexdigest()[:12] + + # ========================================================================= + # SERIALIZATION + # ========================================================================= + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize registry state for persistence. + + Returns: + Dictionary suitable for JSON serialization + """ + return { + "accessor_index": dict(self._id_to_accessor), + "credentials": { + accessor: { + "stable_id": info.stable_id, + "provider": info.provider, + "tier": info.tier, + "priority": info.priority, + "display_name": info.display_name, + } + for accessor, info in self._cache.items() + }, + } + + def from_dict(self, data: Dict[str, Any]) -> None: + """ + Restore registry state from persistence. + + Args: + data: Dictionary from to_dict() + """ + self._id_to_accessor = dict(data.get("accessor_index", {})) + + for accessor, cred_data in data.get("credentials", {}).items(): + info = CredentialInfo( + accessor=accessor, + stable_id=cred_data["stable_id"], + provider=cred_data["provider"], + tier=cred_data.get("tier"), + priority=cred_data.get("priority", 999), + display_name=cred_data.get("display_name"), + ) + self._cache[accessor] = info diff --git a/src/rotator_library/usage/integration/__init__.py b/src/rotator_library/usage/integration/__init__.py new file mode 100644 index 00000000..4e58bdd8 --- /dev/null +++ b/src/rotator_library/usage/integration/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Integration helpers for usage manager.""" + +from .hooks import HookDispatcher +from .api import UsageAPI + +__all__ = ["HookDispatcher", "UsageAPI"] diff --git a/src/rotator_library/usage/integration/api.py b/src/rotator_library/usage/integration/api.py new file mode 100644 index 00000000..30f988c6 --- /dev/null +++ b/src/rotator_library/usage/integration/api.py @@ -0,0 +1,377 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Usage API Facade for Reading and Updating Usage Data. + +This module provides a clean, public API for programmatically interacting with +usage data. It's accessible via `usage_manager.api` and is intended for: + + - Admin endpoints (viewing/modifying credential state) + - Background jobs (quota refresh, cleanup tasks) + - Monitoring and alerting (checking remaining quota) + - External tooling and integrations + - Provider-specific logic that needs to inspect/modify state + +============================================================================= +ACCESSING THE API +============================================================================= + +The UsageAPI is available as a property on UsageManager: + + # From RotatingClient + usage_manager = client.get_usage_manager("my_provider") + api = usage_manager.api + + # Or if you have the manager directly + api = usage_manager.api + +============================================================================= +AVAILABLE METHODS +============================================================================= + +Reading State +------------- + + # Get state for a specific credential + state = api.get_state("path/to/credential.json") + if state: + print(f"Total requests: {state.totals.request_count}") + print(f"Total successes: {state.totals.success_count}") + print(f"Total failures: {state.totals.failure_count}") + + # Get all credential states + all_states = api.get_all_states() + for stable_id, state in all_states.items(): + print(f"{stable_id}: {state.totals.request_count} requests") + + # Check remaining quota in a window + remaining = api.get_window_remaining( + accessor="path/to/credential.json", + window_name="5h", + model="gpt-4o", # Optional: specific model + quota_group="gpt4", # Optional: quota group + ) + print(f"Remaining in 5h window: {remaining}") + +Modifying State +--------------- + + # Apply a manual cooldown + await api.apply_cooldown( + accessor="path/to/credential.json", + duration=1800.0, # 30 minutes + reason="manual_override", + model_or_group="gpt4", # Optional: scope to model/group + ) + + # Clear a cooldown + await api.clear_cooldown( + accessor="path/to/credential.json", + model_or_group="gpt4", # Optional: scope + ) + + # Mark credential as exhausted for fair cycle + await api.mark_exhausted( + accessor="path/to/credential.json", + model_or_group="gpt4", + reason="quota_exceeded", + ) + +============================================================================= +CREDENTIAL STATE STRUCTURE +============================================================================= + +CredentialState contains: + + state.accessor # File path or API key + state.display_name # Human-readable name (e.g., email) + state.tier # Tier name (e.g., "standard-tier") + state.priority # Priority level (1 = highest) + state.active_requests # Currently in-flight requests + + state.totals # TotalStats - credential-level totals + state.model_usage # Dict[model, ModelStats] + state.group_usage # Dict[group, GroupStats] + + state.cooldowns # Dict[key, CooldownState] + state.fair_cycle # Dict[key, FairCycleState] + +ModelStats / GroupStats contain: + + stats.windows # Dict[name, WindowStats] - time-based windows + stats.totals # TotalStats - all-time totals for this scope + +TotalStats contains: + + totals.request_count + totals.success_count + totals.failure_count + totals.prompt_tokens + totals.completion_tokens + totals.thinking_tokens + totals.output_tokens + totals.prompt_tokens_cache_read + totals.prompt_tokens_cache_write + totals.total_tokens + totals.approx_cost + totals.first_used_at + totals.last_used_at + +WindowStats contains: + + window.request_count + window.success_count + window.failure_count + window.prompt_tokens + window.completion_tokens + window.thinking_tokens + window.output_tokens + window.prompt_tokens_cache_read + window.prompt_tokens_cache_write + window.total_tokens + window.approx_cost + window.started_at + window.reset_at + window.limit + window.remaining # Computed: limit - request_count (if limit set) + +============================================================================= +EXAMPLE: BUILDING AN ADMIN ENDPOINT +============================================================================= + + from fastapi import APIRouter + from rotator_library import RotatingClient + + router = APIRouter() + + @router.get("/admin/credentials/{provider}") + async def list_credentials(provider: str): + usage_manager = client.get_usage_manager(provider) + if not usage_manager: + return {"error": "Provider not found"} + + api = usage_manager.api + result = [] + + for stable_id, state in api.get_all_states().items(): + result.append({ + "id": stable_id, + "accessor": state.accessor, + "tier": state.tier, + "priority": state.priority, + "requests": state.totals.request_count, + "successes": state.totals.success_count, + "failures": state.totals.failure_count, + "cooldowns": [ + {"key": k, "remaining": v.remaining_seconds} + for k, v in state.cooldowns.items() + if v.is_active + ], + }) + + return {"credentials": result} + + @router.post("/admin/credentials/{provider}/{accessor}/cooldown") + async def apply_cooldown(provider: str, accessor: str, duration: float): + usage_manager = client.get_usage_manager(provider) + api = usage_manager.api + await api.apply_cooldown(accessor, duration, reason="admin") + return {"status": "cooldown applied"} + +============================================================================= +""" + +from typing import Any, Dict, Optional, TYPE_CHECKING + +from ..types import CredentialState + +if TYPE_CHECKING: + from ..manager import UsageManager + + +class UsageAPI: + """ + Public API facade for reading and updating usage data. + + Provides a clean interface for external code to interact with usage + tracking without needing to understand the internal component structure. + + Access via: usage_manager.api + + Example: + api = usage_manager.api + state = api.get_state("path/to/credential.json") + remaining = api.get_window_remaining("path/to/cred.json", "5h", "gpt-4o") + await api.apply_cooldown("path/to/cred.json", 1800.0, "manual") + """ + + def __init__(self, manager: "UsageManager"): + """ + Initialize the API facade. + + Args: + manager: The UsageManager instance to wrap. + """ + self._manager = manager + + def get_state(self, accessor: str) -> Optional[CredentialState]: + """ + Get the credential state for a given accessor. + + Args: + accessor: Credential file path or API key. + + Returns: + CredentialState if found, None otherwise. + + Example: + state = api.get_state("oauth_creds/my_cred.json") + if state: + print(f"Requests: {state.totals.request_count}") + """ + stable_id = self._manager.registry.get_stable_id( + accessor, self._manager.provider + ) + return self._manager.states.get(stable_id) + + def get_all_states(self) -> Dict[str, CredentialState]: + """ + Get all credential states. + + Returns: + Dict mapping stable_id to CredentialState. + + Example: + for stable_id, state in api.get_all_states().items(): + print(f"{stable_id}: {state.totals.request_count} requests") + """ + return dict(self._manager.states) + + def get_window_remaining( + self, + accessor: str, + window_name: str, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> Optional[int]: + """ + Get remaining requests in a usage window. + + Args: + accessor: Credential file path or API key. + window_name: Window name (e.g., "5h", "daily"). + model: Optional model to check (uses model-specific window). + quota_group: Optional quota group to check. + + Returns: + Remaining requests (limit - used), or None if: + - Credential not found + - Window has no limit set + + Example: + remaining = api.get_window_remaining("cred.json", "5h", model="gpt-4o") + if remaining is not None and remaining < 10: + print("Warning: low quota remaining") + """ + state = self.get_state(accessor) + if not state: + return None + return self._manager.limits.window_checker.get_remaining( + state, window_name, model=model, quota_group=quota_group + ) + + async def apply_cooldown( + self, + accessor: str, + duration: float, + reason: str = "manual", + model_or_group: Optional[str] = None, + ) -> None: + """ + Apply a cooldown to a credential. + + The credential will not be selected for requests until the cooldown + expires or is cleared. + + Args: + accessor: Credential file path or API key. + duration: Cooldown duration in seconds. + reason: Reason for cooldown (for logging/debugging). + model_or_group: Optional scope (model name or quota group). + If None, applies to credential globally. + + Example: + # Global cooldown + await api.apply_cooldown("cred.json", 1800.0, "maintenance") + + # Model-specific cooldown + await api.apply_cooldown("cred.json", 3600.0, "quota", "gpt-4o") + """ + await self._manager.apply_cooldown( + accessor=accessor, + duration=duration, + reason=reason, + model_or_group=model_or_group, + ) + + async def clear_cooldown( + self, + accessor: str, + model_or_group: Optional[str] = None, + ) -> None: + """ + Clear a cooldown from a credential. + + Args: + accessor: Credential file path or API key. + model_or_group: Optional scope to clear. If None, clears all. + + Example: + # Clear specific cooldown + await api.clear_cooldown("cred.json", "gpt-4o") + + # Clear all cooldowns + await api.clear_cooldown("cred.json") + """ + stable_id = self._manager.registry.get_stable_id( + accessor, self._manager.provider + ) + state = self._manager.states.get(stable_id) + if state: + await self._manager.tracking.clear_cooldown( + state=state, + model_or_group=model_or_group, + ) + + async def mark_exhausted( + self, + accessor: str, + model_or_group: str, + reason: str, + ) -> None: + """ + Mark a credential as exhausted for fair cycle. + + The credential will be skipped during selection until all other + credentials in the same tier are also exhausted, at which point + the fair cycle resets. + + Args: + accessor: Credential file path or API key. + model_or_group: Model name or quota group to mark exhausted. + reason: Reason for exhaustion (for logging/debugging). + + Example: + await api.mark_exhausted("cred.json", "gpt-4o", "quota_exceeded") + """ + stable_id = self._manager.registry.get_stable_id( + accessor, self._manager.provider + ) + state = self._manager.states.get(stable_id) + if state: + await self._manager.tracking.mark_exhausted( + state=state, + model_or_group=model_or_group, + reason=reason, + ) diff --git a/src/rotator_library/usage/integration/hooks.py b/src/rotator_library/usage/integration/hooks.py new file mode 100644 index 00000000..1f59446d --- /dev/null +++ b/src/rotator_library/usage/integration/hooks.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Provider Hook Dispatcher for Usage Manager. + +This module bridges provider plugins to the usage manager, allowing providers +to customize how requests are counted, cooled down, and tracked. + +============================================================================= +OVERVIEW +============================================================================= + +The HookDispatcher calls provider hooks at key points in the request lifecycle. +Currently, the main hook is `on_request_complete`, which is called after every +request (success or failure) and allows the provider to override: + + - Request count (how many requests to record) + - Cooldown duration (custom cooldown to apply) + - Exhaustion state (mark credential for fair cycle) + +============================================================================= +IMPLEMENTING on_request_complete IN YOUR PROVIDER +============================================================================= + +Add this method to your provider class: + + from rotator_library.core.types import RequestCompleteResult + + def on_request_complete( + self, + credential: str, + model: str, + success: bool, + response: Optional[Any], + error: Optional[Any], + ) -> Optional[RequestCompleteResult]: + ''' + Called after each request completes. + + Args: + credential: Credential accessor (file path or API key) + model: Model that was called + success: Whether the request succeeded + response: Response object (if success=True) + error: ClassifiedError object (if success=False) + + Returns: + RequestCompleteResult to override behavior, or None for defaults. + ''' + # Your logic here + return None # Use default behavior + +============================================================================= +RequestCompleteResult FIELDS +============================================================================= + + count_override: Optional[int] + How many requests to count for usage tracking. + - 0 = Don't count this request (e.g., server errors) + - N = Count as N requests (e.g., internal retries) + - None = Use default (1) + + cooldown_override: Optional[float] + Seconds to cool down this credential. + - Applied in addition to any error-based cooldown. + - Use for custom rate limiting logic. + + force_exhausted: bool + Mark credential as exhausted for fair cycle. + - True = Skip this credential until fair cycle resets. + - Useful for quota errors without long cooldowns. + +============================================================================= +USE CASE: COUNTING INTERNAL RETRIES +============================================================================= + +If your provider performs internal retries (e.g., for transient errors, empty +responses, or malformed responses), each retry is an API call that should be +counted. Use the ContextVar pattern for thread-safe counting: + + from contextvars import ContextVar + from rotator_library.core.types import RequestCompleteResult + + # Module-level: each async task gets its own isolated value + _internal_attempt_count: ContextVar[int] = ContextVar( + 'my_provider_attempt_count', default=1 + ) + + class MyProvider: + + async def _make_request_with_retry(self, ...): + # Reset at start of request + _internal_attempt_count.set(1) + + for attempt in range(max_attempts): + try: + result = await self._call_api(...) + return result # Success + except RetryableError: + # Increment before retry + _internal_attempt_count.set(_internal_attempt_count.get() + 1) + continue + + def on_request_complete(self, credential, model, success, response, error): + # Report actual API call count + count = _internal_attempt_count.get() + _internal_attempt_count.set(1) # Reset for safety + + if count > 1: + logging.debug(f"Request used {count} API calls (internal retries)") + + return RequestCompleteResult(count_override=count) + +Why ContextVar? + - Instance variables (self.count) are shared across concurrent requests + - ContextVar gives each async task its own isolated value + - Thread-safe without explicit locking + +============================================================================= +USE CASE: CUSTOM ERROR HANDLING +============================================================================= + +Override counting or cooldown based on error type: + + def on_request_complete(self, credential, model, success, response, error): + if not success and error: + # Don't count server errors against quota + if error.error_type == "server_error": + return RequestCompleteResult(count_override=0) + + # Force exhaustion on quota errors + if error.error_type == "quota_exceeded": + return RequestCompleteResult( + force_exhausted=True, + cooldown_override=3600.0, # 1 hour + ) + + # Custom cooldown for rate limits + if error.error_type == "rate_limit": + retry_after = getattr(error, "retry_after", 60) + return RequestCompleteResult(cooldown_override=retry_after) + + return None # Default behavior + +============================================================================= +""" + +import asyncio +from typing import Any, Dict, Optional + +from ...core.types import RequestCompleteResult + + +class HookDispatcher: + """ + Dispatch optional provider hooks during request lifecycle. + + The HookDispatcher is instantiated by UsageManager with the provider plugins + dict. It lazily instantiates provider instances and calls their hooks. + + Currently supported hooks: + - on_request_complete: Called after each request completes + + Usage: + dispatcher = HookDispatcher(provider_plugins) + result = await dispatcher.dispatch_request_complete( + provider="my_provider", + credential="path/to/cred.json", + model="my-model", + success=True, + response=response_obj, + error=None, + ) + if result and result.count_override is not None: + request_count = result.count_override + """ + + def __init__(self, provider_plugins: Optional[Dict[str, Any]] = None): + """ + Initialize the hook dispatcher. + + Args: + provider_plugins: Dict mapping provider names to plugin classes. + Classes are lazily instantiated on first hook call. + """ + self._plugins = provider_plugins or {} + + def _get_instance(self, provider: str) -> Optional[Any]: + """Get provider plugin instance (singleton via metaclass).""" + plugin_class = self._plugins.get(provider) + if not plugin_class: + return None + if isinstance(plugin_class, type): + return plugin_class() # Singleton - always returns same instance + return plugin_class + + async def dispatch_request_complete( + self, + provider: str, + credential: str, + model: str, + success: bool, + response: Optional[Any], + error: Optional[Any], + ) -> Optional[RequestCompleteResult]: + """ + Dispatch the on_request_complete hook to a provider. + + Called by UsageManager after each request completes (success or failure). + The provider can return a RequestCompleteResult to override default + behavior for request counting, cooldowns, or exhaustion marking. + + Args: + provider: Provider name (e.g., "antigravity", "openai") + credential: Credential accessor (file path or API key) + model: Model that was called (with provider prefix) + success: Whether the request succeeded + response: Response object if success=True, else None + error: ClassifiedError if success=False, else None + + Returns: + RequestCompleteResult from provider, or None if: + - Provider not found in plugins + - Provider doesn't implement on_request_complete + - Provider returns None (use default behavior) + """ + plugin = self._get_instance(provider) + if not plugin or not hasattr(plugin, "on_request_complete"): + return None + + result = plugin.on_request_complete(credential, model, success, response, error) + if asyncio.iscoroutine(result): + result = await result + + return result diff --git a/src/rotator_library/usage/limits/__init__.py b/src/rotator_library/usage/limits/__init__.py new file mode 100644 index 00000000..e759c312 --- /dev/null +++ b/src/rotator_library/usage/limits/__init__.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Limit checking and enforcement.""" + +from .engine import LimitEngine +from .base import LimitChecker +from .window_limits import WindowLimitChecker +from .cooldowns import CooldownChecker +from .fair_cycle import FairCycleChecker +from .custom_caps import CustomCapChecker + +__all__ = [ + "LimitEngine", + "LimitChecker", + "WindowLimitChecker", + "CooldownChecker", + "FairCycleChecker", + "CustomCapChecker", +] diff --git a/src/rotator_library/usage/limits/base.py b/src/rotator_library/usage/limits/base.py new file mode 100644 index 00000000..c6c0a4d8 --- /dev/null +++ b/src/rotator_library/usage/limits/base.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Base interface for limit checkers. + +All limit types implement this interface for consistent behavior. +""" + +from abc import ABC, abstractmethod +from typing import Optional + +from ..types import CredentialState, LimitCheckResult, LimitResult + + +class LimitChecker(ABC): + """ + Abstract base class for limit checkers. + + Each limit type (window, cooldown, fair cycle, custom cap) + implements this interface. + """ + + @property + @abstractmethod + def name(self) -> str: + """Name of this limit checker.""" + ... + + @abstractmethod + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if a credential passes this limit. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail and reason + """ + ... + + def reset( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Reset this limit for a credential. + + Default implementation does nothing - override if needed. + + Args: + state: Credential state to reset + model: Optional model scope + quota_group: Optional quota group scope + """ + pass diff --git a/src/rotator_library/usage/limits/concurrent.py b/src/rotator_library/usage/limits/concurrent.py new file mode 100644 index 00000000..83a510cb --- /dev/null +++ b/src/rotator_library/usage/limits/concurrent.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Concurrent request limit checker. + +Blocks credentials that have reached their max_concurrent limit. +""" + +from typing import Optional + +from ..types import CredentialState, LimitCheckResult, LimitResult +from .base import LimitChecker + + +class ConcurrentLimitChecker(LimitChecker): + """ + Checks concurrent request limits. + + Blocks credentials that have active_requests >= max_concurrent. + This ensures we don't overload any single credential. + """ + + @property + def name(self) -> str: + return "concurrent" + + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if credential is at max concurrent. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail + """ + # If no limit set, always allow + if state.max_concurrent is None: + return LimitCheckResult.ok() + + # Check if at or above limit + if state.active_requests >= state.max_concurrent: + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_CONCURRENT, + reason=f"At max concurrent: {state.active_requests}/{state.max_concurrent}", + blocked_until=None, # No specific time - depends on request completion + ) + + return LimitCheckResult.ok() + + def reset( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Reset concurrent count. + + Note: This is rarely needed as active_requests is + managed by acquire/release, not limit checking. + """ + # Typically don't reset active_requests via limit system + pass diff --git a/src/rotator_library/usage/limits/cooldowns.py b/src/rotator_library/usage/limits/cooldowns.py new file mode 100644 index 00000000..bf08a4a4 --- /dev/null +++ b/src/rotator_library/usage/limits/cooldowns.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Cooldown checker. + +Checks if a credential is currently in cooldown. +""" + +import time +from typing import Optional + +from ..types import CredentialState, LimitCheckResult, LimitResult +from .base import LimitChecker + + +class CooldownChecker(LimitChecker): + """ + Checks cooldown status for credentials. + + Blocks credentials that are currently cooling down from + rate limits, errors, or other causes. + """ + + @property + def name(self) -> str: + return "cooldowns" + + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if credential is in cooldown. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail + """ + now = time.time() + group_key = quota_group or model + + # Check model/group-specific cooldowns + keys_to_check = [] + if group_key: + keys_to_check.append(group_key) + if quota_group and quota_group != model: + keys_to_check.append(model) + + for key in keys_to_check: + cooldown = state.cooldowns.get(key) + if cooldown and cooldown.until > now: + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_COOLDOWN, + reason=f"Cooldown for '{key}': {cooldown.reason} (expires in {cooldown.remaining_seconds:.0f}s)", + blocked_until=cooldown.until, + ) + + # Check global cooldown + global_cooldown = state.cooldowns.get("_global_") + if global_cooldown and global_cooldown.until > now: + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_COOLDOWN, + reason=f"Global cooldown: {global_cooldown.reason} (expires in {global_cooldown.remaining_seconds:.0f}s)", + blocked_until=global_cooldown.until, + ) + + return LimitCheckResult.ok() + + def reset( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Clear cooldown for a credential. + + Args: + state: Credential state + model: Optional model scope + quota_group: Optional quota group scope + """ + if quota_group: + if quota_group in state.cooldowns: + del state.cooldowns[quota_group] + elif model: + if model in state.cooldowns: + del state.cooldowns[model] + else: + # Clear all cooldowns + state.cooldowns.clear() + + def get_cooldown_end( + self, + state: CredentialState, + model_or_group: Optional[str] = None, + ) -> Optional[float]: + """ + Get when cooldown ends for a credential. + + Args: + state: Credential state + model_or_group: Optional scope to check + + Returns: + Timestamp when cooldown ends, or None if not in cooldown + """ + now = time.time() + + # Check specific scope + if model_or_group: + cooldown = state.cooldowns.get(model_or_group) + if cooldown and cooldown.until > now: + return cooldown.until + + # Check global + global_cooldown = state.cooldowns.get("_global_") + if global_cooldown and global_cooldown.until > now: + return global_cooldown.until + + return None diff --git a/src/rotator_library/usage/limits/custom_caps.py b/src/rotator_library/usage/limits/custom_caps.py new file mode 100644 index 00000000..0cc7f308 --- /dev/null +++ b/src/rotator_library/usage/limits/custom_caps.py @@ -0,0 +1,373 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Custom cap limit checker. + +Enforces user-defined limits on API usage. +""" + +import time +import logging +from typing import Dict, List, Optional, Tuple + +from ..types import CredentialState, LimitCheckResult, LimitResult, WindowStats +from ..config import CustomCapConfig, CooldownMode, CapMode, CapMode, CapMode +from ..tracking.windows import WindowManager +from .base import LimitChecker + +lib_logger = logging.getLogger("rotator_library") + + +# Scope constants for cap application +SCOPE_MODEL = "model" +SCOPE_GROUP = "group" + + +class CustomCapChecker(LimitChecker): + """ + Checks custom cap limits. + + Custom caps allow users to set custom usage limits per tier/model/group. + Limits can be absolute numbers or percentages of API limits. + + Caps are checked independently for both model AND group scopes: + - A model cap being exceeded blocks only that model + - A group cap being exceeded blocks the entire group + - Both caps can exist and are checked separately (first blocked wins) + """ + + def __init__( + self, + caps: List[CustomCapConfig], + window_manager: WindowManager, + ): + """ + Initialize custom cap checker. + + Args: + caps: List of custom cap configurations + window_manager: WindowManager for checking window usage + """ + self._caps = caps + self._windows = window_manager + # Index caps by (tier_key, model_or_group) for fast lookup + self._cap_index: Dict[tuple, CustomCapConfig] = {} + for cap in caps: + self._cap_index[(cap.tier_key, cap.model_or_group)] = cap + + @property + def name(self) -> str: + return "custom_caps" + + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if any custom cap is exceeded. + + Checks both model and group caps independently. If both exist, + each is checked against its respective usage scope. First blocked + cap wins. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail + """ + if not self._caps: + return LimitCheckResult.ok() + + primary_def = self._windows.get_primary_definition() + if primary_def is None: + return LimitCheckResult.ok() + + priority = state.priority + + # Find all applicable caps (model + group separately) + all_caps = self._find_all_caps(str(priority), model, quota_group) + if not all_caps: + return LimitCheckResult.ok() + + # Check each cap against its proper scope + for cap, scope, scope_key in all_caps: + result = self._check_single_cap( + state, cap, scope, scope_key, model, quota_group + ) + if not result.allowed: + return result + + return LimitCheckResult.ok() + + def get_cap_for( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> Optional[CustomCapConfig]: + """ + Get the first applicable custom cap for a credential/model. + + Args: + state: Credential state + model: Model name + quota_group: Quota group + + Returns: + CustomCapConfig if one applies, None otherwise + """ + priority = state.priority + all_caps = self._find_all_caps(str(priority), model, quota_group) + if all_caps: + return all_caps[0][0] # Return first cap + return None + + def get_all_caps_for( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> List[Tuple[CustomCapConfig, str, str]]: + """ + Get all applicable custom caps for a credential/model. + + Args: + state: Credential state + model: Model name + quota_group: Quota group + + Returns: + List of (cap, scope, scope_key) tuples + """ + priority = state.priority + return self._find_all_caps(str(priority), model, quota_group) + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _find_all_caps( + self, + priority_key: str, + model: str, + quota_group: Optional[str], + ) -> List[Tuple[CustomCapConfig, str, str]]: + """ + Find all applicable caps for a request. + + Returns caps for both model AND group scopes (if they exist). + Each cap is returned with its scope type and scope key. + + Args: + priority_key: Priority level as string + model: Model name + quota_group: Quota group (optional) + + Returns: + List of (cap, scope, scope_key) tuples where: + - cap: The CustomCapConfig + - scope: SCOPE_MODEL or SCOPE_GROUP + - scope_key: The model name or group name + """ + result: List[Tuple[CustomCapConfig, str, str]] = [] + + # Check model cap (priority-specific, then default) + model_cap = self._cap_index.get((priority_key, model)) or self._cap_index.get( + ("default", model) + ) + if model_cap: + result.append((model_cap, SCOPE_MODEL, model)) + + # Check group cap (priority-specific, then default) - only if group differs from model + if quota_group and quota_group != model: + group_cap = self._cap_index.get( + (priority_key, quota_group) + ) or self._cap_index.get(("default", quota_group)) + if group_cap: + result.append((group_cap, SCOPE_GROUP, quota_group)) + + return result + + def _check_single_cap( + self, + state: CredentialState, + cap: CustomCapConfig, + scope: str, + scope_key: str, + model: str, + quota_group: Optional[str], + ) -> LimitCheckResult: + """ + Check a single cap against its appropriate usage scope. + + Args: + state: Credential state + cap: The cap configuration to check + scope: SCOPE_MODEL or SCOPE_GROUP + scope_key: The model name or group name + model: Original model name (for fallback) + quota_group: Original quota group (for fallback) + + Returns: + LimitCheckResult for this specific cap + """ + # Get windows based on scope + windows = None + + if scope == SCOPE_GROUP: + group_stats = state.get_group_stats(scope_key, create=False) + if group_stats: + windows = group_stats.windows + else: # SCOPE_MODEL + model_stats = state.get_model_stats(scope_key, create=False) + if model_stats: + windows = model_stats.windows + + if windows is None: + return LimitCheckResult.ok() + + # Get usage from primary window + primary_window = self._windows.get_primary_window(windows) + if primary_window is None: + return LimitCheckResult.ok() + + current_usage = primary_window.request_count + max_requests = self._resolve_max_requests(cap, primary_window.limit) + + if current_usage >= max_requests: + # Calculate cooldown end + cooldown_until = self._calculate_cooldown_until(cap, primary_window) + + # Build descriptive reason with scope info + scope_desc = "model" if scope == SCOPE_MODEL else "group" + reason = ( + f"Custom cap for {scope_desc} '{scope_key}' exceeded " + f"({current_usage}/{max_requests})" + ) + + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_CUSTOM_CAP, + reason=reason, + blocked_until=cooldown_until, + ) + + return LimitCheckResult.ok() + + def _find_cap( + self, + priority_key: str, + group_key: str, + model: str, + ) -> Optional[CustomCapConfig]: + """ + Find the most specific applicable cap (legacy method for compatibility). + + Deprecated: Use _find_all_caps() for layered cap checking. + """ + # Try exact matches first + # Priority + group + cap = self._cap_index.get((priority_key, group_key)) + if cap: + return cap + + # Priority + model (if different from group) + if model != group_key: + cap = self._cap_index.get((priority_key, model)) + if cap: + return cap + + # Default tier + group + cap = self._cap_index.get(("default", group_key)) + if cap: + return cap + + # Default tier + model + if model != group_key: + cap = self._cap_index.get(("default", model)) + if cap: + return cap + + return None + + def _resolve_max_requests( + self, + cap: CustomCapConfig, + window_limit: Optional[int], + ) -> int: + """ + Resolve max requests based on mode. + + Modes: + - ABSOLUTE: Use value as-is (e.g., 130 → 130) + - OFFSET: Add/subtract from window limit (e.g., -130 → max - 130) + - PERCENTAGE: Percentage of window limit (e.g., 80 → 80% of max) + + Always clamps result to >= 0. + """ + if cap.max_requests_mode == CapMode.ABSOLUTE: + return max(0, cap.max_requests) + + # For OFFSET and PERCENTAGE, we need window_limit + if window_limit is None: + # No limit known - fallback behavior + if cap.max_requests_mode == CapMode.OFFSET: + # Can't apply offset without knowing the max + # Use absolute value as fallback + return max(0, abs(cap.max_requests)) + # PERCENTAGE with no limit - use safe default + return 1000 + + if cap.max_requests_mode == CapMode.OFFSET: + # +130 means max + 130, -130 means max - 130 + return max(0, window_limit + cap.max_requests) + + if cap.max_requests_mode == CapMode.PERCENTAGE: + return max(0, int(window_limit * cap.max_requests / 100)) + + # Fallback (shouldn't happen) + return max(0, cap.max_requests) + + def _calculate_cooldown_until( + self, + cap: CustomCapConfig, + window: WindowStats, + ) -> Optional[float]: + """ + Calculate when the custom cap cooldown ends. + + Modes: + - QUOTA_RESET: Wait until natural window reset + - OFFSET: Add/subtract offset from natural reset (clamped to >= reset) + - FIXED: Fixed duration from now + """ + now = time.time() + natural_reset = window.reset_at + + if cap.cooldown_mode == CooldownMode.QUOTA_RESET: + # Wait until window resets + return natural_reset + + elif cap.cooldown_mode == CooldownMode.OFFSET: + # Offset from natural reset time + # Positive offset = wait AFTER reset + # Negative offset = wait BEFORE reset (clamped to >= reset for safety) + if natural_reset: + calculated = natural_reset + cap.cooldown_value + # Always clamp to at least natural_reset (can't end before quota resets) + return max(calculated, natural_reset) + else: + # No natural reset known, use absolute offset from now + return now + abs(cap.cooldown_value) + + elif cap.cooldown_mode == CooldownMode.FIXED: + # Fixed duration from now + calculated = now + cap.cooldown_value + return calculated + + return None diff --git a/src/rotator_library/usage/limits/engine.py b/src/rotator_library/usage/limits/engine.py new file mode 100644 index 00000000..2d60a369 --- /dev/null +++ b/src/rotator_library/usage/limits/engine.py @@ -0,0 +1,248 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Limit engine for orchestrating limit checks. + +Central component that runs all limit checkers and determines +if a credential is available for use. +""" + +import logging +from typing import Dict, List, Optional + +from ..types import CredentialState, LimitCheckResult, LimitResult +from ..config import ProviderUsageConfig +from ..tracking.windows import WindowManager +from .base import LimitChecker +from .concurrent import ConcurrentLimitChecker +from .window_limits import WindowLimitChecker +from .cooldowns import CooldownChecker +from .fair_cycle import FairCycleChecker +from .custom_caps import CustomCapChecker +from ...error_handler import mask_credential +from ...error_handler import mask_credential + +lib_logger = logging.getLogger("rotator_library") + + +class LimitEngine: + """ + Central engine for limit checking. + + Orchestrates all limit checkers and provides a single entry point + for determining credential availability. + """ + + def __init__( + self, + config: ProviderUsageConfig, + window_manager: WindowManager, + ): + """ + Initialize limit engine. + + Args: + config: Provider usage configuration + window_manager: WindowManager for window-based checks + """ + self._config = config + self._window_manager = window_manager + + # Initialize all limit checkers + # Order matters: concurrent first (fast check), then others + # Note: WindowLimitChecker is optional - only included if window_limits_enabled + self._checkers: List[LimitChecker] = [ + ConcurrentLimitChecker(), + CooldownChecker(), + ] + + # Window limit checker - kept as reference for info purposes, + # only added to blocking checkers if explicitly enabled + self._window_checker = WindowLimitChecker(window_manager) + if config.window_limits_enabled: + self._checkers.append(self._window_checker) + + # Custom caps and fair cycle always active + self._custom_cap_checker = CustomCapChecker(config.custom_caps, window_manager) + self._fair_cycle_checker = FairCycleChecker(config.fair_cycle, window_manager) + self._checkers.append(self._custom_cap_checker) + self._checkers.append(self._fair_cycle_checker) + + # Quick access to specific checkers + self._concurrent_checker = self._checkers[0] + self._cooldown_checker = self._checkers[1] + + def check_all( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check all limits for a credential. + + Runs all limit checkers in order and returns the first failure, + or success if all pass. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating overall pass/fail + """ + for checker in self._checkers: + result = checker.check(state, model, quota_group) + if not result.allowed: + lib_logger.debug( + f"Credential {mask_credential(state.accessor, style='full')} blocked by {checker.name}: {result.reason}" + ) + return result + + return LimitCheckResult.ok() + + def check_specific( + self, + checker_name: str, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check a specific limit type. + + Args: + checker_name: Name of the checker ("cooldowns", "window_limits", etc.) + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult from the specified checker + """ + for checker in self._checkers: + if checker.name == checker_name: + return checker.check(state, model, quota_group) + + # Unknown checker - return ok + return LimitCheckResult.ok() + + def get_available_candidates( + self, + states: List[CredentialState], + model: str, + quota_group: Optional[str] = None, + ) -> List[CredentialState]: + """ + Filter credentials to only those passing all limits. + + Args: + states: List of credential states to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + List of available credential states + """ + available = [] + for state in states: + result = self.check_all(state, model, quota_group) + if result.allowed: + available.append(state) + + return available + + def get_blocking_info( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> Dict[str, LimitCheckResult]: + """ + Get detailed blocking info for each limit type. + + Useful for debugging and status reporting. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + Dict mapping checker name to its result + """ + results = {} + for checker in self._checkers: + results[checker.name] = checker.check(state, model, quota_group) + return results + + def reset_all( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Reset all limits for a credential. + + Args: + state: Credential state + model: Optional model scope + quota_group: Optional quota group scope + """ + for checker in self._checkers: + checker.reset(state, model, quota_group) + + @property + def concurrent_checker(self) -> ConcurrentLimitChecker: + """Get the concurrent limit checker.""" + return self._concurrent_checker + + @property + def cooldown_checker(self) -> CooldownChecker: + """Get the cooldown checker.""" + return self._cooldown_checker + + @property + def window_checker(self) -> WindowLimitChecker: + """Get the window limit checker.""" + return self._window_checker + + @property + def custom_cap_checker(self) -> CustomCapChecker: + """Get the custom cap checker.""" + return self._custom_cap_checker + + @property + def fair_cycle_checker(self) -> FairCycleChecker: + """Get the fair cycle checker.""" + return self._fair_cycle_checker + + def add_checker(self, checker: LimitChecker) -> None: + """ + Add a custom limit checker. + + Allows extending the limit system with custom logic. + + Args: + checker: LimitChecker implementation to add + """ + self._checkers.append(checker) + + def remove_checker(self, name: str) -> bool: + """ + Remove a limit checker by name. + + Args: + name: Name of the checker to remove + + Returns: + True if removed, False if not found + """ + for i, checker in enumerate(self._checkers): + if checker.name == name: + del self._checkers[i] + return True + return False diff --git a/src/rotator_library/usage/limits/fair_cycle.py b/src/rotator_library/usage/limits/fair_cycle.py new file mode 100644 index 00000000..62e51cbd --- /dev/null +++ b/src/rotator_library/usage/limits/fair_cycle.py @@ -0,0 +1,386 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Fair cycle limit checker. + +Ensures credentials are used fairly by blocking exhausted ones +until all credentials in the pool are exhausted. +""" + +import time +import logging +from typing import Dict, List, Optional, Set + +from ..types import ( + CredentialState, + LimitCheckResult, + LimitResult, + FairCycleState, + GlobalFairCycleState, + TrackingMode, + FAIR_CYCLE_GLOBAL_KEY, +) +from ..config import FairCycleConfig +from ..tracking.windows import WindowManager +from ...error_handler import mask_credential +from ...error_handler import mask_credential +from .base import LimitChecker + +lib_logger = logging.getLogger("rotator_library") + + +class FairCycleChecker(LimitChecker): + """ + Checks fair cycle constraints. + + Blocks credentials that have been "exhausted" (quota used or long cooldown) + until all credentials in the pool have been exhausted, then resets the cycle. + """ + + def __init__( + self, config: FairCycleConfig, window_manager: Optional[WindowManager] = None + ): + """ + Initialize fair cycle checker. + + Args: + config: Fair cycle configuration + window_manager: WindowManager for getting window limits (optional) + """ + self._config = config + self._window_manager = window_manager + # Global cycle state per provider + self._global_state: Dict[str, Dict[str, GlobalFairCycleState]] = {} + + @property + def name(self) -> str: + return "fair_cycle" + + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if credential is blocked by fair cycle. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail + """ + if not self._config.enabled: + return LimitCheckResult.ok() + + group_key = self._resolve_tracking_key(model, quota_group) + fc_state = state.fair_cycle.get(group_key) + + # Check quota-based exhaustion (cycle_request_count >= window.limit * threshold) + # This is separate from explicit exhaustion marking + if fc_state and not fc_state.exhausted: + quota_limit = self._get_quota_limit(state, model, quota_group) + if quota_limit is not None: + threshold = int(quota_limit * self._config.quota_threshold) + if fc_state.cycle_request_count >= threshold: + # Mark as exhausted due to quota threshold + now = time.time() + fc_state.exhausted = True + fc_state.exhausted_at = now + fc_state.exhausted_reason = "quota_threshold" + lib_logger.info( + f"Credential {mask_credential(state.accessor, style='full')} fair-cycle exhausted for {group_key}: " + f"cycle_request_count ({fc_state.cycle_request_count}) >= " + f"quota_threshold ({threshold})" + ) + + # Not exhausted = allowed + if fc_state is None or not fc_state.exhausted: + return LimitCheckResult.ok() + + # Exhausted - check if cycle should reset + provider = state.provider + global_state = self._get_global_state(provider, group_key) + + # Check if cycle has expired + if self._should_reset_cycle(global_state): + # Don't block - cycle will be reset + return LimitCheckResult.ok() + + # Still blocked by fair cycle + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_FAIR_CYCLE, + reason=f"Fair cycle: exhausted for '{group_key}' - waiting for other credentials", + blocked_until=None, # Depends on other credentials + ) + + def reset( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Reset fair cycle state for a credential. + + Args: + state: Credential state + model: Optional model scope + quota_group: Optional quota group scope + """ + group_key = self._resolve_tracking_key(model or "", quota_group) + + if quota_group or model: + if group_key in state.fair_cycle: + fc_state = state.fair_cycle[group_key] + fc_state.exhausted = False + fc_state.exhausted_at = None + fc_state.exhausted_reason = None + fc_state.cycle_request_count = 0 + else: + # Reset all + for fc_state in state.fair_cycle.values(): + fc_state.exhausted = False + fc_state.exhausted_at = None + fc_state.exhausted_reason = None + fc_state.cycle_request_count = 0 + + def check_all_exhausted( + self, + provider: str, + group_key: str, + all_states: List[CredentialState], + priorities: Optional[Dict[str, int]] = None, + ) -> bool: + """ + Check if all credentials in the pool are exhausted. + + Args: + provider: Provider name + group_key: Model or quota group + all_states: All credential states for this provider + priorities: Optional priority filter + + Returns: + True if all are exhausted + """ + # Filter by tier if not cross-tier + if priorities and not self._config.cross_tier: + # Group by priority tier + priority_groups: Dict[int, List[CredentialState]] = {} + for state in all_states: + p = priorities.get(state.stable_id, 999) + priority_groups.setdefault(p, []).append(state) + + # Check each priority group separately + for priority, group_states in priority_groups.items(): + if not self._all_exhausted_in_group(group_states, group_key): + return False + return True + else: + return self._all_exhausted_in_group(all_states, group_key) + + def reset_cycle( + self, + provider: str, + group_key: str, + all_states: List[CredentialState], + ) -> None: + """ + Reset the fair cycle for all credentials. + + Args: + provider: Provider name + group_key: Model or quota group + all_states: All credential states to reset + """ + now = time.time() + + for state in all_states: + if group_key in state.fair_cycle: + fc_state = state.fair_cycle[group_key] + fc_state.exhausted = False + fc_state.exhausted_at = None + fc_state.exhausted_reason = None + fc_state.cycle_request_count = 0 + + # Update global state + global_state = self._get_global_state(provider, group_key) + global_state.cycle_start = now + global_state.all_exhausted_at = None + global_state.cycle_count += 1 + + lib_logger.info( + f"Fair cycle reset for {provider}/{group_key}, cycle #{global_state.cycle_count}" + ) + + def mark_all_exhausted( + self, + provider: str, + group_key: str, + ) -> None: + """ + Record that all credentials are now exhausted. + + Args: + provider: Provider name + group_key: Model or quota group + """ + global_state = self._get_global_state(provider, group_key) + global_state.all_exhausted_at = time.time() + + lib_logger.info(f"All credentials exhausted for {provider}/{group_key}") + + def get_tracking_key(self, model: str, quota_group: Optional[str]) -> str: + """Get the fair cycle tracking key for a request.""" + return self._resolve_tracking_key(model, quota_group) + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _get_global_state( + self, + provider: str, + group_key: str, + ) -> GlobalFairCycleState: + """Get or create global fair cycle state.""" + if provider not in self._global_state: + self._global_state[provider] = {} + + if group_key not in self._global_state[provider]: + self._global_state[provider][group_key] = GlobalFairCycleState( + cycle_start=time.time() + ) + + return self._global_state[provider][group_key] + + def _resolve_tracking_key( + self, + model: str, + quota_group: Optional[str], + ) -> str: + """Resolve tracking key based on fair cycle mode.""" + if self._config.tracking_mode == TrackingMode.CREDENTIAL: + return FAIR_CYCLE_GLOBAL_KEY + return quota_group or model + + def _should_reset_cycle(self, global_state: GlobalFairCycleState) -> bool: + """Check if cycle duration has expired.""" + now = time.time() + return now >= global_state.cycle_start + self._config.duration + + def _get_quota_limit( + self, + state: CredentialState, + model: str, + quota_group: Optional[str], + ) -> Optional[int]: + """ + Get the quota limit for fair cycle comparison. + + Uses the smallest window limit available (most restrictive). + + Args: + state: Credential state + model: Model name + quota_group: Quota group (optional) + + Returns: + The quota limit, or None if unknown + """ + if self._window_manager is None: + return None + + primary_def = self._window_manager.get_primary_definition() + if primary_def is None: + return None + + group_key = quota_group or model + windows = None + + # Check group first if quota_group is specified + if quota_group: + group_stats = state.get_group_stats(quota_group, create=False) + if group_stats: + windows = group_stats.windows + + # Fall back to model + if windows is None: + model_stats = state.get_model_stats(model, create=False) + if model_stats: + windows = model_stats.windows + + if windows is None: + return None + + # Get limit from primary window + primary_window = self._window_manager.get_active_window( + windows, primary_def.name + ) + if primary_window and primary_window.limit: + return primary_window.limit + + # If no primary window limit, try to find smallest limit from any window + smallest_limit: Optional[int] = None + for window in windows.values(): + if window.limit is not None: + if smallest_limit is None or window.limit < smallest_limit: + smallest_limit = window.limit + + return smallest_limit + + def _all_exhausted_in_group( + self, + states: List[CredentialState], + group_key: str, + ) -> bool: + """Check if all credentials in a group are exhausted.""" + if not states: + return True + + for state in states: + fc_state = state.fair_cycle.get(group_key) + if fc_state is None or not fc_state.exhausted: + return False + + return True + + def get_global_state_dict(self) -> Dict[str, Dict[str, Dict]]: + """ + Get global state for serialization. + + Returns: + Dict suitable for JSON serialization + """ + result = {} + for provider, groups in self._global_state.items(): + result[provider] = {} + for group_key, state in groups.items(): + result[provider][group_key] = { + "cycle_start": state.cycle_start, + "all_exhausted_at": state.all_exhausted_at, + "cycle_count": state.cycle_count, + } + return result + + def load_global_state_dict(self, data: Dict[str, Dict[str, Dict]]) -> None: + """ + Load global state from serialized data. + + Args: + data: Dict from get_global_state_dict() + """ + self._global_state.clear() + for provider, groups in data.items(): + self._global_state[provider] = {} + for group_key, state_data in groups.items(): + self._global_state[provider][group_key] = GlobalFairCycleState( + cycle_start=state_data.get("cycle_start", 0), + all_exhausted_at=state_data.get("all_exhausted_at"), + cycle_count=state_data.get("cycle_count", 0), + ) diff --git a/src/rotator_library/usage/limits/window_limits.py b/src/rotator_library/usage/limits/window_limits.py new file mode 100644 index 00000000..31433f8c --- /dev/null +++ b/src/rotator_library/usage/limits/window_limits.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Window limit checker. + +Checks if a credential has exceeded its request quota for a window. +""" + +from typing import Dict, List, Optional + +from ..types import CredentialState, LimitCheckResult, LimitResult, WindowStats +from ..tracking.windows import WindowManager +from .base import LimitChecker + + +class WindowLimitChecker(LimitChecker): + """ + Checks window-based request limits. + + Blocks credentials that have exhausted their quota in any + tracked window. + """ + + def __init__(self, window_manager: WindowManager): + """ + Initialize window limit checker. + + Args: + window_manager: WindowManager instance for window operations + """ + self._windows = window_manager + + @property + def name(self) -> str: + return "window_limits" + + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if any window limit is exceeded. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail + """ + group_key = quota_group or model + + # Check all configured windows + for definition in self._windows.definitions.values(): + windows = None + + if definition.applies_to == "model": + model_stats = state.get_model_stats(model, create=False) + if model_stats: + windows = model_stats.windows + elif definition.applies_to == "group": + group_stats = state.get_group_stats(group_key, create=False) + if group_stats: + windows = group_stats.windows + + if windows is None: + continue + + window = windows.get(definition.name) + if window is None or window.limit is None: + continue + + active = self._windows.get_active_window(windows, definition.name) + if active is None: + continue + + if active.request_count >= active.limit: + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_WINDOW, + reason=( + f"Window '{definition.name}' exhausted " + f"({active.request_count}/{active.limit})" + ), + blocked_until=active.reset_at, + ) + + return LimitCheckResult.ok() + + def get_remaining( + self, + state: CredentialState, + window_name: str, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> Optional[int]: + """ + Get remaining requests in a specific window. + + Args: + state: Credential state + window_name: Name of window to check + model: Model to check + quota_group: Quota group to check + + Returns: + Remaining requests, or None if unlimited/unknown + """ + group_key = quota_group or model or "" + definition = self._windows.definitions.get(window_name) + + windows = None + if definition: + if definition.applies_to == "model" and model: + model_stats = state.get_model_stats(model, create=False) + if model_stats: + windows = model_stats.windows + elif definition.applies_to == "group": + group_stats = state.get_group_stats(group_key, create=False) + if group_stats: + windows = group_stats.windows + + if windows is None: + return None + + return self._windows.get_window_remaining(windows, window_name) + + def get_all_remaining( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> Dict[str, Optional[int]]: + """ + Get remaining requests for all windows. + + Args: + state: Credential state + model: Model to check + quota_group: Quota group to check + + Returns: + Dict of window_name -> remaining (None if unlimited) + """ + result = {} + for definition in self._windows.definitions.values(): + result[definition.name] = self.get_remaining( + state, + definition.name, + model=model, + quota_group=quota_group, + ) + return result diff --git a/src/rotator_library/usage/manager.py b/src/rotator_library/usage/manager.py new file mode 100644 index 00000000..1fbf2da1 --- /dev/null +++ b/src/rotator_library/usage/manager.py @@ -0,0 +1,2142 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +UsageManager facade and CredentialContext. + +This is the main public API for the usage tracking system. +""" + +import asyncio +import logging +import time +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Set, Union + +from ..core.types import CredentialInfo, RequestCompleteResult +from ..error_handler import ClassifiedError, classify_error, mask_credential + +from .types import ( + WindowStats, + TotalStats, + ModelStats, + GroupStats, + CredentialState, + LimitCheckResult, + RotationMode, + LimitResult, + FAIR_CYCLE_GLOBAL_KEY, + TrackingMode, + ResetMode, +) +from .config import ( + ProviderUsageConfig, + load_provider_usage_config, + get_default_windows, + CapMode, +) +from .identity.registry import CredentialRegistry +from .tracking.engine import TrackingEngine +from .tracking.windows import WindowManager +from .limits.engine import LimitEngine +from .selection.engine import SelectionEngine +from .persistence.storage import UsageStorage +from .integration.hooks import HookDispatcher +from .integration.api import UsageAPI + +lib_logger = logging.getLogger("rotator_library") + + +class CredentialContext: + """ + Context manager for credential lifecycle. + + Handles: + - Automatic release on exit + - Success/failure recording + - Usage tracking + + Usage: + async with usage_manager.acquire_credential(provider, model) as ctx: + response = await make_request(ctx.credential) + ctx.mark_success(response) + """ + + def __init__( + self, + manager: "UsageManager", + credential: str, + stable_id: str, + model: str, + quota_group: Optional[str] = None, + ): + self._manager = manager + self.credential = credential # The accessor (path or key) + self.stable_id = stable_id + self.model = model + self.quota_group = quota_group + self._acquired_at = time.time() + self._result: Optional[Literal["success", "failure"]] = None + self._response: Optional[Any] = None + self._response_headers: Optional[Dict[str, Any]] = None + self._error: Optional[ClassifiedError] = None + self._tokens: Dict[str, int] = {} + self._approx_cost: float = 0.0 + + async def __aenter__(self) -> "CredentialContext": + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool: + # Always release the credential + await self._manager._release_credential(self.stable_id, self.model) + + success = False + error = self._error + response = self._response + + if self._result == "success": + success = True + elif self._result == "failure": + success = False + elif exc_val is not None: + error = classify_error(exc_val) + success = False + else: + success = True + + await self._manager._handle_request_complete( + stable_id=self.stable_id, + model=self.model, + quota_group=self.quota_group, + success=success, + response=response, + response_headers=self._response_headers, + error=error, + prompt_tokens=self._tokens.get("prompt", 0), + completion_tokens=self._tokens.get("completion", 0), + thinking_tokens=self._tokens.get("thinking", 0), + prompt_tokens_cache_read=self._tokens.get("prompt_cached", 0), + prompt_tokens_cache_write=self._tokens.get("prompt_cache_write", 0), + approx_cost=self._approx_cost, + ) + + return False # Don't suppress exceptions + + def mark_success( + self, + response: Any = None, + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + response_headers: Optional[Dict[str, Any]] = None, + ) -> None: + """Mark request as successful.""" + self._result = "success" + self._response = response + self._response_headers = response_headers + self._tokens = { + "prompt": prompt_tokens, + "completion": completion_tokens, + "thinking": thinking_tokens, + "prompt_cached": prompt_tokens_cache_read, + "prompt_cache_write": prompt_tokens_cache_write, + } + self._approx_cost = approx_cost + + def mark_failure(self, error: ClassifiedError) -> None: + """Mark request as failed.""" + self._result = "failure" + self._error = error + + +class UsageManager: + """ + Main facade for usage tracking and credential selection. + + This class provides the primary interface for: + - Acquiring credentials for requests (with context manager) + - Recording usage and failures + - Selecting the best available credential + - Managing cooldowns and limits + + Example: + manager = UsageManager(provider="gemini", file_path="usage.json") + await manager.initialize(credentials) + + async with manager.acquire_credential(model="gemini-pro") as ctx: + response = await make_request(ctx.credential) + ctx.mark_success(response, prompt_tokens=100, completion_tokens=50) + """ + + def __init__( + self, + provider: str, + file_path: Optional[Union[str, Path]] = None, + provider_plugins: Optional[Dict[str, Any]] = None, + config: Optional[ProviderUsageConfig] = None, + max_concurrent_per_key: Optional[int] = None, + ): + """ + Initialize UsageManager. + + Args: + provider: Provider name (e.g., "gemini", "openai") + file_path: Path to usage.json file + provider_plugins: Dict of provider plugin classes + config: Optional pre-built configuration + max_concurrent_per_key: Max concurrent requests per credential + """ + self.provider = provider + self._provider_plugins = provider_plugins or {} + self._max_concurrent_per_key = max_concurrent_per_key + + # Load configuration + if config: + self._config = config + else: + self._config = load_provider_usage_config(provider, self._provider_plugins) + + # Initialize components + self._registry = CredentialRegistry() + self._window_manager = WindowManager( + window_definitions=self._config.windows or get_default_windows() + ) + self._tracking = TrackingEngine(self._window_manager, self._config) + self._limits = LimitEngine(self._config, self._window_manager) + self._selection = SelectionEngine( + self._config, self._limits, self._window_manager + ) + self._hooks = HookDispatcher(self._provider_plugins) + self._api = UsageAPI(self) + + # Storage + if file_path: + self._storage = UsageStorage(file_path) + else: + self._storage = None + + # State + self._states: Dict[str, CredentialState] = {} + self._initialized = False + self._lock = asyncio.Lock() + self._loaded_from_storage = False + self._loaded_count = 0 + self._quota_exhausted_summary: Dict[str, Dict[str, float]] = {} + self._quota_exhausted_task: Optional[asyncio.Task] = None + self._quota_exhausted_lock = asyncio.Lock() + self._save_task: Optional[asyncio.Task] = None + self._save_lock = asyncio.Lock() + + # Concurrency control: per-credential locks and conditions for waiting + self._key_locks: Dict[str, asyncio.Lock] = {} + self._key_conditions: Dict[str, asyncio.Condition] = {} + + # Track which credentials are currently active in the proxy session + # (vs. historical data loaded from storage) + self._active_stable_ids: Set[str] = set() + + async def initialize( + self, + credentials: List[str], + priorities: Optional[Dict[str, int]] = None, + tiers: Optional[Dict[str, str]] = None, + ) -> None: + """ + Initialize with credentials. + + Args: + credentials: List of credential accessors (paths or keys) + priorities: Optional priority overrides (accessor -> priority) + tiers: Optional tier overrides (accessor -> tier name) + """ + async with self._lock: + if self._initialized: + return + # Load persisted state + if self._storage: + ( + self._states, + fair_cycle_global, + loaded_from_storage, + ) = await self._storage.load() + self._loaded_from_storage = loaded_from_storage + self._loaded_count = len(self._states) + if fair_cycle_global: + self._limits.fair_cycle_checker.load_global_state_dict( + fair_cycle_global + ) + + # Register credentials and track active ones + self._active_stable_ids.clear() + for accessor in credentials: + stable_id = self._registry.get_stable_id(accessor, self.provider) + self._active_stable_ids.add(stable_id) + + # Create or update state + if stable_id not in self._states: + self._states[stable_id] = CredentialState( + stable_id=stable_id, + provider=self.provider, + accessor=accessor, + created_at=time.time(), + ) + else: + # Update accessor in case it changed + self._states[stable_id].accessor = accessor + + # Apply overrides + if priorities and accessor in priorities: + self._states[stable_id].priority = priorities[accessor] + if tiers and accessor in tiers: + self._states[stable_id].tier = tiers[accessor] + + # Debug: Log state before max_concurrent calculation + old_max_concurrent = self._states[stable_id].max_concurrent + + # Always set max concurrent, applying priority multiplier + # Uses configured value or defaults to 1 if not set + base_concurrent = ( + self._max_concurrent_per_key + if self._max_concurrent_per_key is not None + else 1 + ) + priority = self._states[stable_id].priority + multiplier = self._config.get_effective_multiplier(priority) + effective_concurrent = base_concurrent * multiplier + self._states[stable_id].max_concurrent = effective_concurrent + + # Clean up stale windows from tier changes + # This handles the case where a credential's tier changed and now has + # windows from the old tier that should be removed + # Also populate window_definitions for each credential based on tier + total_removed = 0 + for stable_id, state in self._states.items(): + # Populate window definitions for this credential's tier + state.window_definitions = self._get_window_definitions_for_state(state) + + valid_windows = self._get_valid_window_names_for_state(state) + removed = self._cleanup_stale_windows_for_state(state, valid_windows) + total_removed += removed + + if total_removed > 0: + lib_logger.info( + f"Cleaned up {total_removed} stale window(s) for {self.provider}" + ) + # Mark storage dirty so changes get saved + if self._storage: + self._storage.mark_dirty() + + self._initialized = True + lib_logger.debug( + f"UsageManager initialized for {self.provider} with {len(credentials)} credentials" + ) + + async def acquire_credential( + self, + model: str, + quota_group: Optional[str] = None, + exclude: Optional[Set[str]] = None, + candidates: Optional[List[str]] = None, + priorities: Optional[Dict[str, int]] = None, + deadline: float = 0.0, + ) -> CredentialContext: + """ + Acquire a credential for a request. + + Returns a context manager that automatically releases + the credential and records success/failure. + + This method will wait for credentials to become available if all are + currently busy (at max_concurrent), up until the deadline. + + Args: + model: Model to use + quota_group: Optional quota group (uses model name if None) + exclude: Set of stable_ids to exclude (by accessor) + candidates: Optional list of credential accessors to consider. + If provided, only these will be considered for selection. + priorities: Optional priority overrides (accessor -> priority). + If provided, overrides the stored priorities. + deadline: Request deadline timestamp + + Returns: + CredentialContext for use with async with + + Raises: + NoAvailableKeysError: If no credentials available within deadline + """ + from ..error_handler import NoAvailableKeysError + + # Convert accessor-based exclude to stable_id-based + exclude_ids = set() + if exclude: + for accessor in exclude: + stable_id = self._registry.get_stable_id(accessor, self.provider) + exclude_ids.add(stable_id) + + # Filter states to only candidates if provided + if candidates is not None: + candidate_ids = set() + for accessor in candidates: + stable_id = self._registry.get_stable_id(accessor, self.provider) + candidate_ids.add(stable_id) + states_to_check = { + sid: state + for sid, state in self._states.items() + if sid in candidate_ids + } + else: + states_to_check = self._get_active_states() + + # Convert accessor-based priorities to stable_id-based + priority_overrides = None + if priorities: + priority_overrides = {} + for accessor, priority in priorities.items(): + stable_id = self._registry.get_stable_id(accessor, self.provider) + priority_overrides[stable_id] = priority + + # Normalize model name for consistent tracking and selection + normalized_model = self._normalize_model(model) + + # Ensure key conditions exist for all candidates + for stable_id in states_to_check: + if stable_id not in self._key_conditions: + self._key_conditions[stable_id] = asyncio.Condition() + self._key_locks[stable_id] = asyncio.Lock() + + # Main acquisition loop - continues until deadline + while time.time() < deadline: + # Try to select a credential + stable_id = self._selection.select( + provider=self.provider, + model=normalized_model, + states=states_to_check, + quota_group=quota_group, + exclude=exclude_ids, + priorities=priority_overrides, + deadline=deadline, + ) + + if stable_id is not None: + state = self._states[stable_id] + lock = self._key_locks.get(stable_id) + + if lock: + async with lock: + # Double-check availability after acquiring lock + if ( + state.max_concurrent is None + or state.active_requests < state.max_concurrent + ): + state.active_requests += 1 + lib_logger.debug( + f"Acquired credential {mask_credential(state.accessor, style='full')} " + f"for {model} (active: {state.active_requests}" + f"{f'/{state.max_concurrent}' if state.max_concurrent else ''})" + ) + return CredentialContext( + manager=self, + credential=state.accessor, + stable_id=stable_id, + model=normalized_model, + quota_group=quota_group, + ) + else: + # No lock configured, just increment + state.active_requests += 1 + return CredentialContext( + manager=self, + credential=state.accessor, + stable_id=stable_id, + model=normalized_model, + quota_group=quota_group, + ) + + # No credential available - need to wait + # Find the best credential to wait for (prefer lowest usage) + best_wait_id = None + best_usage = float("inf") + + for sid, state in states_to_check.items(): + if sid in exclude_ids: + continue + if ( + state.max_concurrent is not None + and state.active_requests >= state.max_concurrent + ): + # This one is busy but might become free + usage = state.totals.request_count + if usage < best_usage: + best_usage = usage + best_wait_id = sid + + if best_wait_id is None: + # All credentials blocked by cooldown or limits, not just concurrency + # Check if waiting for cooldown makes sense + soonest_cooldown = self._get_soonest_cooldown_end( + states_to_check, normalized_model, quota_group + ) + + if soonest_cooldown is not None: + remaining_budget = deadline - time.time() + wait_needed = soonest_cooldown - time.time() + + if wait_needed > remaining_budget: + # No credential will be available in time + lib_logger.warning( + f"All credentials on cooldown. Soonest in {wait_needed:.1f}s, " + f"budget {remaining_budget:.1f}s. Failing fast." + ) + break + + # Wait for cooldown to expire + lib_logger.info( + f"All credentials on cooldown. Waiting {wait_needed:.1f}s..." + ) + await asyncio.sleep(min(wait_needed + 0.1, remaining_budget)) + continue + + # No cooldowns and no busy keys - truly no keys available + break + + # Wait on the best credential's condition + condition = self._key_conditions.get(best_wait_id) + if condition: + lib_logger.debug( + f"All credentials busy. Waiting for {mask_credential(self._states[best_wait_id].accessor, style='full')}..." + ) + try: + async with condition: + remaining_budget = deadline - time.time() + if remaining_budget <= 0: + break + # Wait for notification or timeout (max 1 second to re-check) + await asyncio.wait_for( + condition.wait(), + timeout=min(1.0, remaining_budget), + ) + lib_logger.debug("Credential released. Re-evaluating...") + except asyncio.TimeoutError: + # Timeout is normal, just retry the loop + lib_logger.debug("Wait timed out. Re-evaluating...") + else: + # No condition, just sleep briefly and retry + await asyncio.sleep(0.1) + + # Deadline exceeded + raise NoAvailableKeysError( + f"Could not acquire a credential for {self.provider}/{model} " + f"within the time budget." + ) + + def _get_soonest_cooldown_end( + self, + states: Dict[str, CredentialState], + model: str, + quota_group: Optional[str], + ) -> Optional[float]: + """Get the soonest cooldown end time across all credentials.""" + soonest = None + now = time.time() + group_key = quota_group or model + + for state in states.values(): + # Check model-specific cooldown + cooldown = state.get_cooldown(group_key) + if cooldown and cooldown.until > now: + if soonest is None or cooldown.until < soonest: + soonest = cooldown.until + + # Check global cooldown + global_cooldown = state.get_cooldown() + if global_cooldown and global_cooldown.until > now: + if soonest is None or global_cooldown.until < soonest: + soonest = global_cooldown.until + + return soonest + + async def get_best_credential( + self, + model: str, + quota_group: Optional[str] = None, + exclude: Optional[Set[str]] = None, + deadline: float = 0.0, + ) -> Optional[str]: + """ + Get the best available credential without acquiring. + + Useful for checking availability or manual acquisition. + + Args: + model: Model to use + quota_group: Optional quota group + exclude: Set of accessors to exclude + deadline: Request deadline + + Returns: + Credential accessor, or None if none available + """ + # Convert exclude from accessors to stable_ids + exclude_ids = set() + if exclude: + for accessor in exclude: + stable_id = self._registry.get_stable_id(accessor, self.provider) + exclude_ids.add(stable_id) + + # Normalize model name for consistent selection + normalized_model = self._normalize_model(model) + + stable_id = self._selection.select( + provider=self.provider, + model=normalized_model, + states=self._get_active_states(), + quota_group=quota_group, + exclude=exclude_ids, + deadline=deadline, + ) + + if stable_id is None: + return None + + return self._states[stable_id].accessor + + async def record_usage( + self, + accessor: str, + model: str, + success: bool, + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + error: Optional[ClassifiedError] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Record usage for a credential (manual recording). + + Use this for manual tracking outside of context manager. + + Args: + accessor: Credential accessor + model: Model used + success: Whether request succeeded + prompt_tokens: Prompt tokens used + completion_tokens: Completion tokens used + thinking_tokens: Thinking tokens used + prompt_tokens_cache_read: Cached prompt tokens read + prompt_tokens_cache_write: Cached prompt tokens written + approx_cost: Approximate cost + error: Classified error if failed + quota_group: Quota group + """ + stable_id = self._registry.get_stable_id(accessor, self.provider) + + if success: + await self._record_success( + stable_id, + model, + quota_group, + prompt_tokens, + completion_tokens, + thinking_tokens, + prompt_tokens_cache_read, + prompt_tokens_cache_write, + approx_cost, + ) + else: + await self._record_failure( + stable_id, + model, + quota_group, + error, + request_count=1, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + ) + + async def _handle_request_complete( + self, + stable_id: str, + model: str, + quota_group: Optional[str], + success: bool, + response: Optional[Any], + response_headers: Optional[Dict[str, Any]], + error: Optional[ClassifiedError], + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + ) -> None: + """Handle provider hooks and record request outcome.""" + state = self._states.get(stable_id) + if not state: + return + + normalized_model = self._normalize_model(model) + group_key = quota_group or self._get_model_quota_group(normalized_model) + + hook_result: Optional[RequestCompleteResult] = None + if self._hooks: + hook_result = await self._hooks.dispatch_request_complete( + provider=self.provider, + credential=state.accessor, + model=normalized_model, + success=success, + response=response, + error=error, + ) + + request_count = 1 + cooldown_override = None + force_exhausted = False + + if hook_result: + if hook_result.count_override is not None: + request_count = max(0, hook_result.count_override) + cooldown_override = hook_result.cooldown_override + force_exhausted = hook_result.force_exhausted + + if not success and error and hook_result is None: + if error.error_type in {"server_error", "api_connection"}: + request_count = 0 + + if request_count == 0: + prompt_tokens = 0 + completion_tokens = 0 + thinking_tokens = 0 + prompt_tokens_cache_read = 0 + prompt_tokens_cache_write = 0 + approx_cost = 0.0 + + if cooldown_override: + await self._tracking.apply_cooldown( + state=state, + reason="provider_hook", + duration=cooldown_override, + model_or_group=group_key, + source="provider_hook", + ) + + if force_exhausted: + await self._tracking.mark_exhausted( + state=state, + model_or_group=self._resolve_fair_cycle_key( + group_key or normalized_model + ), + reason="provider_hook", + ) + + if success: + await self._record_success( + stable_id, + normalized_model, + quota_group, + prompt_tokens, + completion_tokens, + thinking_tokens, + prompt_tokens_cache_read, + prompt_tokens_cache_write, + approx_cost, + response_headers=response_headers, + request_count=request_count, + ) + else: + await self._record_failure( + stable_id, + normalized_model, + quota_group, + error, + request_count=request_count, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + ) + + async def apply_cooldown( + self, + accessor: str, + duration: float, + reason: str = "manual", + model_or_group: Optional[str] = None, + ) -> None: + """ + Apply a cooldown to a credential. + + Args: + accessor: Credential accessor + duration: Cooldown duration in seconds + reason: Reason for cooldown + model_or_group: Scope of cooldown + """ + stable_id = self._registry.get_stable_id(accessor, self.provider) + state = self._states.get(stable_id) + if state: + await self._tracking.apply_cooldown( + state=state, + reason=reason, + duration=duration, + model_or_group=model_or_group, + ) + await self._save_if_needed() + + async def get_availability_stats( + self, + model: str, + quota_group: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Get availability statistics for credentials. + + Args: + model: Model to check + quota_group: Quota group + + Returns: + Dict with availability info + """ + return self._selection.get_availability_stats( + provider=self.provider, + model=model, + states=self._get_active_states(), + quota_group=quota_group, + ) + + async def get_stats_for_endpoint( + self, + model_filter: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Get comprehensive stats suitable for status endpoints. + + Returns credential states, usage windows, cooldowns, and fair cycle state. + + Args: + model_filter: Optional model to filter stats for + + Returns: + Dict with comprehensive statistics + """ + stats = { + "provider": self.provider, + "credential_count": len(self._active_stable_ids), + "rotation_mode": self._config.rotation_mode.value, + "credentials": {}, + } + + stats.update( + { + "active_count": 0, + "exhausted_count": 0, + "total_requests": 0, + "tokens": { + "input_cached": 0, + "input_uncached": 0, + "input_cache_pct": 0, + "output": 0, + }, + "approx_cost": None, + "quota_groups": {}, + } + ) + + for stable_id, state in self._states.items(): + # Skip credentials not currently active in the proxy + if stable_id not in self._active_stable_ids: + continue + + now = time.time() + + # Determine credential status with proper granularity + status = "active" + has_global_cooldown = False + has_group_cooldown = False + fc_exhausted_groups = [] + + # Check cooldowns (global vs per-group) + for key, cooldown in state.cooldowns.items(): + if cooldown.until > now: + if key == "_global_": + has_global_cooldown = True + else: + has_group_cooldown = True + + # Check fair cycle per group + for group_key, fc_state in state.fair_cycle.items(): + if fc_state.exhausted: + fc_exhausted_groups.append(group_key) + + # Determine final status + known_groups = set(state.group_usage.keys()) if state.group_usage else set() + + if has_global_cooldown: + status = "cooldown" + elif fc_exhausted_groups: + # Check if ALL known groups are exhausted + if known_groups and set(fc_exhausted_groups) >= known_groups: + status = "exhausted" + else: + status = "mixed" # Some groups available + elif has_group_cooldown: + status = "cooldown" + + cred_stats = { + "stable_id": stable_id, + "accessor_masked": mask_credential(state.accessor, style="full"), + "full_path": state.accessor, + "identifier": mask_credential(state.accessor, style="full"), + "email": state.display_name, + "tier": state.tier, + "priority": state.priority, + "active_requests": state.active_requests, + "status": status, + "totals": { + "request_count": state.totals.request_count, + "success_count": state.totals.success_count, + "failure_count": state.totals.failure_count, + "prompt_tokens": state.totals.prompt_tokens, + "completion_tokens": state.totals.completion_tokens, + "thinking_tokens": state.totals.thinking_tokens, + "output_tokens": state.totals.output_tokens, + "prompt_tokens_cache_read": state.totals.prompt_tokens_cache_read, + "prompt_tokens_cache_write": state.totals.prompt_tokens_cache_write, + "total_tokens": state.totals.total_tokens, + "approx_cost": state.totals.approx_cost, + "first_used_at": state.totals.first_used_at, + "last_used_at": state.totals.last_used_at, + }, + "model_usage": {}, + "group_usage": {}, + "cooldowns": {}, + "fair_cycle": {}, + } + + stats["total_requests"] += state.totals.request_count + stats["tokens"]["output"] += state.totals.output_tokens + stats["tokens"]["input_cached"] += state.totals.prompt_tokens_cache_read + # prompt_tokens in LiteLLM = uncached tokens (not total input) + # Total input = prompt_tokens + prompt_tokens_cache_read + stats["tokens"]["input_uncached"] += state.totals.prompt_tokens + if state.totals.approx_cost: + stats["approx_cost"] = ( + stats["approx_cost"] or 0.0 + ) + state.totals.approx_cost + + if status == "active": + stats["active_count"] += 1 + elif status == "exhausted": + stats["exhausted_count"] += 1 + + # Add model usage stats + for model_key, model_stats in state.model_usage.items(): + model_windows = {} + for window_name, window in model_stats.windows.items(): + model_windows[window_name] = { + "request_count": window.request_count, + "success_count": window.success_count, + "failure_count": window.failure_count, + "prompt_tokens": window.prompt_tokens, + "completion_tokens": window.completion_tokens, + "thinking_tokens": window.thinking_tokens, + "output_tokens": window.output_tokens, + "prompt_tokens_cache_read": window.prompt_tokens_cache_read, + "prompt_tokens_cache_write": window.prompt_tokens_cache_write, + "total_tokens": window.total_tokens, + "limit": window.limit, + "remaining": window.remaining, + "max_recorded_requests": window.max_recorded_requests, + "max_recorded_at": window.max_recorded_at, + "reset_at": window.reset_at, + "approx_cost": window.approx_cost, + "first_used_at": window.first_used_at, + "last_used_at": window.last_used_at, + } + cred_stats["model_usage"][model_key] = { + "windows": model_windows, + "totals": { + "request_count": model_stats.totals.request_count, + "success_count": model_stats.totals.success_count, + "failure_count": model_stats.totals.failure_count, + "prompt_tokens": model_stats.totals.prompt_tokens, + "completion_tokens": model_stats.totals.completion_tokens, + "thinking_tokens": model_stats.totals.thinking_tokens, + "output_tokens": model_stats.totals.output_tokens, + "prompt_tokens_cache_read": model_stats.totals.prompt_tokens_cache_read, + "prompt_tokens_cache_write": model_stats.totals.prompt_tokens_cache_write, + "total_tokens": model_stats.totals.total_tokens, + "approx_cost": model_stats.totals.approx_cost, + "first_used_at": model_stats.totals.first_used_at, + "last_used_at": model_stats.totals.last_used_at, + }, + } + + # Add group usage stats + for group_key, group_stats in state.group_usage.items(): + group_windows = {} + for window_name, window in group_stats.windows.items(): + group_windows[window_name] = { + "request_count": window.request_count, + "success_count": window.success_count, + "failure_count": window.failure_count, + "prompt_tokens": window.prompt_tokens, + "completion_tokens": window.completion_tokens, + "thinking_tokens": window.thinking_tokens, + "output_tokens": window.output_tokens, + "prompt_tokens_cache_read": window.prompt_tokens_cache_read, + "prompt_tokens_cache_write": window.prompt_tokens_cache_write, + "total_tokens": window.total_tokens, + "limit": window.limit, + "remaining": window.remaining, + "max_recorded_requests": window.max_recorded_requests, + "max_recorded_at": window.max_recorded_at, + "reset_at": window.reset_at, + "approx_cost": window.approx_cost, + "first_used_at": window.first_used_at, + "last_used_at": window.last_used_at, + } + cred_stats["group_usage"][group_key] = { + "windows": group_windows, + "totals": { + "request_count": group_stats.totals.request_count, + "success_count": group_stats.totals.success_count, + "failure_count": group_stats.totals.failure_count, + "prompt_tokens": group_stats.totals.prompt_tokens, + "completion_tokens": group_stats.totals.completion_tokens, + "thinking_tokens": group_stats.totals.thinking_tokens, + "output_tokens": group_stats.totals.output_tokens, + "prompt_tokens_cache_read": group_stats.totals.prompt_tokens_cache_read, + "prompt_tokens_cache_write": group_stats.totals.prompt_tokens_cache_write, + "total_tokens": group_stats.totals.total_tokens, + "approx_cost": group_stats.totals.approx_cost, + "first_used_at": group_stats.totals.first_used_at, + "last_used_at": group_stats.totals.last_used_at, + }, + } + + # Add per-group status info for this credential + group_data = cred_stats["group_usage"][group_key] + + # Fair cycle status for this group + fc_state = state.fair_cycle.get(group_key) + group_data["fair_cycle_exhausted"] = ( + fc_state.exhausted if fc_state else False + ) + group_data["fair_cycle_reason"] = ( + fc_state.exhausted_reason + if fc_state and fc_state.exhausted + else None + ) + + # Group-specific cooldown + group_cooldown = state.cooldowns.get(group_key) + if group_cooldown and group_cooldown.is_active: + group_data["cooldown_remaining"] = int( + group_cooldown.remaining_seconds + ) + group_data["cooldown_source"] = group_cooldown.source + else: + group_data["cooldown_remaining"] = None + group_data["cooldown_source"] = None + + # Custom cap info for this group + cap = self._limits.custom_cap_checker.get_cap_for( + state, group_key, group_key + ) + if cap: + # Get usage from primary window + primary_window = group_windows.get( + self._window_manager.get_primary_definition().name + if self._window_manager.get_primary_definition() + else "5h" + ) + cap_used = ( + primary_window.get("request_count", 0) if primary_window else 0 + ) + cap_limit = cap.max_requests + # Resolve cap limit based on mode + api_limit = primary_window.get("limit") if primary_window else None + if cap.max_requests_mode == CapMode.OFFSET: + if api_limit: + cap_limit = max(0, api_limit + cap.max_requests) + else: + cap_limit = max(0, abs(cap.max_requests)) + elif cap.max_requests_mode == CapMode.PERCENTAGE: + if api_limit: + cap_limit = max(0, int(api_limit * cap.max_requests / 100)) + else: + cap_limit = 0 + else: # ABSOLUTE + cap_limit = max(0, cap.max_requests) + group_data["custom_cap"] = { + "limit": cap_limit, + "used": cap_used, + "remaining": max(0, cap_limit - cap_used), + } + else: + group_data["custom_cap"] = None + + # Aggregate quota group stats with per-window breakdown + group_agg = stats["quota_groups"].setdefault( + group_key, + { + "tiers": {}, # Credential tier counts (provider-level) + "windows": {}, # Per-window aggregated stats + "fair_cycle_summary": { # FC status across all credentials + "exhausted_count": 0, + "total_count": 0, + }, + }, + ) + + # Update fair cycle summary for this group + group_agg["fair_cycle_summary"]["total_count"] += 1 + if group_data.get("fair_cycle_exhausted"): + group_agg["fair_cycle_summary"]["exhausted_count"] += 1 + + # Add credential to tier count (provider-level, not per-window) + tier_key = state.tier or "unknown" + tier_stats = group_agg["tiers"].setdefault( + tier_key, + {"priority": state.priority or 0, "total": 0}, + ) + tier_stats["total"] += 1 + + # Aggregate per-window stats + for window_name, window in group_windows.items(): + window_agg = group_agg[ + "windows" + ].setdefault( + window_name, + { + "total_used": 0, + "total_remaining": 0, + "total_max": 0, + "remaining_pct": None, + "tier_availability": {}, # Per-window credential availability + }, + ) + + # Track tier availability for this window + tier_avail = window_agg["tier_availability"].setdefault( + tier_key, + {"total": 0, "available": 0}, + ) + tier_avail["total"] += 1 + + # Check if this credential has quota remaining in this window + limit = window.get("limit") + if limit is not None: + used = window["request_count"] + remaining = max(0, limit - used) + window_agg["total_used"] += used + window_agg["total_remaining"] += remaining + window_agg["total_max"] += limit + + # Credential has availability if remaining > 0 + if remaining > 0: + tier_avail["available"] += 1 + else: + # No limit = unlimited = always available + tier_avail["available"] += 1 + + # Add active cooldowns + for key, cooldown in state.cooldowns.items(): + if cooldown.is_active: + cred_stats["cooldowns"][key] = { + "reason": cooldown.reason, + "remaining_seconds": cooldown.remaining_seconds, + "source": cooldown.source, + } + + # Add fair cycle state + for key, fc_state in state.fair_cycle.items(): + if model_filter and key != model_filter: + continue + cred_stats["fair_cycle"][key] = { + "exhausted": fc_state.exhausted, + "cycle_request_count": fc_state.cycle_request_count, + } + + # Sort group_usage by quota limit (lowest first), then alphabetically + # This ensures detail view matches the global summary sort order + def group_sort_key(item): + group_name, group_data = item + windows = group_data.get("windows", {}) + if not windows: + return (float("inf"), group_name) # No windows = sort last + + # Find minimum limit across windows + min_limit = float("inf") + for window_data in windows.values(): + limit = window_data.get("limit") + if limit is not None and limit > 0: + min_limit = min(min_limit, limit) + + return (min_limit, group_name) + + sorted_group_usage = dict( + sorted(cred_stats["group_usage"].items(), key=group_sort_key) + ) + cred_stats["group_usage"] = sorted_group_usage + + stats["credentials"][stable_id] = cred_stats + + # Calculate remaining percentages for each window in each quota group + for group_stats in stats["quota_groups"].values(): + for window_stats in group_stats.get("windows", {}).values(): + if window_stats["total_max"] > 0: + window_stats["remaining_pct"] = round( + window_stats["total_remaining"] + / window_stats["total_max"] + * 100, + 1, + ) + + total_input = ( + stats["tokens"]["input_cached"] + stats["tokens"]["input_uncached"] + ) + stats["tokens"]["input_cache_pct"] = ( + round(stats["tokens"]["input_cached"] / total_input * 100, 1) + if total_input > 0 + else 0 + ) + + return stats + + def _get_provider_plugin_instance(self) -> Optional[Any]: + """Get provider plugin instance for the current provider.""" + if not self._provider_plugins: + return None + + # Provider plugins dict maps provider name -> plugin class or instance + plugin = self._provider_plugins.get(self.provider) + if plugin is None: + return None + + # If it's a class, instantiate it (singleton via metaclass); if already an instance, use directly + if isinstance(plugin, type): + return plugin() + return plugin + + def _normalize_model(self, model: str) -> str: + """ + Normalize model name using provider's mapping. + + Converts internal model names (e.g., claude-sonnet-4-5-thinking) to + public-facing names (e.g., claude-sonnet-4.5) for consistent storage + and tracking. + + Args: + model: Model name (with or without provider prefix) + + Returns: + Normalized model name (provider prefix preserved if present) + """ + plugin_instance = self._get_provider_plugin_instance() + + if plugin_instance and hasattr(plugin_instance, "normalize_model_for_tracking"): + return plugin_instance.normalize_model_for_tracking(model) + + return model + + def _get_model_quota_group(self, model: str) -> Optional[str]: + """ + Get the quota group for a model, if the provider defines one. + + Models in the same quota group share a single quota pool. + For example, all Claude models in Antigravity share the same daily quota. + + Args: + model: Model name (with or without provider prefix) + + Returns: + Group name (e.g., "claude") or None if not grouped + """ + plugin_instance = self._get_provider_plugin_instance() + + if plugin_instance and hasattr(plugin_instance, "get_model_quota_group"): + return plugin_instance.get_model_quota_group(model) + + return None + + def get_model_quota_group(self, model: str) -> Optional[str]: + """Public helper to get quota group for a model.""" + normalized_model = self._normalize_model(model) + return self._get_model_quota_group(normalized_model) + + def _get_grouped_models(self, group: str) -> List[str]: + """ + Get all model names in a quota group (with provider prefix), normalized. + + Returns only public-facing model names, deduplicated. Internal variants + (e.g., claude-sonnet-4-5-thinking) are normalized to their public name + (e.g., claude-sonnet-4.5). + + Args: + group: Group name (e.g., "claude") + + Returns: + List of normalized, deduplicated model names with provider prefix + (e.g., ["antigravity/claude-sonnet-4.5", "antigravity/claude-opus-4.5"]) + """ + plugin_instance = self._get_provider_plugin_instance() + + if plugin_instance and hasattr(plugin_instance, "get_models_in_quota_group"): + models = plugin_instance.get_models_in_quota_group(group) + + # Normalize and deduplicate + if hasattr(plugin_instance, "normalize_model_for_tracking"): + seen: Set[str] = set() + normalized: List[str] = [] + for m in models: + prefixed = f"{self.provider}/{m}" + norm = plugin_instance.normalize_model_for_tracking(prefixed) + if norm not in seen: + seen.add(norm) + normalized.append(norm) + return normalized + + # Fallback: just add provider prefix + return [f"{self.provider}/{m}" for m in models] + + return [] + + async def save(self, force: bool = False) -> bool: + """ + Save usage data to file. + + Args: + force: Force save even if debounce not elapsed + + Returns: + True if saved successfully + """ + if self._storage: + fair_cycle_global = self._limits.fair_cycle_checker.get_global_state_dict() + return await self._storage.save( + self._states, fair_cycle_global, force=force + ) + return False + + async def get_usage_snapshot(self) -> Dict[str, Dict[str, Any]]: + """ + Get a lightweight usage snapshot keyed by accessor. + + Returns: + Dict mapping accessor -> usage metadata. + """ + async with self._lock: + snapshot: Dict[str, Dict[str, Any]] = {} + for state in self._states.values(): + snapshot[state.accessor] = { + "last_used_ts": state.totals.last_used_at or 0, + } + return snapshot + + async def shutdown(self) -> None: + """Shutdown and save any pending data.""" + await self.save(force=True) + + async def reload_from_disk(self) -> None: + """ + Force reload usage data from disk. + + Useful when wanting fresh stats without making external API calls. + This reloads persisted state while preserving current credential registrations. + """ + if not self._storage: + lib_logger.debug( + f"reload_from_disk: No storage configured for {self.provider}" + ) + return + + async with self._lock: + # Load persisted state + loaded_states, fair_cycle_global, _ = await self._storage.load() + + # Merge loaded state with current state + # Keep current accessors but update usage data + for stable_id, loaded_state in loaded_states.items(): + if stable_id in self._states: + # Update usage data from loaded state + current = self._states[stable_id] + current.model_usage = loaded_state.model_usage + current.group_usage = loaded_state.group_usage + current.totals = loaded_state.totals + current.cooldowns = loaded_state.cooldowns + current.fair_cycle = loaded_state.fair_cycle + current.last_updated = loaded_state.last_updated + else: + # New credential from disk, add it + self._states[stable_id] = loaded_state + + # Reload fair cycle global state + if fair_cycle_global: + self._limits.fair_cycle_checker.load_global_state_dict( + fair_cycle_global + ) + + lib_logger.info( + f"Reloaded usage data from disk for {self.provider}: " + f"{len(self._states)} credentials" + ) + + async def update_quota_baseline( + self, + accessor: str, + model: str, + quota_max_requests: Optional[int] = None, + quota_reset_ts: Optional[float] = None, + quota_used: Optional[int] = None, + quota_group: Optional[str] = None, + force: bool = False, + apply_exhaustion: bool = False, + ) -> Optional[Dict[str, Any]]: + """ + Update quota baseline from provider API response. + + Called by provider plugins after receiving rate limit headers or + quota information from API responses. + + Args: + accessor: Credential accessor (path or key) + model: Model name + quota_max_requests: Max requests allowed in window + quota_reset_ts: When quota resets (Unix timestamp) + quota_used: Current used count from API + quota_group: Optional quota group (uses model if None) + force: If True, always use API values (for manual refresh). + If False (default), use max(local, api) to prevent stale + API data from overwriting accurate local counts during + background fetches. + apply_exhaustion: If True, apply cooldown for exhausted quota. + Provider controls when this is set based on its semantics + (e.g., Antigravity only on initial fetch, others always + when remaining == 0). + + Returns: + Cooldown info dict if cooldown was applied, None otherwise + """ + stable_id = self._registry.get_stable_id(accessor, self.provider) + state = self._states.get(stable_id) + if not state: + lib_logger.warning( + f"update_quota_baseline: Unknown credential {accessor[:20]}..." + ) + return None + + # Normalize model name for consistent tracking + normalized_model = self._normalize_model(model) + group_key = quota_group or self._get_model_quota_group(normalized_model) + + primary_def = self._window_manager.get_primary_definition() + + # Update windows based on quota scope + # If group_key exists, quota is at group level - only update group stats + # We can't know which model the requests went to from API-level quota + if group_key: + group_stats = state.get_group_stats(group_key) + if primary_def: + group_window = self._window_manager.get_or_create_window( + group_stats.windows, primary_def.name + ) + self._apply_quota_update( + group_window, quota_max_requests, quota_reset_ts, quota_used, force + ) + + # Sync timing to all model windows in this group + # All models share the same started_at/reset_at/limit as the group + self._sync_group_timing_to_models( + state, group_key, group_window, primary_def.name + ) + else: + # No quota group - model IS the quota scope, update model stats + model_stats = state.get_model_stats(normalized_model) + if primary_def: + model_window = self._window_manager.get_or_create_window( + model_stats.windows, primary_def.name + ) + self._apply_quota_update( + model_window, quota_max_requests, quota_reset_ts, quota_used, force + ) + + # Mark state as updated + state.last_updated = time.time() + + # Apply cooldown if provider indicates exhaustion + # Provider controls when apply_exhaustion is set based on its semantics + if apply_exhaustion: + cooldown_target = group_key or normalized_model + if quota_reset_ts: + await self._tracking.apply_cooldown( + state=state, + reason="quota_exhausted", + until=quota_reset_ts, + model_or_group=cooldown_target, + source="api_quota", + ) + + await self._queue_quota_exhausted_log( + accessor=accessor, + group_key=cooldown_target, + quota_reset_ts=quota_reset_ts, + ) + + await self._save_if_needed() + + return { + "cooldown_until": quota_reset_ts, + "reason": "quota_exhausted", + "model": model, + "cooldown_hours": max(0.0, (quota_reset_ts - time.time()) / 3600), + } + else: + # ERROR: Provider says exhausted but no reset timestamp! + lib_logger.error( + f"Quota exhausted for {cooldown_target} on " + f"{mask_credential(accessor, style='full')} but no reset_timestamp " + f"provided by API - cannot apply cooldown" + ) + + await self._save_if_needed() + + return None + + # ========================================================================= + # WINDOW CLEANUP + # ========================================================================= + + def _get_valid_window_names_for_state(self, state: CredentialState) -> Set[str]: + """ + Get the set of valid window names for a credential based on its tier. + + Uses the provider's usage_reset_configs to determine which window(s) + should exist for this credential's tier/priority. + + Args: + state: The credential state + + Returns: + Set of valid window names (e.g., {"5h"} or {"168h"}) + """ + plugin_class = self._provider_plugins.get(self.provider) + if not plugin_class: + # No plugin - use current config windows as valid + return {w.name for w in self._config.windows} + + # Check if provider defines usage_reset_configs + usage_reset_configs = getattr(plugin_class, "usage_reset_configs", None) + if not usage_reset_configs: + # No tier-specific configs - use current config windows + return {w.name for w in self._config.windows} + + # Get tier priorities mapping + tier_priorities = getattr(plugin_class, "tier_priorities", {}) + default_priority = getattr(plugin_class, "default_tier_priority", 10) + + # Resolve credential's priority from tier + priority = state.priority + if priority is None and state.tier: + priority = tier_priorities.get(state.tier, default_priority) + if priority is None: + priority = default_priority + + # Find matching usage config for this priority + matching_config = None + for key, config in usage_reset_configs.items(): + if isinstance(key, frozenset) and priority in key: + matching_config = config + break + if matching_config is None: + matching_config = usage_reset_configs.get("default") + + if matching_config is None: + # No matching config - use current windows + return {w.name for w in self._config.windows} + + # Generate window name from window_seconds + window_seconds = matching_config.window_seconds + if window_seconds == 86400: + window_name = "daily" + elif window_seconds % 3600 == 0: + window_name = f"{window_seconds // 3600}h" + else: + window_name = "window" + + return {window_name} + + def _get_window_definitions_for_state( + self, state: CredentialState + ) -> List["WindowDefinition"]: + """ + Get the window definitions for a credential based on its tier. + + Uses the provider's usage_reset_configs to determine which window(s) + should be used for this credential's tier/priority. + + Args: + state: The credential state + + Returns: + List of WindowDefinition objects for this credential's tier + """ + from .config import WindowDefinition + + plugin_class = self._provider_plugins.get(self.provider) + if not plugin_class: + # No plugin - use current config windows + return list(self._config.windows) if self._config.windows else [] + + # Check if provider defines usage_reset_configs + usage_reset_configs = getattr(plugin_class, "usage_reset_configs", None) + if not usage_reset_configs: + # No tier-specific configs - use current config windows + return list(self._config.windows) if self._config.windows else [] + + # Get tier priorities mapping + tier_priorities = getattr(plugin_class, "tier_priorities", {}) + default_priority = getattr(plugin_class, "default_tier_priority", 10) + + # Resolve credential's priority from tier + priority = state.priority + if priority is None and state.tier: + priority = tier_priorities.get(state.tier, default_priority) + if priority is None: + priority = default_priority + + # Find matching usage config for this priority + matching_config = None + for key, config in usage_reset_configs.items(): + if isinstance(key, frozenset) and priority in key: + matching_config = config + break + if matching_config is None: + matching_config = usage_reset_configs.get("default") + + if matching_config is None: + # No matching config - use current windows + return list(self._config.windows) if self._config.windows else [] + + # Generate window name from window_seconds + window_seconds = matching_config.window_seconds + if window_seconds == 86400: + window_name = "daily" + elif window_seconds % 3600 == 0: + window_name = f"{window_seconds // 3600}h" + else: + window_name = "window" + + # Create WindowDefinition for this tier + return [ + WindowDefinition.rolling( + name=window_name, + duration_seconds=window_seconds, + is_primary=True, + applies_to=matching_config.field_name or "model", + ) + ] + + def _cleanup_stale_windows_for_state( + self, state: CredentialState, valid_windows: Set[str] + ) -> int: + """ + Remove windows that don't match the credential's current tier config. + + This handles the case where a credential's tier changed and now has + windows from the old tier that should be cleaned up. + + Args: + state: The credential state to clean up + valid_windows: Set of valid window names for this credential + + Returns: + Number of windows removed + """ + removed_count = 0 + + # Clean up model_usage windows + for model_name, model_stats in state.model_usage.items(): + windows_to_remove = [ + name for name in model_stats.windows.keys() if name not in valid_windows + ] + for window_name in windows_to_remove: + del model_stats.windows[window_name] + removed_count += 1 + lib_logger.debug( + f"Removed stale window '{window_name}' from model " + f"'{model_name}' for {mask_credential(state.accessor, style='full')}" + ) + + # Clean up group_usage windows + for group_name, group_stats in state.group_usage.items(): + windows_to_remove = [ + name for name in group_stats.windows.keys() if name not in valid_windows + ] + for window_name in windows_to_remove: + del group_stats.windows[window_name] + removed_count += 1 + lib_logger.debug( + f"Removed stale window '{window_name}' from group " + f"'{group_name}' for {mask_credential(state.accessor, style='full')}" + ) + + return removed_count + + async def clear_cooldown_if_exists( + self, + accessor: str, + model_or_group: Optional[str] = None, + ) -> bool: + """ + Clear a cooldown if one exists for the given scope. + + Used during baseline refresh to clear cooldowns when API + reports quota is available. + + Args: + accessor: Credential accessor (path or key) + model_or_group: Scope of cooldown to clear (None = global) + + Returns: + True if a cooldown was cleared, False if none existed + """ + stable_id = self._registry.get_stable_id(accessor, self.provider) + state = self._states.get(stable_id) + if not state: + return False + + key = model_or_group or "_global_" + cooldown = state.cooldowns.get(key) + + if cooldown and cooldown.is_active: + await self._tracking.clear_cooldown(state, model_or_group) + lib_logger.info( + f"Cleared cooldown for {key} on " + f"{mask_credential(accessor, style='full')} - API shows quota available " + f"(was: {cooldown.reason}, source: {cooldown.source})" + ) + return True + + return False + + def _apply_quota_update( + self, + window: WindowStats, + quota_max_requests: Optional[int], + quota_reset_ts: Optional[float], + quota_used: Optional[int], + force: bool, + ) -> None: + """Apply quota update to a window.""" + if quota_max_requests is not None: + window.limit = quota_max_requests + + # Determine if there's actual usage (either API-reported or local) + has_usage = ( + quota_used is not None and quota_used > 0 + ) or window.request_count > 0 + + # Only set started_at and reset_at if there's actual usage + # This prevents bogus reset times for unused windows + if has_usage: + if quota_reset_ts is not None: + window.reset_at = quota_reset_ts + # Set started_at to now if not already set (API shows usage we don't have locally) + if window.started_at is None: + window.started_at = time.time() + + if quota_used is not None: + if force: + synced_count = quota_used + else: + synced_count = max( + window.request_count, + quota_used, + window.success_count + window.failure_count, + ) + self._reconcile_window_counts(window, synced_count) + + def _reconcile_window_counts(self, window: WindowStats, request_count: int) -> None: + """Reconcile window counts after quota sync.""" + local_total = window.success_count + window.failure_count + window.request_count = request_count + if local_total == 0 and request_count > 0: + window.success_count = request_count + window.failure_count = 0 + return + + if request_count < local_total: + failure_count = min(window.failure_count, request_count) + success_count = max(0, request_count - failure_count) + window.success_count = success_count + window.failure_count = failure_count + return + + if request_count > local_total: + window.success_count += request_count - local_total + + def _sync_group_timing_to_models( + self, + state: "CredentialState", + group_key: str, + group_window: "WindowStats", + window_name: str, + ) -> None: + """ + Sync timing from group window to all model windows in the group. + + Called after updating a group window to ensure all models have + consistent started_at, reset_at, and limit values. All models + in a quota group share the same timing since they share API quota. + + Args: + state: Credential state containing model stats + group_key: Quota group name + group_window: The authoritative group window + window_name: Name of the window to sync (e.g., "5h") + """ + models_in_group = self._get_grouped_models(group_key) + for model_name in models_in_group: + model_stats = state.get_model_stats(model_name, create=False) + if model_stats: + model_window = model_stats.windows.get(window_name) + if model_window: + model_window.started_at = group_window.started_at + model_window.reset_at = group_window.reset_at + if group_window.limit is not None: + model_window.limit = group_window.limit + + # ========================================================================= + # PROPERTIES + # ========================================================================= + + @property + def config(self) -> ProviderUsageConfig: + """Get the configuration.""" + return self._config + + @property + def registry(self) -> CredentialRegistry: + """Get the credential registry.""" + return self._registry + + @property + def api(self) -> UsageAPI: + """Get the usage API facade.""" + return self._api + + @property + def initialized(self) -> bool: + """Check if the manager is initialized.""" + return self._initialized + + @property + def tracking(self) -> TrackingEngine: + """Get the tracking engine.""" + return self._tracking + + @property + def limits(self) -> LimitEngine: + """Get the limit engine.""" + return self._limits + + @property + def window_manager(self) -> WindowManager: + """Get the window manager.""" + return self._window_manager + + @property + def selection(self) -> SelectionEngine: + """Get the selection engine.""" + return self._selection + + @property + def states(self) -> Dict[str, CredentialState]: + """Get all credential states.""" + return self._states + + @property + def loaded_from_storage(self) -> bool: + """Whether usage data was loaded from storage.""" + return self._loaded_from_storage + + @property + def loaded_credentials(self) -> int: + """Number of credentials loaded from storage.""" + return self._loaded_count + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _get_active_states(self) -> Dict[str, CredentialState]: + """ + Get only active credential states. + + Returns states for credentials currently registered with the proxy, + excluding stale/historical credentials that exist only in storage. + Use this for rotation, selection, and availability checking. + + Returns: + Dict of stable_id -> CredentialState for active credentials only + """ + return { + sid: state + for sid, state in self._states.items() + if sid in self._active_stable_ids + } + + def _resolve_fair_cycle_key(self, group_key: str) -> str: + """Resolve fair cycle tracking key based on config.""" + if self._config.fair_cycle.tracking_mode == TrackingMode.CREDENTIAL: + return FAIR_CYCLE_GLOBAL_KEY + return group_key + + async def _release_credential(self, stable_id: str, model: str) -> None: + """Release a credential after use and notify waiting tasks.""" + state = self._states.get(stable_id) + if not state: + return + + # Decrement active requests + lock = self._key_locks.get(stable_id) + if lock: + async with lock: + state.active_requests = max(0, state.active_requests - 1) + else: + state.active_requests = max(0, state.active_requests - 1) + + # Log release with current state + remaining = state.active_requests + max_concurrent = state.max_concurrent + lib_logger.info( + f"Released credential {mask_credential(state.accessor, style='full')} " + f"from {model} (remaining concurrent: {remaining}" + f"{f'/{max_concurrent}' if max_concurrent else ''})" + ) + + # Notify all tasks waiting on this credential's condition + condition = self._key_conditions.get(stable_id) + if condition: + async with condition: + condition.notify_all() + + async def _queue_quota_exhausted_log( + self, accessor: str, group_key: str, quota_reset_ts: float + ) -> None: + async with self._quota_exhausted_lock: + masked = mask_credential(accessor, style="full") + if masked not in self._quota_exhausted_summary: + self._quota_exhausted_summary[masked] = {} + self._quota_exhausted_summary[masked][group_key] = quota_reset_ts + + if self._quota_exhausted_task is None or self._quota_exhausted_task.done(): + self._quota_exhausted_task = asyncio.create_task( + self._flush_quota_exhausted_log() + ) + + async def _flush_quota_exhausted_log(self) -> None: + await asyncio.sleep(0.2) + async with self._quota_exhausted_lock: + summary = self._quota_exhausted_summary + self._quota_exhausted_summary = {} + + if not summary: + return + + now = time.time() + parts = [] + for accessor, groups in sorted(summary.items()): + group_parts = [] + for group, reset_ts in sorted(groups.items()): + hours = max(0.0, (reset_ts - now) / 3600) if reset_ts else 0.0 + group_parts.append(f"{group} {hours:.1f}h") + parts.append(f"{accessor}[{', '.join(group_parts)}]") + + lib_logger.info(f"Quota exhausted: {', '.join(parts)}") + + async def _save_if_needed(self) -> None: + """Persist state if storage is configured.""" + if not self._storage: + return + fair_cycle_global = self._limits.fair_cycle_checker.get_global_state_dict() + saved = await self._storage.save(self._states, fair_cycle_global) + if not saved: + await self._schedule_save_flush() + + async def _schedule_save_flush(self) -> None: + if self._save_task and not self._save_task.done(): + return + self._save_task = asyncio.create_task(self._flush_save()) + + async def _flush_save(self) -> None: + async with self._save_lock: + await asyncio.sleep(self._storage.save_debounce_seconds) + if not self._storage: + return + fair_cycle_global = self._limits.fair_cycle_checker.get_global_state_dict() + await self._storage.save_if_dirty(self._states, fair_cycle_global) + + async def _record_success( + self, + stable_id: str, + model: str, + quota_group: Optional[str] = None, + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + response_headers: Optional[Dict[str, Any]] = None, + request_count: int = 1, + ) -> None: + """Record a successful request.""" + state = self._states.get(stable_id) + if state: + # Normalize model name for consistent tracking + normalized_model = self._normalize_model(model) + group_key = quota_group or self._get_model_quota_group(normalized_model) + + await self._tracking.record_success( + state=state, + model=normalized_model, + quota_group=group_key, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + response_headers=response_headers, + request_count=request_count, + ) + + # Apply custom cap cooldown if exceeded + cap_result = self._limits.custom_cap_checker.check( + state, normalized_model, group_key + ) + if ( + not cap_result.allowed + and cap_result.result == LimitResult.BLOCKED_CUSTOM_CAP + and cap_result.blocked_until + ): + await self._tracking.apply_cooldown( + state=state, + reason="custom_cap", + until=cap_result.blocked_until, + model_or_group=group_key or normalized_model, + source="custom_cap", + ) + + await self._save_if_needed() + + async def _record_failure( + self, + stable_id: str, + model: str, + quota_group: Optional[str] = None, + error: Optional[ClassifiedError] = None, + request_count: int = 1, + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + ) -> None: + """Record a failed request.""" + state = self._states.get(stable_id) + if not state: + return + + # Normalize model name for consistent tracking + normalized_model = self._normalize_model(model) + group_key = quota_group or self._get_model_quota_group(normalized_model) + + # Determine cooldown from error + cooldown_duration = None + quota_reset = None + mark_exhausted = False + + if error: + cooldown_duration = error.retry_after + quota_reset = error.quota_reset_timestamp + + # Mark exhausted for quota errors with long cooldown + if error.error_type == "quota_exceeded": + if ( + cooldown_duration + and cooldown_duration >= self._config.exhaustion_cooldown_threshold + ): + mark_exhausted = True + + # Log quota exhaustion like legacy system + masked = mask_credential(state.accessor, style="full") + cooldown_target = group_key or normalized_model + + if quota_reset: + reset_dt = datetime.fromtimestamp(quota_reset, tz=timezone.utc) + hours = max(0.0, (quota_reset - time.time()) / 3600) + lib_logger.info( + f"Quota exhausted for '{cooldown_target}' on {masked}. " + f"Resets at {reset_dt.isoformat()} ({hours:.1f}h)" + ) + elif cooldown_duration: + hours = cooldown_duration / 3600 + lib_logger.info( + f"Quota exhausted on {masked} for '{cooldown_target}'. " + f"Cooldown: {cooldown_duration:.0f}s ({hours:.1f}h)" + ) + + await self._tracking.record_failure( + state=state, + model=normalized_model, + error_type=error.error_type if error else "unknown", + quota_group=group_key, + cooldown_duration=cooldown_duration, + quota_reset_timestamp=quota_reset, + mark_exhausted=mark_exhausted, + request_count=request_count, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + ) + + # Log fair cycle marking like legacy system + if mark_exhausted and self._config.fair_cycle.enabled: + fc_key = group_key or normalized_model + # Resolve fair cycle tracking key based on config + if self._config.fair_cycle.tracking_mode == TrackingMode.CREDENTIAL: + fc_key = FAIR_CYCLE_GLOBAL_KEY + exhausted_count = sum( + 1 + for s in self._states.values() + if fc_key in s.fair_cycle and s.fair_cycle[fc_key].exhausted + ) + masked = mask_credential(state.accessor, style="full") + lib_logger.info( + f"Fair cycle: marked {masked} exhausted for {fc_key} ({exhausted_count} total)" + ) + + await self._save_if_needed() diff --git a/src/rotator_library/usage/persistence/__init__.py b/src/rotator_library/usage/persistence/__init__.py new file mode 100644 index 00000000..9bdf7d76 --- /dev/null +++ b/src/rotator_library/usage/persistence/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Usage data persistence.""" + +from .storage import UsageStorage + +__all__ = ["UsageStorage"] diff --git a/src/rotator_library/usage/persistence/storage.py b/src/rotator_library/usage/persistence/storage.py new file mode 100644 index 00000000..4b75374c --- /dev/null +++ b/src/rotator_library/usage/persistence/storage.py @@ -0,0 +1,496 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Usage data storage. + +Handles loading and saving usage data to JSON files. +""" + +import asyncio +import json +import logging +import time +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from ..types import ( + WindowStats, + TotalStats, + ModelStats, + GroupStats, + CredentialState, + CooldownInfo, + FairCycleState, + GlobalFairCycleState, + StorageSchema, +) +from ...utils.resilient_io import ResilientStateWriter, safe_read_json +from ...error_handler import mask_credential + +lib_logger = logging.getLogger("rotator_library") + + +def _format_timestamp(ts: Optional[float]) -> Optional[str]: + """Format a unix timestamp as a human-readable local time string.""" + if ts is None: + return None + try: + # Use local timezone for human readability + dt = datetime.fromtimestamp(ts) + return dt.strftime("%Y-%m-%d %H:%M:%S") + except (ValueError, OSError): + return None + + +class UsageStorage: + """ + Handles persistence of usage data to JSON files. + + Features: + - Async file I/O with aiofiles + - Atomic writes (write to temp, then rename) + - Automatic schema migration + - Debounced saves to reduce I/O + """ + + CURRENT_SCHEMA_VERSION = 2 + + def __init__( + self, + file_path: Union[str, Path], + save_debounce_seconds: float = 5.0, + ): + """ + Initialize storage. + + Args: + file_path: Path to the usage.json file + save_debounce_seconds: Minimum time between saves + """ + self.file_path = Path(file_path) + self.save_debounce_seconds = save_debounce_seconds + + self._last_save: float = 0 + self._pending_save: bool = False + self._save_lock = asyncio.Lock() + self._dirty: bool = False + self._writer = ResilientStateWriter(self.file_path, lib_logger) + + async def load( + self, + ) -> tuple[Dict[str, CredentialState], Dict[str, Dict[str, Any]], bool]: + """ + Load usage data from file. + + Returns: + Tuple of (states dict, fair_cycle_global dict, loaded_from_file bool) + """ + if not self.file_path.exists(): + return {}, {}, False + + try: + async with self._file_lock(): + data = safe_read_json(self.file_path, lib_logger, parse_json=True) + + if not data: + return {}, {}, True + + # Check schema version + version = data.get("schema_version", 1) + if version < self.CURRENT_SCHEMA_VERSION: + lib_logger.info( + f"Migrating usage data from v{version} to v{self.CURRENT_SCHEMA_VERSION}" + ) + data = self._migrate(data, version) + + # Parse credentials + states = {} + for stable_id, cred_data in data.get("credentials", {}).items(): + state = self._parse_credential_state(stable_id, cred_data) + if state: + states[stable_id] = state + + lib_logger.info(f"Loaded {len(states)} credentials from {self.file_path}") + return states, data.get("fair_cycle_global", {}), True + + except json.JSONDecodeError as e: + lib_logger.error(f"Failed to parse usage file: {e}") + return {}, {}, True + except Exception as e: + lib_logger.error(f"Failed to load usage file: {e}") + return {}, {}, True + + async def save( + self, + states: Dict[str, CredentialState], + fair_cycle_global: Optional[Dict[str, Dict[str, Any]]] = None, + force: bool = False, + ) -> bool: + """ + Save usage data to file. + + Args: + states: Dict of stable_id -> CredentialState + fair_cycle_global: Global fair cycle state + force: Force save even if debounce not elapsed + + Returns: + True if saved, False if skipped or failed + """ + now = time.time() + + # Check debounce + if not force and (now - self._last_save) < self.save_debounce_seconds: + self._dirty = True + return False + + async with self._save_lock: + try: + # Build storage data + data = { + "schema_version": self.CURRENT_SCHEMA_VERSION, + "updated_at": datetime.now(timezone.utc).isoformat(), + "credentials": {}, + "accessor_index": {}, + "fair_cycle_global": fair_cycle_global or {}, + } + + for stable_id, state in states.items(): + data["credentials"][stable_id] = self._serialize_credential_state( + state + ) + data["accessor_index"][state.accessor] = stable_id + + saved = self._writer.write(data) + + if saved: + self._last_save = now + self._dirty = False + lib_logger.debug( + f"Saved {len(states)} credentials to {self.file_path}" + ) + return True + + self._dirty = True + return False + + except Exception as e: + lib_logger.error(f"Failed to save usage file: {e}") + return False + + async def save_if_dirty( + self, + states: Dict[str, CredentialState], + fair_cycle_global: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> bool: + """ + Save if there are pending changes. + + Args: + states: Dict of stable_id -> CredentialState + fair_cycle_global: Global fair cycle state + + Returns: + True if saved, False otherwise + """ + if self._dirty: + return await self.save(states, fair_cycle_global, force=True) + return False + + def mark_dirty(self) -> None: + """Mark data as changed, needing save.""" + self._dirty = True + + @property + def is_dirty(self) -> bool: + """Check if there are unsaved changes.""" + return self._dirty + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _file_lock(self): + """Get a lock for file operations.""" + return self._save_lock + + def _migrate(self, data: Dict[str, Any], from_version: int) -> Dict[str, Any]: + """Migrate data from older schema versions.""" + if from_version == 1: + # v1 -> v2: Add accessor_index, restructure credentials + data["schema_version"] = 2 + data.setdefault("accessor_index", {}) + data.setdefault("fair_cycle_global", {}) + + # v1 used file paths as keys, v2 uses stable_ids + # For migration, treat paths as stable_ids + old_credentials = data.get("credentials", data.get("key_states", {})) + new_credentials = {} + + for key, cred_data in old_credentials.items(): + # Use path as temporary stable_id + stable_id = cred_data.get("stable_id", key) + new_credentials[stable_id] = cred_data + new_credentials[stable_id]["accessor"] = key + + data["credentials"] = new_credentials + + return data + + def _parse_window_stats(self, name: str, data: Dict[str, Any]) -> WindowStats: + """Parse window stats from storage data.""" + return WindowStats( + name=name, + request_count=data.get("request_count", 0), + success_count=data.get("success_count", 0), + failure_count=data.get("failure_count", 0), + prompt_tokens=data.get("prompt_tokens", 0), + completion_tokens=data.get("completion_tokens", 0), + thinking_tokens=data.get("thinking_tokens", 0), + output_tokens=data.get("output_tokens", 0), + prompt_tokens_cache_read=data.get("prompt_tokens_cache_read", 0), + prompt_tokens_cache_write=data.get("prompt_tokens_cache_write", 0), + total_tokens=data.get("total_tokens", 0), + approx_cost=data.get("approx_cost", 0.0), + started_at=data.get("started_at"), + reset_at=data.get("reset_at"), + limit=data.get("limit"), + max_recorded_requests=data.get("max_recorded_requests"), + max_recorded_at=data.get("max_recorded_at"), + first_used_at=data.get("first_used_at"), + last_used_at=data.get("last_used_at"), + ) + + def _serialize_window_stats(self, window: WindowStats) -> Dict[str, Any]: + """Serialize window stats for storage.""" + return { + "request_count": window.request_count, + "success_count": window.success_count, + "failure_count": window.failure_count, + "prompt_tokens": window.prompt_tokens, + "completion_tokens": window.completion_tokens, + "thinking_tokens": window.thinking_tokens, + "output_tokens": window.output_tokens, + "prompt_tokens_cache_read": window.prompt_tokens_cache_read, + "prompt_tokens_cache_write": window.prompt_tokens_cache_write, + "total_tokens": window.total_tokens, + "approx_cost": window.approx_cost, + "started_at": window.started_at, + "started_at_human": _format_timestamp(window.started_at), + "reset_at": window.reset_at, + "reset_at_human": _format_timestamp(window.reset_at), + "limit": window.limit, + "max_recorded_requests": window.max_recorded_requests, + "max_recorded_at": window.max_recorded_at, + "max_recorded_at_human": _format_timestamp(window.max_recorded_at), + "first_used_at": window.first_used_at, + "first_used_at_human": _format_timestamp(window.first_used_at), + "last_used_at": window.last_used_at, + "last_used_at_human": _format_timestamp(window.last_used_at), + } + + def _parse_total_stats(self, data: Dict[str, Any]) -> TotalStats: + """Parse total stats from storage data.""" + return TotalStats( + request_count=data.get("request_count", 0), + success_count=data.get("success_count", 0), + failure_count=data.get("failure_count", 0), + prompt_tokens=data.get("prompt_tokens", 0), + completion_tokens=data.get("completion_tokens", 0), + thinking_tokens=data.get("thinking_tokens", 0), + output_tokens=data.get("output_tokens", 0), + prompt_tokens_cache_read=data.get("prompt_tokens_cache_read", 0), + prompt_tokens_cache_write=data.get("prompt_tokens_cache_write", 0), + total_tokens=data.get("total_tokens", 0), + approx_cost=data.get("approx_cost", 0.0), + first_used_at=data.get("first_used_at"), + last_used_at=data.get("last_used_at"), + ) + + def _serialize_total_stats(self, totals: TotalStats) -> Dict[str, Any]: + """Serialize total stats for storage.""" + return { + "request_count": totals.request_count, + "success_count": totals.success_count, + "failure_count": totals.failure_count, + "prompt_tokens": totals.prompt_tokens, + "completion_tokens": totals.completion_tokens, + "thinking_tokens": totals.thinking_tokens, + "output_tokens": totals.output_tokens, + "prompt_tokens_cache_read": totals.prompt_tokens_cache_read, + "prompt_tokens_cache_write": totals.prompt_tokens_cache_write, + "total_tokens": totals.total_tokens, + "approx_cost": totals.approx_cost, + "first_used_at": totals.first_used_at, + "first_used_at_human": _format_timestamp(totals.first_used_at), + "last_used_at": totals.last_used_at, + "last_used_at_human": _format_timestamp(totals.last_used_at), + } + + def _parse_model_stats(self, data: Dict[str, Any]) -> ModelStats: + """Parse model stats from storage data.""" + windows = {} + for name, wdata in data.get("windows", {}).items(): + # Skip legacy "total" window - now tracked in totals + if name == "total": + continue + windows[name] = self._parse_window_stats(name, wdata) + + totals = self._parse_total_stats(data.get("totals", {})) + + return ModelStats(windows=windows, totals=totals) + + def _serialize_model_stats(self, stats: ModelStats) -> Dict[str, Any]: + """Serialize model stats for storage.""" + return { + "windows": { + name: self._serialize_window_stats(window) + for name, window in stats.windows.items() + }, + "totals": self._serialize_total_stats(stats.totals), + } + + def _parse_group_stats(self, data: Dict[str, Any]) -> GroupStats: + """Parse group stats from storage data.""" + windows = {} + for name, wdata in data.get("windows", {}).items(): + # Skip legacy "total" window - now tracked in totals + if name == "total": + continue + windows[name] = self._parse_window_stats(name, wdata) + + totals = self._parse_total_stats(data.get("totals", {})) + + return GroupStats(windows=windows, totals=totals) + + def _serialize_group_stats(self, stats: GroupStats) -> Dict[str, Any]: + """Serialize group stats for storage.""" + return { + "windows": { + name: self._serialize_window_stats(window) + for name, window in stats.windows.items() + }, + "totals": self._serialize_total_stats(stats.totals), + } + + def _parse_credential_state( + self, + stable_id: str, + data: Dict[str, Any], + ) -> Optional[CredentialState]: + """Parse a credential state from storage data.""" + try: + # Parse model_usage + model_usage = {} + for key, usage_data in data.get("model_usage", {}).items(): + model_usage[key] = self._parse_model_stats(usage_data) + + # Parse group_usage + group_usage = {} + for key, usage_data in data.get("group_usage", {}).items(): + group_usage[key] = self._parse_group_stats(usage_data) + + # Parse credential-level totals + totals = self._parse_total_stats(data.get("totals", {})) + + # Parse cooldowns + cooldowns = {} + for key, cdata in data.get("cooldowns", {}).items(): + cooldowns[key] = CooldownInfo( + reason=cdata.get("reason", "unknown"), + until=cdata.get("until", 0), + started_at=cdata.get("started_at", 0), + source=cdata.get("source", "system"), + model_or_group=cdata.get("model_or_group"), + backoff_count=cdata.get("backoff_count", 0), + ) + + # Parse fair cycle + fair_cycle = {} + for key, fcdata in data.get("fair_cycle", {}).items(): + fair_cycle[key] = FairCycleState( + exhausted=fcdata.get("exhausted", False), + exhausted_at=fcdata.get("exhausted_at"), + exhausted_reason=fcdata.get("exhausted_reason"), + cycle_request_count=fcdata.get("cycle_request_count", 0), + model_or_group=key, + ) + + return CredentialState( + stable_id=stable_id, + provider=data.get("provider", "unknown"), + accessor=data.get("accessor", stable_id), + display_name=data.get("display_name"), + tier=data.get("tier"), + priority=data.get("priority", 999), + model_usage=model_usage, + group_usage=group_usage, + totals=totals, + cooldowns=cooldowns, + fair_cycle=fair_cycle, + active_requests=0, # Always starts at 0 + max_concurrent=data.get("max_concurrent"), + created_at=data.get("created_at"), + last_updated=data.get("last_updated"), + ) + + except Exception as e: + lib_logger.warning( + f"Failed to parse credential {mask_credential(stable_id, style='full')}: {e}" + ) + return None + + def _serialize_credential_state(self, state: CredentialState) -> Dict[str, Any]: + """Serialize a credential state for storage.""" + # Serialize cooldowns (only active ones) + now = time.time() + cooldowns = {} + for key, cd in state.cooldowns.items(): + if cd.until > now: # Only save active cooldowns + cooldowns[key] = { + "reason": cd.reason, + "until": cd.until, + "until_human": _format_timestamp(cd.until), + "started_at": cd.started_at, + "started_at_human": _format_timestamp(cd.started_at), + "source": cd.source, + "model_or_group": cd.model_or_group, + "backoff_count": cd.backoff_count, + } + + # Serialize fair cycle + fair_cycle = {} + for key, fc in state.fair_cycle.items(): + fair_cycle[key] = { + "exhausted": fc.exhausted, + "exhausted_at": fc.exhausted_at, + "exhausted_at_human": _format_timestamp(fc.exhausted_at), + "exhausted_reason": fc.exhausted_reason, + "cycle_request_count": fc.cycle_request_count, + } + + return { + "provider": state.provider, + "accessor": state.accessor, + "display_name": state.display_name, + "tier": state.tier, + "priority": state.priority, + "model_usage": { + key: self._serialize_model_stats(stats) + for key, stats in state.model_usage.items() + }, + "group_usage": { + key: self._serialize_group_stats(stats) + for key, stats in state.group_usage.items() + }, + "totals": self._serialize_total_stats(state.totals), + "cooldowns": cooldowns, + "fair_cycle": fair_cycle, + "max_concurrent": state.max_concurrent, + "created_at": state.created_at, + "created_at_human": _format_timestamp(state.created_at), + "last_updated": state.last_updated, + "last_updated_human": _format_timestamp(state.last_updated), + } diff --git a/src/rotator_library/usage/selection/__init__.py b/src/rotator_library/usage/selection/__init__.py new file mode 100644 index 00000000..79824e51 --- /dev/null +++ b/src/rotator_library/usage/selection/__init__.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Credential selection and rotation strategies.""" + +from .engine import SelectionEngine +from .strategies.balanced import BalancedStrategy +from .strategies.sequential import SequentialStrategy + +__all__ = [ + "SelectionEngine", + "BalancedStrategy", + "SequentialStrategy", +] diff --git a/src/rotator_library/usage/selection/engine.py b/src/rotator_library/usage/selection/engine.py new file mode 100644 index 00000000..64769a1f --- /dev/null +++ b/src/rotator_library/usage/selection/engine.py @@ -0,0 +1,487 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Selection engine for credential selection. + +Central component that orchestrates limit checking, modifiers, +and rotation strategies to select the best credential. +""" + +import time +import logging +from typing import Any, Dict, List, Optional, Set, Union + +from ..types import ( + CredentialState, + SelectionContext, + RotationMode, + LimitCheckResult, +) +from ..config import ProviderUsageConfig +from ..limits.engine import LimitEngine +from ..tracking.windows import WindowManager +from .strategies.balanced import BalancedStrategy +from .strategies.sequential import SequentialStrategy + +lib_logger = logging.getLogger("rotator_library") + + +class SelectionEngine: + """ + Central engine for credential selection. + + Orchestrates: + 1. Limit checking (filter unavailable credentials) + 2. Fair cycle modifiers (filter exhausted credentials) + 3. Rotation strategy (select from available) + """ + + def __init__( + self, + config: ProviderUsageConfig, + limit_engine: LimitEngine, + window_manager: WindowManager, + ): + """ + Initialize selection engine. + + Args: + config: Provider usage configuration + limit_engine: LimitEngine for availability checks + """ + self._config = config + self._limits = limit_engine + self._windows = window_manager + + # Initialize strategies + self._balanced = BalancedStrategy(config.rotation_tolerance) + self._sequential = SequentialStrategy(config.sequential_fallback_multiplier) + + # Current strategy + if config.rotation_mode == RotationMode.SEQUENTIAL: + self._strategy = self._sequential + else: + self._strategy = self._balanced + + def select( + self, + provider: str, + model: str, + states: Dict[str, CredentialState], + quota_group: Optional[str] = None, + exclude: Optional[Set[str]] = None, + priorities: Optional[Dict[str, int]] = None, + deadline: float = 0.0, + ) -> Optional[str]: + """ + Select the best available credential. + + Args: + provider: Provider name + model: Model being requested + states: Dict of stable_id -> CredentialState + quota_group: Quota group for this model + exclude: Set of stable_ids to exclude + priorities: Override priorities (stable_id -> priority) + deadline: Request deadline timestamp + + Returns: + Selected stable_id, or None if none available + """ + exclude = exclude or set() + + # Step 1: Get all candidates (not excluded) + candidates = [sid for sid in states.keys() if sid not in exclude] + + if not candidates: + return None + + # Step 2: Filter by limits + available = [] + for stable_id in candidates: + state = states[stable_id] + result = self._limits.check_all(state, model, quota_group) + if result.allowed: + available.append(stable_id) + + if not available: + # Check if we should reset fair cycle + if self._config.fair_cycle.enabled: + reset_performed = self._try_fair_cycle_reset( + provider, + model, + quota_group, + states, + candidates, + priorities, + ) + if reset_performed: + # Retry selection after reset + return self.select( + provider, + model, + states, + quota_group, + exclude, + priorities, + deadline, + ) + + lib_logger.debug( + f"No available credentials for {provider}/{model} " + f"(all {len(candidates)} blocked by limits)" + ) + return None + + # Step 3: Build selection context + # Get usage counts for weighting + usage_counts = {} + for stable_id in available: + state = states[stable_id] + usage_counts[stable_id] = self._get_usage_count(state, model, quota_group) + + # Build priorities map + if priorities is None: + priorities = {} + for stable_id in available: + priorities[stable_id] = states[stable_id].priority + + context = SelectionContext( + provider=provider, + model=model, + quota_group=quota_group, + candidates=available, + priorities=priorities, + usage_counts=usage_counts, + rotation_mode=self._config.rotation_mode, + rotation_tolerance=self._config.rotation_tolerance, + deadline=deadline or (time.time() + 120), + ) + + # Step 4: Apply rotation strategy + selected = self._strategy.select(context, states) + + if selected: + lib_logger.debug( + f"Selected credential {selected} for {provider}/{model} " + f"(from {len(available)} available)" + ) + + return selected + + def select_with_retry( + self, + provider: str, + model: str, + states: Dict[str, CredentialState], + quota_group: Optional[str] = None, + tried: Optional[Set[str]] = None, + priorities: Optional[Dict[str, int]] = None, + deadline: float = 0.0, + ) -> Optional[str]: + """ + Select a credential for retry, excluding already-tried ones. + + Convenience method for retry loops. + + Args: + provider: Provider name + model: Model being requested + states: Dict of stable_id -> CredentialState + quota_group: Quota group for this model + tried: Set of already-tried stable_ids + priorities: Override priorities + deadline: Request deadline timestamp + + Returns: + Selected stable_id, or None if none available + """ + return self.select( + provider=provider, + model=model, + states=states, + quota_group=quota_group, + exclude=tried, + priorities=priorities, + deadline=deadline, + ) + + def get_availability_stats( + self, + provider: str, + model: str, + states: Dict[str, CredentialState], + quota_group: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Get availability statistics for credentials. + + Useful for status reporting and debugging. + + Args: + provider: Provider name + model: Model being requested + states: Dict of stable_id -> CredentialState + quota_group: Quota group for this model + + Returns: + Dict with availability stats + """ + total = len(states) + available = 0 + blocked_by = { + "cooldowns": 0, + "window_limits": 0, + "custom_caps": 0, + "fair_cycle": 0, + "concurrent": 0, + } + + for stable_id, state in states.items(): + blocking = self._limits.get_blocking_info(state, model, quota_group) + + is_available = True + for checker_name, result in blocking.items(): + if not result.allowed: + is_available = False + if checker_name in blocked_by: + blocked_by[checker_name] += 1 + break + + if is_available: + available += 1 + + return { + "total": total, + "available": available, + "blocked": total - available, + "blocked_by": blocked_by, + "rotation_mode": self._config.rotation_mode.value, + } + + def set_rotation_mode(self, mode: RotationMode) -> None: + """ + Change the rotation mode. + + Args: + mode: New rotation mode + """ + self._config.rotation_mode = mode + if mode == RotationMode.SEQUENTIAL: + self._strategy = self._sequential + else: + self._strategy = self._balanced + + lib_logger.info(f"Rotation mode changed to {mode.value}") + + def mark_exhausted(self, provider: str, model_or_group: str) -> None: + """ + Mark current credential as exhausted (for sequential mode). + + Args: + provider: Provider name + model_or_group: Model or quota group + """ + if isinstance(self._strategy, SequentialStrategy): + self._strategy.mark_exhausted(provider, model_or_group) + + @property + def balanced_strategy(self) -> BalancedStrategy: + """Get the balanced strategy instance.""" + return self._balanced + + @property + def sequential_strategy(self) -> SequentialStrategy: + """Get the sequential strategy instance.""" + return self._sequential + + def _get_usage_count( + self, + state: CredentialState, + model: str, + quota_group: Optional[str], + ) -> int: + """Get the relevant usage count for rotation weighting.""" + primary_def = self._windows.get_primary_definition() + if primary_def: + windows = None + + if primary_def.applies_to == "model": + model_stats = state.get_model_stats(model, create=False) + if model_stats: + windows = model_stats.windows + elif primary_def.applies_to == "group": + group_key = quota_group or model + group_stats = state.get_group_stats(group_key, create=False) + if group_stats: + windows = group_stats.windows + + if windows: + window = self._windows.get_active_window(windows, primary_def.name) + if window: + return window.request_count + + return state.totals.request_count + + def _get_shortest_cooldown( + self, + states: List[CredentialState], + group_key: str, + ) -> tuple: + """ + Find the shortest remaining cooldown among the given credentials. + + Args: + states: List of credential states to check + group_key: Model or quota group key for cooldown lookup + + Returns: + Tuple of (has_short_cooldown, stable_id, remaining_seconds) + where has_short_cooldown is True if any cooldown is under the threshold + """ + import time + + now = time.time() + threshold = self._config.fair_cycle.reset_cooldown_threshold + shortest_remaining = float("inf") + shortest_cred_id = None + + for state in states: + # Check group-specific cooldown + cooldown = state.cooldowns.get(group_key) + if cooldown and cooldown.until > now: + remaining = cooldown.until - now + if remaining < shortest_remaining: + shortest_remaining = remaining + shortest_cred_id = state.stable_id + + # Also check global cooldown + global_cooldown = state.cooldowns.get("_global_") + if global_cooldown and global_cooldown.until > now: + remaining = global_cooldown.until - now + if remaining < shortest_remaining: + shortest_remaining = remaining + shortest_cred_id = state.stable_id + + if shortest_remaining < threshold: + return (True, shortest_cred_id, shortest_remaining) + return (False, None, shortest_remaining) + + def _try_fair_cycle_reset( + self, + provider: str, + model: str, + quota_group: Optional[str], + states: Dict[str, CredentialState], + candidates: List[str], + priorities: Optional[Dict[str, int]], + ) -> bool: + """ + Try to reset fair cycle if all credentials are exhausted. + + Tier-aware: If cross_tier is disabled, checks each tier separately. + + Args: + provider: Provider name + model: Model being requested + quota_group: Quota group for this model + states: All credential states + candidates: Candidate stable_ids + + Returns: + True if reset was performed, False otherwise + """ + from ..types import LimitResult + + group_key = quota_group or model + fair_cycle_checker = self._limits.fair_cycle_checker + tracking_key = fair_cycle_checker.get_tracking_key(model, quota_group) + + # Check if all candidates are blocked by fair cycle + all_fair_cycle_blocked = True + fair_cycle_blocked_count = 0 + + for stable_id in candidates: + state = states[stable_id] + result = self._limits.check_all(state, model, quota_group) + + if result.allowed: + # Some credential is available - no need to reset + return False + + if result.result == LimitResult.BLOCKED_FAIR_CYCLE: + fair_cycle_blocked_count += 1 + else: + # Blocked by something other than fair cycle + all_fair_cycle_blocked = False + + # If no credentials blocked by fair cycle, can't help + if fair_cycle_blocked_count == 0: + return False + + # Get all candidate states for reset + candidate_states = [states[sid] for sid in candidates] + priority_map = priorities or {sid: states[sid].priority for sid in candidates} + + # Tier-aware reset + if self._config.fair_cycle.cross_tier: + # Cross-tier: reset all at once + if fair_cycle_checker.check_all_exhausted( + provider, tracking_key, candidate_states, priorities=priority_map + ): + # Before resetting, check if any credential has a short cooldown + # that will expire soon - if so, wait instead of resetting + has_short, cred_id, remaining = self._get_shortest_cooldown( + candidate_states, group_key + ) + if has_short: + lib_logger.debug( + f"Skipping fair cycle reset for {provider}/{model}: " + f"credential {cred_id} has short cooldown ({remaining:.0f}s remaining)" + ) + return False + + lib_logger.info( + f"All credentials fair-cycle exhausted for {provider}/{model} " + f"(cross-tier), resetting cycle" + ) + fair_cycle_checker.reset_cycle(provider, tracking_key, candidate_states) + return True + else: + # Per-tier: group by priority and check each tier + tier_groups: Dict[int, List[CredentialState]] = {} + for state in candidate_states: + priority = state.priority + tier_groups.setdefault(priority, []).append(state) + + reset_any = False + for priority, tier_states in tier_groups.items(): + # Check if all in this tier are exhausted + all_tier_exhausted = all( + state.is_fair_cycle_exhausted(tracking_key) for state in tier_states + ) + + if all_tier_exhausted: + # Before resetting, check if any credential has a short cooldown + # that will expire soon - if so, wait instead of resetting + has_short, cred_id, remaining = self._get_shortest_cooldown( + tier_states, group_key + ) + if has_short: + lib_logger.debug( + f"Skipping fair cycle reset for {provider}/{model} tier {priority}: " + f"credential {cred_id} has short cooldown ({remaining:.0f}s remaining)" + ) + continue + + lib_logger.info( + f"All credentials fair-cycle exhausted for {provider}/{model} " + f"in tier {priority}, resetting tier cycle" + ) + fair_cycle_checker.reset_cycle(provider, tracking_key, tier_states) + reset_any = True + + return reset_any + + return False diff --git a/src/rotator_library/usage/selection/modifiers/__init__.py b/src/rotator_library/usage/selection/modifiers/__init__.py new file mode 100644 index 00000000..f06468eb --- /dev/null +++ b/src/rotator_library/usage/selection/modifiers/__init__.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Selection modifiers.""" diff --git a/src/rotator_library/usage/selection/strategies/__init__.py b/src/rotator_library/usage/selection/strategies/__init__.py new file mode 100644 index 00000000..68eeba8c --- /dev/null +++ b/src/rotator_library/usage/selection/strategies/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Rotation strategy implementations.""" + +from .balanced import BalancedStrategy +from .sequential import SequentialStrategy + +__all__ = ["BalancedStrategy", "SequentialStrategy"] diff --git a/src/rotator_library/usage/selection/strategies/balanced.py b/src/rotator_library/usage/selection/strategies/balanced.py new file mode 100644 index 00000000..6d070117 --- /dev/null +++ b/src/rotator_library/usage/selection/strategies/balanced.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Balanced rotation strategy. + +Distributes load evenly across credentials using weighted random selection. +""" + +import random +import logging +from typing import Dict, List, Optional + +from ...types import CredentialState, SelectionContext, RotationMode + +lib_logger = logging.getLogger("rotator_library") + + +class BalancedStrategy: + """ + Balanced credential rotation strategy. + + Uses weighted random selection where less-used credentials have + higher probability of being selected. The tolerance parameter + controls how much randomness is introduced. + + Weight formula: weight = (max_usage - credential_usage) + tolerance + 1 + """ + + def __init__(self, tolerance: float = 3.0): + """ + Initialize balanced strategy. + + Args: + tolerance: Controls randomness of selection. + - 0.0: Deterministic, least-used always selected + - 2.0-4.0: Recommended, balanced randomness + - 5.0+: High randomness + """ + self.tolerance = tolerance + + @property + def name(self) -> str: + return "balanced" + + @property + def mode(self) -> RotationMode: + return RotationMode.BALANCED + + def select( + self, + context: SelectionContext, + states: Dict[str, CredentialState], + ) -> Optional[str]: + """ + Select a credential using weighted random selection. + + Args: + context: Selection context with candidates and usage info + states: Dict of stable_id -> CredentialState + + Returns: + Selected stable_id, or None if no candidates + """ + if not context.candidates: + return None + + if len(context.candidates) == 1: + return context.candidates[0] + + # Group by priority for tiered selection + priority_groups = self._group_by_priority( + context.candidates, context.priorities + ) + + # Try each priority tier in order + for priority in sorted(priority_groups.keys()): + candidates = priority_groups[priority] + if not candidates: + continue + + # Calculate weights for this tier + weights = self._calculate_weights(candidates, context.usage_counts) + + # Weighted random selection + selected = self._weighted_random_choice(candidates, weights) + if selected: + return selected + + # Fallback: first candidate + return context.candidates[0] + + def _group_by_priority( + self, + candidates: List[str], + priorities: Dict[str, int], + ) -> Dict[int, List[str]]: + """Group candidates by priority tier.""" + groups: Dict[int, List[str]] = {} + for stable_id in candidates: + priority = priorities.get(stable_id, 999) + groups.setdefault(priority, []).append(stable_id) + return groups + + def _calculate_weights( + self, + candidates: List[str], + usage_counts: Dict[str, int], + ) -> List[float]: + """ + Calculate selection weights for candidates. + + Weight formula: weight = (max_usage - credential_usage) + tolerance + 1 + """ + if not candidates: + return [] + + # Get usage counts + usages = [usage_counts.get(stable_id, 0) for stable_id in candidates] + max_usage = max(usages) if usages else 0 + + # Calculate weights + weights = [] + for usage in usages: + weight = (max_usage - usage) + self.tolerance + 1 + weights.append(max(weight, 0.1)) # Ensure minimum weight + + return weights + + def _weighted_random_choice( + self, + candidates: List[str], + weights: List[float], + ) -> Optional[str]: + """Select a candidate using weighted random choice.""" + if not candidates: + return None + + if len(candidates) == 1: + return candidates[0] + + # Normalize weights + total = sum(weights) + if total <= 0: + return random.choice(candidates) + + # Weighted selection + r = random.uniform(0, total) + cumulative = 0 + for candidate, weight in zip(candidates, weights): + cumulative += weight + if r <= cumulative: + return candidate + + return candidates[-1] diff --git a/src/rotator_library/usage/selection/strategies/sequential.py b/src/rotator_library/usage/selection/strategies/sequential.py new file mode 100644 index 00000000..c7f22250 --- /dev/null +++ b/src/rotator_library/usage/selection/strategies/sequential.py @@ -0,0 +1,192 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Sequential rotation strategy. + +Uses one credential until exhausted, then moves to the next. +Good for providers that benefit from request caching. +""" + +import logging +from typing import Dict, List, Optional + +from ...types import CredentialState, SelectionContext, RotationMode +from ....error_handler import mask_credential +from ....error_handler import mask_credential + +lib_logger = logging.getLogger("rotator_library") + + +class SequentialStrategy: + """ + Sequential credential rotation strategy. + + Sticks to one credential until it's exhausted (rate limited, + quota exceeded, etc.), then moves to the next in priority order. + + This is useful for providers where repeated requests to the same + credential benefit from caching (e.g., context caching in LLMs). + """ + + def __init__(self, fallback_multiplier: int = 1): + """ + Initialize sequential strategy. + + Args: + fallback_multiplier: Default concurrent slots per priority + when not explicitly configured + """ + self.fallback_multiplier = fallback_multiplier + # Track current "sticky" credential per (provider, model_group) + self._current: Dict[tuple, str] = {} + + @property + def name(self) -> str: + return "sequential" + + @property + def mode(self) -> RotationMode: + return RotationMode.SEQUENTIAL + + def select( + self, + context: SelectionContext, + states: Dict[str, CredentialState], + ) -> Optional[str]: + """ + Select a credential using sequential/sticky selection. + + Prefers the currently active credential if it's still available. + Otherwise, selects the first available by priority. + + Args: + context: Selection context with candidates and usage info + states: Dict of stable_id -> CredentialState + + Returns: + Selected stable_id, or None if no candidates + """ + if not context.candidates: + return None + + if len(context.candidates) == 1: + return context.candidates[0] + + key = (context.provider, context.quota_group or context.model) + + # Check if current sticky credential is still available + current = self._current.get(key) + if current and current in context.candidates: + return current + + # Current not available - select new one by tier -> usage -> recency + selected = self._select_by_priority( + context.candidates, + context.priorities, + context.usage_counts, + states, + ) + + # Make it sticky + if selected: + self._current[key] = selected + masked = ( + mask_credential(states[selected].accessor, style="full") + if selected in states + else mask_credential(selected, style="full") + ) + lib_logger.debug(f"Sequential: switched to credential {masked} for {key}") + + return selected + + def mark_exhausted(self, provider: str, model_or_group: str) -> None: + """ + Mark current credential as exhausted, forcing rotation. + + Args: + provider: Provider name + model_or_group: Model or quota group + """ + key = (provider, model_or_group) + if key in self._current: + old = self._current[key] + del self._current[key] + lib_logger.debug( + f"Sequential: marked {mask_credential(old, style='full')} exhausted for {key}" + ) + + def get_current(self, provider: str, model_or_group: str) -> Optional[str]: + """ + Get the currently sticky credential. + + Args: + provider: Provider name + model_or_group: Model or quota group + + Returns: + Current sticky credential stable_id, or None + """ + key = (provider, model_or_group) + return self._current.get(key) + + def _select_by_priority( + self, + candidates: List[str], + priorities: Dict[str, int], + usage_counts: Optional[Dict[str, int]] = None, + states: Optional[Dict[str, CredentialState]] = None, + ) -> Optional[str]: + """ + Select credential by: tier (priority) -> usage (highest) -> recency (most recent). + + Sequential mode prefers most-used credentials within the window to maximize + cache hits. When selecting a new sticky credential: + 1. Highest tier (lowest priority number) first + 2. Within same tier, prefer highest usage count + 3. Within same usage, prefer most recently used + + Args: + candidates: List of available credential stable_ids + priorities: Dict of stable_id -> priority (lower = higher tier) + usage_counts: Dict of stable_id -> request count for relevant window + states: Dict of stable_id -> CredentialState for recency lookup + + Returns: + Selected stable_id, or None if no candidates + """ + if not candidates: + return None + + usage_counts = usage_counts or {} + states = states or {} + + def sort_key(c: str): + # 1. Priority/tier (lower number = higher tier = preferred) + priority = priorities.get(c, 999) + + # 2. Usage count (higher = preferred, so negate for ascending sort) + usage = -(usage_counts.get(c, 0)) + + # 3. Recency (more recent = preferred, so negate for ascending sort) + state = states.get(c) + last_used = -(state.totals.last_used_at or 0) if state else 0 + + return (priority, usage, last_used) + + sorted_candidates = sorted(candidates, key=sort_key) + return sorted_candidates[0] + + def clear_sticky(self, provider: Optional[str] = None) -> None: + """ + Clear sticky credential state. + + Args: + provider: If specified, only clear for this provider + """ + if provider: + keys_to_remove = [k for k in self._current if k[0] == provider] + for key in keys_to_remove: + del self._current[key] + else: + self._current.clear() diff --git a/src/rotator_library/usage/tracking/__init__.py b/src/rotator_library/usage/tracking/__init__.py new file mode 100644 index 00000000..e459d28f --- /dev/null +++ b/src/rotator_library/usage/tracking/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Usage tracking and window management.""" + +from .engine import TrackingEngine +from .windows import WindowManager + +__all__ = ["TrackingEngine", "WindowManager"] diff --git a/src/rotator_library/usage/tracking/engine.py b/src/rotator_library/usage/tracking/engine.py new file mode 100644 index 00000000..5b6017c9 --- /dev/null +++ b/src/rotator_library/usage/tracking/engine.py @@ -0,0 +1,720 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Tracking engine for usage recording. + +Central component for recording requests, successes, and failures. +""" + +import asyncio +import logging +import time +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set + +if TYPE_CHECKING: + from ..config import WindowDefinition + +from ..types import ( + WindowStats, + TotalStats, + ModelStats, + GroupStats, + CredentialState, + CooldownInfo, + FairCycleState, + TrackingMode, + UsageUpdate, + FAIR_CYCLE_GLOBAL_KEY, +) +from ..config import WindowDefinition, ProviderUsageConfig +from .windows import WindowManager +from ...error_handler import mask_credential +from ...error_handler import mask_credential + +lib_logger = logging.getLogger("rotator_library") + + +class TrackingEngine: + """ + Central engine for usage tracking. + + Responsibilities: + - Recording request successes and failures + - Managing usage windows + - Updating global statistics + - Managing cooldowns + - Tracking fair cycle state + """ + + def __init__( + self, + window_manager: WindowManager, + config: ProviderUsageConfig, + ): + """ + Initialize tracking engine. + + Args: + window_manager: WindowManager instance for window operations + config: Provider usage configuration + """ + self._windows = window_manager + self._config = config + self._lock = asyncio.Lock() + + async def record_usage( + self, + state: CredentialState, + model: str, + update: UsageUpdate, + group: Optional[str] = None, + response_headers: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Record usage for a request (consolidated function). + + Updates: + - model_usage[model].windows[*] + totals + - group_usage[group].windows[*] + totals (if group provided) + - credential.totals + + Args: + state: Credential state to update + model: Model that was used + update: UsageUpdate with all metrics + group: Quota group for this model (None = no group tracking) + response_headers: Optional response headers with rate limit info + """ + async with self._lock: + now = time.time() + fair_cycle_key = self._resolve_fair_cycle_key(group or model) + + # Calculate derived values + output_tokens = update.completion_tokens + update.thinking_tokens + total_tokens = ( + update.prompt_tokens + + update.completion_tokens + + update.thinking_tokens + + update.prompt_tokens_cache_read + + update.prompt_tokens_cache_write + ) + + # 1. Update model stats + model_stats = state.get_model_stats(model) + self._apply_to_windows( + model_stats.windows, + update, + now, + total_tokens, + output_tokens, + window_definitions=state.window_definitions or None, + ) + self._apply_to_totals( + model_stats.totals, update, now, total_tokens, output_tokens + ) + + # 2. Update group stats (if applicable) + if group: + group_stats = state.get_group_stats(group) + self._apply_to_windows( + group_stats.windows, + update, + now, + total_tokens, + output_tokens, + window_definitions=state.window_definitions or None, + ) + self._apply_to_totals( + group_stats.totals, update, now, total_tokens, output_tokens + ) + + # Sync model window timing from group (group is authoritative) + # All models in a quota group share the same started_at/reset_at + self._sync_window_timing_from_group( + model_stats.windows, group_stats.windows + ) + + # 3. Update credential totals + self._apply_to_totals( + state.totals, update, now, total_tokens, output_tokens + ) + + # 4. Update fair cycle request count + if self._config.fair_cycle.enabled: + fc_state = state.fair_cycle.get(fair_cycle_key) + if not fc_state: + fc_state = FairCycleState(model_or_group=fair_cycle_key) + state.fair_cycle[fair_cycle_key] = fc_state + fc_state.cycle_request_count += update.request_count + + # 5. Update from response headers if provided + if response_headers: + self._update_from_headers(state, response_headers, model, group) + + state.last_updated = now + + async def record_success( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + prompt_tokens: int = 0, + completion_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + thinking_tokens: int = 0, + approx_cost: float = 0.0, + request_count: int = 1, + response_headers: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Record a successful request. + + Args: + state: Credential state to update + model: Model that was used + quota_group: Quota group for this model (None = use model name) + prompt_tokens: Prompt tokens used + completion_tokens: Completion tokens used + prompt_tokens_cache_read: Cached prompt tokens read + prompt_tokens_cache_write: Cached prompt tokens written + thinking_tokens: Thinking tokens used + approx_cost: Approximate cost + request_count: Number of requests to record + response_headers: Optional response headers with rate limit info + """ + update = UsageUpdate( + request_count=request_count, + success=True, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + ) + await self.record_usage( + state=state, + model=model, + update=update, + group=quota_group, + response_headers=response_headers, + ) + + async def record_failure( + self, + state: CredentialState, + model: str, + error_type: str, + quota_group: Optional[str] = None, + cooldown_duration: Optional[float] = None, + quota_reset_timestamp: Optional[float] = None, + mark_exhausted: bool = False, + request_count: int = 1, + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + ) -> None: + """ + Record a failed request. + + Args: + state: Credential state to update + model: Model that was used + error_type: Type of error (quota_exceeded, rate_limit, etc.) + quota_group: Quota group for this model + cooldown_duration: How long to cool down (if applicable) + quota_reset_timestamp: When quota resets (from API) + mark_exhausted: Whether to mark as exhausted for fair cycle + request_count: Number of requests to record + prompt_tokens: Prompt tokens used + completion_tokens: Completion tokens used + thinking_tokens: Thinking tokens used + prompt_tokens_cache_read: Cached prompt tokens read + prompt_tokens_cache_write: Cached prompt tokens written + approx_cost: Approximate cost + """ + update = UsageUpdate( + request_count=request_count, + success=False, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + ) + + # Record the usage + await self.record_usage( + state=state, + model=model, + update=update, + group=quota_group, + ) + + async with self._lock: + group_key = quota_group or model + fair_cycle_key = self._resolve_fair_cycle_key(group_key) + + # Apply cooldown if specified + if cooldown_duration is not None and cooldown_duration > 0: + self._apply_cooldown( + state=state, + reason=error_type, + duration=cooldown_duration, + model_or_group=group_key, + source="error", + ) + + # Use quota reset timestamp if provided + if quota_reset_timestamp is not None: + self._apply_cooldown( + state=state, + reason=error_type, + until=quota_reset_timestamp, + model_or_group=group_key, + source="api_quota", + ) + + # Mark exhausted for fair cycle if requested + if mark_exhausted: + self._mark_exhausted(state, fair_cycle_key, error_type) + + async def acquire( + self, + state: CredentialState, + model: str, + ) -> bool: + """ + Acquire a credential for a request (increment active count). + + Args: + state: Credential state + model: Model being used + + Returns: + True if acquired, False if at max concurrent + """ + async with self._lock: + # Check concurrent limit + if state.max_concurrent is not None: + if state.active_requests >= state.max_concurrent: + return False + + state.active_requests += 1 + return True + + async def apply_cooldown( + self, + state: CredentialState, + reason: str, + duration: Optional[float] = None, + until: Optional[float] = None, + model_or_group: Optional[str] = None, + source: str = "system", + ) -> None: + """ + Apply a cooldown to a credential. + + Args: + state: Credential state + reason: Why the cooldown was applied + duration: Cooldown duration in seconds (if not using 'until') + until: Timestamp when cooldown ends (if not using 'duration') + model_or_group: Scope of cooldown (None = credential-wide) + source: Source of cooldown (system, custom_cap, rate_limit, etc.) + """ + async with self._lock: + self._apply_cooldown( + state=state, + reason=reason, + duration=duration, + until=until, + model_or_group=model_or_group, + source=source, + ) + + async def clear_cooldown( + self, + state: CredentialState, + model_or_group: Optional[str] = None, + ) -> None: + """ + Clear a cooldown from a credential. + + Args: + state: Credential state + model_or_group: Scope of cooldown to clear (None = global) + """ + async with self._lock: + key = model_or_group or "_global_" + if key in state.cooldowns: + del state.cooldowns[key] + + async def mark_exhausted( + self, + state: CredentialState, + model_or_group: str, + reason: str, + ) -> None: + """ + Mark a credential as exhausted for fair cycle. + + Args: + state: Credential state + model_or_group: Scope of exhaustion + reason: Why credential was exhausted + """ + async with self._lock: + self._mark_exhausted(state, model_or_group, reason) + + async def reset_fair_cycle( + self, + state: CredentialState, + model_or_group: str, + ) -> None: + """ + Reset fair cycle state for a credential. + + Args: + state: Credential state + model_or_group: Scope to reset + """ + async with self._lock: + if model_or_group in state.fair_cycle: + fc_state = state.fair_cycle[model_or_group] + fc_state.exhausted = False + fc_state.exhausted_at = None + fc_state.exhausted_reason = None + fc_state.cycle_request_count = 0 + + def get_window_usage( + self, + state: CredentialState, + window_name: str, + model: Optional[str] = None, + group: Optional[str] = None, + ) -> int: + """ + Get request count for a specific window. + + Args: + state: Credential state + window_name: Name of window + model: Model to check (optional) + group: Group to check (optional) + + Returns: + Request count (0 if window doesn't exist) + """ + # Check group first if provided + if group: + group_stats = state.group_usage.get(group) + if group_stats: + window = self._windows.get_active_window( + group_stats.windows, window_name + ) + if window: + return window.request_count + + # Check model if provided + if model: + model_stats = state.model_usage.get(model) + if model_stats: + window = self._windows.get_active_window( + model_stats.windows, window_name + ) + if window: + return window.request_count + + return 0 + + def get_primary_window_usage( + self, + state: CredentialState, + model: Optional[str] = None, + group: Optional[str] = None, + ) -> int: + """ + Get request count for the primary window. + + Args: + state: Credential state + model: Model to check (optional) + group: Group to check (optional) + + Returns: + Request count (0 if no primary window) + """ + primary_def = self._windows.get_primary_definition() + if primary_def is None: + return 0 + return self.get_window_usage(state, primary_def.name, model, group) + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _apply_to_windows( + self, + windows: Dict[str, WindowStats], + update: UsageUpdate, + now: float, + total_tokens: int, + output_tokens: int, + window_definitions: Optional[List["WindowDefinition"]] = None, + ) -> None: + """Apply update to all configured windows.""" + # Use credential's window definitions if provided, otherwise fall back to config + defs = window_definitions if window_definitions else self._config.windows + for window_def in defs: + window = self._windows.get_or_create_window(windows, window_def.name) + self._apply_to_window( + window, update, now, total_tokens, output_tokens, window_def + ) + + def _apply_to_window( + self, + window: WindowStats, + update: UsageUpdate, + now: float, + total_tokens: int, + output_tokens: int, + window_def: Optional["WindowDefinition"] = None, + ) -> None: + """Apply update to a single window.""" + window.request_count += update.request_count + if update.success: + window.success_count += update.request_count + else: + window.failure_count += update.request_count + + window.prompt_tokens += update.prompt_tokens + window.completion_tokens += update.completion_tokens + window.thinking_tokens += update.thinking_tokens + window.output_tokens += output_tokens + window.prompt_tokens_cache_read += update.prompt_tokens_cache_read + window.prompt_tokens_cache_write += update.prompt_tokens_cache_write + window.total_tokens += total_tokens + window.approx_cost += update.approx_cost + + window.last_used_at = now + if window.first_used_at is None: + window.first_used_at = now + + # Set started_at on first usage and calculate reset_at + if window.started_at is None: + window.started_at = now + # Calculate reset_at based on window definition + # Use passed window_def first, then fall back to shared definitions + effective_def = window_def or self._windows.definitions.get(window.name) + if effective_def and window.reset_at is None: + window.reset_at = self._windows._calculate_reset_time( + effective_def, now + ) + + # Update max recorded requests (historical high-water mark) + if ( + window.max_recorded_requests is None + or window.request_count > window.max_recorded_requests + ): + window.max_recorded_requests = window.request_count + window.max_recorded_at = now + + def _apply_to_totals( + self, + totals: TotalStats, + update: UsageUpdate, + now: float, + total_tokens: int, + output_tokens: int, + ) -> None: + """Apply update to totals.""" + totals.request_count += update.request_count + if update.success: + totals.success_count += update.request_count + else: + totals.failure_count += update.request_count + + totals.prompt_tokens += update.prompt_tokens + totals.completion_tokens += update.completion_tokens + totals.thinking_tokens += update.thinking_tokens + totals.output_tokens += output_tokens + totals.prompt_tokens_cache_read += update.prompt_tokens_cache_read + totals.prompt_tokens_cache_write += update.prompt_tokens_cache_write + totals.total_tokens += total_tokens + totals.approx_cost += update.approx_cost + + totals.last_used_at = now + if totals.first_used_at is None: + totals.first_used_at = now + + def _sync_window_timing_from_group( + self, + model_windows: Dict[str, WindowStats], + group_windows: Dict[str, WindowStats], + ) -> None: + """ + Sync timing fields from group windows to model windows. + + Group window is authoritative for started_at and reset_at. + All models in a quota group share the same timing to ensure + consistent quota tracking and window resets. + + Args: + model_windows: The model's windows dict to update + group_windows: The group's windows dict (authoritative) + """ + for window_name, group_window in group_windows.items(): + model_window = model_windows.get(window_name) + if model_window: + model_window.started_at = group_window.started_at + model_window.reset_at = group_window.reset_at + + def _apply_cooldown( + self, + state: CredentialState, + reason: str, + duration: Optional[float] = None, + until: Optional[float] = None, + model_or_group: Optional[str] = None, + source: str = "system", + ) -> None: + """Internal cooldown application (no lock).""" + now = time.time() + + if until is not None: + cooldown_until = until + elif duration is not None: + cooldown_until = now + duration + else: + return # No cooldown specified + + key = model_or_group or "_global_" + + # Check for existing cooldown + existing = state.cooldowns.get(key) + backoff_count = 0 + if existing and existing.is_active: + # Preserve original reason/source/started_at - cooldown reason should + # reflect why it was originally set, not subsequent updates + # Time (until) is updated to the new value as API is authoritative + backoff_count = existing.backoff_count + 1 + reason = existing.reason + source = existing.source + started_at = existing.started_at + else: + started_at = now + + state.cooldowns[key] = CooldownInfo( + reason=reason, + until=cooldown_until, + started_at=started_at, + source=source, + model_or_group=model_or_group, + backoff_count=backoff_count, + ) + + # Check if cooldown qualifies as exhaustion + cooldown_duration = cooldown_until - now + if cooldown_duration >= self._config.exhaustion_cooldown_threshold: + if self._config.fair_cycle.enabled and model_or_group: + fair_cycle_key = self._resolve_fair_cycle_key(model_or_group) + self._mark_exhausted(state, fair_cycle_key, f"cooldown_{reason}") + + def _mark_exhausted( + self, + state: CredentialState, + model_or_group: str, + reason: str, + ) -> None: + """Internal exhaustion marking (no lock).""" + now = time.time() + + if model_or_group not in state.fair_cycle: + state.fair_cycle[model_or_group] = FairCycleState( + model_or_group=model_or_group + ) + + fc_state = state.fair_cycle[model_or_group] + + # Idempotency check: skip if already exhausted (avoid duplicate logging) + if fc_state.exhausted: + return + + fc_state.exhausted = True + fc_state.exhausted_at = now + fc_state.exhausted_reason = reason + + lib_logger.info( + f"Credential {mask_credential(state.accessor, style='full')} marked fair-cycle exhausted " + f"for {model_or_group}: {reason}" + ) + + def _resolve_fair_cycle_key(self, group_key: str) -> str: + """Resolve fair cycle tracking key based on config.""" + if self._config.fair_cycle.tracking_mode == TrackingMode.CREDENTIAL: + return FAIR_CYCLE_GLOBAL_KEY + return group_key + + def _update_from_headers( + self, + state: CredentialState, + headers: Dict[str, Any], + model: str, + group: Optional[str], + ) -> None: + """Update state from API response headers.""" + # Common header patterns for rate limiting + # X-RateLimit-Remaining, X-RateLimit-Reset, etc. + remaining = headers.get("x-ratelimit-remaining") + reset = headers.get("x-ratelimit-reset") + limit = headers.get("x-ratelimit-limit") + + primary_def = self._windows.get_primary_definition() + if primary_def is None: + return + + # Update group windows if group is provided + if group: + group_stats = state.get_group_stats(group, create=False) + if group_stats: + window = group_stats.windows.get(primary_def.name) + if window: + self._apply_header_updates(window, limit, reset) + + # Update model windows + model_stats = state.get_model_stats(model, create=False) + if model_stats: + window = model_stats.windows.get(primary_def.name) + if window: + self._apply_header_updates(window, limit, reset) + + def _apply_header_updates( + self, + window: WindowStats, + limit: Optional[Any], + reset: Optional[Any], + ) -> None: + """Apply header updates to a window.""" + if limit is not None: + try: + window.limit = int(limit) + except (ValueError, TypeError): + pass + + if reset is not None: + try: + reset_float = float(reset) + # If reset is in the past, it might be a Unix timestamp + # If it's a small number, it might be seconds until reset + if reset_float < 1000000000: # Less than ~2001, probably relative + reset_float = time.time() + reset_float + window.reset_at = reset_float + except (ValueError, TypeError): + pass diff --git a/src/rotator_library/usage/tracking/windows.py b/src/rotator_library/usage/tracking/windows.py new file mode 100644 index 00000000..65e0f412 --- /dev/null +++ b/src/rotator_library/usage/tracking/windows.py @@ -0,0 +1,407 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Window management for usage tracking. + +Handles time-based usage windows with various reset modes. +""" + +import time +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone, time as dt_time +from typing import Any, Dict, List, Optional, Tuple + +from ..types import WindowStats, ResetMode +from ..config import WindowDefinition + +lib_logger = logging.getLogger("rotator_library") + + +class WindowManager: + """ + Manages usage tracking windows for credentials. + + Handles: + - Rolling windows (e.g., last 5 hours) + - Fixed daily windows (reset at specific UTC time) + - Calendar windows (weekly, monthly) + - API-authoritative windows (provider determines reset) + """ + + def __init__( + self, + window_definitions: List[WindowDefinition], + daily_reset_time_utc: str = "03:00", + ): + """ + Initialize window manager. + + Args: + window_definitions: List of window configurations + daily_reset_time_utc: Time for daily reset in HH:MM format + """ + self.definitions = {w.name: w for w in window_definitions} + self.daily_reset_time_utc = self._parse_time(daily_reset_time_utc) + + def get_active_window( + self, + windows: Dict[str, WindowStats], + window_name: str, + ) -> Optional[WindowStats]: + """ + Get an active (non-expired) window by name. + + Args: + windows: Current windows dict for a credential + window_name: Name of window to get + + Returns: + WindowStats if active, None if expired or doesn't exist + """ + window = windows.get(window_name) + if window is None: + return None + + definition = self.definitions.get(window_name) + if definition is None: + return window # Unknown window, return as-is + + # Check if window needs reset + if self._should_reset(window, definition): + return None + + return window + + def get_or_create_window( + self, + windows: Dict[str, WindowStats], + window_name: str, + limit: Optional[int] = None, + ) -> WindowStats: + """ + Get an active window or create a new one. + + Args: + windows: Current windows dict for a credential + window_name: Name of window to get/create + limit: Optional request limit for the window + + Returns: + Active WindowStats (may be newly created) + """ + window = self.get_active_window(windows, window_name) + if window is not None: + return window + + # Preserve fields from expired window (if exists) + old_max = None + old_max_at = None + old_limit = None + old_window = windows.get(window_name) + if old_window is not None: + # Preserve limit across window resets (until new API baseline arrives) + old_limit = old_window.limit + + # Take max of the old window's recorded max and its final request count + old_recorded_max = old_window.max_recorded_requests or 0 + if old_window.request_count > old_recorded_max: + old_max = old_window.request_count + old_max_at = old_window.last_used_at or time.time() + elif old_recorded_max > 0: + old_max = old_recorded_max + old_max_at = old_window.max_recorded_at + + # Create new window + # Note: started_at and reset_at are left as None until first actual usage + # This prevents bogus reset times from being displayed for unused windows + new_window = WindowStats( + name=window_name, + started_at=None, + reset_at=None, + limit=limit or old_limit, # Use passed limit, fall back to preserved limit + max_recorded_requests=old_max, # Carry forward historical max + max_recorded_at=old_max_at, + ) + + windows[window_name] = new_window + return new_window + + def get_primary_window( + self, + windows: Dict[str, WindowStats], + ) -> Optional[WindowStats]: + """ + Get the primary window used for rotation decisions. + + Args: + windows: Current windows dict for a credential + + Returns: + Primary WindowStats or None + """ + for name, definition in self.definitions.items(): + if definition.is_primary: + return self.get_active_window(windows, name) + return None + + def get_primary_definition(self) -> Optional[WindowDefinition]: + """Get the primary window definition.""" + for definition in self.definitions.values(): + if definition.is_primary: + return definition + return None + + def get_window_remaining( + self, + windows: Dict[str, WindowStats], + window_name: str, + ) -> Optional[int]: + """ + Get remaining requests in a window. + + Args: + windows: Current windows dict for a credential + window_name: Name of window to check + + Returns: + Remaining requests, or None if unlimited/unknown + """ + window = self.get_active_window(windows, window_name) + if window is None: + return None + return window.remaining + + def update_limit( + self, + windows: Dict[str, WindowStats], + window_name: str, + new_limit: int, + ) -> None: + """ + Update the limit for a window (e.g., from API response). + + Args: + windows: Current windows dict for a credential + window_name: Name of window to update + new_limit: New request limit + """ + window = windows.get(window_name) + if window is not None: + window.limit = new_limit + + def update_reset_time( + self, + windows: Dict[str, WindowStats], + window_name: str, + reset_timestamp: float, + ) -> None: + """ + Update the reset time for a window (e.g., from API response). + + Args: + windows: Current windows dict for a credential + window_name: Name of window to update + reset_timestamp: New reset timestamp + """ + window = windows.get(window_name) + if window is not None: + window.reset_at = reset_timestamp + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _should_reset(self, window: WindowStats, definition: WindowDefinition) -> bool: + """ + Check if a window should be reset based on its definition. + """ + now = time.time() + + # If window has an explicit reset time, use it + if window.reset_at is not None: + return now >= window.reset_at + + # If window has no start time, it hasn't been used yet - no need to reset + if window.started_at is None: + return False + + # Check based on reset mode + if definition.reset_mode == ResetMode.ROLLING: + if definition.duration_seconds is None: + return False # Infinite window + return now >= window.started_at + definition.duration_seconds + + elif definition.reset_mode == ResetMode.FIXED_DAILY: + return self._past_daily_reset(window.started_at, now) + + elif definition.reset_mode == ResetMode.CALENDAR_WEEKLY: + return self._past_weekly_reset(window.started_at, now) + + elif definition.reset_mode == ResetMode.CALENDAR_MONTHLY: + return self._past_monthly_reset(window.started_at, now) + + elif definition.reset_mode == ResetMode.API_AUTHORITATIVE: + # Only reset if explicit reset_at is set and passed + return False + + return False + + def _calculate_reset_time( + self, + definition: WindowDefinition, + start_time: float, + ) -> Optional[float]: + """ + Calculate when a window should reset based on its definition. + """ + if definition.reset_mode == ResetMode.ROLLING: + if definition.duration_seconds is None: + return None # Infinite window + return start_time + definition.duration_seconds + + elif definition.reset_mode == ResetMode.FIXED_DAILY: + return self._next_daily_reset(start_time) + + elif definition.reset_mode == ResetMode.CALENDAR_WEEKLY: + return self._next_weekly_reset(start_time) + + elif definition.reset_mode == ResetMode.CALENDAR_MONTHLY: + return self._next_monthly_reset(start_time) + + elif definition.reset_mode == ResetMode.API_AUTHORITATIVE: + return None # Will be set by API response + + return None + + def _parse_time(self, time_str: str) -> dt_time: + """Parse HH:MM time string.""" + try: + parts = time_str.split(":") + return dt_time(hour=int(parts[0]), minute=int(parts[1])) + except (ValueError, IndexError): + return dt_time(hour=3, minute=0) # Default 03:00 + + def _past_daily_reset(self, started_at: float, now: float) -> bool: + """Check if we've passed the daily reset time since window started.""" + start_dt = datetime.fromtimestamp(started_at, tz=timezone.utc) + now_dt = datetime.fromtimestamp(now, tz=timezone.utc) + + # Get reset time for the day after start + reset_dt = start_dt.replace( + hour=self.daily_reset_time_utc.hour, + minute=self.daily_reset_time_utc.minute, + second=0, + microsecond=0, + ) + if reset_dt <= start_dt: + # Reset time already passed today, use tomorrow + from datetime import timedelta + + reset_dt += timedelta(days=1) + + return now_dt >= reset_dt + + def _next_daily_reset(self, from_time: float) -> float: + """Calculate next daily reset timestamp.""" + from datetime import timedelta + + from_dt = datetime.fromtimestamp(from_time, tz=timezone.utc) + reset_dt = from_dt.replace( + hour=self.daily_reset_time_utc.hour, + minute=self.daily_reset_time_utc.minute, + second=0, + microsecond=0, + ) + if reset_dt <= from_dt: + reset_dt += timedelta(days=1) + + return reset_dt.timestamp() + + def _past_weekly_reset(self, started_at: float, now: float) -> bool: + """Check if we've passed the weekly reset (Sunday 03:00 UTC).""" + start_dt = datetime.fromtimestamp(started_at, tz=timezone.utc) + now_dt = datetime.fromtimestamp(now, tz=timezone.utc) + + # Get start of next week (Sunday 03:00 UTC) + days_until_sunday = (6 - start_dt.weekday()) % 7 + if days_until_sunday == 0 and start_dt.hour >= 3: + days_until_sunday = 7 + + from datetime import timedelta + + reset_dt = start_dt.replace( + hour=3, minute=0, second=0, microsecond=0 + ) + timedelta(days=days_until_sunday) + + return now_dt >= reset_dt + + def _next_weekly_reset(self, from_time: float) -> float: + """Calculate next weekly reset timestamp.""" + from datetime import timedelta + + from_dt = datetime.fromtimestamp(from_time, tz=timezone.utc) + days_until_sunday = (6 - from_dt.weekday()) % 7 + if days_until_sunday == 0 and from_dt.hour >= 3: + days_until_sunday = 7 + + reset_dt = from_dt.replace( + hour=3, minute=0, second=0, microsecond=0 + ) + timedelta(days=days_until_sunday) + + return reset_dt.timestamp() + + def _past_monthly_reset(self, started_at: float, now: float) -> bool: + """Check if we've passed the monthly reset (1st 03:00 UTC).""" + start_dt = datetime.fromtimestamp(started_at, tz=timezone.utc) + now_dt = datetime.fromtimestamp(now, tz=timezone.utc) + + # Get 1st of next month + if start_dt.month == 12: + reset_dt = start_dt.replace( + year=start_dt.year + 1, + month=1, + day=1, + hour=3, + minute=0, + second=0, + microsecond=0, + ) + else: + reset_dt = start_dt.replace( + month=start_dt.month + 1, + day=1, + hour=3, + minute=0, + second=0, + microsecond=0, + ) + + return now_dt >= reset_dt + + def _next_monthly_reset(self, from_time: float) -> float: + """Calculate next monthly reset timestamp.""" + from_dt = datetime.fromtimestamp(from_time, tz=timezone.utc) + + if from_dt.month == 12: + reset_dt = from_dt.replace( + year=from_dt.year + 1, + month=1, + day=1, + hour=3, + minute=0, + second=0, + microsecond=0, + ) + else: + reset_dt = from_dt.replace( + month=from_dt.month + 1, + day=1, + hour=3, + minute=0, + second=0, + microsecond=0, + ) + + return reset_dt.timestamp() diff --git a/src/rotator_library/usage/types.py b/src/rotator_library/usage/types.py new file mode 100644 index 00000000..48d09fc0 --- /dev/null +++ b/src/rotator_library/usage/types.py @@ -0,0 +1,473 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Type definitions for the usage tracking package. + +This module contains dataclasses and type definitions specific to +usage tracking, limits, and credential selection. +""" + +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set, Tuple, Union + +if TYPE_CHECKING: + from .config import WindowDefinition + + +# ============================================================================= +# ENUMS +# ============================================================================= + + +FAIR_CYCLE_GLOBAL_KEY = "_credential_" + + +class ResetMode(str, Enum): + """How a usage window resets.""" + + ROLLING = "rolling" # Continuous rolling window + FIXED_DAILY = "fixed_daily" # Reset at specific time each day + CALENDAR_WEEKLY = "calendar_weekly" # Reset at start of week + CALENDAR_MONTHLY = "calendar_monthly" # Reset at start of month + API_AUTHORITATIVE = "api_authoritative" # Provider API determines reset + + +class LimitResult(str, Enum): + """Result of a limit check.""" + + ALLOWED = "allowed" + BLOCKED_WINDOW = "blocked_window" + BLOCKED_COOLDOWN = "blocked_cooldown" + BLOCKED_FAIR_CYCLE = "blocked_fair_cycle" + BLOCKED_CUSTOM_CAP = "blocked_custom_cap" + BLOCKED_CONCURRENT = "blocked_concurrent" + + +class RotationMode(str, Enum): + """How credentials are rotated.""" + + BALANCED = "balanced" # Weighted random selection + SEQUENTIAL = "sequential" # Sticky until exhausted + + +class TrackingMode(str, Enum): + """How fair cycle tracks exhaustion.""" + + MODEL_GROUP = "model_group" # Track per quota group or model + CREDENTIAL = "credential" # Track per credential globally + + +class CooldownMode(str, Enum): + """How custom cap cooldowns are calculated.""" + + QUOTA_RESET = "quota_reset" # Wait until quota window resets + OFFSET = "offset" # Add offset seconds to current time + FIXED = "fixed" # Use fixed duration + + +class CapMode(str, Enum): + """How custom cap max_requests values are interpreted.""" + + ABSOLUTE = "absolute" # e.g., 130 → exactly 130 requests + OFFSET = "offset" # e.g., -130 → max - 130, +130 → max + 130 + PERCENTAGE = "percentage" # e.g., 80% → 80% of max + + +# ============================================================================= +# WINDOW STATS +# ============================================================================= + + +@dataclass +class WindowStats: + """ + Statistics for a single time-based usage window (e.g., 5h, daily). + + Tracks usage within a specific time window for quota management. + """ + + name: str # Window identifier (e.g., "5h", "daily") + request_count: int = 0 + success_count: int = 0 + failure_count: int = 0 + + # Token stats + prompt_tokens: int = 0 + completion_tokens: int = 0 + thinking_tokens: int = 0 + output_tokens: int = 0 # completion + thinking + prompt_tokens_cache_read: int = 0 + prompt_tokens_cache_write: int = 0 + total_tokens: int = 0 + + approx_cost: float = 0.0 + + # Window timing + started_at: Optional[float] = None # When window period started + reset_at: Optional[float] = None # When window resets + limit: Optional[int] = None # Max requests allowed (None = unlimited) + + # Historical max tracking (persists across window resets) + max_recorded_requests: Optional[int] = ( + None # Highest request_count ever in any window period + ) + max_recorded_at: Optional[float] = None # When the max was recorded + + # Usage timing (for smart selection) + first_used_at: Optional[float] = None # First request in this window + last_used_at: Optional[float] = None # Last request in this window + + @property + def remaining(self) -> Optional[int]: + """Remaining requests in this window, or None if unlimited.""" + if self.limit is None: + return None + return max(0, self.limit - self.request_count) + + @property + def is_exhausted(self) -> bool: + """True if limit reached.""" + if self.limit is None: + return False + return self.request_count >= self.limit + + +# ============================================================================= +# TOTAL STATS +# ============================================================================= + + +@dataclass +class TotalStats: + """ + All-time totals for a model, group, or credential. + + Tracks cumulative usage across all time (never resets). + """ + + request_count: int = 0 + success_count: int = 0 + failure_count: int = 0 + + # Token stats + prompt_tokens: int = 0 + completion_tokens: int = 0 + thinking_tokens: int = 0 + output_tokens: int = 0 # completion + thinking + prompt_tokens_cache_read: int = 0 + prompt_tokens_cache_write: int = 0 + total_tokens: int = 0 + + approx_cost: float = 0.0 + + # Timestamps + first_used_at: Optional[float] = None # All-time first use + last_used_at: Optional[float] = None # All-time last use + + +# ============================================================================= +# MODEL & GROUP STATS CONTAINERS +# ============================================================================= + + +@dataclass +class ModelStats: + """ + Stats for a single model (own usage only). + + Contains time-based windows and all-time totals. + Each model only tracks its own usage, not shared quota. + """ + + windows: Dict[str, WindowStats] = field(default_factory=dict) + totals: TotalStats = field(default_factory=TotalStats) + + +@dataclass +class GroupStats: + """ + Stats for a quota group (shared usage). + + Contains time-based windows and all-time totals. + Updated when ANY model in the group is used. + """ + + windows: Dict[str, WindowStats] = field(default_factory=dict) + totals: TotalStats = field(default_factory=TotalStats) + + +# ============================================================================= +# COOLDOWN TYPES +# ============================================================================= + + +@dataclass +class CooldownInfo: + """ + Information about a cooldown period. + + Cooldowns temporarily block a credential from being used. + """ + + reason: str # Why the cooldown was applied + until: float # Timestamp when cooldown ends + started_at: float # Timestamp when cooldown started + source: str = "system" # "system", "custom_cap", "rate_limit", "provider_hook" + model_or_group: Optional[str] = None # Scope of cooldown (None = credential-wide) + backoff_count: int = 0 # Number of consecutive cooldowns + + @property + def remaining_seconds(self) -> float: + """Seconds remaining in cooldown.""" + import time + + return max(0.0, self.until - time.time()) + + @property + def is_active(self) -> bool: + """True if cooldown is still in effect.""" + import time + + return time.time() < self.until + + +# ============================================================================= +# FAIR CYCLE TYPES +# ============================================================================= + + +@dataclass +class FairCycleState: + """ + Fair cycle state for a credential. + + Tracks whether a credential has been exhausted in the current cycle. + """ + + exhausted: bool = False + exhausted_at: Optional[float] = None + exhausted_reason: Optional[str] = None + cycle_request_count: int = 0 # Requests in current cycle + model_or_group: Optional[str] = None # Scope of exhaustion + + +@dataclass +class GlobalFairCycleState: + """ + Global fair cycle state for a provider. + + Tracks the overall cycle across all credentials. + """ + + cycle_start: float = 0.0 # Timestamp when current cycle started + all_exhausted_at: Optional[float] = None # When all credentials exhausted + cycle_count: int = 0 # How many full cycles completed + + +# ============================================================================= +# USAGE UPDATE (for consolidated tracking) +# ============================================================================= + + +@dataclass +class UsageUpdate: + """ + All data for a single usage update. + + Used by TrackingEngine.record_usage() to apply updates atomically. + """ + + request_count: int = 1 + success: bool = True + + # Tokens (optional) + prompt_tokens: int = 0 + completion_tokens: int = 0 + thinking_tokens: int = 0 + prompt_tokens_cache_read: int = 0 + prompt_tokens_cache_write: int = 0 + approx_cost: float = 0.0 + + +# ============================================================================= +# CREDENTIAL STATE +# ============================================================================= + + +@dataclass +class CredentialState: + """ + Complete state for a single credential. + + This is the primary storage unit for credential data. + """ + + # Identity + stable_id: str # Email (OAuth) or hash (API key) + provider: str + accessor: str # Current file path or API key + display_name: Optional[str] = None + tier: Optional[str] = None + priority: int = 999 # Lower = higher priority + + # Window definitions for this credential's tier + # Populated during initialization based on tier/priority + window_definitions: List["WindowDefinition"] = field(default_factory=list) + + # Stats - source of truth + model_usage: Dict[str, ModelStats] = field(default_factory=dict) + group_usage: Dict[str, GroupStats] = field(default_factory=dict) + totals: TotalStats = field(default_factory=TotalStats) # Credential-level totals + + # Cooldowns (keyed by model/group or "_global_") + cooldowns: Dict[str, CooldownInfo] = field(default_factory=dict) + + # Fair cycle state (keyed by model/group) + fair_cycle: Dict[str, FairCycleState] = field(default_factory=dict) + + # Active requests (for concurrent request limiting) + active_requests: int = 0 + max_concurrent: Optional[int] = None + + # Metadata + created_at: Optional[float] = None + last_updated: Optional[float] = None + + def get_cooldown( + self, model_or_group: Optional[str] = None + ) -> Optional[CooldownInfo]: + """Get active cooldown for given scope.""" + import time + + now = time.time() + + # Check specific cooldown + if model_or_group: + cooldown = self.cooldowns.get(model_or_group) + if cooldown and cooldown.until > now: + return cooldown + + # Check global cooldown + global_cooldown = self.cooldowns.get("_global_") + if global_cooldown and global_cooldown.until > now: + return global_cooldown + + return None + + def is_fair_cycle_exhausted(self, model_or_group: str) -> bool: + """Check if exhausted for fair cycle purposes.""" + state = self.fair_cycle.get(model_or_group) + return state.exhausted if state else False + + def get_model_stats(self, model: str, create: bool = True) -> Optional[ModelStats]: + """Get model stats, optionally creating if not exists.""" + if create: + return self.model_usage.setdefault(model, ModelStats()) + return self.model_usage.get(model) + + def get_group_stats(self, group: str, create: bool = True) -> Optional[GroupStats]: + """Get group stats, optionally creating if not exists.""" + if create: + return self.group_usage.setdefault(group, GroupStats()) + return self.group_usage.get(group) + + def get_window_for_model( + self, model: str, window_name: str + ) -> Optional[WindowStats]: + """Get a specific window for a model.""" + model_stats = self.model_usage.get(model) + if model_stats: + return model_stats.windows.get(window_name) + return None + + def get_window_for_group( + self, group: str, window_name: str + ) -> Optional[WindowStats]: + """Get a specific window for a group.""" + group_stats = self.group_usage.get(group) + if group_stats: + return group_stats.windows.get(window_name) + return None + + +# ============================================================================= +# SELECTION TYPES +# ============================================================================= + + +@dataclass +class SelectionContext: + """ + Context passed to rotation strategies during credential selection. + + Contains all information needed to make a selection decision. + """ + + provider: str + model: str + quota_group: Optional[str] # Quota group for this model + candidates: List[str] # Stable IDs of available candidates + priorities: Dict[str, int] # stable_id -> priority + usage_counts: Dict[str, int] # stable_id -> request count for relevant window + rotation_mode: RotationMode + rotation_tolerance: float + deadline: float + + +@dataclass +class LimitCheckResult: + """ + Result of checking all limits for a credential. + + Used by LimitEngine to report why a credential was blocked. + """ + + allowed: bool + result: LimitResult = LimitResult.ALLOWED + reason: Optional[str] = None + blocked_until: Optional[float] = None # When the block expires + + @classmethod + def ok(cls) -> "LimitCheckResult": + """Create an allowed result.""" + return cls(allowed=True, result=LimitResult.ALLOWED) + + @classmethod + def blocked( + cls, + result: LimitResult, + reason: str, + blocked_until: Optional[float] = None, + ) -> "LimitCheckResult": + """Create a blocked result.""" + return cls( + allowed=False, + result=result, + reason=reason, + blocked_until=blocked_until, + ) + + +# ============================================================================= +# STORAGE TYPES +# ============================================================================= + + +@dataclass +class StorageSchema: + """ + Root schema for usage.json storage file. + """ + + schema_version: int = 2 + updated_at: Optional[str] = None # ISO format + credentials: Dict[str, Dict[str, Any]] = field(default_factory=dict) + accessor_index: Dict[str, str] = field( + default_factory=dict + ) # accessor -> stable_id + fair_cycle_global: Dict[str, Dict[str, Any]] = field( + default_factory=dict + ) # provider -> GlobalFairCycleState diff --git a/src/rotator_library/usage_manager.py b/src/rotator_library/usage_manager.py index 46e30bbc..5430dc80 100644 --- a/src/rotator_library/usage_manager.py +++ b/src/rotator_library/usage_manager.py @@ -1,3980 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-only # Copyright (c) 2026 Mirrowel -import json -import os -import time -import logging -import asyncio -import random -from datetime import date, datetime, timezone, time as dt_time -from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, Union -import aiofiles -import litellm +"""Compatibility shim for legacy imports.""" -from .error_handler import ClassifiedError, NoAvailableKeysError, mask_credential -from .providers import PROVIDER_PLUGINS -from .utils.resilient_io import ResilientStateWriter -from .utils.paths import get_data_file -from .config import ( - DEFAULT_FAIR_CYCLE_DURATION, - DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, - DEFAULT_CUSTOM_CAP_COOLDOWN_MODE, - DEFAULT_CUSTOM_CAP_COOLDOWN_VALUE, - COOLDOWN_BACKOFF_TIERS, - COOLDOWN_BACKOFF_MAX, - COOLDOWN_AUTH_ERROR, - COOLDOWN_TRANSIENT_ERROR, - COOLDOWN_RATE_LIMIT_DEFAULT, -) +from .usage import UsageManager, CredentialContext -lib_logger = logging.getLogger("rotator_library") -lib_logger.propagate = False -if not lib_logger.handlers: - lib_logger.addHandler(logging.NullHandler()) - - -class UsageManager: - """ - Manages usage statistics and cooldowns for API keys with asyncio-safe locking, - asynchronous file I/O, lazy-loading mechanism, and weighted random credential rotation. - - The credential rotation strategy can be configured via the `rotation_tolerance` parameter: - - - **tolerance = 0.0**: Deterministic least-used selection. The credential with - the lowest usage count is always selected. This provides predictable, perfectly balanced - load distribution but may be vulnerable to fingerprinting. - - - **tolerance = 2.0 - 4.0 (default, recommended)**: Balanced weighted randomness. Credentials are selected - randomly with weights biased toward less-used ones. Credentials within 2 uses of the - maximum can still be selected with reasonable probability. This provides security through - unpredictability while maintaining good load balance. - - - **tolerance = 5.0+**: High randomness. Even heavily-used credentials have significant - selection probability. Useful for stress testing or maximum unpredictability, but may - result in less balanced load distribution. - - The weight formula is: `weight = (max_usage - credential_usage) + tolerance + 1` - - This ensures lower-usage credentials are preferred while tolerance controls how much - randomness is introduced into the selection process. - - Additionally, providers can specify a rotation mode: - - "balanced" (default): Rotate credentials to distribute load evenly - - "sequential": Use one credential until exhausted (preserves caching) - """ - - def __init__( - self, - file_path: Optional[Union[str, Path]] = None, - daily_reset_time_utc: Optional[str] = "03:00", - rotation_tolerance: float = 0.0, - provider_rotation_modes: Optional[Dict[str, str]] = None, - provider_plugins: Optional[Dict[str, Any]] = None, - priority_multipliers: Optional[Dict[str, Dict[int, int]]] = None, - priority_multipliers_by_mode: Optional[ - Dict[str, Dict[str, Dict[int, int]]] - ] = None, - sequential_fallback_multipliers: Optional[Dict[str, int]] = None, - fair_cycle_enabled: Optional[Dict[str, bool]] = None, - fair_cycle_tracking_mode: Optional[Dict[str, str]] = None, - fair_cycle_cross_tier: Optional[Dict[str, bool]] = None, - fair_cycle_duration: Optional[Dict[str, int]] = None, - exhaustion_cooldown_threshold: Optional[Dict[str, int]] = None, - custom_caps: Optional[ - Dict[str, Dict[Union[int, Tuple[int, ...], str], Dict[str, Dict[str, Any]]]] - ] = None, - ): - """ - Initialize the UsageManager. - - Args: - file_path: Path to the usage data JSON file. If None, uses get_data_file("key_usage.json"). - Can be absolute Path, relative Path, or string. - daily_reset_time_utc: Time in UTC when daily stats should reset (HH:MM format) - rotation_tolerance: Tolerance for weighted random credential rotation. - - 0.0: Deterministic, least-used credential always selected - - tolerance = 2.0 - 4.0 (default, recommended): Balanced randomness, can pick credentials within 2 uses of max - - 5.0+: High randomness, more unpredictable selection patterns - provider_rotation_modes: Dict mapping provider names to rotation modes. - - "balanced": Rotate credentials to distribute load evenly (default) - - "sequential": Use one credential until exhausted (preserves caching) - provider_plugins: Dict mapping provider names to provider plugin instances. - Used for per-provider usage reset configuration (window durations, field names). - priority_multipliers: Dict mapping provider -> priority -> multiplier. - Universal multipliers that apply regardless of rotation mode. - Example: {"antigravity": {1: 5, 2: 3}} - priority_multipliers_by_mode: Dict mapping provider -> mode -> priority -> multiplier. - Mode-specific overrides. Example: {"antigravity": {"balanced": {3: 1}}} - sequential_fallback_multipliers: Dict mapping provider -> fallback multiplier. - Used in sequential mode when priority not in priority_multipliers. - Example: {"antigravity": 2} - fair_cycle_enabled: Dict mapping provider -> bool to enable fair cycle rotation. - When enabled, credentials must all exhaust before any can be reused. - Default: enabled for sequential mode only. - fair_cycle_tracking_mode: Dict mapping provider -> tracking mode. - - "model_group": Track per quota group or model (default) - - "credential": Track per credential globally - fair_cycle_cross_tier: Dict mapping provider -> bool for cross-tier tracking. - - False: Each tier cycles independently (default) - - True: All credentials must exhaust regardless of tier - fair_cycle_duration: Dict mapping provider -> cycle duration in seconds. - Default: 86400 (24 hours) - exhaustion_cooldown_threshold: Dict mapping provider -> threshold in seconds. - A cooldown must exceed this to qualify as "exhausted". Default: 300 (5 min) - custom_caps: Dict mapping provider -> tier -> model/group -> cap config. - Allows setting custom usage limits per tier, per model or quota group. - See ProviderInterface.default_custom_caps for format details. - """ - # Resolve file_path - use default if not provided - if file_path is None: - self.file_path = str(get_data_file("key_usage.json")) - elif isinstance(file_path, Path): - self.file_path = str(file_path) - else: - # String path - could be relative or absolute - self.file_path = file_path - self.rotation_tolerance = rotation_tolerance - self.provider_rotation_modes = provider_rotation_modes or {} - self.provider_plugins = provider_plugins or PROVIDER_PLUGINS - self.priority_multipliers = priority_multipliers or {} - self.priority_multipliers_by_mode = priority_multipliers_by_mode or {} - self.sequential_fallback_multipliers = sequential_fallback_multipliers or {} - self._provider_instances: Dict[str, Any] = {} # Cache for provider instances - self.key_states: Dict[str, Dict[str, Any]] = {} - - # Fair cycle rotation configuration - self.fair_cycle_enabled = fair_cycle_enabled or {} - self.fair_cycle_tracking_mode = fair_cycle_tracking_mode or {} - self.fair_cycle_cross_tier = fair_cycle_cross_tier or {} - self.fair_cycle_duration = fair_cycle_duration or {} - self.exhaustion_cooldown_threshold = exhaustion_cooldown_threshold or {} - self.custom_caps = custom_caps or {} - # In-memory cycle state: {provider: {tier_key: {tracking_key: {"cycle_started_at": float, "exhausted": Set[str]}}}} - self._cycle_exhausted: Dict[str, Dict[str, Dict[str, Dict[str, Any]]]] = {} - - self._data_lock = asyncio.Lock() - self._usage_data: Optional[Dict] = None - self._initialized = asyncio.Event() - self._init_lock = asyncio.Lock() - - self._timeout_lock = asyncio.Lock() - self._claimed_on_timeout: Set[str] = set() - - # Resilient writer for usage data persistence - self._state_writer = ResilientStateWriter(file_path, lib_logger) - - if daily_reset_time_utc: - hour, minute = map(int, daily_reset_time_utc.split(":")) - self.daily_reset_time_utc = dt_time( - hour=hour, minute=minute, tzinfo=timezone.utc - ) - else: - self.daily_reset_time_utc = None - - def _get_rotation_mode(self, provider: str) -> str: - """ - Get the rotation mode for a provider. - - Args: - provider: Provider name (e.g., "antigravity", "gemini_cli") - - Returns: - "balanced" or "sequential" - """ - return self.provider_rotation_modes.get(provider, "balanced") - - # ========================================================================= - # FAIR CYCLE ROTATION HELPERS - # ========================================================================= - - def _is_fair_cycle_enabled(self, provider: str, rotation_mode: str) -> bool: - """ - Check if fair cycle rotation is enabled for a provider. - - Args: - provider: Provider name - rotation_mode: Current rotation mode ("balanced" or "sequential") - - Returns: - True if fair cycle is enabled - """ - # Check provider-specific setting first - if provider in self.fair_cycle_enabled: - return self.fair_cycle_enabled[provider] - # Default: enabled only for sequential mode - return rotation_mode == "sequential" - - def _get_fair_cycle_tracking_mode(self, provider: str) -> str: - """ - Get fair cycle tracking mode for a provider. - - Returns: - "model_group" or "credential" - """ - return self.fair_cycle_tracking_mode.get(provider, "model_group") - - def _is_fair_cycle_cross_tier(self, provider: str) -> bool: - """ - Check if fair cycle tracks across all tiers (ignoring priority boundaries). - - Returns: - True if cross-tier tracking is enabled - """ - return self.fair_cycle_cross_tier.get(provider, False) - - def _get_fair_cycle_duration(self, provider: str) -> int: - """ - Get fair cycle duration in seconds for a provider. - - Returns: - Duration in seconds (default 86400 = 24 hours) - """ - return self.fair_cycle_duration.get(provider, DEFAULT_FAIR_CYCLE_DURATION) - - def _get_exhaustion_cooldown_threshold(self, provider: str) -> int: - """ - Get exhaustion cooldown threshold in seconds for a provider. - - A cooldown must exceed this duration to qualify as "exhausted" for fair cycle. - - Returns: - Threshold in seconds (default 300 = 5 minutes) - """ - return self.exhaustion_cooldown_threshold.get( - provider, DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD - ) - - # ========================================================================= - # CUSTOM CAPS HELPERS - # ========================================================================= - - def _get_custom_cap_config( - self, - provider: str, - tier_priority: int, - model: str, - ) -> Optional[Dict[str, Any]]: - """ - Get custom cap config for a provider/tier/model combination. - - Resolution order: - 1. tier + model (exact match) - 2. tier + group (model's quota group) - 3. "default" + model - 4. "default" + group - - Args: - provider: Provider name - tier_priority: Credential's priority level - model: Model name (with provider prefix) - - Returns: - Cap config dict or None if no custom cap applies - """ - provider_caps = self.custom_caps.get(provider) - if not provider_caps: - return None - - # Strip provider prefix from model - clean_model = model.split("/")[-1] if "/" in model else model - - # Get quota group for this model - group = self._get_model_quota_group_by_provider(provider, model) - - # Try to find matching tier config - tier_config = None - default_config = None - - for tier_key, models_config in provider_caps.items(): - if tier_key == "default": - default_config = models_config - continue - - # Check if this tier_key matches our priority - if isinstance(tier_key, int) and tier_key == tier_priority: - tier_config = models_config - break - elif isinstance(tier_key, tuple) and tier_priority in tier_key: - tier_config = models_config - break - - # Resolution order for tier config - if tier_config: - # Try model first - if clean_model in tier_config: - return tier_config[clean_model] - # Try group - if group and group in tier_config: - return tier_config[group] - - # Resolution order for default config - if default_config: - # Try model first - if clean_model in default_config: - return default_config[clean_model] - # Try group - if group and group in default_config: - return default_config[group] - - return None - - def _get_model_quota_group_by_provider( - self, provider: str, model: str - ) -> Optional[str]: - """ - Get quota group for a model using provider name instead of credential. - - Args: - provider: Provider name - model: Model name - - Returns: - Group name or None - """ - plugin_instance = self._get_provider_instance(provider) - if plugin_instance and hasattr(plugin_instance, "get_model_quota_group"): - return plugin_instance.get_model_quota_group(model) - return None - - def _resolve_custom_cap_max( - self, - provider: str, - model: str, - cap_config: Dict[str, Any], - actual_max: Optional[int], - ) -> Optional[int]: - """ - Resolve custom cap max_requests value, handling percentages and clamping. - - Args: - provider: Provider name - model: Model name (for logging) - cap_config: Custom cap configuration - actual_max: Actual API max requests (may be None if unknown) - - Returns: - Resolved cap value (clamped), or None if can't be calculated - """ - max_requests = cap_config.get("max_requests") - if max_requests is None: - return None - - # Handle percentage - if isinstance(max_requests, str) and max_requests.endswith("%"): - if actual_max is None: - lib_logger.warning( - f"Custom cap '{max_requests}' for {provider}/{model} requires known max_requests. " - f"Skipping until quota baseline is fetched. Use absolute value for immediate enforcement." - ) - return None - try: - percentage = float(max_requests.rstrip("%")) / 100.0 - calculated = int(actual_max * percentage) - except ValueError: - lib_logger.warning( - f"Invalid percentage cap '{max_requests}' for {provider}/{model}" - ) - return None - else: - # Absolute value - try: - calculated = int(max_requests) - except (ValueError, TypeError): - lib_logger.warning( - f"Invalid cap value '{max_requests}' for {provider}/{model}" - ) - return None - - # Clamp to actual max (can only be MORE restrictive) - if actual_max is not None: - return min(calculated, actual_max) - return calculated - - def _calculate_custom_cooldown_until( - self, - cap_config: Dict[str, Any], - window_start_ts: Optional[float], - natural_reset_ts: Optional[float], - ) -> Optional[float]: - """ - Calculate when custom cap cooldown should end, clamped to natural reset. - - Args: - cap_config: Custom cap configuration - window_start_ts: When first request was made (for fixed mode) - natural_reset_ts: Natural quota reset timestamp - - Returns: - Cooldown end timestamp (clamped), or None if can't calculate - """ - mode = cap_config.get("cooldown_mode", DEFAULT_CUSTOM_CAP_COOLDOWN_MODE) - value = cap_config.get("cooldown_value", DEFAULT_CUSTOM_CAP_COOLDOWN_VALUE) - - if mode == "quota_reset": - calculated = natural_reset_ts - elif mode == "offset": - if natural_reset_ts is None: - return None - calculated = natural_reset_ts + value - elif mode == "fixed": - if window_start_ts is None: - return None - calculated = window_start_ts + value - else: - lib_logger.warning(f"Unknown cooldown_mode '{mode}', using quota_reset") - calculated = natural_reset_ts - - if calculated is None: - return None - - # Clamp to natural reset (can only be MORE restrictive = longer cooldown) - if natural_reset_ts is not None: - return max(calculated, natural_reset_ts) - return calculated - - def _check_and_apply_custom_cap( - self, - credential: str, - model: str, - request_count: int, - ) -> bool: - """ - Check if custom cap is exceeded and apply cooldown if so. - - This should be called after incrementing request_count in record_success(). - - Args: - credential: Credential identifier - model: Model name (with provider prefix) - request_count: Current request count for this model - - Returns: - True if cap exceeded and cooldown applied, False otherwise - """ - provider = self._get_provider_from_credential(credential) - if not provider: - return False - - priority = self._get_credential_priority(credential, provider) - cap_config = self._get_custom_cap_config(provider, priority, model) - if not cap_config: - return False - - # Get model data for actual max and timing info - key_data = self._usage_data.get(credential, {}) - model_data = key_data.get("models", {}).get(model, {}) - actual_max = model_data.get("quota_max_requests") - window_start_ts = model_data.get("window_start_ts") - natural_reset_ts = model_data.get("quota_reset_ts") - - # Resolve custom cap max - custom_max = self._resolve_custom_cap_max( - provider, model, cap_config, actual_max - ) - if custom_max is None: - return False - - # Check if exceeded - if request_count < custom_max: - return False - - # Calculate cooldown end time - cooldown_until = self._calculate_custom_cooldown_until( - cap_config, window_start_ts, natural_reset_ts - ) - if cooldown_until is None: - # Can't calculate cooldown, use natural reset if available - if natural_reset_ts: - cooldown_until = natural_reset_ts - else: - lib_logger.warning( - f"Custom cap hit for {mask_credential(credential)}/{model} but can't calculate cooldown. " - f"Skipping cooldown application." - ) - return False - - now_ts = time.time() - - # Apply cooldown - model_cooldowns = key_data.setdefault("model_cooldowns", {}) - model_cooldowns[model] = cooldown_until - - # Store custom cap info in model data for reference - model_data["custom_cap_max"] = custom_max - model_data["custom_cap_hit_at"] = now_ts - model_data["custom_cap_cooldown_until"] = cooldown_until - - hours_until = (cooldown_until - now_ts) / 3600 - lib_logger.info( - f"Custom cap hit: {mask_credential(credential)} reached {request_count}/{custom_max} " - f"for {model}. Cooldown for {hours_until:.1f}h" - ) - - # Sync cooldown across quota group - group = self._get_model_quota_group(credential, model) - if group: - grouped_models = self._get_grouped_models(credential, group) - for grouped_model in grouped_models: - if grouped_model != model: - model_cooldowns[grouped_model] = cooldown_until - - # Check if this should trigger fair cycle exhaustion - cooldown_duration = cooldown_until - now_ts - threshold = self._get_exhaustion_cooldown_threshold(provider) - if cooldown_duration > threshold: - rotation_mode = self._get_rotation_mode(provider) - if self._is_fair_cycle_enabled(provider, rotation_mode): - tier_key = self._get_tier_key(provider, priority) - tracking_key = self._get_tracking_key(credential, model, provider) - self._mark_credential_exhausted( - credential, provider, tier_key, tracking_key - ) - - return True - - def _get_tier_key(self, provider: str, priority: int) -> str: - """ - Get the tier key for cycle tracking based on cross_tier setting. - - Args: - provider: Provider name - priority: Credential priority level - - Returns: - "__all_tiers__" if cross-tier enabled, else str(priority) - """ - if self._is_fair_cycle_cross_tier(provider): - return "__all_tiers__" - return str(priority) - - def _get_tracking_key(self, credential: str, model: str, provider: str) -> str: - """ - Get the key for exhaustion tracking based on tracking mode. - - Args: - credential: Credential identifier - model: Model name (with provider prefix) - provider: Provider name - - Returns: - Tracking key string (quota group name, model name, or "__credential__") - """ - mode = self._get_fair_cycle_tracking_mode(provider) - if mode == "credential": - return "__credential__" - # model_group mode: use quota group if exists, else model - group = self._get_model_quota_group(credential, model) - return group if group else model - - def _get_credential_priority(self, credential: str, provider: str) -> int: - """ - Get the priority level for a credential. - - Args: - credential: Credential identifier - provider: Provider name - - Returns: - Priority level (default 999 if unknown) - """ - plugin_instance = self._get_provider_instance(provider) - if plugin_instance and hasattr(plugin_instance, "get_credential_priority"): - priority = plugin_instance.get_credential_priority(credential) - if priority is not None: - return priority - return 999 - - def _get_cycle_data( - self, provider: str, tier_key: str, tracking_key: str - ) -> Optional[Dict[str, Any]]: - """ - Get cycle data for a provider/tier/tracking key combination. - - Returns: - Cycle data dict or None if not exists - """ - return ( - self._cycle_exhausted.get(provider, {}).get(tier_key, {}).get(tracking_key) - ) - - def _ensure_cycle_structure( - self, provider: str, tier_key: str, tracking_key: str - ) -> Dict[str, Any]: - """ - Ensure the nested cycle structure exists and return the cycle data dict. - """ - if provider not in self._cycle_exhausted: - self._cycle_exhausted[provider] = {} - if tier_key not in self._cycle_exhausted[provider]: - self._cycle_exhausted[provider][tier_key] = {} - if tracking_key not in self._cycle_exhausted[provider][tier_key]: - self._cycle_exhausted[provider][tier_key][tracking_key] = { - "cycle_started_at": None, - "exhausted": set(), - } - return self._cycle_exhausted[provider][tier_key][tracking_key] - - def _mark_credential_exhausted( - self, - credential: str, - provider: str, - tier_key: str, - tracking_key: str, - ) -> None: - """ - Mark a credential as exhausted for fair cycle tracking. - - Starts the cycle timer on first exhaustion. - Skips if credential is already in the exhausted set (prevents duplicate logging). - """ - cycle_data = self._ensure_cycle_structure(provider, tier_key, tracking_key) - - # Skip if already exhausted in this cycle (prevents duplicate logging) - if credential in cycle_data.get("exhausted", set()): - return - - # Start cycle timer on first exhaustion - if cycle_data["cycle_started_at"] is None: - cycle_data["cycle_started_at"] = time.time() - lib_logger.info( - f"Fair cycle started for {provider} tier={tier_key} tracking='{tracking_key}'" - ) - - cycle_data["exhausted"].add(credential) - lib_logger.info( - f"Fair cycle: marked {mask_credential(credential)} exhausted " - f"for {tracking_key} ({len(cycle_data['exhausted'])} total)" - ) - - def _is_credential_exhausted_in_cycle( - self, - credential: str, - provider: str, - tier_key: str, - tracking_key: str, - ) -> bool: - """ - Check if a credential was exhausted in the current cycle. - """ - cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) - if cycle_data is None: - return False - return credential in cycle_data.get("exhausted", set()) - - def _is_cycle_expired( - self, provider: str, tier_key: str, tracking_key: str - ) -> bool: - """ - Check if the current cycle has exceeded its duration. - """ - cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) - if cycle_data is None: - return False - cycle_started = cycle_data.get("cycle_started_at") - if cycle_started is None: - return False - duration = self._get_fair_cycle_duration(provider) - return time.time() >= cycle_started + duration - - def _should_reset_cycle( - self, - provider: str, - tier_key: str, - tracking_key: str, - all_credentials_in_tier: List[str], - available_not_on_cooldown: Optional[List[str]] = None, - ) -> bool: - """ - Check if cycle should reset. - - Returns True if: - 1. Cycle duration has expired, OR - 2. No credentials remain available (after cooldown + fair cycle exclusion), OR - 3. All credentials in the tier have been marked exhausted (fallback) - """ - # Check duration first - if self._is_cycle_expired(provider, tier_key, tracking_key): - return True - - cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) - if cycle_data is None: - return False - - # If available credentials are provided, reset when none remain usable - if available_not_on_cooldown is not None: - has_available = any( - not self._is_credential_exhausted_in_cycle( - cred, provider, tier_key, tracking_key - ) - for cred in available_not_on_cooldown - ) - if not has_available and len(all_credentials_in_tier) > 0: - return True - - exhausted = cycle_data.get("exhausted", set()) - # All must be exhausted (and there must be at least one credential) - return ( - len(exhausted) >= len(all_credentials_in_tier) - and len(all_credentials_in_tier) > 0 - ) - - def _reset_cycle(self, provider: str, tier_key: str, tracking_key: str) -> None: - """ - Reset exhaustion tracking for a completed cycle. - """ - cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) - if cycle_data: - exhausted_count = len(cycle_data.get("exhausted", set())) - lib_logger.info( - f"Fair cycle complete for {provider} tier={tier_key} " - f"tracking='{tracking_key}' - resetting ({exhausted_count} credentials cycled)" - ) - cycle_data["cycle_started_at"] = None - cycle_data["exhausted"] = set() - - def _get_all_credentials_for_tier_key( - self, - provider: str, - tier_key: str, - available_keys: List[str], - credential_priorities: Optional[Dict[str, int]], - ) -> List[str]: - """ - Get all credentials that belong to a tier key. - - Args: - provider: Provider name - tier_key: Either "__all_tiers__" or str(priority) - available_keys: List of available credential identifiers - credential_priorities: Dict mapping credentials to priorities - - Returns: - List of credentials belonging to this tier key - """ - if tier_key == "__all_tiers__": - # Cross-tier: all credentials for this provider - return list(available_keys) - else: - # Within-tier: only credentials with matching priority - priority = int(tier_key) - if credential_priorities: - return [ - k - for k in available_keys - if credential_priorities.get(k, 999) == priority - ] - return list(available_keys) - - def _count_fair_cycle_excluded( - self, - provider: str, - tier_key: str, - tracking_key: str, - candidates: List[str], - ) -> int: - """ - Count how many candidates are excluded by fair cycle. - - Args: - provider: Provider name - tier_key: Tier key for tracking - tracking_key: Model/group tracking key - candidates: List of candidate credentials (not on cooldown) - - Returns: - Number of candidates excluded by fair cycle - """ - count = 0 - for cred in candidates: - if self._is_credential_exhausted_in_cycle( - cred, provider, tier_key, tracking_key - ): - count += 1 - return count - - def _get_priority_multiplier( - self, provider: str, priority: int, rotation_mode: str - ) -> int: - """ - Get the concurrency multiplier for a provider/priority/mode combination. - - Lookup order: - 1. Mode-specific tier override: priority_multipliers_by_mode[provider][mode][priority] - 2. Universal tier multiplier: priority_multipliers[provider][priority] - 3. Sequential fallback (if mode is sequential): sequential_fallback_multipliers[provider] - 4. Global default: 1 (no multiplier effect) - - Args: - provider: Provider name (e.g., "antigravity") - priority: Priority level (1 = highest priority) - rotation_mode: Current rotation mode ("sequential" or "balanced") - - Returns: - Multiplier value - """ - provider_lower = provider.lower() - - # 1. Check mode-specific override - if provider_lower in self.priority_multipliers_by_mode: - mode_multipliers = self.priority_multipliers_by_mode[provider_lower] - if rotation_mode in mode_multipliers: - if priority in mode_multipliers[rotation_mode]: - return mode_multipliers[rotation_mode][priority] - - # 2. Check universal tier multiplier - if provider_lower in self.priority_multipliers: - if priority in self.priority_multipliers[provider_lower]: - return self.priority_multipliers[provider_lower][priority] - - # 3. Sequential fallback (only for sequential mode) - if rotation_mode == "sequential": - if provider_lower in self.sequential_fallback_multipliers: - return self.sequential_fallback_multipliers[provider_lower] - - # 4. Global default - return 1 - - def _get_provider_from_credential(self, credential: str) -> Optional[str]: - """ - Extract provider name from credential path or identifier. - - Supports multiple credential formats: - - OAuth: "oauth_creds/antigravity_oauth_15.json" -> "antigravity" - - OAuth: "C:\\...\\oauth_creds\\gemini_cli_oauth_1.json" -> "gemini_cli" - - OAuth filename only: "antigravity_oauth_1.json" -> "antigravity" - - API key style: extracted from model names in usage data (e.g., "firmware/model" -> "firmware") - - Args: - credential: The credential identifier (path or key) - - Returns: - Provider name string or None if cannot be determined - """ - import re - - # Pattern: env:// URI format (e.g., "env://antigravity/1" -> "antigravity") - if credential.startswith("env://"): - parts = credential[6:].split("/") # Remove "env://" prefix - if parts and parts[0]: - return parts[0].lower() - # Malformed env:// URI (empty provider name) - lib_logger.warning(f"Malformed env:// credential URI: {credential}") - return None - - # Normalize path separators - normalized = credential.replace("\\", "/") - - # Pattern: path ending with {provider}_oauth_{number}.json - match = re.search(r"/([a-z_]+)_oauth_\d+\.json$", normalized, re.IGNORECASE) - if match: - return match.group(1).lower() - - # Pattern: oauth_creds/{provider}_... - match = re.search(r"oauth_creds/([a-z_]+)_", normalized, re.IGNORECASE) - if match: - return match.group(1).lower() - - # Pattern: filename only {provider}_oauth_{number}.json (no path) - match = re.match(r"([a-z_]+)_oauth_\d+\.json$", normalized, re.IGNORECASE) - if match: - return match.group(1).lower() - - # Pattern: API key prefixes for specific providers - # These are raw API keys with recognizable prefixes - api_key_prefixes = { - "sk-nano-": "nanogpt", - "sk-or-": "openrouter", - "sk-ant-": "anthropic", - } - for prefix, provider in api_key_prefixes.items(): - if credential.startswith(prefix): - return provider - - # Fallback: For raw API keys, extract provider from model names in usage data - # This handles providers like firmware, chutes, nanogpt that use credential-level quota - if self._usage_data and credential in self._usage_data: - cred_data = self._usage_data[credential] - - # Check "models" section first (for per_model mode and quota tracking) - models_data = cred_data.get("models", {}) - if models_data: - # Get first model name and extract provider prefix - first_model = next(iter(models_data.keys()), None) - if first_model and "/" in first_model: - provider = first_model.split("/")[0].lower() - return provider - - # Fallback to "daily" section (legacy structure) - daily_data = cred_data.get("daily", {}) - daily_models = daily_data.get("models", {}) - if daily_models: - # Get first model name and extract provider prefix - first_model = next(iter(daily_models.keys()), None) - if first_model and "/" in first_model: - provider = first_model.split("/")[0].lower() - return provider - - return None - - def _get_provider_instance(self, provider: str) -> Optional[Any]: - """ - Get or create a provider plugin instance. - - Args: - provider: The provider name - - Returns: - Provider plugin instance or None - """ - if not provider: - return None - - plugin_class = self.provider_plugins.get(provider) - if not plugin_class: - return None - - # Get or create provider instance from cache - if provider not in self._provider_instances: - # Instantiate the plugin if it's a class, or use it directly if already an instance - if isinstance(plugin_class, type): - self._provider_instances[provider] = plugin_class() - else: - self._provider_instances[provider] = plugin_class - - return self._provider_instances[provider] - - def _get_usage_reset_config(self, credential: str) -> Optional[Dict[str, Any]]: - """ - Get the usage reset configuration for a credential from its provider plugin. - - Args: - credential: The credential identifier - - Returns: - Configuration dict with window_seconds, field_name, etc. - or None to use default daily reset. - """ - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "get_usage_reset_config"): - return plugin_instance.get_usage_reset_config(credential) - - return None - - def _get_reset_mode(self, credential: str) -> str: - """ - Get the reset mode for a credential: 'credential' or 'per_model'. - - Args: - credential: The credential identifier - - Returns: - "per_model" or "credential" (default) - """ - config = self._get_usage_reset_config(credential) - return config.get("mode", "credential") if config else "credential" - - def _get_model_quota_group(self, credential: str, model: str) -> Optional[str]: - """ - Get the quota group for a model, if the provider defines one. - - Args: - credential: The credential identifier - model: Model name (with or without provider prefix) - - Returns: - Group name (e.g., "claude") or None if not grouped - """ - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "get_model_quota_group"): - return plugin_instance.get_model_quota_group(model) - - return None - - def _get_grouped_models(self, credential: str, group: str) -> List[str]: - """ - Get all model names in a quota group (with provider prefix), normalized. - - Returns only public-facing model names, deduplicated. Internal variants - (e.g., claude-sonnet-4-5-thinking) are normalized to their public name - (e.g., claude-sonnet-4.5). - - Args: - credential: The credential identifier - group: Group name (e.g., "claude") - - Returns: - List of normalized, deduplicated model names with provider prefix - (e.g., ["antigravity/claude-sonnet-4.5", "antigravity/claude-opus-4.5"]) - """ - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "get_models_in_quota_group"): - models = plugin_instance.get_models_in_quota_group(group) - - # Normalize and deduplicate - if hasattr(plugin_instance, "normalize_model_for_tracking"): - seen = set() - normalized = [] - for m in models: - prefixed = f"{provider}/{m}" - norm = plugin_instance.normalize_model_for_tracking(prefixed) - if norm not in seen: - seen.add(norm) - normalized.append(norm) - return normalized - - # Fallback: just add provider prefix - return [f"{provider}/{m}" for m in models] - - return [] - - def _get_model_usage_weight(self, credential: str, model: str) -> int: - """ - Get the usage weight for a model when calculating grouped usage. - - Args: - credential: The credential identifier - model: Model name (with or without provider prefix) - - Returns: - Weight multiplier (default 1 if not configured) - """ - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "get_model_usage_weight"): - return plugin_instance.get_model_usage_weight(model) - - return 1 - - def _normalize_model(self, credential: str, model: str) -> str: - """ - Normalize model name using provider's mapping. - - Converts internal model names (e.g., claude-sonnet-4-5-thinking) to - public-facing names (e.g., claude-sonnet-4.5) for consistent storage. - - Args: - credential: The credential identifier - model: Model name (with or without provider prefix) - - Returns: - Normalized model name (provider prefix preserved if present) - """ - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "normalize_model_for_tracking"): - return plugin_instance.normalize_model_for_tracking(model) - - return model - - # Providers where request_count should be used for credential selection - # instead of success_count (because failed requests also consume quota) - _REQUEST_COUNT_PROVIDERS = {"antigravity", "gemini_cli", "chutes", "nanogpt"} - - def _get_grouped_usage_count(self, key: str, model: str) -> int: - """ - Get usage count for credential selection, considering quota groups. - - For providers in _REQUEST_COUNT_PROVIDERS (e.g., antigravity), uses - request_count instead of success_count since failed requests also - consume quota. - - If the model belongs to a quota group, the request_count is already - synced across all models in the group (by record_success/record_failure), - so we just read from the requested model directly. - - Args: - key: Credential identifier - model: Model name (with provider prefix, e.g., "antigravity/claude-sonnet-4-5") - - Returns: - Usage count for the model (synced across group if applicable) - """ - # Determine usage field based on provider - # Some providers (antigravity) count failed requests against quota - provider = self._get_provider_from_credential(key) - usage_field = ( - "request_count" - if provider in self._REQUEST_COUNT_PROVIDERS - else "success_count" - ) - - # For providers with synced quota groups (antigravity), request_count - # is already synced across all models in the group, so just read directly. - # For other providers, we still need to sum success_count across group. - if provider in self._REQUEST_COUNT_PROVIDERS: - # request_count is synced - just read the model's value - return self._get_usage_count(key, model, usage_field) - - # For non-synced providers, check if model is in a quota group and sum - group = self._get_model_quota_group(key, model) - - if group: - # Get all models in the group - grouped_models = self._get_grouped_models(key, group) - - # Sum weighted usage across all models in the group - total_weighted_usage = 0 - for grouped_model in grouped_models: - usage = self._get_usage_count(key, grouped_model, usage_field) - weight = self._get_model_usage_weight(key, grouped_model) - total_weighted_usage += usage * weight - return total_weighted_usage - - # Not grouped - return individual model usage (no weight applied) - return self._get_usage_count(key, model, usage_field) - - def _get_quota_display(self, key: str, model: str) -> str: - """ - Get a formatted quota display string for logging. - - For antigravity (providers in _REQUEST_COUNT_PROVIDERS), returns: - "quota: 170/250 [32%]" format - - For other providers, returns: - "usage: 170" format (no max available) - - Args: - key: Credential identifier - model: Model name (with provider prefix) - - Returns: - Formatted string for logging - """ - provider = self._get_provider_from_credential(key) - - if provider not in self._REQUEST_COUNT_PROVIDERS: - # Non-antigravity: just show usage count - usage = self._get_usage_count(key, model, "success_count") - return f"usage: {usage}" - - # Antigravity: show quota display with remaining percentage - if self._usage_data is None: - return "quota: 0/? [100%]" - - # Normalize model name for consistent lookup (data is stored under normalized names) - model = self._normalize_model(key, model) - - key_data = self._usage_data.get(key, {}) - model_data = key_data.get("models", {}).get(model, {}) - - request_count = model_data.get("request_count", 0) - max_requests = model_data.get("quota_max_requests") - - if max_requests: - remaining = max_requests - request_count - remaining_pct = ( - int((remaining / max_requests) * 100) if max_requests > 0 else 0 - ) - return f"quota: {request_count}/{max_requests} [{remaining_pct}%]" - else: - return f"quota: {request_count}" - - def _get_usage_field_name(self, credential: str) -> str: - """ - Get the usage tracking field name for a credential. - - Returns the provider-specific field name if configured, - otherwise falls back to "daily". - - Args: - credential: The credential identifier - - Returns: - Field name string (e.g., "5h_window", "weekly", "daily") - """ - config = self._get_usage_reset_config(credential) - if config and "field_name" in config: - return config["field_name"] - - # Check provider default - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "get_default_usage_field_name"): - return plugin_instance.get_default_usage_field_name() - - return "daily" - - def _get_usage_count( - self, key: str, model: str, field: str = "success_count" - ) -> int: - """ - Get the current usage count for a model from the appropriate usage structure. - - Supports both: - - New per-model structure: {"models": {"model_name": {"success_count": N, ...}}} - - Legacy structure: {"daily": {"models": {"model_name": {"success_count": N, ...}}}} - - Args: - key: Credential identifier - model: Model name - field: The field to read for usage count (default: "success_count"). - Use "request_count" for providers where failed requests also - consume quota (e.g., antigravity). - - Returns: - Usage count for the model in the current window/period - """ - if self._usage_data is None: - return 0 - - # Normalize model name for consistent lookup (data is stored under normalized names) - model = self._normalize_model(key, model) - - key_data = self._usage_data.get(key, {}) - reset_mode = self._get_reset_mode(key) - - if reset_mode == "per_model": - # New per-model structure: key_data["models"][model][field] - return key_data.get("models", {}).get(model, {}).get(field, 0) - else: - # Legacy structure: key_data["daily"]["models"][model][field] - return ( - key_data.get("daily", {}).get("models", {}).get(model, {}).get(field, 0) - ) - - # ========================================================================= - # TIMESTAMP FORMATTING HELPERS - # ========================================================================= - - def _format_timestamp_local(self, ts: Optional[float]) -> Optional[str]: - """ - Format Unix timestamp as local time string with timezone offset. - - Args: - ts: Unix timestamp or None - - Returns: - Formatted string like "2025-12-07 14:30:17 +0100" or None - """ - if ts is None: - return None - try: - dt = datetime.fromtimestamp(ts).astimezone() # Local timezone - # Use UTC offset for conciseness (works on all platforms) - return dt.strftime("%Y-%m-%d %H:%M:%S %z") - except (OSError, ValueError, OverflowError): - return None - - def _add_readable_timestamps(self, data: Dict) -> Dict: - """ - Add human-readable timestamp fields to usage data before saving. - - Adds 'window_started' and 'quota_resets' fields derived from - Unix timestamps for easier debugging and monitoring. - - Args: - data: The usage data dict to enhance - - Returns: - The same dict with readable timestamp fields added - """ - for key, key_data in data.items(): - # Handle per-model structure - models = key_data.get("models", {}) - for model_name, model_stats in models.items(): - if not isinstance(model_stats, dict): - continue - - # Add readable window start time - window_start = model_stats.get("window_start_ts") - if window_start: - model_stats["window_started"] = self._format_timestamp_local( - window_start - ) - elif "window_started" in model_stats: - del model_stats["window_started"] - - # Add readable reset time - quota_reset = model_stats.get("quota_reset_ts") - if quota_reset: - model_stats["quota_resets"] = self._format_timestamp_local( - quota_reset - ) - elif "quota_resets" in model_stats: - del model_stats["quota_resets"] - - return data - - def _sort_sequential( - self, - candidates: List[Tuple[str, int]], - credential_priorities: Optional[Dict[str, int]] = None, - ) -> List[Tuple[str, int]]: - """ - Sort credentials for sequential mode with position retention. - - Credentials maintain their position based on established usage patterns, - ensuring that actively-used credentials remain primary until exhausted. - - Sorting order (within each sort key, lower value = higher priority): - 1. Priority tier (lower number = higher priority) - 2. Usage count (higher = more established in rotation, maintains position) - 3. Last used timestamp (higher = more recent, tiebreaker for stickiness) - 4. Credential ID (alphabetical, stable ordering) - - Args: - candidates: List of (credential_id, usage_count) tuples - credential_priorities: Optional dict mapping credentials to priority levels - - Returns: - Sorted list of candidates (same format as input) - """ - if not candidates: - return [] - - if len(candidates) == 1: - return candidates - - def sort_key(item: Tuple[str, int]) -> Tuple[int, int, float, str]: - cred, usage_count = item - priority = ( - credential_priorities.get(cred, 999) if credential_priorities else 999 - ) - last_used = ( - self._usage_data.get(cred, {}).get("last_used_ts", 0) - if self._usage_data - else 0 - ) - return ( - priority, # ASC: lower priority number = higher priority - -usage_count, # DESC: higher usage = more established - -last_used, # DESC: more recent = preferred for ties - cred, # ASC: stable alphabetical ordering - ) - - sorted_candidates = sorted(candidates, key=sort_key) - - # Debug logging - show top 3 credentials in ordering - if lib_logger.isEnabledFor(logging.DEBUG): - order_info = [ - f"{mask_credential(c)}(p={credential_priorities.get(c, 999) if credential_priorities else 'N/A'}, u={u})" - for c, u in sorted_candidates[:3] - ] - lib_logger.debug(f"Sequential ordering: {' → '.join(order_info)}") - - return sorted_candidates - - # ========================================================================= - # FAIR CYCLE PERSISTENCE - # ========================================================================= - - def _serialize_cycle_state(self) -> Dict[str, Any]: - """ - Serialize in-memory cycle state for JSON persistence. - - Converts sets to lists for JSON compatibility. - """ - result: Dict[str, Any] = {} - for provider, tier_data in self._cycle_exhausted.items(): - result[provider] = {} - for tier_key, tracking_data in tier_data.items(): - result[provider][tier_key] = {} - for tracking_key, cycle_data in tracking_data.items(): - result[provider][tier_key][tracking_key] = { - "cycle_started_at": cycle_data.get("cycle_started_at"), - "exhausted": list(cycle_data.get("exhausted", set())), - } - return result - - def _deserialize_cycle_state(self, data: Dict[str, Any]) -> None: - """ - Deserialize cycle state from JSON and populate in-memory structure. - - Converts lists back to sets and validates expired cycles. - """ - self._cycle_exhausted = {} - now_ts = time.time() - - for provider, tier_data in data.items(): - if not isinstance(tier_data, dict): - continue - self._cycle_exhausted[provider] = {} - - for tier_key, tracking_data in tier_data.items(): - if not isinstance(tracking_data, dict): - continue - self._cycle_exhausted[provider][tier_key] = {} - - for tracking_key, cycle_data in tracking_data.items(): - if not isinstance(cycle_data, dict): - continue - - cycle_started = cycle_data.get("cycle_started_at") - exhausted_list = cycle_data.get("exhausted", []) - - # Check if cycle has expired - if cycle_started is not None: - duration = self._get_fair_cycle_duration(provider) - if now_ts >= cycle_started + duration: - # Cycle expired - skip (don't restore) - lib_logger.debug( - f"Fair cycle expired for {provider}/{tier_key}/{tracking_key} - not restoring" - ) - continue - - # Restore valid cycle - self._cycle_exhausted[provider][tier_key][tracking_key] = { - "cycle_started_at": cycle_started, - "exhausted": set(exhausted_list) if exhausted_list else set(), - } - - # Log restoration summary - total_cycles = sum( - len(tracking) - for tier in self._cycle_exhausted.values() - for tracking in tier.values() - ) - if total_cycles > 0: - lib_logger.info(f"Restored {total_cycles} active fair cycle(s) from disk") - - async def _lazy_init(self): - """Initializes the usage data by loading it from the file asynchronously.""" - async with self._init_lock: - if not self._initialized.is_set(): - await self._load_usage() - await self._reset_daily_stats_if_needed() - self._initialized.set() - - async def _load_usage(self): - """Loads usage data from the JSON file asynchronously with resilience.""" - async with self._data_lock: - if not os.path.exists(self.file_path): - self._usage_data = {} - return - - try: - async with aiofiles.open(self.file_path, "r") as f: - content = await f.read() - self._usage_data = json.loads(content) if content.strip() else {} - except FileNotFoundError: - # File deleted between exists check and open - self._usage_data = {} - except json.JSONDecodeError as e: - lib_logger.warning( - f"Corrupted usage file {self.file_path}: {e}. Starting fresh." - ) - self._usage_data = {} - except (OSError, PermissionError, IOError) as e: - lib_logger.warning( - f"Cannot read usage file {self.file_path}: {e}. Using empty state." - ) - self._usage_data = {} - - # Restore fair cycle state from persisted data - fair_cycle_data = self._usage_data.get("__fair_cycle__", {}) - if fair_cycle_data: - self._deserialize_cycle_state(fair_cycle_data) - - async def _save_usage(self): - """Saves the current usage data using the resilient state writer.""" - if self._usage_data is None: - return - - async with self._data_lock: - # Add human-readable timestamp fields before saving - self._add_readable_timestamps(self._usage_data) - - # Persist fair cycle state (separate from credential data) - if self._cycle_exhausted: - self._usage_data["__fair_cycle__"] = self._serialize_cycle_state() - elif "__fair_cycle__" in self._usage_data: - # Clean up empty cycle data - del self._usage_data["__fair_cycle__"] - - # Hand off to resilient writer - handles retries and disk failures - self._state_writer.write(self._usage_data) - - async def _get_usage_data_snapshot(self) -> Dict[str, Any]: - """ - Get a shallow copy of the current usage data. - - Returns: - Copy of usage data dict (safe for reading without lock) - """ - await self._lazy_init() - async with self._data_lock: - return dict(self._usage_data) if self._usage_data else {} - - async def get_available_credentials_for_model( - self, credentials: List[str], model: str - ) -> List[str]: - """ - Get credentials that are not on cooldown for a specific model. - - Filters out credentials where: - - key_cooldown_until > now (key-level cooldown) - - model_cooldowns[model] > now (model-specific cooldown, includes quota exhausted) - - Args: - credentials: List of credential identifiers to check - model: Model name to check cooldowns for - - Returns: - List of credentials that are available (not on cooldown) for this model - """ - await self._lazy_init() - now = time.time() - available = [] - - async with self._data_lock: - for key in credentials: - key_data = self._usage_data.get(key, {}) - - # Skip if key-level cooldown is active - if (key_data.get("key_cooldown_until") or 0) > now: - continue - - # Normalize model name for consistent cooldown lookup - # (cooldowns are stored under normalized names by record_failure) - # For providers without normalize_model_for_tracking (non-Antigravity), - # this returns the model unchanged, so cooldown lookups work as before. - normalized_model = self._normalize_model(key, model) - - # Skip if model-specific cooldown is active - if ( - key_data.get("model_cooldowns", {}).get(normalized_model) or 0 - ) > now: - continue - - available.append(key) - - return available - - async def get_credential_availability_stats( - self, - credentials: List[str], - model: str, - credential_priorities: Optional[Dict[str, int]] = None, - ) -> Dict[str, int]: - """ - Get credential availability statistics including cooldown and fair cycle exclusions. - - This is used for logging to show why credentials are excluded. - - Args: - credentials: List of credential identifiers to check - model: Model name to check - credential_priorities: Optional dict mapping credentials to priorities - - Returns: - Dict with: - "total": Total credentials - "on_cooldown": Count on cooldown - "fair_cycle_excluded": Count excluded by fair cycle - "available": Count available for selection - """ - await self._lazy_init() - now = time.time() - - total = len(credentials) - on_cooldown = 0 - not_on_cooldown = [] - - # First pass: check cooldowns - async with self._data_lock: - for key in credentials: - key_data = self._usage_data.get(key, {}) - - # Check if key-level or model-level cooldown is active - normalized_model = self._normalize_model(key, model) - if (key_data.get("key_cooldown_until") or 0) > now or ( - key_data.get("model_cooldowns", {}).get(normalized_model) or 0 - ) > now: - on_cooldown += 1 - else: - not_on_cooldown.append(key) - - # Second pass: check fair cycle exclusions (only for non-cooldown credentials) - fair_cycle_excluded = 0 - if not_on_cooldown: - provider = self._get_provider_from_credential(not_on_cooldown[0]) - if provider: - rotation_mode = self._get_rotation_mode(provider) - if self._is_fair_cycle_enabled(provider, rotation_mode): - # Check each credential against its own tier's exhausted set - for key in not_on_cooldown: - key_priority = ( - credential_priorities.get(key, 999) - if credential_priorities - else 999 - ) - tier_key = self._get_tier_key(provider, key_priority) - tracking_key = self._get_tracking_key(key, model, provider) - - if self._is_credential_exhausted_in_cycle( - key, provider, tier_key, tracking_key - ): - fair_cycle_excluded += 1 - - available = total - on_cooldown - fair_cycle_excluded - - return { - "total": total, - "on_cooldown": on_cooldown, - "fair_cycle_excluded": fair_cycle_excluded, - "available": available, - } - - async def get_soonest_cooldown_end( - self, - credentials: List[str], - model: str, - ) -> Optional[float]: - """ - Find the soonest time when any credential will come off cooldown. - - This is used for smart waiting logic - if no credentials are available, - we can determine whether to wait (if soonest cooldown < deadline) or - fail fast (if soonest cooldown > deadline). - - Args: - credentials: List of credential identifiers to check - model: Model name to check cooldowns for - - Returns: - Timestamp of soonest cooldown end, or None if no credentials are on cooldown - """ - await self._lazy_init() - now = time.time() - soonest_end = None - - async with self._data_lock: - for key in credentials: - key_data = self._usage_data.get(key, {}) - normalized_model = self._normalize_model(key, model) - - # Check key-level cooldown - key_cooldown = key_data.get("key_cooldown_until") or 0 - if key_cooldown > now: - if soonest_end is None or key_cooldown < soonest_end: - soonest_end = key_cooldown - - # Check model-level cooldown - model_cooldown = ( - key_data.get("model_cooldowns", {}).get(normalized_model) or 0 - ) - if model_cooldown > now: - if soonest_end is None or model_cooldown < soonest_end: - soonest_end = model_cooldown - - return soonest_end - - async def _reset_daily_stats_if_needed(self): - """ - Checks if usage stats need to be reset for any key. - - Supports three reset modes: - 1. per_model: Each model has its own window, resets based on quota_reset_ts or fallback window - 2. credential: One window per credential (legacy with custom window duration) - 3. daily: Legacy daily reset at daily_reset_time_utc - """ - if self._usage_data is None: - return - - now_utc = datetime.now(timezone.utc) - now_ts = time.time() - today_str = now_utc.date().isoformat() - needs_saving = False - - for key, data in self._usage_data.items(): - reset_config = self._get_usage_reset_config(key) - - if reset_config: - reset_mode = reset_config.get("mode", "credential") - - if reset_mode == "per_model": - # Per-model window reset - needs_saving |= await self._check_per_model_resets( - key, data, reset_config, now_ts - ) - else: - # Credential-level window reset (legacy) - needs_saving |= await self._check_window_reset( - key, data, reset_config, now_ts - ) - elif self.daily_reset_time_utc: - # Legacy daily reset - needs_saving |= await self._check_daily_reset( - key, data, now_utc, today_str, now_ts - ) - - if needs_saving: - await self._save_usage() - - async def _check_per_model_resets( - self, - key: str, - data: Dict[str, Any], - reset_config: Dict[str, Any], - now_ts: float, - ) -> bool: - """ - Check and perform per-model resets for a credential. - - Each model resets independently based on: - 1. quota_reset_ts (authoritative, from quota exhausted error) if set - 2. window_start_ts + window_seconds (fallback) otherwise - - Grouped models reset together - all models in a group must be ready. - - Args: - key: Credential identifier - data: Usage data for this credential - reset_config: Provider's reset configuration - now_ts: Current timestamp - - Returns: - True if data was modified and needs saving - """ - window_seconds = reset_config.get("window_seconds", 86400) - models_data = data.get("models", {}) - - if not models_data: - return False - - modified = False - processed_groups = set() - - for model, model_data in list(models_data.items()): - # Check if this model is in a quota group - group = self._get_model_quota_group(key, model) - - if group: - if group in processed_groups: - continue # Already handled this group - - # Check if entire group should reset - if self._should_group_reset( - key, group, models_data, window_seconds, now_ts - ): - # Archive and reset all models in group - grouped_models = self._get_grouped_models(key, group) - archived_count = 0 - - for grouped_model in grouped_models: - if grouped_model in models_data: - gm_data = models_data[grouped_model] - self._archive_model_to_global(data, grouped_model, gm_data) - self._reset_model_data(gm_data) - archived_count += 1 - - if archived_count > 0: - lib_logger.info( - f"Reset model group '{group}' ({archived_count} models) for {mask_credential(key)}" - ) - modified = True - - processed_groups.add(group) - - else: - # Ungrouped model - check individually - if self._should_model_reset(model_data, window_seconds, now_ts): - self._archive_model_to_global(data, model, model_data) - self._reset_model_data(model_data) - lib_logger.info(f"Reset model {model} for {mask_credential(key)}") - modified = True - - # Preserve unexpired cooldowns - if modified: - self._preserve_unexpired_cooldowns(key, data, now_ts) - if "failures" in data: - data["failures"] = {} - - return modified - - def _should_model_reset( - self, model_data: Dict[str, Any], window_seconds: int, now_ts: float - ) -> bool: - """ - Check if a single model should reset. - - Returns True if: - - quota_reset_ts is set AND now >= quota_reset_ts, OR - - quota_reset_ts is NOT set AND now >= window_start_ts + window_seconds - """ - quota_reset = model_data.get("quota_reset_ts") - window_start = model_data.get("window_start_ts") - - if quota_reset: - return now_ts >= quota_reset - elif window_start: - return now_ts >= window_start + window_seconds - return False - - def _should_group_reset( - self, - key: str, - group: str, - models_data: Dict[str, Dict], - window_seconds: int, - now_ts: float, - ) -> bool: - """ - Check if all models in a group should reset. - - All models in the group must be ready to reset. - If any model has an active cooldown/window, the whole group waits. - """ - grouped_models = self._get_grouped_models(key, group) - - # Track if any model in group has data - any_has_data = False - - for grouped_model in grouped_models: - model_data = models_data.get(grouped_model, {}) - - if not model_data or ( - model_data.get("window_start_ts") is None - and model_data.get("success_count", 0) == 0 - ): - continue # No stats for this model yet - - any_has_data = True - - if not self._should_model_reset(model_data, window_seconds, now_ts): - return False # At least one model not ready - - return any_has_data - - def _archive_model_to_global( - self, data: Dict[str, Any], model: str, model_data: Dict[str, Any] - ) -> None: - """Archive a single model's stats to global.""" - global_data = data.setdefault("global", {"models": {}}) - global_model = global_data["models"].setdefault( - model, - { - "success_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - - global_model["success_count"] += model_data.get("success_count", 0) - global_model["prompt_tokens"] += model_data.get("prompt_tokens", 0) - global_model["prompt_tokens_cached"] = global_model.get( - "prompt_tokens_cached", 0 - ) + model_data.get("prompt_tokens_cached", 0) - global_model["completion_tokens"] += model_data.get("completion_tokens", 0) - global_model["approx_cost"] += model_data.get("approx_cost", 0.0) - - def _reset_model_data(self, model_data: Dict[str, Any]) -> None: - """Reset a model's window and stats.""" - model_data["window_start_ts"] = None - model_data["quota_reset_ts"] = None - model_data["success_count"] = 0 - model_data["failure_count"] = 0 - model_data["request_count"] = 0 - model_data["prompt_tokens"] = 0 - model_data["completion_tokens"] = 0 - model_data["approx_cost"] = 0.0 - # Reset quota baseline fields only if they exist (Antigravity-specific) - # These are added by update_quota_baseline(), only called for Antigravity - if "baseline_remaining_fraction" in model_data: - model_data["baseline_remaining_fraction"] = None - model_data["baseline_fetched_at"] = None - model_data["requests_at_baseline"] = None - # Reset quota display but keep max_requests (it doesn't change between periods) - max_req = model_data.get("quota_max_requests") - if max_req: - model_data["quota_display"] = f"0/{max_req}" - - async def _check_window_reset( - self, - key: str, - data: Dict[str, Any], - reset_config: Dict[str, Any], - now_ts: float, - ) -> bool: - """ - Check and perform rolling window reset for a credential. - - Args: - key: Credential identifier - data: Usage data for this credential - reset_config: Provider's reset configuration - now_ts: Current timestamp - - Returns: - True if data was modified and needs saving - """ - window_seconds = reset_config.get("window_seconds", 86400) # Default 24h - field_name = reset_config.get("field_name", "window") - description = reset_config.get("description", "rolling window") - - # Get current window data - window_data = data.get(field_name, {}) - window_start = window_data.get("start_ts") - - # No window started yet - nothing to reset - if window_start is None: - return False - - # Check if window has expired - window_end = window_start + window_seconds - if now_ts < window_end: - # Window still active - return False - - # Window expired - perform reset - hours_elapsed = (now_ts - window_start) / 3600 - lib_logger.info( - f"Resetting {field_name} for {mask_credential(key)} - " - f"{description} expired after {hours_elapsed:.1f}h" - ) - - # Archive to global - self._archive_to_global(data, window_data) - - # Preserve unexpired cooldowns - self._preserve_unexpired_cooldowns(key, data, now_ts) - - # Reset window stats (but don't start new window until first request) - data[field_name] = {"start_ts": None, "models": {}} - - # Reset consecutive failures - if "failures" in data: - data["failures"] = {} - - return True - - async def _check_daily_reset( - self, - key: str, - data: Dict[str, Any], - now_utc: datetime, - today_str: str, - now_ts: float, - ) -> bool: - """ - Check and perform legacy daily reset for a credential. - - Args: - key: Credential identifier - data: Usage data for this credential - now_utc: Current datetime in UTC - today_str: Today's date as ISO string - now_ts: Current timestamp - - Returns: - True if data was modified and needs saving - """ - last_reset_str = data.get("last_daily_reset", "") - - if last_reset_str == today_str: - return False - - last_reset_dt = None - if last_reset_str: - try: - last_reset_dt = datetime.fromisoformat(last_reset_str).replace( - tzinfo=timezone.utc - ) - except ValueError: - pass - - # Determine the reset threshold for today - reset_threshold_today = datetime.combine( - now_utc.date(), self.daily_reset_time_utc - ) - - if not ( - last_reset_dt is None or last_reset_dt < reset_threshold_today <= now_utc - ): - return False - - lib_logger.debug(f"Performing daily reset for key {mask_credential(key)}") - - # Preserve unexpired cooldowns - self._preserve_unexpired_cooldowns(key, data, now_ts) - - # Reset consecutive failures - if "failures" in data: - data["failures"] = {} - - # Archive daily stats to global - daily_data = data.get("daily", {}) - if daily_data: - self._archive_to_global(data, daily_data) - - # Reset daily stats - data["daily"] = {"date": today_str, "models": {}} - data["last_daily_reset"] = today_str - - return True - - def _archive_to_global( - self, data: Dict[str, Any], source_data: Dict[str, Any] - ) -> None: - """ - Archive usage stats from a source field (daily/window) to global. - - Args: - data: The credential's usage data - source_data: The source field data to archive (has "models" key) - """ - global_data = data.setdefault("global", {"models": {}}) - for model, stats in source_data.get("models", {}).items(): - global_model_stats = global_data["models"].setdefault( - model, - { - "success_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - global_model_stats["success_count"] += stats.get("success_count", 0) - global_model_stats["prompt_tokens"] += stats.get("prompt_tokens", 0) - global_model_stats["prompt_tokens_cached"] = global_model_stats.get( - "prompt_tokens_cached", 0 - ) + stats.get("prompt_tokens_cached", 0) - global_model_stats["completion_tokens"] += stats.get("completion_tokens", 0) - global_model_stats["approx_cost"] += stats.get("approx_cost", 0.0) - - def _preserve_unexpired_cooldowns( - self, key: str, data: Dict[str, Any], now_ts: float - ) -> None: - """ - Preserve unexpired cooldowns during reset (important for long quota cooldowns). - - Args: - key: Credential identifier (for logging) - data: The credential's usage data - now_ts: Current timestamp - """ - # Preserve unexpired model cooldowns - if "model_cooldowns" in data: - active_cooldowns = { - model: end_time - for model, end_time in data["model_cooldowns"].items() - if end_time > now_ts - } - if active_cooldowns: - max_remaining = max( - end_time - now_ts for end_time in active_cooldowns.values() - ) - hours_remaining = max_remaining / 3600 - lib_logger.info( - f"Preserving {len(active_cooldowns)} active cooldown(s) " - f"for key {mask_credential(key)} during reset " - f"(longest: {hours_remaining:.1f}h remaining)" - ) - data["model_cooldowns"] = active_cooldowns - else: - data["model_cooldowns"] = {} - - # Preserve unexpired key-level cooldown - if data.get("key_cooldown_until"): - if data["key_cooldown_until"] <= now_ts: - data["key_cooldown_until"] = None - else: - hours_remaining = (data["key_cooldown_until"] - now_ts) / 3600 - lib_logger.info( - f"Preserving key-level cooldown for {mask_credential(key)} " - f"during reset ({hours_remaining:.1f}h remaining)" - ) - else: - data["key_cooldown_until"] = None - - def _initialize_key_states(self, keys: List[str]): - """Initializes state tracking for all provided keys if not already present.""" - for key in keys: - if key not in self.key_states: - self.key_states[key] = { - "lock": asyncio.Lock(), - "condition": asyncio.Condition(), - "models_in_use": {}, # Dict[model_name, concurrent_count] - } - - def _select_weighted_random(self, candidates: List[tuple], tolerance: float) -> str: - """ - Selects a credential using weighted random selection based on usage counts. - - Args: - candidates: List of (credential_id, usage_count) tuples - tolerance: Tolerance value for weight calculation - - Returns: - Selected credential ID - - Formula: - weight = (max_usage - credential_usage) + tolerance + 1 - - This formula ensures: - - Lower usage = higher weight = higher selection probability - - Tolerance adds variability: higher tolerance means more randomness - - The +1 ensures all credentials have at least some chance of selection - """ - if not candidates: - raise ValueError("Cannot select from empty candidate list") - - if len(candidates) == 1: - return candidates[0][0] - - # Extract usage counts - usage_counts = [usage for _, usage in candidates] - max_usage = max(usage_counts) - - # Calculate weights using the formula: (max - current) + tolerance + 1 - weights = [] - for credential, usage in candidates: - weight = (max_usage - usage) + tolerance + 1 - weights.append(weight) - - # Log weight distribution for debugging - if lib_logger.isEnabledFor(logging.DEBUG): - total_weight = sum(weights) - weight_info = ", ".join( - f"{mask_credential(cred)}: w={w:.1f} ({w / total_weight * 100:.1f}%)" - for (cred, _), w in zip(candidates, weights) - ) - # lib_logger.debug(f"Weighted selection candidates: {weight_info}") - - # Random selection with weights - selected_credential = random.choices( - [cred for cred, _ in candidates], weights=weights, k=1 - )[0] - - return selected_credential - - async def acquire_key( - self, - available_keys: List[str], - model: str, - deadline: float, - max_concurrent: int = 1, - credential_priorities: Optional[Dict[str, int]] = None, - credential_tier_names: Optional[Dict[str, str]] = None, - all_provider_credentials: Optional[List[str]] = None, - ) -> str: - """ - Acquires the best available key using a tiered, model-aware locking strategy, - respecting a global deadline and credential priorities. - - Priority Logic: - - Groups credentials by priority level (1=highest, 2=lower, etc.) - - Always tries highest priority (lowest number) first - - Within same priority, sorts by usage count (load balancing) - - Only moves to next priority if all higher-priority keys exhausted/busy - - Args: - available_keys: List of credential identifiers to choose from - model: Model name being requested - deadline: Timestamp after which to stop trying - max_concurrent: Maximum concurrent requests allowed per credential - credential_priorities: Optional dict mapping credentials to priority levels (1=highest) - credential_tier_names: Optional dict mapping credentials to tier names (for logging) - all_provider_credentials: Full list of provider credentials (used for cycle reset checks) - - Returns: - Selected credential identifier - - Raises: - NoAvailableKeysError: If no key could be acquired within the deadline - """ - await self._lazy_init() - await self._reset_daily_stats_if_needed() - self._initialize_key_states(available_keys) - - # Normalize model name for consistent cooldown lookup - # (cooldowns are stored under normalized names by record_failure) - # Use first credential for provider detection; all credentials passed here - # are for the same provider (filtered by client.py before calling acquire_key). - # For providers without normalize_model_for_tracking (non-Antigravity), - # this returns the model unchanged, so cooldown lookups work as before. - normalized_model = ( - self._normalize_model(available_keys[0], model) if available_keys else model - ) - - # This loop continues as long as the global deadline has not been met. - while time.time() < deadline: - now = time.time() - - # Group credentials by priority level (if priorities provided) - if credential_priorities: - # Group keys by priority level - priority_groups = {} - async with self._data_lock: - for key in available_keys: - key_data = self._usage_data.get(key, {}) - - # Skip keys on cooldown (use normalized model for lookup) - if (key_data.get("key_cooldown_until") or 0) > now or ( - key_data.get("model_cooldowns", {}).get(normalized_model) - or 0 - ) > now: - continue - - # Get priority for this key (default to 999 if not specified) - priority = credential_priorities.get(key, 999) - - # Get usage count for load balancing within priority groups - # Uses grouped usage if model is in a quota group - usage_count = self._get_grouped_usage_count(key, model) - - # Group by priority - if priority not in priority_groups: - priority_groups[priority] = [] - priority_groups[priority].append((key, usage_count)) - - # Try priority groups in order (1, 2, 3, ...) - sorted_priorities = sorted(priority_groups.keys()) - - for priority_level in sorted_priorities: - keys_in_priority = priority_groups[priority_level] - - # Determine selection method based on provider's rotation mode - provider = model.split("/")[0] if "/" in model else "" - rotation_mode = self._get_rotation_mode(provider) - - # Fair cycle filtering - if provider and self._is_fair_cycle_enabled( - provider, rotation_mode - ): - tier_key = self._get_tier_key(provider, priority_level) - tracking_key = self._get_tracking_key( - keys_in_priority[0][0] if keys_in_priority else "", - model, - provider, - ) - - # Get all credentials for this tier (for cycle completion check) - all_tier_creds = self._get_all_credentials_for_tier_key( - provider, - tier_key, - all_provider_credentials or available_keys, - credential_priorities, - ) - - # Check if cycle should reset (all exhausted, expired, or none available) - if self._should_reset_cycle( - provider, - tier_key, - tracking_key, - all_tier_creds, - available_not_on_cooldown=[ - key for key, _ in keys_in_priority - ], - ): - self._reset_cycle(provider, tier_key, tracking_key) - - # Filter out exhausted credentials - filtered_keys = [] - for key, usage_count in keys_in_priority: - if not self._is_credential_exhausted_in_cycle( - key, provider, tier_key, tracking_key - ): - filtered_keys.append((key, usage_count)) - - keys_in_priority = filtered_keys - - # Calculate effective concurrency based on priority tier - multiplier = self._get_priority_multiplier( - provider, priority_level, rotation_mode - ) - effective_max_concurrent = max_concurrent * multiplier - - # Within each priority group, use existing tier1/tier2 logic - tier1_keys, tier2_keys = [], [] - for key, usage_count in keys_in_priority: - key_state = self.key_states[key] - - # Tier 1: Completely idle keys (preferred) - if not key_state["models_in_use"]: - tier1_keys.append((key, usage_count)) - # Tier 2: Keys that can accept more concurrent requests - elif ( - key_state["models_in_use"].get(model, 0) - < effective_max_concurrent - ): - tier2_keys.append((key, usage_count)) - - if rotation_mode == "sequential": - # Sequential mode: sort credentials by priority, usage, recency - # Keep all candidates in sorted order (no filtering to single key) - selection_method = "sequential" - if tier1_keys: - tier1_keys = self._sort_sequential( - tier1_keys, credential_priorities - ) - if tier2_keys: - tier2_keys = self._sort_sequential( - tier2_keys, credential_priorities - ) - elif self.rotation_tolerance > 0: - # Balanced mode with weighted randomness - selection_method = "weighted-random" - if tier1_keys: - selected_key = self._select_weighted_random( - tier1_keys, self.rotation_tolerance - ) - tier1_keys = [ - (k, u) for k, u in tier1_keys if k == selected_key - ] - if tier2_keys: - selected_key = self._select_weighted_random( - tier2_keys, self.rotation_tolerance - ) - tier2_keys = [ - (k, u) for k, u in tier2_keys if k == selected_key - ] - else: - # Deterministic: sort by usage within each tier - selection_method = "least-used" - tier1_keys.sort(key=lambda x: x[1]) - tier2_keys.sort(key=lambda x: x[1]) - - # Try to acquire from Tier 1 first - for key, usage in tier1_keys: - state = self.key_states[key] - async with state["lock"]: - if not state["models_in_use"]: - state["models_in_use"][model] = 1 - tier_name = ( - credential_tier_names.get(key, "unknown") - if credential_tier_names - else "unknown" - ) - quota_display = self._get_quota_display(key, model) - lib_logger.info( - f"Acquired key {mask_credential(key)} for model {model} " - f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, {quota_display})" - ) - return key - - # Then try Tier 2 - for key, usage in tier2_keys: - state = self.key_states[key] - async with state["lock"]: - current_count = state["models_in_use"].get(model, 0) - if current_count < effective_max_concurrent: - state["models_in_use"][model] = current_count + 1 - tier_name = ( - credential_tier_names.get(key, "unknown") - if credential_tier_names - else "unknown" - ) - quota_display = self._get_quota_display(key, model) - lib_logger.info( - f"Acquired key {mask_credential(key)} for model {model} " - f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{effective_max_concurrent}, {quota_display})" - ) - return key - - # If we get here, all priority groups were exhausted but keys might become available - # Collect all keys across all priorities for waiting - all_potential_keys = [] - for keys_list in priority_groups.values(): - all_potential_keys.extend(keys_list) - - if not all_potential_keys: - # All credentials are on cooldown - check if waiting makes sense - soonest_end = await self.get_soonest_cooldown_end( - available_keys, model - ) - - if soonest_end is None: - # No cooldowns active but no keys available (shouldn't happen) - lib_logger.warning( - "No keys eligible and no cooldowns active. Re-evaluating..." - ) - await asyncio.sleep(1) - continue - - remaining_budget = deadline - time.time() - wait_needed = soonest_end - time.time() - - if wait_needed > remaining_budget: - # Fail fast - no credential will be available in time - lib_logger.warning( - f"All credentials on cooldown. Soonest available in {wait_needed:.1f}s, " - f"but only {remaining_budget:.1f}s budget remaining. Failing fast." - ) - break # Exit loop, will raise NoAvailableKeysError - - # Wait for the credential to become available - lib_logger.info( - f"All credentials on cooldown. Waiting {wait_needed:.1f}s for soonest credential..." - ) - await asyncio.sleep(min(wait_needed + 0.1, remaining_budget)) - continue - - # Wait for the highest priority key with lowest usage - best_priority = min(priority_groups.keys()) - best_priority_keys = priority_groups[best_priority] - best_wait_key = min(best_priority_keys, key=lambda x: x[1])[0] - wait_condition = self.key_states[best_wait_key]["condition"] - - lib_logger.info( - f"All Priority-{best_priority} keys are busy. Waiting for highest priority credential to become available..." - ) - - else: - # Original logic when no priorities specified - - # Determine selection method based on provider's rotation mode - provider = model.split("/")[0] if "/" in model else "" - rotation_mode = self._get_rotation_mode(provider) - - # Calculate effective concurrency for default priority (999) - # When no priorities are specified, all credentials get default priority - default_priority = 999 - multiplier = self._get_priority_multiplier( - provider, default_priority, rotation_mode - ) - effective_max_concurrent = max_concurrent * multiplier - - tier1_keys, tier2_keys = [], [] - - # First, filter the list of available keys to exclude any on cooldown. - async with self._data_lock: - for key in available_keys: - key_data = self._usage_data.get(key, {}) - - # Skip keys on cooldown (use normalized model for lookup) - if (key_data.get("key_cooldown_until") or 0) > now or ( - key_data.get("model_cooldowns", {}).get(normalized_model) - or 0 - ) > now: - continue - - # Prioritize keys based on their current usage to ensure load balancing. - # Uses grouped usage if model is in a quota group - usage_count = self._get_grouped_usage_count(key, model) - key_state = self.key_states[key] - - # Tier 1: Completely idle keys (preferred). - if not key_state["models_in_use"]: - tier1_keys.append((key, usage_count)) - # Tier 2: Keys that can accept more concurrent requests for this model. - elif ( - key_state["models_in_use"].get(model, 0) - < effective_max_concurrent - ): - tier2_keys.append((key, usage_count)) - - # Fair cycle filtering (non-priority case) - if provider and self._is_fair_cycle_enabled(provider, rotation_mode): - tier_key = self._get_tier_key(provider, default_priority) - tracking_key = self._get_tracking_key( - available_keys[0] if available_keys else "", - model, - provider, - ) - - # Get all credentials for this tier (for cycle completion check) - all_tier_creds = self._get_all_credentials_for_tier_key( - provider, - tier_key, - all_provider_credentials or available_keys, - None, - ) - - # Check if cycle should reset (all exhausted, expired, or none available) - if self._should_reset_cycle( - provider, - tier_key, - tracking_key, - all_tier_creds, - available_not_on_cooldown=[ - key for key, _ in (tier1_keys + tier2_keys) - ], - ): - self._reset_cycle(provider, tier_key, tracking_key) - - # Filter out exhausted credentials from both tiers - tier1_keys = [ - (key, usage) - for key, usage in tier1_keys - if not self._is_credential_exhausted_in_cycle( - key, provider, tier_key, tracking_key - ) - ] - tier2_keys = [ - (key, usage) - for key, usage in tier2_keys - if not self._is_credential_exhausted_in_cycle( - key, provider, tier_key, tracking_key - ) - ] - - if rotation_mode == "sequential": - # Sequential mode: sort credentials by priority, usage, recency - # Keep all candidates in sorted order (no filtering to single key) - selection_method = "sequential" - if tier1_keys: - tier1_keys = self._sort_sequential( - tier1_keys, credential_priorities - ) - if tier2_keys: - tier2_keys = self._sort_sequential( - tier2_keys, credential_priorities - ) - elif self.rotation_tolerance > 0: - # Balanced mode with weighted randomness - selection_method = "weighted-random" - if tier1_keys: - selected_key = self._select_weighted_random( - tier1_keys, self.rotation_tolerance - ) - tier1_keys = [ - (k, u) for k, u in tier1_keys if k == selected_key - ] - if tier2_keys: - selected_key = self._select_weighted_random( - tier2_keys, self.rotation_tolerance - ) - tier2_keys = [ - (k, u) for k, u in tier2_keys if k == selected_key - ] - else: - # Deterministic: sort by usage within each tier - selection_method = "least-used" - tier1_keys.sort(key=lambda x: x[1]) - tier2_keys.sort(key=lambda x: x[1]) - - # Attempt to acquire a key from Tier 1 first. - for key, usage in tier1_keys: - state = self.key_states[key] - async with state["lock"]: - if not state["models_in_use"]: - state["models_in_use"][model] = 1 - tier_name = ( - credential_tier_names.get(key) - if credential_tier_names - else None - ) - tier_info = f"tier: {tier_name}, " if tier_name else "" - quota_display = self._get_quota_display(key, model) - lib_logger.info( - f"Acquired key {mask_credential(key)} for model {model} " - f"({tier_info}selection: {selection_method}, {quota_display})" - ) - return key - - # If no Tier 1 keys are available, try Tier 2. - for key, usage in tier2_keys: - state = self.key_states[key] - async with state["lock"]: - current_count = state["models_in_use"].get(model, 0) - if current_count < effective_max_concurrent: - state["models_in_use"][model] = current_count + 1 - tier_name = ( - credential_tier_names.get(key) - if credential_tier_names - else None - ) - tier_info = f"tier: {tier_name}, " if tier_name else "" - quota_display = self._get_quota_display(key, model) - lib_logger.info( - f"Acquired key {mask_credential(key)} for model {model} " - f"({tier_info}selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{effective_max_concurrent}, {quota_display})" - ) - return key - - # If all eligible keys are locked, wait for a key to be released. - lib_logger.info( - "All eligible keys are currently locked for this model. Waiting..." - ) - - all_potential_keys = tier1_keys + tier2_keys - if not all_potential_keys: - # All credentials are on cooldown - check if waiting makes sense - soonest_end = await self.get_soonest_cooldown_end( - available_keys, model - ) - - if soonest_end is None: - # No cooldowns active but no keys available (shouldn't happen) - lib_logger.warning( - "No keys eligible and no cooldowns active. Re-evaluating..." - ) - await asyncio.sleep(1) - continue - - remaining_budget = deadline - time.time() - wait_needed = soonest_end - time.time() - - if wait_needed > remaining_budget: - # Fail fast - no credential will be available in time - lib_logger.warning( - f"All credentials on cooldown. Soonest available in {wait_needed:.1f}s, " - f"but only {remaining_budget:.1f}s budget remaining. Failing fast." - ) - break # Exit loop, will raise NoAvailableKeysError - - # Wait for the credential to become available - lib_logger.info( - f"All credentials on cooldown. Waiting {wait_needed:.1f}s for soonest credential..." - ) - await asyncio.sleep(min(wait_needed + 0.1, remaining_budget)) - continue - - # Wait on the condition of the key with the lowest current usage. - best_wait_key = min(all_potential_keys, key=lambda x: x[1])[0] - wait_condition = self.key_states[best_wait_key]["condition"] - - try: - async with wait_condition: - remaining_budget = deadline - time.time() - if remaining_budget <= 0: - break # Exit if the budget has already been exceeded. - # Wait for a notification, but no longer than the remaining budget or 1 second. - await asyncio.wait_for( - wait_condition.wait(), timeout=min(1, remaining_budget) - ) - lib_logger.info("Notified that a key was released. Re-evaluating...") - except asyncio.TimeoutError: - # This is not an error, just a timeout for the wait. The main loop will re-evaluate. - lib_logger.info("Wait timed out. Re-evaluating for any available key.") - - # If the loop exits, it means the deadline was exceeded. - raise NoAvailableKeysError( - f"Could not acquire a key for model {model} within the global time budget." - ) - - async def release_key(self, key: str, model: str): - """Releases a key's lock for a specific model and notifies waiting tasks.""" - if key not in self.key_states: - return - - state = self.key_states[key] - async with state["lock"]: - if model in state["models_in_use"]: - state["models_in_use"][model] -= 1 - remaining = state["models_in_use"][model] - if remaining <= 0: - del state["models_in_use"][model] # Clean up when count reaches 0 - lib_logger.info( - f"Released credential {mask_credential(key)} from model {model} " - f"(remaining concurrent: {max(0, remaining)})" - ) - else: - lib_logger.warning( - f"Attempted to release credential {mask_credential(key)} for model {model}, but it was not in use." - ) - - # Notify all tasks waiting on this key's condition - async with state["condition"]: - state["condition"].notify_all() - - async def record_success( - self, - key: str, - model: str, - completion_response: Optional[litellm.ModelResponse] = None, - ): - """ - Records a successful API call, resetting failure counters. - It safely handles cases where token usage data is not available. - - Supports two modes based on provider configuration: - - per_model: Each model has its own window_start_ts and stats in key_data["models"] - - credential: Legacy mode with key_data["daily"]["models"] - """ - await self._lazy_init() - - # Normalize model name to public-facing name for consistent tracking - model = self._normalize_model(key, model) - - async with self._data_lock: - now_ts = time.time() - today_utc_str = datetime.now(timezone.utc).date().isoformat() - - reset_config = self._get_usage_reset_config(key) - reset_mode = ( - reset_config.get("mode", "credential") if reset_config else "credential" - ) - - if reset_mode == "per_model": - # New per-model structure - key_data = self._usage_data.setdefault( - key, - { - "models": {}, - "global": {"models": {}}, - "model_cooldowns": {}, - "failures": {}, - }, - ) - - # Ensure models dict exists - if "models" not in key_data: - key_data["models"] = {} - - # Get or create per-model data with window tracking - model_data = key_data["models"].setdefault( - model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - - # Start window on first request for this model - if model_data.get("window_start_ts") is None: - model_data["window_start_ts"] = now_ts - - # Set expected quota reset time from provider config - window_seconds = ( - reset_config.get("window_seconds", 0) if reset_config else 0 - ) - if window_seconds > 0: - model_data["quota_reset_ts"] = now_ts + window_seconds - - window_hours = window_seconds / 3600 if window_seconds else 0 - lib_logger.info( - f"Started {window_hours:.1f}h window for model {model} on {mask_credential(key)}" - ) - - # Record stats - model_data["success_count"] += 1 - model_data["request_count"] = model_data.get("request_count", 0) + 1 - - # Sync request_count across quota group (for providers with shared quota pools) - new_request_count = model_data["request_count"] - group = self._get_model_quota_group(key, model) - if group: - grouped_models = self._get_grouped_models(key, group) - for grouped_model in grouped_models: - if grouped_model != model: - other_model_data = key_data["models"].setdefault( - grouped_model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - other_model_data["request_count"] = new_request_count - # Sync window timing (shared quota pool = shared window) - window_start = model_data.get("window_start_ts") - if window_start: - other_model_data["window_start_ts"] = window_start - # Also sync quota_max_requests if set - max_req = model_data.get("quota_max_requests") - if max_req: - other_model_data["quota_max_requests"] = max_req - other_model_data["quota_display"] = ( - f"{new_request_count}/{max_req}" - ) - - # Update quota_display if max_requests is set (Antigravity-specific) - max_req = model_data.get("quota_max_requests") - if max_req: - model_data["quota_display"] = ( - f"{model_data['request_count']}/{max_req}" - ) - - # Check custom cap - if self._check_and_apply_custom_cap( - key, model, model_data["request_count"] - ): - # Custom cap exceeded, cooldown applied - # Continue to record tokens/cost but credential will be skipped next time - pass - - usage_data_ref = model_data # For token/cost recording below - - else: - # Legacy credential-level structure - key_data = self._usage_data.setdefault( - key, - { - "daily": {"date": today_utc_str, "models": {}}, - "global": {"models": {}}, - "model_cooldowns": {}, - "failures": {}, - }, - ) - - if "last_daily_reset" not in key_data: - key_data["last_daily_reset"] = today_utc_str - - # Get or create model data in daily structure - usage_data_ref = key_data["daily"]["models"].setdefault( - model, - { - "success_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - usage_data_ref["success_count"] += 1 - - # Reset failures for this model - model_failures = key_data.setdefault("failures", {}).setdefault(model, {}) - model_failures["consecutive_failures"] = 0 - - # Clear transient cooldown on success (but NOT quota_reset_ts) - if model in key_data.get("model_cooldowns", {}): - del key_data["model_cooldowns"][model] - - # Record token and cost usage - if ( - completion_response - and hasattr(completion_response, "usage") - and completion_response.usage - ): - usage = completion_response.usage - prompt_total = usage.prompt_tokens - - # Extract cached tokens from prompt_tokens_details if present - cached_tokens = 0 - prompt_details = getattr(usage, "prompt_tokens_details", None) - if prompt_details: - if isinstance(prompt_details, dict): - cached_tokens = prompt_details.get("cached_tokens", 0) or 0 - elif hasattr(prompt_details, "cached_tokens"): - cached_tokens = prompt_details.cached_tokens or 0 - - # Store uncached tokens (prompt_tokens is total, subtract cached) - uncached_tokens = prompt_total - cached_tokens - usage_data_ref["prompt_tokens"] += uncached_tokens - - # Store cached tokens separately - if cached_tokens > 0: - usage_data_ref["prompt_tokens_cached"] = ( - usage_data_ref.get("prompt_tokens_cached", 0) + cached_tokens - ) - - usage_data_ref["completion_tokens"] += getattr( - usage, "completion_tokens", 0 - ) - lib_logger.info( - f"Recorded usage from response object for key {mask_credential(key)}" - ) - try: - provider_name = model.split("/")[0] - provider_instance = self._get_provider_instance(provider_name) - - if provider_instance and getattr( - provider_instance, "skip_cost_calculation", False - ): - lib_logger.debug( - f"Skipping cost calculation for provider '{provider_name}' (custom provider)." - ) - else: - if isinstance(completion_response, litellm.EmbeddingResponse): - model_info = litellm.get_model_info(model) - input_cost = model_info.get("input_cost_per_token") - if input_cost: - cost = ( - completion_response.usage.prompt_tokens * input_cost - ) - else: - cost = None - else: - cost = litellm.completion_cost( - completion_response=completion_response, model=model - ) - - if cost is not None: - usage_data_ref["approx_cost"] += cost - except Exception as e: - lib_logger.warning( - f"Could not calculate cost for model {model}: {e}" - ) - elif isinstance(completion_response, asyncio.Future) or hasattr( - completion_response, "__aiter__" - ): - pass # Stream - usage recorded from chunks - else: - lib_logger.warning( - f"No usage data found in completion response for model {model}. Recording success without token count." - ) - - key_data["last_used_ts"] = now_ts - - await self._save_usage() - - async def record_failure( - self, - key: str, - model: str, - classified_error: ClassifiedError, - increment_consecutive_failures: bool = True, - ): - """Records a failure and applies cooldowns based on error type. - - Distinguishes between: - - quota_exceeded: Long cooldown with exact reset time (from quota_reset_timestamp) - Sets quota_reset_ts on model (and group) - this becomes authoritative stats reset time - - rate_limit: Short transient cooldown (just wait and retry) - Only sets model_cooldowns - does NOT affect stats reset timing - - Args: - key: The API key or credential identifier - model: The model name - classified_error: The classified error object - increment_consecutive_failures: Whether to increment the failure counter. - Set to False for provider-level errors that shouldn't count against the key. - """ - await self._lazy_init() - - # Normalize model name to public-facing name for consistent tracking - model = self._normalize_model(key, model) - - async with self._data_lock: - now_ts = time.time() - today_utc_str = datetime.now(timezone.utc).date().isoformat() - - reset_config = self._get_usage_reset_config(key) - reset_mode = ( - reset_config.get("mode", "credential") if reset_config else "credential" - ) - - # Initialize key data with appropriate structure - if reset_mode == "per_model": - key_data = self._usage_data.setdefault( - key, - { - "models": {}, - "global": {"models": {}}, - "model_cooldowns": {}, - "failures": {}, - }, - ) - else: - key_data = self._usage_data.setdefault( - key, - { - "daily": {"date": today_utc_str, "models": {}}, - "global": {"models": {}}, - "model_cooldowns": {}, - "failures": {}, - }, - ) - - # Provider-level errors (transient issues) should not count against the key - provider_level_errors = {"server_error", "api_connection"} - - # Determine if we should increment the failure counter - should_increment = ( - increment_consecutive_failures - and classified_error.error_type not in provider_level_errors - ) - - # Calculate cooldown duration based on error type - cooldown_seconds = None - model_cooldowns = key_data.setdefault("model_cooldowns", {}) - - # Capture existing cooldown BEFORE we modify it - # Used to determine if this is a fresh exhaustion vs re-processing - existing_cooldown_before = model_cooldowns.get(model) - was_already_on_cooldown = ( - existing_cooldown_before is not None - and existing_cooldown_before > now_ts - ) - - if classified_error.error_type == "quota_exceeded": - # Quota exhausted - use authoritative reset timestamp if available - quota_reset_ts = classified_error.quota_reset_timestamp - cooldown_seconds = ( - classified_error.retry_after or COOLDOWN_RATE_LIMIT_DEFAULT - ) - - if quota_reset_ts and reset_mode == "per_model": - # Set quota_reset_ts on model - this becomes authoritative stats reset time - models_data = key_data.setdefault("models", {}) - model_data = models_data.setdefault( - model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - model_data["quota_reset_ts"] = quota_reset_ts - # Track failure for quota estimation (request still consumes quota) - model_data["failure_count"] = model_data.get("failure_count", 0) + 1 - model_data["request_count"] = model_data.get("request_count", 0) + 1 - - # Clamp request_count to quota_max_requests when quota is exhausted - # This prevents display overflow (e.g., 151/150) when requests are - # counted locally before API refresh corrects the value - max_req = model_data.get("quota_max_requests") - if max_req is not None and model_data["request_count"] > max_req: - model_data["request_count"] = max_req - # Update quota_display with clamped value - model_data["quota_display"] = f"{max_req}/{max_req}" - new_request_count = model_data["request_count"] - - # Apply to all models in the same quota group - group = self._get_model_quota_group(key, model) - if group: - grouped_models = self._get_grouped_models(key, group) - for grouped_model in grouped_models: - group_model_data = models_data.setdefault( - grouped_model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - group_model_data["quota_reset_ts"] = quota_reset_ts - # Sync request_count across quota group - group_model_data["request_count"] = new_request_count - # Also sync quota_max_requests if set - max_req = model_data.get("quota_max_requests") - if max_req: - group_model_data["quota_max_requests"] = max_req - group_model_data["quota_display"] = ( - f"{new_request_count}/{max_req}" - ) - # Also set transient cooldown for selection logic - model_cooldowns[grouped_model] = quota_reset_ts - - reset_dt = datetime.fromtimestamp( - quota_reset_ts, tz=timezone.utc - ) - lib_logger.info( - f"Quota exhausted for group '{group}' ({len(grouped_models)} models) " - f"on {mask_credential(key)}. Resets at {reset_dt.isoformat()}" - ) - else: - reset_dt = datetime.fromtimestamp( - quota_reset_ts, tz=timezone.utc - ) - hours = (quota_reset_ts - now_ts) / 3600 - lib_logger.info( - f"Quota exhausted for model {model} on {mask_credential(key)}. " - f"Resets at {reset_dt.isoformat()} ({hours:.1f}h)" - ) - - # Set transient cooldown for selection logic - model_cooldowns[model] = quota_reset_ts - else: - # No authoritative timestamp or legacy mode - just use retry_after - model_cooldowns[model] = now_ts + cooldown_seconds - hours = cooldown_seconds / 3600 - lib_logger.info( - f"Quota exhausted on {mask_credential(key)} for model {model}. " - f"Cooldown: {cooldown_seconds}s ({hours:.1f}h)" - ) - - # Mark credential as exhausted for fair cycle if cooldown exceeds threshold - # BUT only if this is a FRESH exhaustion (wasn't already on cooldown) - # This prevents re-marking after cycle reset - if not was_already_on_cooldown: - effective_cooldown = ( - (quota_reset_ts - now_ts) - if quota_reset_ts - else (cooldown_seconds or 0) - ) - provider = self._get_provider_from_credential(key) - if provider: - threshold = self._get_exhaustion_cooldown_threshold(provider) - if effective_cooldown > threshold: - rotation_mode = self._get_rotation_mode(provider) - if self._is_fair_cycle_enabled(provider, rotation_mode): - priority = self._get_credential_priority(key, provider) - tier_key = self._get_tier_key(provider, priority) - tracking_key = self._get_tracking_key( - key, model, provider - ) - self._mark_credential_exhausted( - key, provider, tier_key, tracking_key - ) - - elif classified_error.error_type == "rate_limit": - # Transient rate limit - just set short cooldown (does NOT set quota_reset_ts) - cooldown_seconds = ( - classified_error.retry_after or COOLDOWN_RATE_LIMIT_DEFAULT - ) - model_cooldowns[model] = now_ts + cooldown_seconds - lib_logger.info( - f"Rate limit on {mask_credential(key)} for model {model}. " - f"Transient cooldown: {cooldown_seconds}s" - ) - - elif classified_error.error_type == "authentication": - # Apply a 5-minute key-level lockout for auth errors - key_data["key_cooldown_until"] = now_ts + COOLDOWN_AUTH_ERROR - cooldown_seconds = COOLDOWN_AUTH_ERROR - model_cooldowns[model] = now_ts + cooldown_seconds - lib_logger.warning( - f"Authentication error on key {mask_credential(key)}. Applying 5-minute key-level lockout." - ) - - # If we should increment failures, calculate escalating backoff - if should_increment: - failures_data = key_data.setdefault("failures", {}) - model_failures = failures_data.setdefault( - model, {"consecutive_failures": 0} - ) - model_failures["consecutive_failures"] += 1 - count = model_failures["consecutive_failures"] - - # If cooldown wasn't set by specific error type, use escalating backoff - if cooldown_seconds is None: - cooldown_seconds = COOLDOWN_BACKOFF_TIERS.get( - count, COOLDOWN_BACKOFF_MAX - ) - model_cooldowns[model] = now_ts + cooldown_seconds - lib_logger.warning( - f"Failure #{count} for key {mask_credential(key)} with model {model}. " - f"Error type: {classified_error.error_type}, cooldown: {cooldown_seconds}s" - ) - else: - # Provider-level errors: apply short cooldown but don't count against key - if cooldown_seconds is None: - cooldown_seconds = COOLDOWN_TRANSIENT_ERROR - model_cooldowns[model] = now_ts + cooldown_seconds - lib_logger.info( - f"Provider-level error ({classified_error.error_type}) for key {mask_credential(key)} " - f"with model {model}. NOT incrementing failures. Cooldown: {cooldown_seconds}s" - ) - - # Check for key-level lockout condition - await self._check_key_lockout(key, key_data) - - # Track failure count for quota estimation (all failures consume quota) - # This is separate from consecutive_failures which is for backoff logic - if reset_mode == "per_model": - models_data = key_data.setdefault("models", {}) - model_data = models_data.setdefault( - model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - # Only increment if not already incremented in quota_exceeded branch - if classified_error.error_type != "quota_exceeded": - model_data["failure_count"] = model_data.get("failure_count", 0) + 1 - model_data["request_count"] = model_data.get("request_count", 0) + 1 - - # Sync request_count across quota group - new_request_count = model_data["request_count"] - group = self._get_model_quota_group(key, model) - if group: - grouped_models = self._get_grouped_models(key, group) - for grouped_model in grouped_models: - if grouped_model != model: - other_model_data = models_data.setdefault( - grouped_model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - other_model_data["request_count"] = new_request_count - # Also sync quota_max_requests if set - max_req = model_data.get("quota_max_requests") - if max_req: - other_model_data["quota_max_requests"] = max_req - other_model_data["quota_display"] = ( - f"{new_request_count}/{max_req}" - ) - - key_data["last_failure"] = { - "timestamp": now_ts, - "model": model, - "error": str(classified_error.original_exception), - } - - await self._save_usage() - - async def update_quota_baseline( - self, - credential: str, - model: str, - remaining_fraction: float, - max_requests: Optional[int] = None, - reset_timestamp: Optional[float] = None, - ) -> Optional[Dict[str, Any]]: - """ - Update quota baseline data for a credential/model after fetching from API. - - This stores the current quota state as a baseline, which is used to - estimate remaining quota based on subsequent request counts. - - When quota is exhausted (remaining_fraction <= 0.0) and a valid reset_timestamp - is provided, this also sets model_cooldowns to prevent wasted requests. - - Args: - credential: Credential identifier (file path or env:// URI) - model: Model name (with or without provider prefix) - remaining_fraction: Current remaining quota as fraction (0.0 to 1.0) - max_requests: Maximum requests allowed per quota period (e.g., 250 for Claude) - reset_timestamp: Unix timestamp when quota resets. Only trusted when - remaining_fraction < 1.0 (quota has been used). API returns garbage - reset times for unused quota (100%). - - Returns: - None if no cooldown was set/updated, otherwise: - { - "group_or_model": str, # quota group name or model name if ungrouped - "hours_until_reset": float, - } - """ - await self._lazy_init() - async with self._data_lock: - now_ts = time.time() - - # Get or create key data structure - key_data = self._usage_data.setdefault( - credential, - { - "models": {}, - "global": {"models": {}}, - "model_cooldowns": {}, - "failures": {}, - }, - ) - - # Ensure models dict exists - if "models" not in key_data: - key_data["models"] = {} - - # Get or create per-model data - model_data = key_data["models"].setdefault( - model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - "baseline_remaining_fraction": None, - "baseline_fetched_at": None, - "requests_at_baseline": None, - }, - ) - - # Calculate actual used requests from API's remaining fraction - # The API is authoritative - sync our local count to match reality - if max_requests is not None: - used_requests = int((1.0 - remaining_fraction) * max_requests) - else: - # Estimate max_requests from provider's quota cost - # This matches how get_max_requests_for_model() calculates it - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - if plugin_instance and hasattr( - plugin_instance, "get_max_requests_for_model" - ): - # Get tier from provider's cache - tier = getattr(plugin_instance, "project_tier_cache", {}).get( - credential, "standard-tier" - ) - # Strip provider prefix from model if present - clean_model = model.split("/")[-1] if "/" in model else model - max_requests = plugin_instance.get_max_requests_for_model( - clean_model, tier - ) - used_requests = int((1.0 - remaining_fraction) * max_requests) - else: - # Fallback: keep existing count if we can't calculate - used_requests = model_data.get("request_count", 0) - max_requests = model_data.get("quota_max_requests") - - # Sync local request count to API's authoritative value - # Use max() to prevent API from resetting our count if it returns stale/cached 100% - # The API can only increase our count (if we missed requests), not decrease it - # See: https://github.com/Mirrowel/LLM-API-Key-Proxy/issues/75 - current_count = model_data.get("request_count", 0) - synced_count = max(current_count, used_requests) - model_data["request_count"] = synced_count - model_data["requests_at_baseline"] = synced_count - - # Update baseline fields - model_data["baseline_remaining_fraction"] = remaining_fraction - model_data["baseline_fetched_at"] = now_ts - - # Update max_requests and quota_display - if max_requests is not None: - model_data["quota_max_requests"] = max_requests - model_data["quota_display"] = f"{synced_count}/{max_requests}" - - # Handle reset_timestamp: only trust it when quota has been used (< 100%) - # API returns garbage reset times for unused quota - valid_reset_ts = ( - reset_timestamp is not None - and remaining_fraction < 1.0 - and reset_timestamp > now_ts - ) - - if valid_reset_ts: - model_data["quota_reset_ts"] = reset_timestamp - - # Set cooldowns when quota is exhausted - model_cooldowns = key_data.setdefault("model_cooldowns", {}) - is_exhausted = remaining_fraction <= 0.0 - cooldown_set_info = ( - None # Will be returned if cooldown was newly set/updated - ) - - if is_exhausted and valid_reset_ts: - # Check if there was an existing ACTIVE cooldown before we update - # This distinguishes between fresh exhaustion vs refresh of existing state - existing_cooldown = model_cooldowns.get(model) - was_already_on_cooldown = ( - existing_cooldown is not None and existing_cooldown > now_ts - ) - - # Only update cooldown if not set or differs by more than 5 minutes - should_update = ( - existing_cooldown is None - or abs(existing_cooldown - reset_timestamp) > 300 - ) - if should_update: - model_cooldowns[model] = reset_timestamp - hours_until_reset = (reset_timestamp - now_ts) / 3600 - # Determine group or model name for logging - group = self._get_model_quota_group(credential, model) - cooldown_set_info = { - "group_or_model": group if group else model.split("/")[-1], - "hours_until_reset": hours_until_reset, - } - - # Mark credential as exhausted in fair cycle if cooldown exceeds threshold - # BUT only if this is a FRESH exhaustion (wasn't already on cooldown) - # This prevents re-marking after cycle reset when quota refresh sees existing cooldown - if not was_already_on_cooldown: - cooldown_duration = reset_timestamp - now_ts - provider = self._get_provider_from_credential(credential) - if provider: - threshold = self._get_exhaustion_cooldown_threshold(provider) - if cooldown_duration > threshold: - rotation_mode = self._get_rotation_mode(provider) - if self._is_fair_cycle_enabled(provider, rotation_mode): - priority = self._get_credential_priority( - credential, provider - ) - tier_key = self._get_tier_key(provider, priority) - tracking_key = self._get_tracking_key( - credential, model, provider - ) - self._mark_credential_exhausted( - credential, provider, tier_key, tracking_key - ) - - # Defensive clamp: ensure request_count doesn't exceed max when exhausted - if ( - max_requests is not None - and model_data["request_count"] > max_requests - ): - model_data["request_count"] = max_requests - model_data["quota_display"] = f"{max_requests}/{max_requests}" - - # Sync baseline fields and quota info across quota group - group = self._get_model_quota_group(credential, model) - if group: - grouped_models = self._get_grouped_models(credential, group) - for grouped_model in grouped_models: - if grouped_model != model: - other_model_data = key_data["models"].setdefault( - grouped_model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - # Sync request tracking (use synced_count to prevent reset bug) - other_model_data["request_count"] = synced_count - if max_requests is not None: - other_model_data["quota_max_requests"] = max_requests - other_model_data["quota_display"] = ( - f"{synced_count}/{max_requests}" - ) - # Sync baseline fields - other_model_data["baseline_remaining_fraction"] = ( - remaining_fraction - ) - other_model_data["baseline_fetched_at"] = now_ts - other_model_data["requests_at_baseline"] = synced_count - # Sync reset timestamp if valid - if valid_reset_ts: - other_model_data["quota_reset_ts"] = reset_timestamp - # Sync window start time - window_start = model_data.get("window_start_ts") - if window_start: - other_model_data["window_start_ts"] = window_start - # Sync cooldown if exhausted (with ±5 min check) - if is_exhausted and valid_reset_ts: - existing_grouped = model_cooldowns.get(grouped_model) - should_update_grouped = ( - existing_grouped is None - or abs(existing_grouped - reset_timestamp) > 300 - ) - if should_update_grouped: - model_cooldowns[grouped_model] = reset_timestamp - - # Defensive clamp for grouped models when exhausted - if ( - max_requests is not None - and other_model_data["request_count"] > max_requests - ): - other_model_data["request_count"] = max_requests - other_model_data["quota_display"] = ( - f"{max_requests}/{max_requests}" - ) - - lib_logger.debug( - f"Updated quota baseline for {mask_credential(credential)} model={model}: " - f"remaining={remaining_fraction:.2%}, synced_request_count={synced_count}" - ) - - await self._save_usage() - return cooldown_set_info - - async def _check_key_lockout(self, key: str, key_data: Dict): - """ - Checks if a key should be locked out due to multiple model failures. - - NOTE: This check is currently disabled. The original logic counted individual - models in long-term lockout, but this caused issues with quota groups - when - a single quota group (e.g., "claude" with 5 models) was exhausted, it would - count as 5 lockouts and trigger key-level lockout, blocking other quota groups - (like gemini) that were still available. - - The per-model and per-group cooldowns already handle quota exhaustion properly. - """ - # Disabled - see docstring above - pass - - async def get_stats_for_endpoint( - self, - provider_filter: Optional[str] = None, - include_global: bool = True, - ) -> Dict[str, Any]: - """ - Get usage stats formatted for the /v1/quota-stats endpoint. - - Aggregates data from key_usage.json grouped by provider. - Includes both current period stats and global (lifetime) stats. - - Args: - provider_filter: If provided, only return stats for this provider - include_global: If True, include global/lifetime stats alongside current - - Returns: - { - "providers": { - "provider_name": { - "credential_count": int, - "active_count": int, - "on_cooldown_count": int, - "total_requests": int, - "tokens": { - "input_cached": int, - "input_uncached": int, - "input_cache_pct": float, - "output": int - }, - "approx_cost": float | None, - "credentials": [...], - "global": {...} # If include_global is True - } - }, - "summary": {...}, - "global_summary": {...}, # If include_global is True - "timestamp": float - } - """ - await self._lazy_init() - - now_ts = time.time() - providers: Dict[str, Dict[str, Any]] = {} - # Track global stats separately - global_providers: Dict[str, Dict[str, Any]] = {} - - async with self._data_lock: - if not self._usage_data: - return { - "providers": {}, - "summary": { - "total_providers": 0, - "total_credentials": 0, - "active_credentials": 0, - "exhausted_credentials": 0, - "total_requests": 0, - "tokens": { - "input_cached": 0, - "input_uncached": 0, - "input_cache_pct": 0, - "output": 0, - }, - "approx_total_cost": 0.0, - }, - "global_summary": { - "total_providers": 0, - "total_credentials": 0, - "total_requests": 0, - "tokens": { - "input_cached": 0, - "input_uncached": 0, - "input_cache_pct": 0, - "output": 0, - }, - "approx_total_cost": 0.0, - }, - "data_source": "cache", - "timestamp": now_ts, - } - - for credential, cred_data in self._usage_data.items(): - # Extract provider from credential path - provider = self._get_provider_from_credential(credential) - if not provider: - continue - - # Apply filter if specified - if provider_filter and provider != provider_filter: - continue - - # Initialize provider entry - if provider not in providers: - providers[provider] = { - "credential_count": 0, - "active_count": 0, - "on_cooldown_count": 0, - "exhausted_count": 0, - "total_requests": 0, - "tokens": { - "input_cached": 0, - "input_uncached": 0, - "input_cache_pct": 0, - "output": 0, - }, - "approx_cost": 0.0, - "credentials": [], - } - global_providers[provider] = { - "total_requests": 0, - "tokens": { - "input_cached": 0, - "input_uncached": 0, - "input_cache_pct": 0, - "output": 0, - }, - "approx_cost": 0.0, - } - - prov_stats = providers[provider] - prov_stats["credential_count"] += 1 - - # Determine credential status and cooldowns - key_cooldown = cred_data.get("key_cooldown_until", 0) or 0 - model_cooldowns = cred_data.get("model_cooldowns", {}) - - # Build active cooldowns with remaining time - active_cooldowns = {} - for model, cooldown_ts in model_cooldowns.items(): - if cooldown_ts > now_ts: - remaining_seconds = int(cooldown_ts - now_ts) - active_cooldowns[model] = { - "until_ts": cooldown_ts, - "remaining_seconds": remaining_seconds, - } - - key_cooldown_remaining = None - if key_cooldown > now_ts: - key_cooldown_remaining = int(key_cooldown - now_ts) - - has_active_cooldown = key_cooldown > now_ts or len(active_cooldowns) > 0 - - # Check if exhausted (all quota groups exhausted for Antigravity) - is_exhausted = False - models_data = cred_data.get("models", {}) - if models_data: - # Check if any model has remaining quota - all_exhausted = True - for model_stats in models_data.values(): - if isinstance(model_stats, dict): - baseline = model_stats.get("baseline_remaining_fraction") - if baseline is None or baseline > 0: - all_exhausted = False - break - if all_exhausted and len(models_data) > 0: - is_exhausted = True - - if is_exhausted: - prov_stats["exhausted_count"] += 1 - status = "exhausted" - elif has_active_cooldown: - prov_stats["on_cooldown_count"] += 1 - status = "cooldown" - else: - prov_stats["active_count"] += 1 - status = "active" - - # Aggregate token stats (current period) - cred_tokens = { - "input_cached": 0, - "input_uncached": 0, - "output": 0, - } - cred_requests = 0 - cred_cost = 0.0 - - # Aggregate global token stats - cred_global_tokens = { - "input_cached": 0, - "input_uncached": 0, - "output": 0, - } - cred_global_requests = 0 - cred_global_cost = 0.0 - - # Handle per-model structure (current period) - if models_data: - for model_name, model_stats in models_data.items(): - if not isinstance(model_stats, dict): - continue - # Prefer request_count if available and non-zero, else fall back to success+failure - req_count = model_stats.get("request_count", 0) - if req_count > 0: - cred_requests += req_count - else: - cred_requests += model_stats.get("success_count", 0) - cred_requests += model_stats.get("failure_count", 0) - # Token stats - track cached separately - cred_tokens["input_cached"] += model_stats.get( - "prompt_tokens_cached", 0 - ) - cred_tokens["input_uncached"] += model_stats.get( - "prompt_tokens", 0 - ) - cred_tokens["output"] += model_stats.get("completion_tokens", 0) - cred_cost += model_stats.get("approx_cost", 0.0) - - # Handle legacy daily structure - daily_data = cred_data.get("daily", {}) - daily_models = daily_data.get("models", {}) - for model_name, model_stats in daily_models.items(): - if not isinstance(model_stats, dict): - continue - cred_requests += model_stats.get("success_count", 0) - cred_tokens["input_cached"] += model_stats.get( - "prompt_tokens_cached", 0 - ) - cred_tokens["input_uncached"] += model_stats.get("prompt_tokens", 0) - cred_tokens["output"] += model_stats.get("completion_tokens", 0) - cred_cost += model_stats.get("approx_cost", 0.0) - - # Handle global stats - global_data = cred_data.get("global", {}) - global_models = global_data.get("models", {}) - for model_name, model_stats in global_models.items(): - if not isinstance(model_stats, dict): - continue - cred_global_requests += model_stats.get("success_count", 0) - cred_global_tokens["input_cached"] += model_stats.get( - "prompt_tokens_cached", 0 - ) - cred_global_tokens["input_uncached"] += model_stats.get( - "prompt_tokens", 0 - ) - cred_global_tokens["output"] += model_stats.get( - "completion_tokens", 0 - ) - cred_global_cost += model_stats.get("approx_cost", 0.0) - - # Add current period stats to global totals - cred_global_requests += cred_requests - cred_global_tokens["input_cached"] += cred_tokens["input_cached"] - cred_global_tokens["input_uncached"] += cred_tokens["input_uncached"] - cred_global_tokens["output"] += cred_tokens["output"] - cred_global_cost += cred_cost - - # Build credential entry - # Mask credential identifier for display - if credential.startswith("env://"): - identifier = credential - else: - identifier = Path(credential).name - - cred_entry = { - "identifier": identifier, - "full_path": credential, - "status": status, - "last_used_ts": cred_data.get("last_used_ts"), - "requests": cred_requests, - "tokens": cred_tokens, - "approx_cost": cred_cost if cred_cost > 0 else None, - } - - # Add cooldown info - if key_cooldown_remaining is not None: - cred_entry["key_cooldown_remaining"] = key_cooldown_remaining - if active_cooldowns: - cred_entry["model_cooldowns"] = active_cooldowns - - # Add global stats for this credential - if include_global: - # Calculate global cache percentage - global_total_input = ( - cred_global_tokens["input_cached"] - + cred_global_tokens["input_uncached"] - ) - global_cache_pct = ( - round( - cred_global_tokens["input_cached"] - / global_total_input - * 100, - 1, - ) - if global_total_input > 0 - else 0 - ) - - cred_entry["global"] = { - "requests": cred_global_requests, - "tokens": { - "input_cached": cred_global_tokens["input_cached"], - "input_uncached": cred_global_tokens["input_uncached"], - "input_cache_pct": global_cache_pct, - "output": cred_global_tokens["output"], - }, - "approx_cost": cred_global_cost - if cred_global_cost > 0 - else None, - } - - # Add model-specific data for providers with per-model tracking - if models_data: - cred_entry["models"] = {} - for model_name, model_stats in models_data.items(): - if not isinstance(model_stats, dict): - continue - cred_entry["models"][model_name] = { - "requests": model_stats.get("success_count", 0) - + model_stats.get("failure_count", 0), - "request_count": model_stats.get("request_count", 0), - "success_count": model_stats.get("success_count", 0), - "failure_count": model_stats.get("failure_count", 0), - "prompt_tokens": model_stats.get("prompt_tokens", 0), - "prompt_tokens_cached": model_stats.get( - "prompt_tokens_cached", 0 - ), - "completion_tokens": model_stats.get( - "completion_tokens", 0 - ), - "approx_cost": model_stats.get("approx_cost", 0.0), - "window_start_ts": model_stats.get("window_start_ts"), - "quota_reset_ts": model_stats.get("quota_reset_ts"), - # Quota baseline fields (Antigravity-specific) - "baseline_remaining_fraction": model_stats.get( - "baseline_remaining_fraction" - ), - "baseline_fetched_at": model_stats.get( - "baseline_fetched_at" - ), - "quota_max_requests": model_stats.get("quota_max_requests"), - "quota_display": model_stats.get("quota_display"), - } - - prov_stats["credentials"].append(cred_entry) - - # Aggregate to provider totals (current period) - prov_stats["total_requests"] += cred_requests - prov_stats["tokens"]["input_cached"] += cred_tokens["input_cached"] - prov_stats["tokens"]["input_uncached"] += cred_tokens["input_uncached"] - prov_stats["tokens"]["output"] += cred_tokens["output"] - if cred_cost > 0: - prov_stats["approx_cost"] += cred_cost - - # Aggregate to global provider totals - global_providers[provider]["total_requests"] += cred_global_requests - global_providers[provider]["tokens"]["input_cached"] += ( - cred_global_tokens["input_cached"] - ) - global_providers[provider]["tokens"]["input_uncached"] += ( - cred_global_tokens["input_uncached"] - ) - global_providers[provider]["tokens"]["output"] += cred_global_tokens[ - "output" - ] - global_providers[provider]["approx_cost"] += cred_global_cost - - # Calculate cache percentages for each provider - for provider, prov_stats in providers.items(): - total_input = ( - prov_stats["tokens"]["input_cached"] - + prov_stats["tokens"]["input_uncached"] - ) - if total_input > 0: - prov_stats["tokens"]["input_cache_pct"] = round( - prov_stats["tokens"]["input_cached"] / total_input * 100, 1 - ) - # Set cost to None if 0 - if prov_stats["approx_cost"] == 0: - prov_stats["approx_cost"] = None - - # Calculate global cache percentages - if include_global and provider in global_providers: - gp = global_providers[provider] - global_total = ( - gp["tokens"]["input_cached"] + gp["tokens"]["input_uncached"] - ) - if global_total > 0: - gp["tokens"]["input_cache_pct"] = round( - gp["tokens"]["input_cached"] / global_total * 100, 1 - ) - if gp["approx_cost"] == 0: - gp["approx_cost"] = None - prov_stats["global"] = gp - - # Build summary (current period) - total_creds = sum(p["credential_count"] for p in providers.values()) - active_creds = sum(p["active_count"] for p in providers.values()) - exhausted_creds = sum(p["exhausted_count"] for p in providers.values()) - total_requests = sum(p["total_requests"] for p in providers.values()) - total_input_cached = sum( - p["tokens"]["input_cached"] for p in providers.values() - ) - total_input_uncached = sum( - p["tokens"]["input_uncached"] for p in providers.values() - ) - total_output = sum(p["tokens"]["output"] for p in providers.values()) - total_cost = sum(p["approx_cost"] or 0 for p in providers.values()) - - total_input = total_input_cached + total_input_uncached - input_cache_pct = ( - round(total_input_cached / total_input * 100, 1) if total_input > 0 else 0 - ) - - result = { - "providers": providers, - "summary": { - "total_providers": len(providers), - "total_credentials": total_creds, - "active_credentials": active_creds, - "exhausted_credentials": exhausted_creds, - "total_requests": total_requests, - "tokens": { - "input_cached": total_input_cached, - "input_uncached": total_input_uncached, - "input_cache_pct": input_cache_pct, - "output": total_output, - }, - "approx_total_cost": total_cost if total_cost > 0 else None, - }, - "data_source": "cache", - "timestamp": now_ts, - } - - # Build global summary - if include_global: - global_total_requests = sum( - gp["total_requests"] for gp in global_providers.values() - ) - global_total_input_cached = sum( - gp["tokens"]["input_cached"] for gp in global_providers.values() - ) - global_total_input_uncached = sum( - gp["tokens"]["input_uncached"] for gp in global_providers.values() - ) - global_total_output = sum( - gp["tokens"]["output"] for gp in global_providers.values() - ) - global_total_cost = sum( - gp["approx_cost"] or 0 for gp in global_providers.values() - ) - - global_total_input = global_total_input_cached + global_total_input_uncached - global_input_cache_pct = ( - round(global_total_input_cached / global_total_input * 100, 1) - if global_total_input > 0 - else 0 - ) - - result["global_summary"] = { - "total_providers": len(global_providers), - "total_credentials": total_creds, - "total_requests": global_total_requests, - "tokens": { - "input_cached": global_total_input_cached, - "input_uncached": global_total_input_uncached, - "input_cache_pct": global_input_cache_pct, - "output": global_total_output, - }, - "approx_total_cost": global_total_cost - if global_total_cost > 0 - else None, - } - - return result - - async def reload_from_disk(self) -> None: - """ - Force reload usage data from disk. - - Useful when another process may have updated the file. - """ - async with self._init_lock: - self._initialized.clear() - await self._load_usage() - await self._reset_daily_stats_if_needed() - self._initialized.set() +__all__ = ["UsageManager", "CredentialContext"] diff --git a/src/rotator_library/utils/__init__.py b/src/rotator_library/utils/__init__.py index ce8d959d..a51d1db7 100644 --- a/src/rotator_library/utils/__init__.py +++ b/src/rotator_library/utils/__init__.py @@ -17,6 +17,7 @@ ResilientStateWriter, safe_write_json, safe_log_write, + safe_read_json, safe_mkdir, ) from .suppress_litellm_warnings import suppress_litellm_serialization_warnings @@ -34,6 +35,7 @@ "ResilientStateWriter", "safe_write_json", "safe_log_write", + "safe_read_json", "safe_mkdir", "suppress_litellm_serialization_warnings", ] diff --git a/src/rotator_library/utils/resilient_io.py b/src/rotator_library/utils/resilient_io.py index 1125e3b7..91e96f37 100644 --- a/src/rotator_library/utils/resilient_io.py +++ b/src/rotator_library/utils/resilient_io.py @@ -660,6 +660,36 @@ def safe_log_write( return False +def safe_read_json( + path: Union[str, Path], + logger: logging.Logger, + *, + parse_json: bool = True, +) -> Optional[Any]: + """ + Read file contents with error handling. + + Args: + path: File path to read from + logger: Logger for warnings/errors + parse_json: When True, parse JSON; when False, return raw text + + Returns: + Parsed JSON dict, raw text, or None on failure + """ + path = Path(path) + try: + with open(path, "r", encoding="utf-8") as f: + if parse_json: + return json.load(f) + return f.read() + except FileNotFoundError: + return None + except (OSError, PermissionError, IOError, json.JSONDecodeError) as e: + logger.error(f"Failed to read {path}: {e}") + return None + + def safe_mkdir(path: Union[str, Path], logger: logging.Logger) -> bool: """ Create directory with error handling. diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..07ec5a39 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,9 @@ +import sys +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] +SRC_DIR = ROOT / "src" + +if str(SRC_DIR) not in sys.path: + sys.path.insert(0, str(SRC_DIR)) diff --git a/tests/fixtures/openai_codex/error_missing_instructions.json b/tests/fixtures/openai_codex/error_missing_instructions.json new file mode 100644 index 00000000..528acb8e --- /dev/null +++ b/tests/fixtures/openai_codex/error_missing_instructions.json @@ -0,0 +1 @@ +{"detail":"Instructions are required"} diff --git a/tests/fixtures/openai_codex/error_stream_required.json b/tests/fixtures/openai_codex/error_stream_required.json new file mode 100644 index 00000000..f90fb371 --- /dev/null +++ b/tests/fixtures/openai_codex/error_stream_required.json @@ -0,0 +1 @@ +{"detail":"Stream must be set to true"} diff --git a/tests/fixtures/openai_codex/error_unsupported_verbosity.json b/tests/fixtures/openai_codex/error_unsupported_verbosity.json new file mode 100644 index 00000000..0ab6622e --- /dev/null +++ b/tests/fixtures/openai_codex/error_unsupported_verbosity.json @@ -0,0 +1,8 @@ +{ + "error": { + "message": "Unsupported value: 'low' is not supported with the 'gpt-5.1-codex' model. Supported values are: 'medium'.", + "type": "invalid_request_error", + "param": "text.verbosity", + "code": "unsupported_value" + } +} diff --git a/tests/fixtures/openai_codex/protocol_notes.md b/tests/fixtures/openai_codex/protocol_notes.md new file mode 100644 index 00000000..2aa6f459 --- /dev/null +++ b/tests/fixtures/openai_codex/protocol_notes.md @@ -0,0 +1,72 @@ +# OpenAI Codex protocol capture (2026-02-12) + +Captured against `https://chatgpt.com/backend-api/codex/responses` using a valid Codex OAuth token from `~/.codex/auth.json`. + +## OAuth + +- Authorization endpoint: `https://auth.openai.com/oauth/authorize` +- Token endpoint: `https://auth.openai.com/oauth/token` +- Authorization code token exchange params: + - `grant_type=authorization_code` + - `client_id=app_EMoamEEZ73f0CkXaXp7hrann` + - `redirect_uri=http://localhost:/auth/callback` + - `code_verifier=` +- Refresh params: + - `grant_type=refresh_token` + - `refresh_token=` + - `client_id=app_EMoamEEZ73f0CkXaXp7hrann` + +## Endpoint + request shape + +- Endpoint: `POST /codex/responses` +- Requires `stream=true` (non-stream returns 400 with `{"detail":"Stream must be set to true"}`) +- Requires non-empty `instructions` (missing instructions returns 400 with `{"detail":"Instructions are required"}`) + +Observed working request body fields: + +- `model` +- `stream` (must be `true`) +- `store` (`false`) +- `instructions` +- `input` (Responses input format) +- `text.verbosity` (for `gpt-5.1-codex`, `low` was rejected; `medium` worked) +- `tool_choice` +- `parallel_tool_calls` + +## Headers + +Observed and/or validated for provider implementation: + +- `Authorization: Bearer ` +- `chatgpt-account-id: ` +- `OpenAI-Beta: responses=experimental` +- `originator: pi` +- `Accept: text/event-stream` +- `Content-Type: application/json` + +## SSE event taxonomy (observed) + +- `response.created` +- `response.in_progress` +- `response.output_item.added` +- `response.output_item.done` +- `response.content_part.added` +- `response.output_text.delta` +- `response.output_text.done` +- `response.content_part.done` +- `response.completed` + +Provider additionally supports planned aliases/events: + +- `response.content_part.delta` +- `response.function_call_arguments.delta` +- `response.function_call_arguments.done` +- `response.incomplete` +- `response.failed` +- `error` + +## Error body fixtures + +- `error_missing_instructions.json` +- `error_stream_required.json` +- `error_unsupported_verbosity.json` diff --git a/tests/fixtures/openai_codex/response_completed_event.json b/tests/fixtures/openai_codex/response_completed_event.json new file mode 100644 index 00000000..22f83f6a --- /dev/null +++ b/tests/fixtures/openai_codex/response_completed_event.json @@ -0,0 +1,77 @@ +{ + "type": "response.completed", + "response": { + "id": "id_redacted_10", + "object": "response", + "created_at": 1770926997, + "status": "completed", + "background": false, + "completed_at": 1770926998, + "error": null, + "frequency_penalty": 0.0, + "incomplete_details": null, + "instructions": "You are a concise assistant.", + "max_output_tokens": null, + "max_tool_calls": null, + "model": "gpt-5.1-codex", + "output": [ + { + "id": "id_redacted_11", + "type": "reasoning", + "summary": [] + }, + { + "id": "id_redacted_12", + "type": "message", + "status": "completed", + "content": [ + { + "type": "output_text", + "annotations": [], + "logprobs": [], + "text": "pong" + } + ], + "role": "assistant" + } + ], + "parallel_tool_calls": true, + "presence_penalty": 0.0, + "previous_response_id": null, + "prompt_cache_key": "prompt_cache_key_redacted_4", + "prompt_cache_retention": null, + "reasoning": { + "effort": "medium", + "summary": null + }, + "safety_identifier": "safety_identifier_redacted_4", + "service_tier": "default", + "store": false, + "temperature": 1.0, + "text": { + "format": { + "type": "text" + }, + "verbosity": "medium" + }, + "tool_choice": "auto", + "tools": [], + "top_logprobs": 0, + "top_p": 1.0, + "truncation": "disabled", + "usage": { + "input_tokens": 21, + "input_tokens_details": { + "cached_tokens": 0 + }, + "output_tokens": 13, + "output_tokens_details": { + "reasoning_tokens": 0 + }, + "total_tokens": 34 + }, + "user": null, + "metadata": {} + }, + "sequence_number": 10 +} \ No newline at end of file diff --git a/tests/fixtures/openai_codex/stream_content_part_delta_events.json b/tests/fixtures/openai_codex/stream_content_part_delta_events.json new file mode 100644 index 00000000..e90034cc --- /dev/null +++ b/tests/fixtures/openai_codex/stream_content_part_delta_events.json @@ -0,0 +1,44 @@ +[ + { + "type": "response.created", + "response": { + "id": "resp_delta_1", + "created_at": 1770927001, + "status": "in_progress" + } + }, + { + "type": "response.output_item.added", + "item": { + "id": "msg_1", + "type": "message", + "status": "in_progress", + "role": "assistant" + } + }, + { + "type": "response.content_part.delta", + "item_id": "msg_1", + "delta": "Hello" + }, + { + "type": "response.content_part.delta", + "item_id": "msg_1", + "delta": " world" + }, + { + "type": "response.incomplete", + "response": { + "id": "resp_delta_1", + "status": "incomplete", + "incomplete_details": { + "reason": "max_output_tokens" + }, + "usage": { + "input_tokens": 10, + "output_tokens": 20, + "total_tokens": 30 + } + } + } +] diff --git a/tests/fixtures/openai_codex/stream_success_events.json b/tests/fixtures/openai_codex/stream_success_events.json new file mode 100644 index 00000000..c0028c2b --- /dev/null +++ b/tests/fixtures/openai_codex/stream_success_events.json @@ -0,0 +1,269 @@ +[ + { + "type": "response.created", + "response": { + "id": "id_redacted_1", + "object": "response", + "created_at": 1770926997, + "status": "in_progress", + "background": false, + "completed_at": null, + "error": null, + "frequency_penalty": 0.0, + "incomplete_details": null, + "instructions": "You are a concise assistant.", + "max_output_tokens": null, + "max_tool_calls": null, + "model": "gpt-5.1-codex", + "output": [], + "parallel_tool_calls": true, + "presence_penalty": 0.0, + "previous_response_id": null, + "prompt_cache_key": "prompt_cache_key_redacted_1", + "prompt_cache_retention": null, + "reasoning": { + "effort": "medium", + "summary": null + }, + "safety_identifier": "safety_identifier_redacted_1", + "service_tier": "auto", + "store": false, + "temperature": 1.0, + "text": { + "format": { + "type": "text" + }, + "verbosity": "medium" + }, + "tool_choice": "auto", + "tools": [], + "top_logprobs": 0, + "top_p": 1.0, + "truncation": "disabled", + "usage": null, + "user": null, + "metadata": {} + }, + "sequence_number": 0 + }, + { + "type": "response.in_progress", + "response": { + "id": "id_redacted_2", + "object": "response", + "created_at": 1770926997, + "status": "in_progress", + "background": false, + "completed_at": null, + "error": null, + "frequency_penalty": 0.0, + "incomplete_details": null, + "instructions": "You are a concise assistant.", + "max_output_tokens": null, + "max_tool_calls": null, + "model": "gpt-5.1-codex", + "output": [], + "parallel_tool_calls": true, + "presence_penalty": 0.0, + "previous_response_id": null, + "prompt_cache_key": "prompt_cache_key_redacted_2", + "prompt_cache_retention": null, + "reasoning": { + "effort": "medium", + "summary": null + }, + "safety_identifier": "safety_identifier_redacted_2", + "service_tier": "auto", + "store": false, + "temperature": 1.0, + "text": { + "format": { + "type": "text" + }, + "verbosity": "medium" + }, + "tool_choice": "auto", + "tools": [], + "top_logprobs": 0, + "top_p": 1.0, + "truncation": "disabled", + "usage": null, + "user": null, + "metadata": {} + }, + "sequence_number": 1 + }, + { + "type": "response.output_item.added", + "item": { + "id": "id_redacted_3", + "type": "reasoning", + "summary": [] + }, + "output_index": 0, + "sequence_number": 2 + }, + { + "type": "response.output_item.done", + "item": { + "id": "id_redacted_4", + "type": "reasoning", + "summary": [] + }, + "output_index": 0, + "sequence_number": 3 + }, + { + "type": "response.output_item.added", + "item": { + "id": "id_redacted_5", + "type": "message", + "status": "in_progress", + "content": [], + "role": "assistant" + }, + "output_index": 1, + "sequence_number": 4 + }, + { + "type": "response.content_part.added", + "content_index": 0, + "item_id": "item_id_redacted_1", + "output_index": 1, + "part": { + "type": "output_text", + "annotations": [], + "logprobs": [], + "text": "" + }, + "sequence_number": 5 + }, + { + "type": "response.output_text.delta", + "content_index": 0, + "delta": "pong", + "item_id": "item_id_redacted_2", + "logprobs": [], + "obfuscation": "obfuscation_redacted_1", + "output_index": 1, + "sequence_number": 6 + }, + { + "type": "response.output_text.done", + "content_index": 0, + "item_id": "item_id_redacted_3", + "logprobs": [], + "output_index": 1, + "sequence_number": 7, + "text": "pong" + }, + { + "type": "response.content_part.done", + "content_index": 0, + "item_id": "item_id_redacted_4", + "output_index": 1, + "part": { + "type": "output_text", + "annotations": [], + "logprobs": [], + "text": "pong" + }, + "sequence_number": 8 + }, + { + "type": "response.output_item.done", + "item": { + "id": "id_redacted_6", + "type": "message", + "status": "completed", + "content": [ + { + "type": "output_text", + "annotations": [], + "logprobs": [], + "text": "pong" + } + ], + "role": "assistant" + }, + "output_index": 1, + "sequence_number": 9 + }, + { + "type": "response.completed", + "response": { + "id": "id_redacted_7", + "object": "response", + "created_at": 1770926997, + "status": "completed", + "background": false, + "completed_at": 1770926998, + "error": null, + "frequency_penalty": 0.0, + "incomplete_details": null, + "instructions": "You are a concise assistant.", + "max_output_tokens": null, + "max_tool_calls": null, + "model": "gpt-5.1-codex", + "output": [ + { + "id": "id_redacted_8", + "type": "reasoning", + "summary": [] + }, + { + "id": "id_redacted_9", + "type": "message", + "status": "completed", + "content": [ + { + "type": "output_text", + "annotations": [], + "logprobs": [], + "text": "pong" + } + ], + "role": "assistant" + } + ], + "parallel_tool_calls": true, + "presence_penalty": 0.0, + "previous_response_id": null, + "prompt_cache_key": "prompt_cache_key_redacted_3", + "prompt_cache_retention": null, + "reasoning": { + "effort": "medium", + "summary": null + }, + "safety_identifier": "safety_identifier_redacted_3", + "service_tier": "default", + "store": false, + "temperature": 1.0, + "text": { + "format": { + "type": "text" + }, + "verbosity": "medium" + }, + "tool_choice": "auto", + "tools": [], + "top_logprobs": 0, + "top_p": 1.0, + "truncation": "disabled", + "usage": { + "input_tokens": 21, + "input_tokens_details": { + "cached_tokens": 0 + }, + "output_tokens": 13, + "output_tokens_details": { + "reasoning_tokens": 0 + }, + "total_tokens": 34 + }, + "user": null, + "metadata": {} + }, + "sequence_number": 10 + } +] \ No newline at end of file diff --git a/tests/fixtures/openai_codex/stream_tool_call_events.json b/tests/fixtures/openai_codex/stream_tool_call_events.json new file mode 100644 index 00000000..e2aca1f5 --- /dev/null +++ b/tests/fixtures/openai_codex/stream_tool_call_events.json @@ -0,0 +1,50 @@ +[ + { + "type": "response.created", + "response": { + "id": "resp_tool_1", + "created_at": 1770927000, + "status": "in_progress" + } + }, + { + "type": "response.output_item.added", + "item": { + "id": "call_item_1", + "type": "function_call", + "call_id": "call_1", + "name": "get_weather", + "arguments": "" + } + }, + { + "type": "response.function_call_arguments.delta", + "call_id": "call_1", + "delta": "{\"city\":\"San" + }, + { + "type": "response.function_call_arguments.delta", + "call_id": "call_1", + "delta": " Francisco\"}" + }, + { + "type": "response.function_call_arguments.done", + "call_id": "call_1", + "arguments": "{\"city\":\"San Francisco\"}" + }, + { + "type": "response.completed", + "response": { + "id": "resp_tool_1", + "status": "incomplete", + "incomplete_details": { + "reason": "tool_calls" + }, + "usage": { + "input_tokens": 50, + "output_tokens": 10, + "total_tokens": 60 + } + } + } +] diff --git a/tests/test_openai_codex_auth.py b/tests/test_openai_codex_auth.py new file mode 100644 index 00000000..0113812c --- /dev/null +++ b/tests/test_openai_codex_auth.py @@ -0,0 +1,379 @@ +import asyncio +import base64 +import json +import time +from pathlib import Path + +import pytest + +from rotator_library.providers.openai_codex_auth_base import ( + CALLBACK_PATH, + LEGACY_CALLBACK_PATH, + OpenAICodexAuthBase, +) + + +def _build_jwt(payload: dict) -> str: + header = {"alg": "HS256", "typ": "JWT"} + + def b64url(data: dict) -> str: + raw = json.dumps(data, separators=(",", ":")).encode("utf-8") + return base64.urlsafe_b64encode(raw).decode("utf-8").rstrip("=") + + return f"{b64url(header)}.{b64url(payload)}.signature" + + +def test_callback_paths_match_codex_oauth_client_registration(): + assert CALLBACK_PATH == "/auth/callback" + assert LEGACY_CALLBACK_PATH == "/oauth2callback" + + +def test_decode_jwt_helper_valid_token(): + auth = OpenAICodexAuthBase() + payload = { + "sub": "user-123", + "email": "user@example.com", + "exp": int(time.time()) + 3600, + "https://api.openai.com/auth": {"chatgpt_account_id": "acct_123"}, + } + token = _build_jwt(payload) + + decoded = auth._decode_jwt_unverified(token) + assert decoded is not None + assert decoded["sub"] == "user-123" + + +def test_decode_jwt_helper_malformed_token(): + auth = OpenAICodexAuthBase() + + assert auth._decode_jwt_unverified("not-a-jwt") is None + assert auth._decode_jwt_unverified("a.b") is None + + +def test_decode_jwt_helper_missing_claims_fallbacks(): + auth = OpenAICodexAuthBase() + + payload = {"sub": "fallback-sub", "exp": int(time.time()) + 300} + token = _build_jwt(payload) + + decoded = auth._decode_jwt_unverified(token) + email = auth._extract_email_from_payload(decoded) + account_id = auth._extract_account_id_from_payload(decoded) + + assert email == "fallback-sub" # email -> sub fallback chain + assert account_id is None + + +def test_ensure_proxy_metadata_prefers_id_token_explicit_email(): + auth = OpenAICodexAuthBase() + + access_payload = { + "sub": "workspace-sub-shared", + "exp": int(time.time()) + 3600, + "https://api.openai.com/auth": {"chatgpt_account_id": "acct_workspace"}, + } + id_payload = { + "email": "real-user@example.com", + "sub": "user-sub-123", + "exp": int(time.time()) + 3600, + "https://api.openai.com/auth": {"chatgpt_account_id": "acct_workspace"}, + } + + creds = { + "access_token": _build_jwt(access_payload), + "id_token": _build_jwt(id_payload), + "refresh_token": "rt_test", + } + + auth._ensure_proxy_metadata(creds) + + assert creds["_proxy_metadata"]["email"] == "real-user@example.com" + assert creds["_proxy_metadata"]["account_id"] == "acct_workspace" + + +def test_expiry_logic_with_proactive_buffer_and_true_expiry(): + auth = OpenAICodexAuthBase() + + now_ms = int(time.time() * 1000) + + # still valid (outside proactive buffer) + fresh = {"expiry_date": now_ms + 20 * 60 * 1000} + assert auth._is_token_expired(fresh) is False + assert auth._is_token_truly_expired(fresh) is False + + # proactive refresh window (expired for refresh, still truly valid) + near_expiry = {"expiry_date": now_ms + 60 * 1000} + assert auth._is_token_expired(near_expiry) is True + assert auth._is_token_truly_expired(near_expiry) is False + + # truly expired + expired = {"expiry_date": now_ms - 60 * 1000} + assert auth._is_token_expired(expired) is True + assert auth._is_token_truly_expired(expired) is True + + +@pytest.mark.asyncio +async def test_env_loading_legacy_and_numbered(monkeypatch): + auth = OpenAICodexAuthBase() + + payload = { + "sub": "env-user", + "exp": int(time.time()) + 3600, + "https://api.openai.com/auth": {"chatgpt_account_id": "acct_env"}, + } + access = _build_jwt(payload) + refresh = "rt_env" + + monkeypatch.setenv("OPENAI_CODEX_ACCESS_TOKEN", access) + monkeypatch.setenv("OPENAI_CODEX_REFRESH_TOKEN", refresh) + + # legacy load + legacy = auth._load_from_env("0") + assert legacy is not None + assert legacy["access_token"] == access + assert legacy["_proxy_metadata"]["loaded_from_env"] is True + assert legacy["_proxy_metadata"]["account_id"] == "acct_env" + + # numbered load via env:// path + payload_n = { + "email": "numbered@example.com", + "exp": int(time.time()) + 3600, + "https://api.openai.com/auth": {"chatgpt_account_id": "acct_num"}, + } + access_n = _build_jwt(payload_n) + monkeypatch.setenv("OPENAI_CODEX_1_ACCESS_TOKEN", access_n) + monkeypatch.setenv("OPENAI_CODEX_1_REFRESH_TOKEN", "rt_num") + + creds = await auth._load_credentials("env://openai_codex/1") + assert creds["access_token"] == access_n + assert creds["_proxy_metadata"]["env_credential_index"] == "1" + assert creds["_proxy_metadata"]["account_id"] == "acct_num" + + +@pytest.mark.asyncio +async def test_save_load_round_trip_with_proxy_metadata(tmp_path: Path): + auth = OpenAICodexAuthBase() + cred_path = tmp_path / "openai_codex_oauth_1.json" + + payload = { + "email": "roundtrip@example.com", + "exp": int(time.time()) + 3600, + "https://api.openai.com/auth": {"chatgpt_account_id": "acct_roundtrip"}, + } + access = _build_jwt(payload) + + creds = { + "access_token": access, + "refresh_token": "rt_roundtrip", + "id_token": _build_jwt(payload), + "expiry_date": int((time.time() + 3600) * 1000), + "token_uri": "https://auth.openai.com/oauth/token", + "_proxy_metadata": { + "email": "roundtrip@example.com", + "account_id": "acct_roundtrip", + "last_check_timestamp": time.time(), + "loaded_from_env": False, + "env_credential_index": None, + }, + } + + assert await auth._save_credentials(str(cred_path), creds) is True + + # clear cache to verify disk round-trip + auth._credentials_cache.clear() + loaded = await auth._load_credentials(str(cred_path)) + + assert loaded["refresh_token"] == "rt_roundtrip" + assert loaded["_proxy_metadata"]["email"] == "roundtrip@example.com" + assert loaded["_proxy_metadata"]["account_id"] == "acct_roundtrip" + + +@pytest.mark.asyncio +async def test_is_credential_available_reauth_queue_and_ttl_cleanup(): + auth = OpenAICodexAuthBase() + path = "/tmp/openai_codex_oauth_1.json" + + # credential in active re-auth queue => unavailable + auth._unavailable_credentials[path] = time.time() + assert auth.is_credential_available(path) is False + + # stale unavailable entry should auto-clean and become available + auth._unavailable_credentials[path] = time.time() - 999 + auth._queued_credentials.add(path) + assert auth.is_credential_available(path) is True + assert path not in auth._unavailable_credentials + + # truly expired credential should be unavailable + auth._credentials_cache[path] = { + "expiry_date": int((time.time() - 10) * 1000), + "_proxy_metadata": {"loaded_from_env": False}, + } + assert auth.is_credential_available(path) is False + + # let background queue task schedule to avoid un-awaited coroutine warnings + await asyncio.sleep(0) + + +def test_find_existing_credential_identity_allows_same_email_different_account(tmp_path: Path): + auth = OpenAICodexAuthBase() + + existing = tmp_path / "openai_codex_oauth_1.json" + existing.write_text( + json.dumps( + { + "_proxy_metadata": { + "email": "shared@example.com", + "account_id": "acct_original", + } + } + ) + ) + + # Different account_id with same email should NOT be treated as an update target. + match = auth._find_existing_credential_by_identity( + email="shared@example.com", + account_id="acct_new", + base_dir=tmp_path, + ) + assert match is None + + # Exact account_id + email should still match. + match_same_identity = auth._find_existing_credential_by_identity( + email="shared@example.com", + account_id="acct_original", + base_dir=tmp_path, + ) + assert match_same_identity == existing + + # Email fallback should work when account_id is unknown. + match_email_fallback = auth._find_existing_credential_by_identity( + email="shared@example.com", + account_id=None, + base_dir=tmp_path, + ) + assert match_email_fallback == existing + + +def test_find_existing_credential_identity_allows_same_account_different_email(tmp_path: Path): + auth = OpenAICodexAuthBase() + + existing = tmp_path / "openai_codex_oauth_1.json" + existing.write_text( + json.dumps( + { + "_proxy_metadata": { + "email": "first@example.com", + "account_id": "acct_workspace", + } + } + ) + ) + + # Same account_id but different email should not auto-update when both + # identifiers are available (prevents workspace-level collisions). + match = auth._find_existing_credential_by_identity( + email="second@example.com", + account_id="acct_workspace", + base_dir=tmp_path, + ) + assert match is None + + +@pytest.mark.asyncio +async def test_setup_credential_creates_new_file_for_same_email_new_account(tmp_path: Path): + auth = OpenAICodexAuthBase() + + existing = tmp_path / "openai_codex_oauth_1.json" + existing.write_text( + json.dumps( + { + "access_token": "old_access", + "refresh_token": "old_refresh", + "expiry_date": int((time.time() + 3600) * 1000), + "token_uri": "https://auth.openai.com/oauth/token", + "_proxy_metadata": { + "email": "shared@example.com", + "account_id": "acct_original", + "loaded_from_env": False, + "env_credential_index": None, + }, + } + ) + ) + + async def fake_initialize_token(_creds): + return { + "access_token": "new_access", + "refresh_token": "new_refresh", + "id_token": "new_id", + "expiry_date": int((time.time() + 3600) * 1000), + "token_uri": "https://auth.openai.com/oauth/token", + "_proxy_metadata": { + "email": "shared@example.com", + "account_id": "acct_new", + "loaded_from_env": False, + "env_credential_index": None, + }, + } + + auth.initialize_token = fake_initialize_token + + result = await auth.setup_credential(base_dir=tmp_path) + + assert result.success is True + assert result.is_update is False + assert result.file_path is not None + assert result.file_path.endswith("openai_codex_oauth_2.json") + + files = sorted(p.name for p in tmp_path.glob("openai_codex_oauth_*.json")) + assert files == ["openai_codex_oauth_1.json", "openai_codex_oauth_2.json"] + + +@pytest.mark.asyncio +async def test_setup_credential_creates_new_file_for_same_account_new_email(tmp_path: Path): + auth = OpenAICodexAuthBase() + + existing = tmp_path / "openai_codex_oauth_1.json" + existing.write_text( + json.dumps( + { + "access_token": "old_access", + "refresh_token": "old_refresh", + "expiry_date": int((time.time() + 3600) * 1000), + "token_uri": "https://auth.openai.com/oauth/token", + "_proxy_metadata": { + "email": "first@example.com", + "account_id": "acct_workspace", + "loaded_from_env": False, + "env_credential_index": None, + }, + } + ) + ) + + async def fake_initialize_token(_creds): + return { + "access_token": "new_access", + "refresh_token": "new_refresh", + "id_token": "new_id", + "expiry_date": int((time.time() + 3600) * 1000), + "token_uri": "https://auth.openai.com/oauth/token", + "_proxy_metadata": { + "email": "second@example.com", + "account_id": "acct_workspace", + "loaded_from_env": False, + "env_credential_index": None, + }, + } + + auth.initialize_token = fake_initialize_token + + result = await auth.setup_credential(base_dir=tmp_path) + + assert result.success is True + assert result.is_update is False + assert result.file_path is not None + assert result.file_path.endswith("openai_codex_oauth_2.json") + + files = sorted(p.name for p in tmp_path.glob("openai_codex_oauth_*.json")) + assert files == ["openai_codex_oauth_1.json", "openai_codex_oauth_2.json"] diff --git a/tests/test_openai_codex_import.py b/tests/test_openai_codex_import.py new file mode 100644 index 00000000..a94c8f5a --- /dev/null +++ b/tests/test_openai_codex_import.py @@ -0,0 +1,217 @@ +import json +import os +import time +from pathlib import Path + +from rotator_library.credential_manager import CredentialManager + + +def _build_jwt(payload: dict) -> str: + import base64 + + header = {"alg": "HS256", "typ": "JWT"} + + def b64url(data: dict) -> str: + raw = json.dumps(data, separators=(",", ":")).encode("utf-8") + return base64.urlsafe_b64encode(raw).decode("utf-8").rstrip("=") + + return f"{b64url(header)}.{b64url(payload)}.sig" + + +def _write_codex_auth_json(path: Path): + payload = { + "email": "single@example.com", + "exp": int(time.time()) + 3600, + "https://api.openai.com/auth": {"chatgpt_account_id": "acct_single"}, + } + data = { + "auth_mode": "oauth", + "OPENAI_API_KEY": None, + "tokens": { + "id_token": _build_jwt(payload), + "access_token": _build_jwt(payload), + "refresh_token": "rt_single", + "account_id": "acct_single", + }, + "last_refresh": "2026-02-12T00:00:00Z", + } + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(data, indent=2)) + + +def _write_codex_accounts_json(path: Path): + payload_a = { + "email": "multi-a@example.com", + "exp": int(time.time()) + 3600, + "https://api.openai.com/auth": {"chatgpt_account_id": "acct_a"}, + } + payload_b = { + "email": "multi-b@example.com", + "exp": int(time.time()) + 7200, + "https://api.openai.com/auth": {"chatgpt_account_id": "acct_b"}, + } + + data = { + "schemaVersion": 1, + "activeLabel": "A", + "accounts": [ + { + "label": "A", + "accountId": "acct_a", + "access": _build_jwt(payload_a), + "refresh": "rt_a", + "idToken": _build_jwt(payload_a), + "expires": int((time.time() + 3600) * 1000), + }, + { + "label": "B", + "accountId": "acct_b", + "access": _build_jwt(payload_b), + "refresh": "rt_b", + "idToken": _build_jwt(payload_b), + "expires": int((time.time() + 7200) * 1000), + }, + ], + } + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(data, indent=2)) + + +def test_import_from_codex_auth_and_accounts_formats(tmp_path: Path): + oauth_dir = tmp_path / "oauth_creds" + manager = CredentialManager(env_vars={}, oauth_dir=oauth_dir) + + auth_json = tmp_path / ".codex" / "auth.json" + accounts_json = tmp_path / ".codex-accounts.json" + _write_codex_auth_json(auth_json) + _write_codex_accounts_json(accounts_json) + + imported = manager._import_openai_codex_cli_credentials( + auth_json_path=auth_json, + accounts_json_path=accounts_json, + ) + + # one from auth.json + two from accounts.json + assert len(imported) == 3 + + imported_files = sorted(oauth_dir.glob("openai_codex_oauth_*.json")) + assert len(imported_files) == 3 + + payload = json.loads(imported_files[0].read_text()) + assert payload["refresh_token"].startswith("rt_") + assert "_proxy_metadata" in payload + assert payload["_proxy_metadata"].get("account_id") + + +def test_explicit_openai_codex_oauth_path_auth_json_is_normalized(tmp_path: Path): + oauth_dir = tmp_path / "oauth_creds" + + auth_json = tmp_path / ".codex" / "auth.json" + _write_codex_auth_json(auth_json) + + manager = CredentialManager( + env_vars={"OPENAI_CODEX_OAUTH_1": str(auth_json)}, + oauth_dir=oauth_dir, + ) + discovered = manager.discover_and_prepare() + + assert "openai_codex" in discovered + assert len(discovered["openai_codex"]) == 1 + + imported_file = oauth_dir / "openai_codex_oauth_1.json" + payload = json.loads(imported_file.read_text()) + + # normalized proxy schema at root level (not nested under "tokens") + assert "tokens" not in payload + assert isinstance(payload.get("access_token"), str) + assert isinstance(payload.get("refresh_token"), str) + assert payload.get("token_uri") == "https://auth.openai.com/oauth/token" + assert "_proxy_metadata" in payload + + +def test_skip_import_when_env_openai_codex_credentials_exist(tmp_path: Path): + oauth_dir = tmp_path / "oauth_creds" + manager = CredentialManager( + env_vars={ + "OPENAI_CODEX_ACCESS_TOKEN": "env_access", + "OPENAI_CODEX_REFRESH_TOKEN": "env_refresh", + }, + oauth_dir=oauth_dir, + ) + + discovered = manager.discover_and_prepare() + + assert discovered["openai_codex"] == ["env://openai_codex/0"] + assert list(oauth_dir.glob("openai_codex_oauth_*.json")) == [] + + +def test_skip_import_when_local_openai_codex_credentials_exist(tmp_path: Path): + oauth_dir = tmp_path / "oauth_creds" + oauth_dir.mkdir(parents=True, exist_ok=True) + + existing = oauth_dir / "openai_codex_oauth_1.json" + existing.write_text( + json.dumps( + { + "access_token": "existing", + "refresh_token": "existing_rt", + "expiry_date": int((time.time() + 3600) * 1000), + "token_uri": "https://auth.openai.com/oauth/token", + "_proxy_metadata": { + "email": "existing@example.com", + "account_id": "acct_existing", + "last_check_timestamp": time.time(), + "loaded_from_env": False, + "env_credential_index": None, + }, + }, + indent=2, + ) + ) + + manager = CredentialManager(env_vars={}, oauth_dir=oauth_dir) + discovered = manager.discover_and_prepare() + + assert "openai_codex" in discovered + assert discovered["openai_codex"] == [str(existing.resolve())] + + +def test_malformed_codex_source_files_are_handled_gracefully(tmp_path: Path): + oauth_dir = tmp_path / "oauth_creds" + manager = CredentialManager(env_vars={}, oauth_dir=oauth_dir) + + auth_json = tmp_path / ".codex" / "auth.json" + accounts_json = tmp_path / ".codex-accounts.json" + auth_json.parent.mkdir(parents=True, exist_ok=True) + + auth_json.write_text("{not valid json") + accounts_json.write_text(json.dumps({"schemaVersion": 1, "accounts": ["bad-entry"]})) + + imported = manager._import_openai_codex_cli_credentials( + auth_json_path=auth_json, + accounts_json_path=accounts_json, + ) + + assert imported == [] + assert list(oauth_dir.glob("openai_codex_oauth_*.json")) == [] + + +def test_codex_source_files_never_modified_during_import(tmp_path: Path): + oauth_dir = tmp_path / "oauth_creds" + manager = CredentialManager(env_vars={}, oauth_dir=oauth_dir) + + auth_json = tmp_path / ".codex" / "auth.json" + accounts_json = tmp_path / ".codex-accounts.json" + _write_codex_auth_json(auth_json) + _write_codex_accounts_json(accounts_json) + + auth_before = auth_json.read_text() + accounts_before = accounts_json.read_text() + + manager._import_openai_codex_cli_credentials( + auth_json_path=auth_json, + accounts_json_path=accounts_json, + ) + + assert auth_json.read_text() == auth_before + assert accounts_json.read_text() == accounts_before diff --git a/tests/test_openai_codex_provider.py b/tests/test_openai_codex_provider.py new file mode 100644 index 00000000..148d825d --- /dev/null +++ b/tests/test_openai_codex_provider.py @@ -0,0 +1,262 @@ +import base64 +import json +import time +from pathlib import Path + +import httpx +import pytest +import respx + +from rotator_library.providers.openai_codex_provider import OpenAICodexProvider + + +def _build_jwt(payload: dict) -> str: + header = {"alg": "HS256", "typ": "JWT"} + + def b64url(data: dict) -> str: + raw = json.dumps(data, separators=(",", ":")).encode("utf-8") + return base64.urlsafe_b64encode(raw).decode("utf-8").rstrip("=") + + return f"{b64url(header)}.{b64url(payload)}.sig" + + +def _build_sse_payload(text: str = "pong") -> bytes: + events = [ + { + "type": "response.created", + "response": {"id": "resp_1", "created_at": int(time.time()), "status": "in_progress"}, + }, + { + "type": "response.output_item.added", + "item": { + "id": "msg_1", + "type": "message", + "status": "in_progress", + "content": [], + "role": "assistant", + }, + }, + { + "type": "response.content_part.added", + "item_id": "msg_1", + "part": {"type": "output_text", "text": ""}, + }, + { + "type": "response.output_text.delta", + "item_id": "msg_1", + "delta": text, + }, + { + "type": "response.completed", + "response": { + "id": "resp_1", + "status": "completed", + "usage": { + "input_tokens": 5, + "output_tokens": 3, + "total_tokens": 8, + }, + }, + }, + ] + + sse = "\n\n".join(f"data: {json.dumps(evt)}" for evt in events) + "\n\n" + return sse.encode("utf-8") + + +@pytest.fixture +def provider() -> OpenAICodexProvider: + return OpenAICodexProvider() + + +@pytest.fixture +def credential_file(tmp_path: Path) -> Path: + payload = { + "email": "provider@example.com", + "exp": int(time.time()) + 3600, + "https://api.openai.com/auth": {"chatgpt_account_id": "acct_provider"}, + } + + cred_path = tmp_path / "openai_codex_oauth_1.json" + cred_path.write_text( + json.dumps( + { + "access_token": _build_jwt(payload), + "refresh_token": "rt_provider", + "id_token": _build_jwt(payload), + "expiry_date": int((time.time() + 3600) * 1000), + "token_uri": "https://auth.openai.com/oauth/token", + "_proxy_metadata": { + "email": "provider@example.com", + "account_id": "acct_provider", + "last_check_timestamp": time.time(), + "loaded_from_env": False, + "env_credential_index": None, + }, + }, + indent=2, + ) + ) + return cred_path + + +def test_chat_request_mapping_to_codex_payload(provider: OpenAICodexProvider): + payload = provider._build_codex_payload( + model_name="gpt-5.1-codex", + messages=[ + {"role": "system", "content": "System guidance"}, + {"role": "user", "content": "hello"}, + ], + temperature=0.2, + top_p=0.9, + max_tokens=123, + tools=[ + { + "type": "function", + "function": { + "name": "lookup", + "description": "Lookup data", + "parameters": {"type": "object", "properties": {"q": {"type": "string"}}}, + }, + } + ], + tool_choice="auto", + ) + + assert payload["model"] == "gpt-5.1-codex" + assert payload["stream"] is True + assert payload["store"] is False + assert payload["instructions"] == "System guidance" + assert payload["input"][0]["role"] == "user" + assert payload["temperature"] == 0.2 + assert payload["top_p"] == 0.9 + assert payload["max_output_tokens"] == 123 + assert payload["tool_choice"] == "auto" + assert payload["tools"][0]["name"] == "lookup" + + +@pytest.mark.asyncio +async def test_non_stream_response_mapping_and_header_construction( + provider: OpenAICodexProvider, + credential_file: Path, +): + endpoint = "https://chatgpt.com/backend-api/codex/responses" + + with respx.mock(assert_all_called=True) as mock_router: + route = mock_router.post(endpoint) + + def responder(request: httpx.Request) -> httpx.Response: + assert request.headers.get("authorization", "").startswith("Bearer ") + assert request.headers.get("chatgpt-account-id") == "acct_provider" + assert request.headers.get("openai-beta") == "responses=experimental" + assert request.headers.get("originator") == "pi" + + body = json.loads(request.content.decode("utf-8")) + assert body["stream"] is True + assert "instructions" in body + assert "input" in body + + return httpx.Response( + status_code=200, + content=_build_sse_payload("pong"), + headers={"content-type": "text/event-stream"}, + ) + + route.mock(side_effect=responder) + + async with httpx.AsyncClient() as client: + response = await provider.acompletion( + client, + model="openai_codex/gpt-5.1-codex", + messages=[{"role": "user", "content": "say pong"}], + stream=False, + credential_identifier=str(credential_file), + ) + + assert response.choices[0]["message"]["content"] == "pong" + assert response.usage["prompt_tokens"] == 5 + assert response.usage["completion_tokens"] == 3 + + +@pytest.mark.asyncio +async def test_env_credential_identifier_supported(monkeypatch): + provider = OpenAICodexProvider() + + payload = { + "email": "env-provider@example.com", + "exp": int(time.time()) + 3600, + "https://api.openai.com/auth": {"chatgpt_account_id": "acct_env_provider"}, + } + + monkeypatch.setenv("OPENAI_CODEX_1_ACCESS_TOKEN", _build_jwt(payload)) + monkeypatch.setenv("OPENAI_CODEX_1_REFRESH_TOKEN", "rt_env_provider") + + endpoint = "https://chatgpt.com/backend-api/codex/responses" + + with respx.mock(assert_all_called=True) as mock_router: + route = mock_router.post(endpoint) + + def responder(request: httpx.Request) -> httpx.Response: + assert request.headers.get("chatgpt-account-id") == "acct_env_provider" + return httpx.Response( + status_code=200, + content=_build_sse_payload("env-ok"), + headers={"content-type": "text/event-stream"}, + ) + + route.mock(side_effect=responder) + + async with httpx.AsyncClient() as client: + response = await provider.acompletion( + client, + model="openai_codex/gpt-5.1-codex", + messages=[{"role": "user", "content": "test env"}], + stream=False, + credential_identifier="env://openai_codex/1", + ) + + assert response.choices[0]["message"]["content"] == "env-ok" + + +def test_parse_quota_error_from_retry_after_header(provider: OpenAICodexProvider): + request = httpx.Request("POST", "https://chatgpt.com/backend-api/codex/responses") + response = httpx.Response( + status_code=429, + request=request, + headers={"Retry-After": "42"}, + text=json.dumps({"error": {"code": "rate_limit", "message": "Too many requests"}}), + ) + error = httpx.HTTPStatusError("Rate limited", request=request, response=response) + + parsed = provider.parse_quota_error(error) + assert parsed is not None + assert parsed["retry_after"] == 42 + assert parsed["reason"] == "RATE_LIMIT" + + +def test_parse_quota_error_from_resets_at_field(provider: OpenAICodexProvider): + now = int(time.time()) + reset_ts = now + 120 + + request = httpx.Request("POST", "https://chatgpt.com/backend-api/codex/responses") + response = httpx.Response( + status_code=429, + request=request, + text=json.dumps( + { + "error": { + "code": "usage_limit", + "message": "quota exceeded", + "resets_at": reset_ts, + } + } + ), + ) + error = httpx.HTTPStatusError("Quota hit", request=request, response=response) + + parsed = provider.parse_quota_error(error) + assert parsed is not None + assert parsed["reason"] == "USAGE_LIMIT" + assert parsed["quota_reset_timestamp"] == float(reset_ts) + assert isinstance(parsed["retry_after"], int) + assert parsed["retry_after"] >= 1 diff --git a/tests/test_openai_codex_sse.py b/tests/test_openai_codex_sse.py new file mode 100644 index 00000000..ec1411f7 --- /dev/null +++ b/tests/test_openai_codex_sse.py @@ -0,0 +1,110 @@ +import json +from pathlib import Path + +import pytest + +from rotator_library.providers.openai_codex_provider import ( + CodexSSETranslator, + CodexStreamError, +) + + +FIXTURES_DIR = Path(__file__).parent / "fixtures" / "openai_codex" + + +def _load_events(name: str): + return json.loads((FIXTURES_DIR / name).read_text()) + + +def test_fixture_driven_event_sequence_to_expected_chunks(): + events = _load_events("stream_success_events.json") + translator = CodexSSETranslator(model_id="openai_codex/gpt-5.1-codex") + + chunks = [] + for event in events: + chunks.extend(translator.process_event(event)) + + # content delta chunk present + content_chunks = [ + c for c in chunks if c["choices"][0]["delta"].get("content") + ] + assert content_chunks + assert content_chunks[-1]["choices"][0]["delta"]["content"] == "pong" + + # terminal chunk contains usage mapping + final_chunk = chunks[-1] + assert final_chunk["choices"][0]["finish_reason"] == "stop" + assert final_chunk["usage"]["prompt_tokens"] == 21 + assert final_chunk["usage"]["completion_tokens"] == 13 + assert final_chunk["usage"]["total_tokens"] == 34 + + +def test_tool_call_deltas_and_finish_reason_mapping(): + events = _load_events("stream_tool_call_events.json") + translator = CodexSSETranslator(model_id="openai_codex/gpt-5.1-codex") + + chunks = [] + for event in events: + chunks.extend(translator.process_event(event)) + + tool_chunks = [ + c for c in chunks if c["choices"][0]["delta"].get("tool_calls") + ] + assert tool_chunks + + # Validate streaming argument assembly appears in deltas + all_args = "".join( + tc["function"]["arguments"] + for chunk in tool_chunks + for tc in chunk["choices"][0]["delta"]["tool_calls"] + ) + assert "San" in all_args + assert "Francisco" in all_args + + final_chunk = chunks[-1] + assert final_chunk["choices"][0]["finish_reason"] == "tool_calls" + assert final_chunk["usage"]["total_tokens"] == 60 + + +def test_content_part_delta_alias_and_length_finish_reason(): + events = _load_events("stream_content_part_delta_events.json") + translator = CodexSSETranslator(model_id="openai_codex/gpt-5.1-codex") + + chunks = [] + for event in events: + chunks.extend(translator.process_event(event)) + + text = "".join( + c["choices"][0]["delta"].get("content", "") + for c in chunks + ) + assert text == "Hello world" + + final_chunk = chunks[-1] + assert final_chunk["choices"][0]["finish_reason"] == "length" + assert final_chunk["usage"]["total_tokens"] == 30 + + +def test_error_event_propagation(): + translator = CodexSSETranslator(model_id="openai_codex/gpt-5.1-codex") + + with pytest.raises(CodexStreamError) as exc: + translator.process_event( + { + "type": "error", + "error": { + "code": "usage_limit_reached", + "message": "quota reached", + "type": "rate_limit_error", + }, + } + ) + + assert exc.value.status_code == 429 + assert "quota" in str(exc.value).lower() + + +def test_unknown_event_tolerance(): + translator = CodexSSETranslator(model_id="openai_codex/gpt-5.1-codex") + chunks = translator.process_event({"type": "response.some_unknown_event"}) + assert chunks == [] diff --git a/tests/test_openai_codex_wiring.py b/tests/test_openai_codex_wiring.py new file mode 100644 index 00000000..d1a44830 --- /dev/null +++ b/tests/test_openai_codex_wiring.py @@ -0,0 +1,26 @@ +from rotator_library.credential_manager import CredentialManager +from rotator_library.provider_factory import get_provider_auth_class +from rotator_library.providers import PROVIDER_PLUGINS +from rotator_library.providers.openai_codex_auth_base import OpenAICodexAuthBase + + +def test_credential_discovery_recognizes_openai_codex_env_vars(tmp_path): + env_vars = { + "OPENAI_CODEX_1_ACCESS_TOKEN": "access-1", + "OPENAI_CODEX_1_REFRESH_TOKEN": "refresh-1", + } + + manager = CredentialManager(env_vars=env_vars, oauth_dir=tmp_path / "oauth_creds") + discovered = manager.discover_and_prepare() + + assert "openai_codex" in discovered + assert discovered["openai_codex"] == ["env://openai_codex/1"] + + +def test_provider_factory_returns_openai_codex_auth_base(): + auth_class = get_provider_auth_class("openai_codex") + assert auth_class is OpenAICodexAuthBase + + +def test_provider_auto_registration_includes_openai_codex(): + assert "openai_codex" in PROVIDER_PLUGINS