CHANGELOG.md: 1 change (1 addition & 0 deletions)
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

* Added support for built-in provider tools via a new `ToolBuiltIn` class. This enables provider-specific functionality like OpenAI's image generation to be registered and used as tools. Built-in tools pass raw provider definitions directly to the API rather than wrapping Python functions. (#214)
* `ChatGoogle()` gains basic support for image generation. (#214)
+* `ChatOpenAI()` and `ChatAzureOpenAI()` gain a new `service_tier` parameter to request a specific service tier (e.g., `"flex"` for slower/cheaper or `"priority"` for faster/more expensive). (#204)

### Changes

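To make the changelog entry concrete: the tier is chosen once at construction time and then sent with every request. A minimal sketch based on the signature and tests in this PR (the model and prompt are illustrative):

```python
from chatlas import ChatOpenAI

# Ask OpenAI for the slower/cheaper "flex" tier for every request in
# this session. Model and prompt here are illustrative.
chat = ChatOpenAI(model="gpt-4.1", service_tier="flex")

# The tier is stored as a per-request kwarg and sent with each API call.
assert chat.kwargs_chat.get("service_tier") == "flex"

chat.chat("What is 1+1?")
```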
chatlas/_provider_openai.py: 34 changes (29 additions & 5 deletions)
@@ -2,7 +2,7 @@

import base64
import warnings
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Literal, Optional, cast

import orjson
from openai.types.responses import Response, ResponseStreamEvent
@@ -48,6 +48,9 @@ def ChatOpenAI(
model: "Optional[ResponsesModel | str]" = None,
api_key: Optional[str] = None,
base_url: str = "https://api.openai.com/v1",
+    service_tier: Optional[
+        Literal["auto", "default", "flex", "scale", "priority"]
+    ] = None,
kwargs: Optional["ChatClientArgs"] = None,
) -> Chat["SubmitInputArgs", Response]:
"""
@@ -92,6 +95,13 @@ def ChatOpenAI(
variable.
base_url
The base URL to the endpoint; the default uses OpenAI.
+    service_tier
+        Request a specific service tier. Options:
+        - `"auto"` (default): uses the service tier configured in Project settings.
+        - `"default"`: standard pricing and performance.
+        - `"flex"`: slower and cheaper.
+        - `"scale"`: batch-like pricing for high-volume use.
+        - `"priority"`: faster and more expensive.
kwargs
Additional arguments to pass to the `openai.OpenAI()` client
constructor.
@@ -145,6 +155,10 @@ def ChatOpenAI(
if model is None:
model = log_model_default("gpt-4.1")

+    kwargs_chat: "SubmitInputArgs" = {}
+    if service_tier is not None:
+        kwargs_chat["service_tier"] = service_tier
+
return Chat(
provider=OpenAIProvider(
api_key=api_key,
@@ -153,6 +167,7 @@ def ChatOpenAI(
kwargs=kwargs,
),
system_prompt=system_prompt,
+        kwargs_chat=kwargs_chat,
)


@@ -261,6 +276,16 @@ def stream_text(self, chunk):
def stream_merge_chunks(self, completion, chunk):
if chunk.type == "response.completed":
return chunk.response
+        elif chunk.type == "response.failed":
+            error = chunk.response.error
+            if error is None:
+                msg = "Request failed with an unknown error."
+            else:
+                msg = f"Request failed ({error.code}): {error.message}"
+            raise RuntimeError(msg)
+        elif chunk.type == "error":
+            raise RuntimeError(f"Request errored: {chunk.message}")
+
# Since this value won't actually be used, we can lie about the type
return cast(Response, None)

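With the two new branches above, a failed or errored stream now surfaces as a `RuntimeError` instead of silently producing a `None` completion. A sketch of caller-side handling, assuming the usual `Chat.stream()` generator API:

```python
# Sketch: a streamed request that fails mid-flight now raises.
try:
    for chunk in chat.stream("What is 1+1?"):
        print(chunk, end="")
except RuntimeError as err:
    # Message format comes from the handler above, e.g.
    # "Request failed (<code>): <message>" or "Request errored: <message>"
    print(f"Stream aborted: {err}")
```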
@@ -296,12 +321,11 @@ def value_cost(
if tokens is None:
return None

-        # Extract service_tier from completion if available
-        variant = ""
+        service_tier = ""
        if completion is not None:
-            variant = getattr(completion, "service_tier", None) or ""
+            service_tier = completion.service_tier or ""

-        return get_token_cost(self.name, self.model, tokens, variant)
+        return get_token_cost(self.name, self.model, tokens, service_tier)

def batch_result_turn(self, result, has_data_model: bool = False):
response = BatchResult.model_validate(result).response
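The net effect of the `value_cost()` change: the tier reported back on the completion (`Response.service_tier`, now read as a typed attribute rather than via `getattr`) selects the pricing variant, so flex or priority traffic is costed at its own rates. Condensed from the test added at the bottom of this PR:

```python
from chatlas import ChatOpenAI
from chatlas._tokens import get_token_cost

chat = ChatOpenAI(service_tier="priority")
chat.chat("What is 1+1?")
turn = chat.get_last_turn()

# An empty variant ("") selects the default price table; "priority"
# selects priority rates, which is what turn.cost now reflects.
default_cost = get_token_cost("OpenAI", chat.provider.model, turn.tokens, "")
priority_cost = get_token_cost("OpenAI", chat.provider.model, turn.tokens, "priority")
assert turn.cost == priority_cost
assert priority_cost > default_cost
```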
chatlas/_provider_openai_azure.py: 17 changes (16 additions & 1 deletion)
@@ -1,6 +1,6 @@
from __future__ import annotations

-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Literal, Optional

from openai import AsyncAzureOpenAI, AzureOpenAI
from openai.types.chat import ChatCompletion
@@ -21,6 +21,9 @@ def ChatAzureOpenAI(
api_version: str,
api_key: Optional[str] = None,
system_prompt: Optional[str] = None,
+    service_tier: Optional[
+        Literal["auto", "default", "flex", "scale", "priority"]
+    ] = None,
kwargs: Optional["ChatAzureClientArgs"] = None,
) -> Chat["SubmitInputArgs", ChatCompletion]:
"""
@@ -62,6 +65,13 @@ def ChatAzureOpenAI(
variable.
system_prompt
A system prompt to set the behavior of the assistant.
+    service_tier
+        Request a specific service tier. Options:
+        - `"auto"` (default): uses the service tier configured in Project settings.
+        - `"default"`: standard pricing and performance.
+        - `"flex"`: slower and cheaper.
+        - `"scale"`: batch-like pricing for high-volume use.
+        - `"priority"`: faster and more expensive.
kwargs
Additional arguments to pass to the `openai.AzureOpenAI()` client constructor.

@@ -71,6 +81,10 @@ def ChatAzureOpenAI(
A Chat object.
"""

+    kwargs_chat: "SubmitInputArgs" = {}
+    if service_tier is not None:
+        kwargs_chat["service_tier"] = service_tier
+
return Chat(
provider=OpenAIAzureProvider(
endpoint=endpoint,
@@ -80,6 +94,7 @@ def ChatAzureOpenAI(
kwargs=kwargs,
),
system_prompt=system_prompt,
+        kwargs_chat=kwargs_chat,
)


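Usage on Azure mirrors `ChatOpenAI`. A sketch with placeholder endpoint and API version; any other arguments your deployment requires are elided here:

```python
from chatlas import ChatAzureOpenAI

chat = ChatAzureOpenAI(
    endpoint="https://my-resource.openai.azure.com",  # placeholder
    api_version="2025-04-01-preview",                 # placeholder
    service_tier="flex",
)

# As with ChatOpenAI, the tier lands in the per-request kwargs.
assert chat.kwargs_chat.get("service_tier") == "flex"
```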
tests/test_provider_openai.py: 28 changes (28 additions & 0 deletions)
@@ -105,3 +105,31 @@ def test_openai_custom_http_client():

def test_openai_list_models():
assert_list_models(ChatOpenAI)
+
+
+def test_openai_service_tier():
+    chat = ChatOpenAI(service_tier="flex")
+    assert chat.kwargs_chat.get("service_tier") == "flex"
+
+
+def test_openai_service_tier_affects_pricing():
+    from chatlas._tokens import get_token_cost
+
+    chat = ChatOpenAI(service_tier="priority")
+    chat.chat("What is 1+1?")
+
+    turn = chat.get_last_turn()
+    assert turn is not None
+    assert turn.tokens is not None
+    assert turn.cost is not None
+
+    # Verify that cost was calculated using priority pricing
+    tokens = turn.tokens
+    priority_cost = get_token_cost("OpenAI", chat.provider.model, tokens, "priority")
+    assert priority_cost is not None
+    assert turn.cost == priority_cost
+
+    # Verify priority pricing is more expensive than default
+    default_cost = get_token_cost("OpenAI", chat.provider.model, tokens, "")
+    assert default_cost is not None
+    assert turn.cost > default_cost