diff --git a/CHANGELOG.md b/CHANGELOG.md
index fb7e8559..fb1ca385 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 * Added support for built-in provider tools via a new `ToolBuiltIn` class. This enables provider-specific functionality like OpenAI's image generation to be registered and used as tools. Built-in tools pass raw provider definitions directly to the API rather than wrapping Python functions. (#214)
 * `ChatGoogle()` gains basic support for image generation. (#214)
+* `ChatOpenAI()` and `ChatAzureOpenAI()` gain a new `service_tier` parameter to request a specific service tier (e.g., `"flex"` for slower/cheaper or `"priority"` for faster/more expensive). (#204)
 
 ### Changes
 
diff --git a/chatlas/_provider_openai.py b/chatlas/_provider_openai.py
index 4486d828..e35eef4c 100644
--- a/chatlas/_provider_openai.py
+++ b/chatlas/_provider_openai.py
@@ -2,7 +2,7 @@
 
 import base64
 import warnings
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Literal, Optional, cast
 
 import orjson
 from openai.types.responses import Response, ResponseStreamEvent
@@ -48,6 +48,9 @@ def ChatOpenAI(
     model: "Optional[ResponsesModel | str]" = None,
     api_key: Optional[str] = None,
     base_url: str = "https://api.openai.com/v1",
+    service_tier: Optional[
+        Literal["auto", "default", "flex", "scale", "priority"]
+    ] = None,
     kwargs: Optional["ChatClientArgs"] = None,
 ) -> Chat["SubmitInputArgs", Response]:
     """
@@ -92,6 +95,13 @@ def ChatOpenAI(
         variable.
     base_url
         The base URL to the endpoint; the default uses OpenAI.
+    service_tier
+        Request a specific service tier. Options:
+        - `"auto"` (default): uses the service tier configured in Project settings.
+        - `"default"`: standard pricing and performance.
+        - `"flex"`: slower and cheaper.
+        - `"scale"`: batch-like pricing for high-volume use.
+        - `"priority"`: faster and more expensive.
     kwargs
         Additional arguments to pass to the `openai.OpenAI()` client
         constructor.
@@ -145,6 +155,10 @@ def ChatOpenAI(
     if model is None:
         model = log_model_default("gpt-4.1")
 
+    kwargs_chat: "SubmitInputArgs" = {}
+    if service_tier is not None:
+        kwargs_chat["service_tier"] = service_tier
+
     return Chat(
         provider=OpenAIProvider(
             api_key=api_key,
@@ -153,6 +167,7 @@ def ChatOpenAI(
             kwargs=kwargs,
         ),
         system_prompt=system_prompt,
+        kwargs_chat=kwargs_chat,
     )
 
 
@@ -261,6 +276,16 @@ def stream_text(self, chunk):
     def stream_merge_chunks(self, completion, chunk):
         if chunk.type == "response.completed":
             return chunk.response
+        elif chunk.type == "response.failed":
+            error = chunk.response.error
+            if error is None:
+                msg = "Request failed with an unknown error."
+            else:
+                msg = f"Request failed ({error.code}): {error.message}"
+            raise RuntimeError(msg)
+        elif chunk.type == "error":
+            raise RuntimeError(f"Request errored: {chunk.message}")
+
         # Since this value won't actually be used, we can lie about the type
         return cast(Response, None)
 
@@ -296,12 +321,11 @@ def value_cost(
         if tokens is None:
             return None
 
-        # Extract service_tier from completion if available
-        variant = ""
+        service_tier = ""
         if completion is not None:
-            variant = getattr(completion, "service_tier", None) or ""
+            service_tier = completion.service_tier or ""
 
-        return get_token_cost(self.name, self.model, tokens, variant)
+        return get_token_cost(self.name, self.model, tokens, service_tier)
 
     def batch_result_turn(self, result, has_data_model: bool = False):
         response = BatchResult.model_validate(result).response
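
A minimal usage sketch of the new parameter (illustrative, not part of the patch; the model and prompt here are arbitrary):

```python
from chatlas import ChatOpenAI

# Request flex processing: slower but cheaper than the default tier.
# The tier lands in `kwargs_chat`, so it is sent with every request.
chat = ChatOpenAI(model="gpt-4.1", service_tier="flex")
chat.chat("What is 1+1?")
```

Since `value_cost()` now reads `service_tier` off the completed response, each turn's cost is computed with tier-specific pricing rather than the default rate.
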
diff --git a/chatlas/_provider_openai_azure.py b/chatlas/_provider_openai_azure.py
index 7a1b78bc..0f2bc39d 100644
--- a/chatlas/_provider_openai_azure.py
+++ b/chatlas/_provider_openai_azure.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Literal, Optional
 
 from openai import AsyncAzureOpenAI, AzureOpenAI
 from openai.types.chat import ChatCompletion
@@ -21,6 +21,9 @@ def ChatAzureOpenAI(
     api_version: str,
     api_key: Optional[str] = None,
     system_prompt: Optional[str] = None,
+    service_tier: Optional[
+        Literal["auto", "default", "flex", "scale", "priority"]
+    ] = None,
     kwargs: Optional["ChatAzureClientArgs"] = None,
 ) -> Chat["SubmitInputArgs", ChatCompletion]:
     """
@@ -62,6 +65,13 @@ def ChatAzureOpenAI(
         variable.
     system_prompt
         A system prompt to set the behavior of the assistant.
+    service_tier
+        Request a specific service tier. Options:
+        - `"auto"` (default): uses the service tier configured in Project settings.
+        - `"default"`: standard pricing and performance.
+        - `"flex"`: slower and cheaper.
+        - `"scale"`: batch-like pricing for high-volume use.
+        - `"priority"`: faster and more expensive.
     kwargs
         Additional arguments to pass to the `openai.AzureOpenAI()` client
         constructor.
@@ -71,6 +81,10 @@ def ChatAzureOpenAI(
         A Chat object.
     """
+    kwargs_chat: "SubmitInputArgs" = {}
+    if service_tier is not None:
+        kwargs_chat["service_tier"] = service_tier
+
     return Chat(
         provider=OpenAIAzureProvider(
             endpoint=endpoint,
@@ -80,6 +94,7 @@ def ChatAzureOpenAI(
             kwargs=kwargs,
         ),
         system_prompt=system_prompt,
+        kwargs_chat=kwargs_chat,
     )
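
`ChatAzureOpenAI()` follows the same pattern; a sketch with placeholder Azure values (the endpoint, deployment, and API version are assumptions, not from the patch):

```python
from chatlas import ChatAzureOpenAI

# Placeholder resource details; `service_tier` rides along in `kwargs_chat`
# exactly as in ChatOpenAI().
chat = ChatAzureOpenAI(
    endpoint="https://my-endpoint.openai.azure.com",
    deployment_id="gpt-4o",
    api_version="2024-10-21",
    service_tier="priority",
)
```
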
""" + kwargs_chat: "SubmitInputArgs" = {} + if service_tier is not None: + kwargs_chat["service_tier"] = service_tier + return Chat( provider=OpenAIAzureProvider( endpoint=endpoint, @@ -80,6 +94,7 @@ def ChatAzureOpenAI( kwargs=kwargs, ), system_prompt=system_prompt, + kwargs_chat=kwargs_chat, ) diff --git a/tests/test_provider_openai.py b/tests/test_provider_openai.py index a36e14dc..42bcafab 100644 --- a/tests/test_provider_openai.py +++ b/tests/test_provider_openai.py @@ -105,3 +105,31 @@ def test_openai_custom_http_client(): def test_openai_list_models(): assert_list_models(ChatOpenAI) + + +def test_openai_service_tier(): + chat = ChatOpenAI(service_tier="flex") + assert chat.kwargs_chat.get("service_tier") == "flex" + + +def test_openai_service_tier_affects_pricing(): + from chatlas._tokens import get_token_cost + + chat = ChatOpenAI(service_tier="priority") + chat.chat("What is 1+1?") + + turn = chat.get_last_turn() + assert turn is not None + assert turn.tokens is not None + assert turn.cost is not None + + # Verify that cost was calculated using priority pricing + tokens = turn.tokens + priority_cost = get_token_cost("OpenAI", chat.provider.model, tokens, "priority") + assert priority_cost is not None + assert turn.cost == priority_cost + + # Verify priority pricing is more expensive than default + default_cost = get_token_cost("OpenAI", chat.provider.model, tokens, "") + assert default_cost is not None + assert turn.cost > default_cost \ No newline at end of file