Skip to content

Commit 1ca6508

Browse files
committed
Sync updates from stainless branch: hardikjshah/dev
1 parent b664564 commit 1ca6508

File tree

8 files changed

+67
-21
lines changed

8 files changed

+67
-21
lines changed

src/llama_stack_client/lib/agents/client_tool.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,17 @@
77
import inspect
88
import json
99
from abc import abstractmethod
10-
from typing import Any, Callable, Dict, get_args, get_origin, get_type_hints, List, TypeVar, Union
10+
from typing import (
11+
Any,
12+
Callable,
13+
Dict,
14+
get_args,
15+
get_origin,
16+
get_type_hints,
17+
List,
18+
TypeVar,
19+
Union,
20+
)
1121

1222
from llama_stack_client.types import CompletionMessage, Message, ToolResponse
1323
from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam
@@ -72,7 +82,14 @@ def run(
7282

7383
metadata = {}
7484
try:
75-
response = self.run_impl(**tool_call.arguments)
85+
if tool_call.arguments_json is not None:
86+
params = json.loads(tool_call.arguments_json)
87+
elif isinstance(tool_call.arguments, str):
88+
params = json.loads(tool_call.arguments)
89+
else:
90+
params = tool_call.arguments
91+
92+
response = self.run_impl(**params)
7693
if isinstance(response, dict) and "content" in response:
7794
content = json.dumps(response["content"], ensure_ascii=False)
7895
metadata = response.get("metadata", {})

src/llama_stack_client/lib/agents/react/tool_parser.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7+
import json
78
import uuid
89
from typing import List, Optional, Union
910

10-
from pydantic import BaseModel, ValidationError
11-
1211
from llama_stack_client.types.shared.completion_message import CompletionMessage
1312
from llama_stack_client.types.shared.tool_call import ToolCall
13+
14+
from pydantic import BaseModel, ValidationError
15+
1416
from ..tool_parser import ToolParser
1517

1618

@@ -31,6 +33,7 @@ class ReActOutput(BaseModel):
3133

3234

3335
class ReActToolParser(ToolParser):
36+
@override
3437
def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
3538
tool_calls = []
3639
response_text = str(output_message.content)
@@ -49,6 +52,13 @@ def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
4952
params = {param.name: param.value for param in tool_params}
5053
if tool_name and tool_params:
5154
call_id = str(uuid.uuid4())
52-
tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=params)]
55+
tool_calls = [
56+
ToolCall(
57+
call_id=call_id,
58+
tool_name=tool_name,
59+
arguments=params,
60+
arguments_json=json.dumps(params),
61+
)
62+
]
5363

5464
return tool_calls

src/llama_stack_client/resources/datasets.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def iterrows(
122122
Uses cursor-based pagination.
123123
124124
Args:
125-
limit: The number of rows to get per page.
125+
limit: The number of rows to get.
126126
127127
start_index: Index into dataset for the first row to get. Get all rows if None.
128128
@@ -185,8 +185,8 @@ def register(
185185
"Hello, John Doe. How can I help you today?"}, {"role": "user", "content":
186186
"What's my name?"}, ], "answer": "John Doe" }
187187
188-
source:
189-
The data source of the dataset. Examples: - { "type": "uri", "uri":
188+
source: The data source of the dataset. Ensure that the data source schema is compatible
189+
with the purpose of the dataset. Examples: - { "type": "uri", "uri":
190190
"https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
191191
"lsfs://mydata.jsonl" } - { "type": "uri", "uri":
192192
"data:csv;base64,{base64_content}" } - { "type": "uri", "uri":
@@ -347,7 +347,7 @@ async def iterrows(
347347
Uses cursor-based pagination.
348348
349349
Args:
350-
limit: The number of rows to get per page.
350+
limit: The number of rows to get.
351351
352352
start_index: Index into dataset for the first row to get. Get all rows if None.
353353
@@ -410,8 +410,8 @@ async def register(
410410
"Hello, John Doe. How can I help you today?"}, {"role": "user", "content":
411411
"What's my name?"}, ], "answer": "John Doe" }
412412
413-
source:
414-
The data source of the dataset. Examples: - { "type": "uri", "uri":
413+
source: The data source of the dataset. Ensure that the data source schema is compatible
414+
with the purpose of the dataset. Examples: - { "type": "uri", "uri":
415415
"https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
416416
"lsfs://mydata.jsonl" } - { "type": "uri", "uri":
417417
"data:csv;base64,{base64_content}" } - { "type": "uri", "uri":

src/llama_stack_client/types/dataset_iterrows_params.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
class DatasetIterrowsParams(TypedDict, total=False):
1111
limit: int
12-
"""The number of rows to get per page."""
12+
"""The number of rows to get."""
1313

1414
start_index: int
1515
"""Index into dataset for the first row to get. Get all rows if None."""

src/llama_stack_client/types/dataset_iterrows_response.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class DatasetIterrowsResponse(BaseModel):
1111
data: List[Dict[str, Union[bool, float, str, List[object], object, None]]]
1212
"""The rows in the current page."""
1313

14-
next_index: Optional[int] = None
14+
next_start_index: Optional[int] = None
1515
"""Index into dataset for the first row in the next page.
1616
1717
None if there are no more rows.

src/llama_stack_client/types/dataset_register_params.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ class DatasetRegisterParams(TypedDict, total=False):
2727
source: Required[Source]
2828
"""The data source of the dataset.
2929
30-
Examples: - { "type": "uri", "uri": "https://mywebsite.com/mydata.jsonl" } - {
31-
"type": "uri", "uri": "lsfs://mydata.jsonl" } - { "type": "uri", "uri":
30+
Ensure that the data source schema is compatible with the purpose of the
31+
dataset. Examples: - { "type": "uri", "uri":
32+
"https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
33+
"lsfs://mydata.jsonl" } - { "type": "uri", "uri":
3234
"data:csv;base64,{base64_content}" } - { "type": "uri", "uri":
3335
"huggingface://llamastack/simpleqa?split=train" } - { "type": "rows", "rows": [
3436
{ "messages": [ {"role": "user", "content": "Hello, world!"}, {"role":
Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
22

3-
from typing import Dict, List, Union
3+
from typing import Dict, List, Union, Optional
44
from typing_extensions import Literal
55

66
from ..._models import BaseModel
@@ -9,11 +9,18 @@
99

1010

1111
class ToolCall(BaseModel):
12-
arguments: Dict[
12+
arguments: Union[
1313
str,
14-
Union[str, float, bool, List[Union[str, float, bool, None]], Dict[str, Union[str, float, bool, None]], None],
14+
Dict[
15+
str,
16+
Union[
17+
str, float, bool, List[Union[str, float, bool, None]], Dict[str, Union[str, float, bool, None]], None
18+
],
19+
],
1520
]
1621

1722
call_id: str
1823

1924
tool_name: Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str]
25+
26+
arguments_json: Optional[str] = None

src/llama_stack_client/types/shared_params/tool_call.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,24 @@
1010

1111
class ToolCall(TypedDict, total=False):
1212
arguments: Required[
13-
Dict[
13+
Union[
1414
str,
15-
Union[
16-
str, float, bool, List[Union[str, float, bool, None]], Dict[str, Union[str, float, bool, None]], None
15+
Dict[
16+
str,
17+
Union[
18+
str,
19+
float,
20+
bool,
21+
List[Union[str, float, bool, None]],
22+
Dict[str, Union[str, float, bool, None]],
23+
None,
24+
],
1725
],
1826
]
1927
]
2028

2129
call_id: Required[str]
2230

2331
tool_name: Required[Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str]]
32+
33+
arguments_json: str

0 commit comments

Comments
 (0)