Skip to content

Commit 7584610

Browse files
ehhuang and Eric Huang (AI Platform)
authored
Sync updates from stainless branch: ehhuang_dev (#129)
# What does this PR do? Sync stainless sdk ## Test Plan ``` # llama-stack-client-python uv pip install -e . cd llama-stack uv pip install -e . LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/ --safety-shield meta-llama/Llama-Guard-3-8B ``` Outputs: ======================================= short test summary info ======================================= FAILED tests/client-sdk/agents/test_agents.py::test_override_system_message_behavior[meta-llama/Llama-3.1-8B-Instruct] - assert 'function' in 'shield_call> No Violationinference> "Why did the bicycle fall over? Because ... ================== 1 failed, 32 passed, 3 skipped, 115 warnings in 64.37s (0:01:04) =================== 1 test fails on both head and base. test_override_system_message_behavior is a bit flaky; it passed with ``` LLAMA_STACK_CONFIG=fireworks pytest --inference-model=meta-llama/Llama-3.3-70B-Instruct -s -v tests/client-sdk/agents/test_agents.py::test_override_system_message_behavior ``` Will look into fixing this test. ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. Co-authored-by: Eric Huang (AI Platform) <erichuang@fb.com>
1 parent cf64724 commit 7584610

36 files changed

+388
-64
lines changed

src/llama_stack_client/_base_client.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,10 +418,17 @@ def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0
418418
if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
419419
headers[idempotency_header] = options.idempotency_key or self._idempotency_key()
420420

421-
# Don't set the retry count header if it was already set or removed by the caller. We check
421+
# Don't set these headers if they were already set or removed by the caller. We check
422422
# `custom_headers`, which can contain `Omit()`, instead of `headers` to account for the removal case.
423-
if "x-stainless-retry-count" not in (header.lower() for header in custom_headers):
423+
lower_custom_headers = [header.lower() for header in custom_headers]
424+
if "x-stainless-retry-count" not in lower_custom_headers:
424425
headers["x-stainless-retry-count"] = str(retries_taken)
426+
if "x-stainless-read-timeout" not in lower_custom_headers:
427+
timeout = self.timeout if isinstance(options.timeout, NotGiven) else options.timeout
428+
if isinstance(timeout, Timeout):
429+
timeout = timeout.read
430+
if timeout is not None:
431+
headers["x-stainless-read-timeout"] = str(timeout)
425432

426433
return headers
427434

src/llama_stack_client/_client.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,13 @@ class LlamaStackClient(SyncAPIClient):
9898
with_streaming_response: LlamaStackClientWithStreamedResponse
9999

100100
# client options
101+
api_key: str | None
101102

102103
def __init__(
103104
self,
104105
*,
105-
base_url: str | httpx.URL | None = None,
106106
api_key: str | None = None,
107+
base_url: str | httpx.URL | None = None,
107108
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
108109
max_retries: int = DEFAULT_MAX_RETRIES,
109110
default_headers: Mapping[str, str] | None = None,
@@ -123,19 +124,20 @@ def __init__(
123124
_strict_response_validation: bool = False,
124125
provider_data: Mapping[str, Any] | None = None,
125126
) -> None:
126-
"""Construct a new synchronous llama-stack-client client instance."""
127-
if base_url is None:
128-
base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL")
129-
if base_url is None:
130-
base_url = f"http://any-hosted-llama-stack.com"
127+
"""Construct a new synchronous llama-stack-client client instance.
131128
129+
This automatically infers the `api_key` argument from the `LLAMA_STACK_CLIENT_API_KEY` environment variable if it is not provided.
130+
"""
132131
if api_key is None:
133132
api_key = os.environ.get("LLAMA_STACK_CLIENT_API_KEY")
134133
self.api_key = api_key
135134

135+
if base_url is None:
136+
base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL")
137+
if base_url is None:
138+
base_url = f"http://any-hosted-llama-stack.com"
139+
136140
custom_headers = default_headers or {}
137-
if api_key is not None:
138-
custom_headers["Authorization"] = f"Bearer {api_key}"
139141
custom_headers["X-LlamaStack-Client-Version"] = __version__
140142
if provider_data is not None:
141143
custom_headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
@@ -182,6 +184,14 @@ def __init__(
182184
def qs(self) -> Querystring:
183185
return Querystring(array_format="comma")
184186

187+
@property
188+
@override
189+
def auth_headers(self) -> dict[str, str]:
190+
api_key = self.api_key
191+
if api_key is None:
192+
return {}
193+
return {"Authorization": f"Bearer {api_key}"}
194+
185195
@property
186196
@override
187197
def default_headers(self) -> dict[str, str | Omit]:
@@ -194,8 +204,8 @@ def default_headers(self) -> dict[str, str | Omit]:
194204
def copy(
195205
self,
196206
*,
197-
base_url: str | httpx.URL | None = None,
198207
api_key: str | None = None,
208+
base_url: str | httpx.URL | None = None,
199209
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
200210
http_client: httpx.Client | None = None,
201211
max_retries: int | NotGiven = NOT_GIVEN,
@@ -228,8 +238,8 @@ def copy(
228238

229239
http_client = http_client or self._client
230240
return self.__class__(
231-
base_url=base_url or self.base_url,
232241
api_key=api_key or self.api_key,
242+
base_url=base_url or self.base_url,
233243
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
234244
http_client=http_client,
235245
max_retries=max_retries if is_given(max_retries) else self.max_retries,
@@ -304,12 +314,13 @@ class AsyncLlamaStackClient(AsyncAPIClient):
304314
with_streaming_response: AsyncLlamaStackClientWithStreamedResponse
305315

306316
# client options
317+
api_key: str | None
307318

308319
def __init__(
309320
self,
310321
*,
311-
base_url: str | httpx.URL | None = None,
312322
api_key: str | None = None,
323+
base_url: str | httpx.URL | None = None,
313324
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
314325
max_retries: int = DEFAULT_MAX_RETRIES,
315326
default_headers: Mapping[str, str] | None = None,
@@ -329,19 +340,20 @@ def __init__(
329340
_strict_response_validation: bool = False,
330341
provider_data: Mapping[str, Any] | None = None,
331342
) -> None:
332-
"""Construct a new async llama-stack-client client instance."""
333-
if base_url is None:
334-
base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL")
335-
if base_url is None:
336-
base_url = f"http://any-hosted-llama-stack.com"
343+
"""Construct a new async llama-stack-client client instance.
337344
345+
This automatically infers the `api_key` argument from the `LLAMA_STACK_CLIENT_API_KEY` environment variable if it is not provided.
346+
"""
338347
if api_key is None:
339348
api_key = os.environ.get("LLAMA_STACK_CLIENT_API_KEY")
340349
self.api_key = api_key
341350

351+
if base_url is None:
352+
base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL")
353+
if base_url is None:
354+
base_url = f"http://any-hosted-llama-stack.com"
355+
342356
custom_headers = default_headers or {}
343-
if api_key is not None:
344-
custom_headers["Authorization"] = f"Bearer {api_key}"
345357
custom_headers["X-LlamaStack-Client-Version"] = __version__
346358
if provider_data is not None:
347359
custom_headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
@@ -388,6 +400,14 @@ def __init__(
388400
def qs(self) -> Querystring:
389401
return Querystring(array_format="comma")
390402

403+
@property
404+
@override
405+
def auth_headers(self) -> dict[str, str]:
406+
api_key = self.api_key
407+
if api_key is None:
408+
return {}
409+
return {"Authorization": f"Bearer {api_key}"}
410+
391411
@property
392412
@override
393413
def default_headers(self) -> dict[str, str | Omit]:
@@ -400,8 +420,8 @@ def default_headers(self) -> dict[str, str | Omit]:
400420
def copy(
401421
self,
402422
*,
403-
base_url: str | httpx.URL | None = None,
404423
api_key: str | None = None,
424+
base_url: str | httpx.URL | None = None,
405425
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
406426
http_client: httpx.AsyncClient | None = None,
407427
max_retries: int | NotGiven = NOT_GIVEN,
@@ -434,8 +454,8 @@ def copy(
434454

435455
http_client = http_client or self._client
436456
return self.__class__(
437-
base_url=base_url or self.base_url,
438457
api_key=api_key or self.api_key,
458+
base_url=base_url or self.base_url,
439459
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
440460
http_client=http_client,
441461
max_retries=max_retries if is_given(max_retries) else self.max_retries,

src/llama_stack_client/_constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
OVERRIDE_CAST_TO_HEADER = "____stainless_override_cast_to"
77

88
# default timeout is 1 minute
9-
DEFAULT_TIMEOUT = httpx.Timeout(timeout=60.0, connect=5.0)
9+
DEFAULT_TIMEOUT = httpx.Timeout(timeout=60, connect=5.0)
1010
DEFAULT_MAX_RETRIES = 2
1111
DEFAULT_CONNECTION_LIMITS = httpx.Limits(max_connections=100, max_keepalive_connections=20)
1212

src/llama_stack_client/resources/agents/turn.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def create(
5959
messages: Iterable[turn_create_params.Message],
6060
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
6161
stream: Literal[False] | NotGiven = NOT_GIVEN,
62+
tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN,
6263
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
6364
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
6465
# The extra values given here take precedence over values defined on the client or passed to this method.
@@ -69,6 +70,8 @@ def create(
6970
) -> Turn:
7071
"""
7172
Args:
73+
tool_config: Configuration for tool use.
74+
7275
extra_headers: Send extra headers
7376
7477
extra_query: Add additional query parameters to the request
@@ -88,6 +91,7 @@ def create(
8891
messages: Iterable[turn_create_params.Message],
8992
stream: Literal[True],
9093
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
94+
tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN,
9195
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
9296
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
9397
# The extra values given here take precedence over values defined on the client or passed to this method.
@@ -98,6 +102,8 @@ def create(
98102
) -> Stream[AgentTurnResponseStreamChunk]:
99103
"""
100104
Args:
105+
tool_config: Configuration for tool use.
106+
101107
extra_headers: Send extra headers
102108
103109
extra_query: Add additional query parameters to the request
@@ -117,6 +123,7 @@ def create(
117123
messages: Iterable[turn_create_params.Message],
118124
stream: bool,
119125
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
126+
tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN,
120127
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
121128
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
122129
# The extra values given here take precedence over values defined on the client or passed to this method.
@@ -127,6 +134,8 @@ def create(
127134
) -> Turn | Stream[AgentTurnResponseStreamChunk]:
128135
"""
129136
Args:
137+
tool_config: Configuration for tool use.
138+
130139
extra_headers: Send extra headers
131140
132141
extra_query: Add additional query parameters to the request
@@ -146,6 +155,7 @@ def create(
146155
messages: Iterable[turn_create_params.Message],
147156
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
148157
stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN,
158+
tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN,
149159
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
150160
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
151161
# The extra values given here take precedence over values defined on the client or passed to this method.
@@ -165,6 +175,7 @@ def create(
165175
"messages": messages,
166176
"documents": documents,
167177
"stream": stream,
178+
"tool_config": tool_config,
168179
"toolgroups": toolgroups,
169180
},
170181
turn_create_params.TurnCreateParams,
@@ -244,6 +255,7 @@ async def create(
244255
messages: Iterable[turn_create_params.Message],
245256
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
246257
stream: Literal[False] | NotGiven = NOT_GIVEN,
258+
tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN,
247259
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
248260
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
249261
# The extra values given here take precedence over values defined on the client or passed to this method.
@@ -254,6 +266,8 @@ async def create(
254266
) -> Turn:
255267
"""
256268
Args:
269+
tool_config: Configuration for tool use.
270+
257271
extra_headers: Send extra headers
258272
259273
extra_query: Add additional query parameters to the request
@@ -273,6 +287,7 @@ async def create(
273287
messages: Iterable[turn_create_params.Message],
274288
stream: Literal[True],
275289
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
290+
tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN,
276291
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
277292
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
278293
# The extra values given here take precedence over values defined on the client or passed to this method.
@@ -283,6 +298,8 @@ async def create(
283298
) -> AsyncStream[AgentTurnResponseStreamChunk]:
284299
"""
285300
Args:
301+
tool_config: Configuration for tool use.
302+
286303
extra_headers: Send extra headers
287304
288305
extra_query: Add additional query parameters to the request
@@ -302,6 +319,7 @@ async def create(
302319
messages: Iterable[turn_create_params.Message],
303320
stream: bool,
304321
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
322+
tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN,
305323
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
306324
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
307325
# The extra values given here take precedence over values defined on the client or passed to this method.
@@ -312,6 +330,8 @@ async def create(
312330
) -> Turn | AsyncStream[AgentTurnResponseStreamChunk]:
313331
"""
314332
Args:
333+
tool_config: Configuration for tool use.
334+
315335
extra_headers: Send extra headers
316336
317337
extra_query: Add additional query parameters to the request
@@ -331,6 +351,7 @@ async def create(
331351
messages: Iterable[turn_create_params.Message],
332352
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
333353
stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN,
354+
tool_config: turn_create_params.ToolConfig | NotGiven = NOT_GIVEN,
334355
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
335356
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
336357
# The extra values given here take precedence over values defined on the client or passed to this method.
@@ -350,6 +371,7 @@ async def create(
350371
"messages": messages,
351372
"documents": documents,
352373
"stream": stream,
374+
"tool_config": tool_config,
353375
"toolgroups": toolgroups,
354376
},
355377
turn_create_params.TurnCreateParams,

src/llama_stack_client/resources/batch_inference.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ def chat_completion(
7272
) -> BatchInferenceChatCompletionResponse:
7373
"""
7474
Args:
75+
response_format: Configuration for JSON schema-guided response generation.
76+
77+
tool_choice: Whether tool use is required or automatic. This is a hint to the model which may
78+
not be followed. It depends on the Instruction Following capabilities of the
79+
model.
80+
81+
tool_prompt_format: Prompt format for calling custom / zero shot tools.
82+
7583
extra_headers: Send extra headers
7684
7785
extra_query: Add additional query parameters to the request
@@ -118,6 +126,8 @@ def completion(
118126
) -> BatchCompletion:
119127
"""
120128
Args:
129+
response_format: Configuration for JSON schema-guided response generation.
130+
121131
extra_headers: Send extra headers
122132
123133
extra_query: Add additional query parameters to the request
@@ -185,6 +195,14 @@ async def chat_completion(
185195
) -> BatchInferenceChatCompletionResponse:
186196
"""
187197
Args:
198+
response_format: Configuration for JSON schema-guided response generation.
199+
200+
tool_choice: Whether tool use is required or automatic. This is a hint to the model which may
201+
not be followed. It depends on the Instruction Following capabilities of the
202+
model.
203+
204+
tool_prompt_format: Prompt format for calling custom / zero shot tools.
205+
188206
extra_headers: Send extra headers
189207
190208
extra_query: Add additional query parameters to the request
@@ -231,6 +249,8 @@ async def completion(
231249
) -> BatchCompletion:
232250
"""
233251
Args:
252+
response_format: Configuration for JSON schema-guided response generation.
253+
234254
extra_headers: Send extra headers
235255
236256
extra_query: Add additional query parameters to the request

0 commit comments

Comments
 (0)