Skip to content

Commit 96749af

Browse files
authored
chore: api sync, deprecate allow_resume_turn + rename task_config->benchmark_config (Sync updates from stainless branch: yanxi0830/dev) (#176)
# What does this PR do? - Adapt to API changes in llamastack/llama-stack#1397 and llamastack/llama-stack#1377 [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan ``` LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/integration/agents/test_agents.py--inference-model "meta-llama/Llama-3.3-70B-Instruct" --record-responses ``` ``` pytest -v -s --nbval-lax ./docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb ``` [//]: # (## Documentation) [//]: # (- [ ] Added a Changelog entry if the change is significant)
1 parent 430dc09 commit 96749af

File tree

14 files changed

+119
-226
lines changed

14 files changed

+119
-226
lines changed

docs/cli_reference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ Options:
182182
- `--num-examples`: Optional. Number of examples to evaluate (useful for debugging)
183183
- `--visualize`: Optional flag. If set, visualizes evaluation results after completion
184184

185-
Example eval_task_config.json:
185+
Example eval_benchmark_config.json:
186186
```json
187187
{
188188
"type": "benchmark",

src/llama_stack_client/_base_client.py

Lines changed: 8 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import inspect
1010
import logging
1111
import platform
12-
import warnings
1312
import email.utils
1413
from types import TracebackType
1514
from random import random
@@ -36,7 +35,7 @@
3635
import httpx
3736
import distro
3837
import pydantic
39-
from httpx import URL, Limits
38+
from httpx import URL
4039
from pydantic import PrivateAttr
4140

4241
from . import _exceptions
@@ -51,19 +50,16 @@
5150
Timeout,
5251
NotGiven,
5352
ResponseT,
54-
Transport,
5553
AnyMapping,
5654
PostParser,
57-
ProxiesTypes,
5855
RequestFiles,
5956
HttpxSendArgs,
60-
AsyncTransport,
6157
RequestOptions,
6258
HttpxRequestFiles,
6359
ModelBuilderProtocol,
6460
)
6561
from ._utils import is_dict, is_list, asyncify, is_given, lru_cache, is_mapping
66-
from ._compat import model_copy, model_dump
62+
from ._compat import PYDANTIC_V2, model_copy, model_dump
6763
from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type
6864
from ._response import (
6965
APIResponse,
@@ -207,6 +203,9 @@ def _set_private_attributes(
207203
model: Type[_T],
208204
options: FinalRequestOptions,
209205
) -> None:
206+
if PYDANTIC_V2 and getattr(self, "__pydantic_private__", None) is None:
207+
self.__pydantic_private__ = {}
208+
210209
self._model = model
211210
self._client = client
212211
self._options = options
@@ -292,6 +291,9 @@ def _set_private_attributes(
292291
client: AsyncAPIClient,
293292
options: FinalRequestOptions,
294293
) -> None:
294+
if PYDANTIC_V2 and getattr(self, "__pydantic_private__", None) is None:
295+
self.__pydantic_private__ = {}
296+
295297
self._model = model
296298
self._client = client
297299
self._options = options
@@ -331,9 +333,6 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
331333
_base_url: URL
332334
max_retries: int
333335
timeout: Union[float, Timeout, None]
334-
_limits: httpx.Limits
335-
_proxies: ProxiesTypes | None
336-
_transport: Transport | AsyncTransport | None
337336
_strict_response_validation: bool
338337
_idempotency_header: str | None
339338
_default_stream_cls: type[_DefaultStreamT] | None = None
@@ -346,19 +345,13 @@ def __init__(
346345
_strict_response_validation: bool,
347346
max_retries: int = DEFAULT_MAX_RETRIES,
348347
timeout: float | Timeout | None = DEFAULT_TIMEOUT,
349-
limits: httpx.Limits,
350-
transport: Transport | AsyncTransport | None,
351-
proxies: ProxiesTypes | None,
352348
custom_headers: Mapping[str, str] | None = None,
353349
custom_query: Mapping[str, object] | None = None,
354350
) -> None:
355351
self._version = version
356352
self._base_url = self._enforce_trailing_slash(URL(base_url))
357353
self.max_retries = max_retries
358354
self.timeout = timeout
359-
self._limits = limits
360-
self._proxies = proxies
361-
self._transport = transport
362355
self._custom_headers = custom_headers or {}
363356
self._custom_query = custom_query or {}
364357
self._strict_response_validation = _strict_response_validation
@@ -794,46 +787,11 @@ def __init__(
794787
base_url: str | URL,
795788
max_retries: int = DEFAULT_MAX_RETRIES,
796789
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
797-
transport: Transport | None = None,
798-
proxies: ProxiesTypes | None = None,
799-
limits: Limits | None = None,
800790
http_client: httpx.Client | None = None,
801791
custom_headers: Mapping[str, str] | None = None,
802792
custom_query: Mapping[str, object] | None = None,
803793
_strict_response_validation: bool,
804794
) -> None:
805-
kwargs: dict[str, Any] = {}
806-
if limits is not None:
807-
warnings.warn(
808-
"The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead",
809-
category=DeprecationWarning,
810-
stacklevel=3,
811-
)
812-
if http_client is not None:
813-
raise ValueError("The `http_client` argument is mutually exclusive with `connection_pool_limits`")
814-
else:
815-
limits = DEFAULT_CONNECTION_LIMITS
816-
817-
if transport is not None:
818-
kwargs["transport"] = transport
819-
warnings.warn(
820-
"The `transport` argument is deprecated. The `http_client` argument should be passed instead",
821-
category=DeprecationWarning,
822-
stacklevel=3,
823-
)
824-
if http_client is not None:
825-
raise ValueError("The `http_client` argument is mutually exclusive with `transport`")
826-
827-
if proxies is not None:
828-
kwargs["proxies"] = proxies
829-
warnings.warn(
830-
"The `proxies` argument is deprecated. The `http_client` argument should be passed instead",
831-
category=DeprecationWarning,
832-
stacklevel=3,
833-
)
834-
if http_client is not None:
835-
raise ValueError("The `http_client` argument is mutually exclusive with `proxies`")
836-
837795
if not is_given(timeout):
838796
# if the user passed in a custom http client with a non-default
839797
# timeout set then we use that timeout.
@@ -854,12 +812,9 @@ def __init__(
854812

855813
super().__init__(
856814
version=version,
857-
limits=limits,
858815
# cast to a valid type because mypy doesn't understand our type narrowing
859816
timeout=cast(Timeout, timeout),
860-
proxies=proxies,
861817
base_url=base_url,
862-
transport=transport,
863818
max_retries=max_retries,
864819
custom_query=custom_query,
865820
custom_headers=custom_headers,
@@ -869,9 +824,6 @@ def __init__(
869824
base_url=base_url,
870825
# cast to a valid type because mypy doesn't understand our type narrowing
871826
timeout=cast(Timeout, timeout),
872-
limits=limits,
873-
follow_redirects=True,
874-
**kwargs, # type: ignore
875827
)
876828

877829
def is_closed(self) -> bool:
@@ -1366,45 +1318,10 @@ def __init__(
13661318
_strict_response_validation: bool,
13671319
max_retries: int = DEFAULT_MAX_RETRIES,
13681320
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
1369-
transport: AsyncTransport | None = None,
1370-
proxies: ProxiesTypes | None = None,
1371-
limits: Limits | None = None,
13721321
http_client: httpx.AsyncClient | None = None,
13731322
custom_headers: Mapping[str, str] | None = None,
13741323
custom_query: Mapping[str, object] | None = None,
13751324
) -> None:
1376-
kwargs: dict[str, Any] = {}
1377-
if limits is not None:
1378-
warnings.warn(
1379-
"The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead",
1380-
category=DeprecationWarning,
1381-
stacklevel=3,
1382-
)
1383-
if http_client is not None:
1384-
raise ValueError("The `http_client` argument is mutually exclusive with `connection_pool_limits`")
1385-
else:
1386-
limits = DEFAULT_CONNECTION_LIMITS
1387-
1388-
if transport is not None:
1389-
kwargs["transport"] = transport
1390-
warnings.warn(
1391-
"The `transport` argument is deprecated. The `http_client` argument should be passed instead",
1392-
category=DeprecationWarning,
1393-
stacklevel=3,
1394-
)
1395-
if http_client is not None:
1396-
raise ValueError("The `http_client` argument is mutually exclusive with `transport`")
1397-
1398-
if proxies is not None:
1399-
kwargs["proxies"] = proxies
1400-
warnings.warn(
1401-
"The `proxies` argument is deprecated. The `http_client` argument should be passed instead",
1402-
category=DeprecationWarning,
1403-
stacklevel=3,
1404-
)
1405-
if http_client is not None:
1406-
raise ValueError("The `http_client` argument is mutually exclusive with `proxies`")
1407-
14081325
if not is_given(timeout):
14091326
# if the user passed in a custom http client with a non-default
14101327
# timeout set then we use that timeout.
@@ -1426,11 +1343,8 @@ def __init__(
14261343
super().__init__(
14271344
version=version,
14281345
base_url=base_url,
1429-
limits=limits,
14301346
# cast to a valid type because mypy doesn't understand our type narrowing
14311347
timeout=cast(Timeout, timeout),
1432-
proxies=proxies,
1433-
transport=transport,
14341348
max_retries=max_retries,
14351349
custom_query=custom_query,
14361350
custom_headers=custom_headers,
@@ -1440,9 +1354,6 @@ def __init__(
14401354
base_url=base_url,
14411355
# cast to a valid type because mypy doesn't understand our type narrowing
14421356
timeout=cast(Timeout, timeout),
1443-
limits=limits,
1444-
follow_redirects=True,
1445-
**kwargs, # type: ignore
14461357
)
14471358

14481359
def is_closed(self) -> bool:

src/llama_stack_client/_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def __init__(
124124
_strict_response_validation: bool = False,
125125
provider_data: Mapping[str, Any] | None = None,
126126
) -> None:
127-
"""Construct a new synchronous llama-stack-client client instance.
127+
"""Construct a new synchronous LlamaStackClient client instance.
128128
129129
This automatically infers the `api_key` argument from the `LLAMA_STACK_API_KEY` environment variable if it is not provided.
130130
"""
@@ -340,7 +340,7 @@ def __init__(
340340
_strict_response_validation: bool = False,
341341
provider_data: Mapping[str, Any] | None = None,
342342
) -> None:
343-
"""Construct a new async llama-stack-client client instance.
343+
"""Construct a new async AsyncLlamaStackClient client instance.
344344
345345
This automatically infers the `api_key` argument from the `LLAMA_STACK_API_KEY` environment variable if it is not provided.
346346
"""

src/llama_stack_client/lib/agents/agent.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,6 @@ def _create_turn_streaming(
151151
stream=True,
152152
documents=documents,
153153
toolgroups=toolgroups,
154-
allow_turn_resume=True,
155154
)
156155

157156
# 2. process turn and resume if there's a tool call

src/llama_stack_client/lib/cli/eval/run_benchmark.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,7 @@
1313
from tqdm.rich import tqdm
1414

1515
from ..common.utils import create_bar_chart
16-
from .utils import (
17-
aggregate_accuracy,
18-
aggregate_average,
19-
aggregate_categorical_count,
20-
aggregate_median,
21-
)
16+
from .utils import aggregate_accuracy, aggregate_average, aggregate_categorical_count, aggregate_median
2217

2318

2419
@click.command("run-benchmark")
@@ -110,7 +105,7 @@ def run_benchmark(
110105
benchmark_id=benchmark_id,
111106
input_rows=[r],
112107
scoring_functions=scoring_functions,
113-
task_config={
108+
benchmark_config={
114109
"type": "benchmark",
115110
"eval_candidate": {
116111
"type": "model",

src/llama_stack_client/resources/agents/turn.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ def create(
5858
*,
5959
agent_id: str,
6060
messages: Iterable[turn_create_params.Message],
61-
allow_turn_resume: bool | NotGiven = NOT_GIVEN,
6261
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
6362
stream: Literal[False] | NotGiven = NOT_GIVEN,
6463
tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN,
@@ -92,7 +91,6 @@ def create(
9291
agent_id: str,
9392
messages: Iterable[turn_create_params.Message],
9493
stream: Literal[True],
95-
allow_turn_resume: bool | NotGiven = NOT_GIVEN,
9694
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
9795
tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN,
9896
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
@@ -125,7 +123,6 @@ def create(
125123
agent_id: str,
126124
messages: Iterable[turn_create_params.Message],
127125
stream: bool,
128-
allow_turn_resume: bool | NotGiven = NOT_GIVEN,
129126
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
130127
tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN,
131128
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
@@ -157,7 +154,6 @@ def create(
157154
*,
158155
agent_id: str,
159156
messages: Iterable[turn_create_params.Message],
160-
allow_turn_resume: bool | NotGiven = NOT_GIVEN,
161157
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
162158
stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN,
163159
tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN,
@@ -178,7 +174,6 @@ def create(
178174
body=maybe_transform(
179175
{
180176
"messages": messages,
181-
"allow_turn_resume": allow_turn_resume,
182177
"documents": documents,
183178
"stream": stream,
184179
"tool_config": tool_config,
@@ -412,7 +407,6 @@ async def create(
412407
*,
413408
agent_id: str,
414409
messages: Iterable[turn_create_params.Message],
415-
allow_turn_resume: bool | NotGiven = NOT_GIVEN,
416410
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
417411
stream: Literal[False] | NotGiven = NOT_GIVEN,
418412
tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN,
@@ -446,7 +440,6 @@ async def create(
446440
agent_id: str,
447441
messages: Iterable[turn_create_params.Message],
448442
stream: Literal[True],
449-
allow_turn_resume: bool | NotGiven = NOT_GIVEN,
450443
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
451444
tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN,
452445
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
@@ -479,7 +472,6 @@ async def create(
479472
agent_id: str,
480473
messages: Iterable[turn_create_params.Message],
481474
stream: bool,
482-
allow_turn_resume: bool | NotGiven = NOT_GIVEN,
483475
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
484476
tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN,
485477
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
@@ -511,7 +503,6 @@ async def create(
511503
*,
512504
agent_id: str,
513505
messages: Iterable[turn_create_params.Message],
514-
allow_turn_resume: bool | NotGiven = NOT_GIVEN,
515506
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
516507
stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN,
517508
tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN,
@@ -532,7 +523,6 @@ async def create(
532523
body=await async_maybe_transform(
533524
{
534525
"messages": messages,
535-
"allow_turn_resume": allow_turn_resume,
536526
"documents": documents,
537527
"stream": stream,
538528
"tool_config": tool_config,

0 commit comments

Comments
 (0)