Skip to content

Commit ccf0157

Browse files
authored
Add basic image generation support; introduce new ToolBuiltIn class (#214)
* Add basic image generation support; introduce new ToolBuiltIn class * Cleanup * More cleanup; get image display somewhat working * Avoid unnecessary methods * Update types * Fix 3.9 typing issue * Fix typing issue * Update CHANGELOG * Tweak wording
1 parent 68a1d2f commit ccf0157

15 files changed

+221
-100
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
77
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
88
-->
99

10+
## [UNRELEASED]
11+
12+
* Added support for built-in provider tools via a new `ToolBuiltIn` class. This enables provider-specific functionality like OpenAI's image generation to be registered and used as tools. Built-in tools pass raw provider definitions directly to the API rather than wrapping Python functions. (#214)
13+
* `ChatGoogle()` gains basic support for image generation. (#214)
14+
1015
## [0.14.0] - 2025-12-09
1116

1217
### New features

chatlas/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from ._provider_portkey import ChatPortkey
3535
from ._provider_snowflake import ChatSnowflake
3636
from ._tokens import token_usage
37-
from ._tools import Tool, ToolRejectError
37+
from ._tools import Tool, ToolBuiltIn, ToolRejectError
3838
from ._turn import AssistantTurn, SystemTurn, Turn, UserTurn
3939

4040
try:
@@ -84,6 +84,7 @@
8484
"Provider",
8585
"token_usage",
8686
"Tool",
87+
"ToolBuiltIn",
8788
"ToolRejectError",
8889
"Turn",
8990
"UserTurn",

chatlas/_chat.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from ._mcp_manager import MCPSessionManager
5050
from ._provider import ModelInfo, Provider, StandardModelParams, SubmitInputArgsT
5151
from ._tokens import compute_cost, get_token_pricing, tokens_log
52-
from ._tools import Tool, ToolRejectError
52+
from ._tools import Tool, ToolBuiltIn, ToolRejectError
5353
from ._turn import AssistantTurn, SystemTurn, Turn, UserTurn, user_turn
5454
from ._typing_extensions import TypedDict, TypeGuard
5555
from ._utils import MISSING, MISSING_TYPE, html_escape, wrap_async
@@ -132,7 +132,7 @@ def __init__(
132132
self.system_prompt = system_prompt
133133
self.kwargs_chat: SubmitInputArgsT = kwargs_chat or {}
134134

135-
self._tools: dict[str, Tool] = {}
135+
self._tools: dict[str, Tool | ToolBuiltIn] = {}
136136
self._on_tool_request_callbacks = CallbackManager()
137137
self._on_tool_result_callbacks = CallbackManager()
138138
self._current_display: Optional[MarkdownDisplay] = None
@@ -1880,7 +1880,7 @@ async def cleanup_mcp_tools(self, names: Optional[Sequence[str]] = None):
18801880

18811881
def register_tool(
18821882
self,
1883-
func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool,
1883+
func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool | ToolBuiltIn,
18841884
*,
18851885
force: bool = False,
18861886
name: Optional[str] = None,
@@ -1982,23 +1982,30 @@ def add(a: int, b: int) -> int:
19821982
func.func, name=name, model=model, annotations=annotations
19831983
)
19841984
func = func.func
1985+
tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
1986+
else:
1987+
if isinstance(func, ToolBuiltIn):
1988+
tool = func
1989+
else:
1990+
tool = Tool.from_func(
1991+
func, name=name, model=model, annotations=annotations
1992+
)
19851993

1986-
tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
19871994
if tool.name in self._tools and not force:
19881995
raise ValueError(
19891996
f"Tool with name '{tool.name}' is already registered. "
19901997
"Set `force=True` to overwrite it."
19911998
)
19921999
self._tools[tool.name] = tool
19932000

1994-
def get_tools(self) -> list[Tool]:
2001+
def get_tools(self) -> list[Tool | ToolBuiltIn]:
19952002
"""
19962003
Get the list of registered tools.
19972004
19982005
Returns
19992006
-------
2000-
list[Tool]
2001-
A list of `Tool` instances that are currently registered with the chat.
2007+
list[Tool | ToolBuiltIn]
2008+
A list of `Tool` or `ToolBuiltIn` instances that are currently registered with the chat.
20022009
"""
20032010
return list(self._tools.values())
20042011

@@ -2522,7 +2529,7 @@ def _submit_turns(
25222529
data_model: type[BaseModel] | None = None,
25232530
kwargs: Optional[SubmitInputArgsT] = None,
25242531
) -> Generator[str, None, None]:
2525-
if any(x._is_async for x in self._tools.values()):
2532+
if any(isinstance(x, Tool) and x._is_async for x in self._tools.values()):
25262533
raise ValueError("Cannot use async tools in a synchronous chat")
25272534

25282535
def emit(text: str | Content):
@@ -2683,15 +2690,24 @@ def _collect_all_kwargs(
26832690

26842691
def _invoke_tool(self, request: ContentToolRequest):
26852692
tool = self._tools.get(request.name)
2686-
func = tool.func if tool is not None else None
26872693

2688-
if func is None:
2694+
if tool is None:
26892695
yield self._handle_tool_error_result(
26902696
request,
26912697
error=RuntimeError("Unknown tool."),
26922698
)
26932699
return
26942700

2701+
if isinstance(tool, ToolBuiltIn):
2702+
yield self._handle_tool_error_result(
2703+
request,
2704+
error=RuntimeError(
2705+
f"Built-in tool '{request.name}' cannot be invoked directly. "
2706+
"It should be handled by the provider."
2707+
),
2708+
)
2709+
return
2710+
26952711
# First, invoke the request callbacks. If a ToolRejectError is raised,
26962712
# treat it like a tool failure (i.e., gracefully handle it).
26972713
result: ContentToolResult | None = None
@@ -2703,9 +2719,9 @@ def _invoke_tool(self, request: ContentToolRequest):
27032719

27042720
try:
27052721
if isinstance(request.arguments, dict):
2706-
res = func(**request.arguments)
2722+
res = tool.func(**request.arguments)
27072723
else:
2708-
res = func(request.arguments)
2724+
res = tool.func(request.arguments)
27092725

27102726
# Normalize res as a generator of results.
27112727
if not inspect.isgenerator(res):
@@ -2739,10 +2755,15 @@ async def _invoke_tool_async(self, request: ContentToolRequest):
27392755
)
27402756
return
27412757

2742-
if tool._is_async:
2743-
func = tool.func
2744-
else:
2745-
func = wrap_async(tool.func)
2758+
if isinstance(tool, ToolBuiltIn):
2759+
yield self._handle_tool_error_result(
2760+
request,
2761+
error=RuntimeError(
2762+
f"Built-in tool '{request.name}' cannot be invoked directly. "
2763+
"It should be handled by the provider."
2764+
),
2765+
)
2766+
return
27462767

27472768
# First, invoke the request callbacks. If a ToolRejectError is raised,
27482769
# treat it like a tool failure (i.e., gracefully handle it).
@@ -2753,6 +2774,11 @@ async def _invoke_tool_async(self, request: ContentToolRequest):
27532774
yield self._handle_tool_error_result(request, e)
27542775
return
27552776

2777+
if tool._is_async:
2778+
func = tool.func
2779+
else:
2780+
func = wrap_async(tool.func)
2781+
27562782
# Invoke the tool (if it hasn't been rejected).
27572783
try:
27582784
if isinstance(request.arguments, dict):

chatlas/_content.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ._typing_extensions import TypedDict
1212

1313
if TYPE_CHECKING:
14-
from ._tools import Tool
14+
from ._tools import Tool, ToolBuiltIn
1515

1616

1717
class ToolAnnotations(TypedDict, total=False):
@@ -104,15 +104,21 @@ class ToolInfo(BaseModel):
104104
annotations: Optional[ToolAnnotations] = None
105105

106106
@classmethod
107-
def from_tool(cls, tool: "Tool") -> "ToolInfo":
108-
"""Create a ToolInfo from a Tool instance."""
109-
func_schema = tool.schema["function"]
110-
return cls(
111-
name=tool.name,
112-
description=func_schema.get("description", ""),
113-
parameters=func_schema.get("parameters", {}),
114-
annotations=tool.annotations,
115-
)
107+
def from_tool(cls, tool: "Tool | ToolBuiltIn") -> "ToolInfo":
108+
"""Create a ToolInfo from a Tool or ToolBuiltIn instance."""
109+
from ._tools import ToolBuiltIn
110+
111+
if isinstance(tool, ToolBuiltIn):
112+
return cls(name=tool.name, description=tool.name, parameters={})
113+
else:
114+
# For regular tools, extract from schema
115+
func_schema = tool.schema["function"]
116+
return cls(
117+
name=tool.name,
118+
description=func_schema.get("description", ""),
119+
parameters=func_schema.get("parameters", {}),
120+
annotations=tool.annotations,
121+
)
116122

117123

118124
ContentTypeEnum = Literal[

chatlas/_mcp_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from dataclasses import dataclass, field
88
from typing import TYPE_CHECKING, Any, Optional, Sequence
99

10-
from ._tools import Tool
10+
from ._tools import Tool, ToolBuiltIn
1111

1212
if TYPE_CHECKING:
1313
from mcp import ClientSession
@@ -23,7 +23,7 @@ class SessionInfo(ABC):
2323

2424
# Primary derived attributes
2525
session: ClientSession | None = None
26-
tools: dict[str, Tool] = field(default_factory=dict)
26+
tools: dict[str, Tool | ToolBuiltIn] = field(default_factory=dict)
2727

2828
# Background task management
2929
ready_event: asyncio.Event = field(default_factory=asyncio.Event)
@@ -74,7 +74,7 @@ async def request_tools(self) -> None:
7474
tool_names = tool_names.difference(exclude)
7575

7676
# Apply namespace and convert to chatlas.Tool instances
77-
self_tools: dict[str, Tool] = {}
77+
self_tools: dict[str, Tool | ToolBuiltIn] = {}
7878
for tool in response.tools:
7979
if tool.name not in tool_names:
8080
continue

chatlas/_provider.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pydantic import BaseModel
1717

1818
from ._content import Content
19-
from ._tools import Tool
19+
from ._tools import Tool, ToolBuiltIn
2020
from ._turn import AssistantTurn, Turn
2121
from ._typing_extensions import NotRequired, TypedDict
2222

@@ -162,7 +162,7 @@ def chat_perform(
162162
*,
163163
stream: Literal[False],
164164
turns: list[Turn],
165-
tools: dict[str, Tool],
165+
tools: dict[str, Tool | ToolBuiltIn],
166166
data_model: Optional[type[BaseModel]],
167167
kwargs: SubmitInputArgsT,
168168
) -> ChatCompletionT: ...
@@ -174,7 +174,7 @@ def chat_perform(
174174
*,
175175
stream: Literal[True],
176176
turns: list[Turn],
177-
tools: dict[str, Tool],
177+
tools: dict[str, Tool | ToolBuiltIn],
178178
data_model: Optional[type[BaseModel]],
179179
kwargs: SubmitInputArgsT,
180180
) -> Iterable[ChatCompletionChunkT]: ...
@@ -185,7 +185,7 @@ def chat_perform(
185185
*,
186186
stream: bool,
187187
turns: list[Turn],
188-
tools: dict[str, Tool],
188+
tools: dict[str, Tool | ToolBuiltIn],
189189
data_model: Optional[type[BaseModel]],
190190
kwargs: SubmitInputArgsT,
191191
) -> Iterable[ChatCompletionChunkT] | ChatCompletionT: ...
@@ -197,7 +197,7 @@ async def chat_perform_async(
197197
*,
198198
stream: Literal[False],
199199
turns: list[Turn],
200-
tools: dict[str, Tool],
200+
tools: dict[str, Tool | ToolBuiltIn],
201201
data_model: Optional[type[BaseModel]],
202202
kwargs: SubmitInputArgsT,
203203
) -> ChatCompletionT: ...
@@ -209,7 +209,7 @@ async def chat_perform_async(
209209
*,
210210
stream: Literal[True],
211211
turns: list[Turn],
212-
tools: dict[str, Tool],
212+
tools: dict[str, Tool | ToolBuiltIn],
213213
data_model: Optional[type[BaseModel]],
214214
kwargs: SubmitInputArgsT,
215215
) -> AsyncIterable[ChatCompletionChunkT]: ...
@@ -220,7 +220,7 @@ async def chat_perform_async(
220220
*,
221221
stream: bool,
222222
turns: list[Turn],
223-
tools: dict[str, Tool],
223+
tools: dict[str, Tool | ToolBuiltIn],
224224
data_model: Optional[type[BaseModel]],
225225
kwargs: SubmitInputArgsT,
226226
) -> AsyncIterable[ChatCompletionChunkT] | ChatCompletionT: ...
@@ -259,15 +259,15 @@ def value_tokens(
259259
def token_count(
260260
self,
261261
*args: Content | str,
262-
tools: dict[str, Tool],
262+
tools: dict[str, Tool | ToolBuiltIn],
263263
data_model: Optional[type[BaseModel]],
264264
) -> int: ...
265265

266266
@abstractmethod
267267
async def token_count_async(
268268
self,
269269
*args: Content | str,
270-
tools: dict[str, Tool],
270+
tools: dict[str, Tool | ToolBuiltIn],
271271
data_model: Optional[type[BaseModel]],
272272
) -> int: ...
273273

0 commit comments

Comments
 (0)