
Commit fa1067b

Sync updates from stainless branch: ehhuang/dev (#243)
# What does this PR do?
[Provide a short summary of what this PR does and why. Link to relevant issues if applicable.]

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
[Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*]

[//]: # (## Documentation)
[//]: # (- [ ] Added a Changelog entry if the change is significant)
1 parent d2e7537 commit fa1067b

File tree: 7 files changed (+124 additions, −107 deletions)

src/llama_stack_client/resources/vector_stores/vector_stores.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -318,7 +318,7 @@ def search(
         query: Union[str, List[str]],
         filters: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
         max_num_results: int | NotGiven = NOT_GIVEN,
-        ranking_options: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
+        ranking_options: vector_store_search_params.RankingOptions | NotGiven = NOT_GIVEN,
         rewrite_query: bool | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
         # The extra values given here take precedence over values defined on the client or passed to this method.
@@ -651,7 +651,7 @@ async def search(
         query: Union[str, List[str]],
         filters: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
         max_num_results: int | NotGiven = NOT_GIVEN,
-        ranking_options: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
+        ranking_options: vector_store_search_params.RankingOptions | NotGiven = NOT_GIVEN,
         rewrite_query: bool | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
         # The extra values given here take precedence over values defined on the client or passed to this method.
```
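With this change, `ranking_options` on both the sync and async `search` methods is typed as `vector_store_search_params.RankingOptions` rather than an open-ended dict. A minimal usage sketch; the `base_url`, the `vs_123` store id, and the ranker name are hypothetical, and the positional store-id argument follows this generated client's usual path-parameter convention:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # hypothetical server URL

# ranking_options now type-checks against RankingOptions rather than Dict[str, ...]
results = client.vector_stores.search(
    vector_store_id="vs_123",  # hypothetical store id
    query="how do retries work?",
    max_num_results=5,
    ranking_options={
        "ranker": "default",     # assumed ranker name; any server-accepted string
        "score_threshold": 0.5,  # drop results scoring below 0.5
    },
)
```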

src/llama_stack_client/types/response_object.py

Lines changed: 0 additions & 10 deletions
```diff
@@ -189,16 +189,6 @@ class Error(BaseModel):
 
 
 class ResponseObject(BaseModel):
-    @property
-    def output_text(self) -> str:
-        texts: List[str] = []
-        for output in self.output:
-            if output.type == "message":
-                for content in output.content:
-                    if content.type == "output_text":
-                        texts.append(content.text)
-        return "".join(texts)
-
     id: str
 
     created_at: int
```
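With the `output_text` convenience property removed from `ResponseObject`, callers that relied on it can reproduce the same concatenation in their own code. This sketch reuses the removed property's logic verbatim as a standalone helper:

```python
from llama_stack_client.types.response_object import ResponseObject


def collect_output_text(response: ResponseObject) -> str:
    # Same logic as the removed property: join the text of every
    # "output_text" content part inside "message" outputs, in order.
    texts: list[str] = []
    for output in response.output:
        if output.type == "message":
            for content in output.content:
                if content.type == "output_text":
                    texts.append(content.text)
    return "".join(texts)
```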

src/llama_stack_client/types/vector_store.py

Lines changed: 14 additions & 2 deletions
```diff
@@ -5,15 +5,27 @@
 
 from .._models import BaseModel
 
-__all__ = ["VectorStore"]
+__all__ = ["VectorStore", "FileCounts"]
+
+
+class FileCounts(BaseModel):
+    cancelled: int
+
+    completed: int
+
+    failed: int
+
+    in_progress: int
+
+    total: int
 
 
 class VectorStore(BaseModel):
     id: str
 
     created_at: int
 
-    file_counts: Dict[str, int]
+    file_counts: FileCounts
 
     metadata: Dict[str, Union[bool, float, str, List[object], object, None]]
 
```
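`file_counts` switches from `Dict[str, int]` to a typed `FileCounts` model, so counts are read as attributes rather than dict keys. A small sketch, constructing the model directly (keyword construction is how these pydantic-based `BaseModel` types are built):

```python
from llama_stack_client.types.vector_store import FileCounts

counts = FileCounts(cancelled=0, completed=8, failed=1, in_progress=1, total=10)

# Attribute access replaces the old dict lookups such as counts["completed"].
print(f"{counts.completed}/{counts.total} files indexed")
if counts.failed or counts.cancelled:
    print("some files did not finish processing")
```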

src/llama_stack_client/types/vector_store_search_params.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -5,7 +5,7 @@
 from typing import Dict, List, Union, Iterable
 from typing_extensions import Required, TypedDict
 
-__all__ = ["VectorStoreSearchParams"]
+__all__ = ["VectorStoreSearchParams", "RankingOptions"]
 
 
 class VectorStoreSearchParams(TypedDict, total=False):
@@ -18,8 +18,14 @@ class VectorStoreSearchParams(TypedDict, total=False):
     max_num_results: int
     """Maximum number of results to return (1 to 50 inclusive, default 10)."""
 
-    ranking_options: Dict[str, Union[bool, float, str, Iterable[object], object, None]]
+    ranking_options: RankingOptions
     """Ranking options for fine-tuning the search results."""
 
     rewrite_query: bool
     """Whether to rewrite the natural language query for vector search (default false)"""
+
+
+class RankingOptions(TypedDict, total=False):
+    ranker: str
+
+    score_threshold: float
```
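`RankingOptions` is declared with `total=False`, so each key is optional and partial dicts still type-check. For example (the ranker name shown is an assumption, not a documented value):

```python
from llama_stack_client.types.vector_store_search_params import RankingOptions

full: RankingOptions = {"ranker": "default", "score_threshold": 0.8}  # "default" is an assumed name
partial: RankingOptions = {"score_threshold": 0.2}                    # ranker omitted; still valid
```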

tests/api_resources/test_vector_stores.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -242,7 +242,10 @@ def test_method_search_with_all_params(self, client: LlamaStackClient) -> None:
             query="string",
             filters={"foo": True},
             max_num_results=0,
-            ranking_options={"foo": True},
+            ranking_options={
+                "ranker": "ranker",
+                "score_threshold": 0,
+            },
             rewrite_query=True,
         )
         assert_matches_type(VectorStoreSearchResponse, vector_store, path=["response"])
@@ -505,7 +508,10 @@ async def test_method_search_with_all_params(self, async_client: AsyncLlamaStackClient) -> None:
             query="string",
             filters={"foo": True},
             max_num_results=0,
-            ranking_options={"foo": True},
+            ranking_options={
+                "ranker": "ranker",
+                "score_threshold": 0,
+            },
             rewrite_query=True,
         )
         assert_matches_type(VectorStoreSearchResponse, vector_store, path=["response"])
```

tests/conftest.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1,3 +1,5 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
 from __future__ import annotations
 
 import os
```
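The reworked leak tests below take `client` / `async_client` as pytest fixture arguments instead of using `self.client`; those fixtures live in this generated conftest.py. A purely hypothetical sketch of their shape, for orientation only:

```python
# Hypothetical sketch: the real fixtures are generated by Stainless and may
# differ in naming, scope, and construction details.
import pytest

from llama_stack_client import LlamaStackClient


@pytest.fixture
def client() -> LlamaStackClient:
    # A local mock-server base URL is assumed here.
    return LlamaStackClient(base_url="http://127.0.0.1:4010")
```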

tests/test_client.py

Lines changed: 90 additions & 89 deletions
```diff
@@ -23,17 +23,16 @@
 
 from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient, APIResponseValidationError
 from llama_stack_client._types import Omit
-from llama_stack_client._utils import maybe_transform
 from llama_stack_client._models import BaseModel, FinalRequestOptions
-from llama_stack_client._constants import RAW_RESPONSE_HEADER
 from llama_stack_client._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError
 from llama_stack_client._base_client import (
     DEFAULT_TIMEOUT,
     HTTPX_DEFAULT_TIMEOUT,
     BaseClient,
+    DefaultHttpxClient,
+    DefaultAsyncHttpxClient,
     make_request_options,
 )
-from llama_stack_client.types.inference_chat_completion_params import InferenceChatCompletionParamsNonStreaming
 
 from .utils import update_env
 
```
```diff
@@ -677,60 +676,37 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str
 
     @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
-    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: LlamaStackClient) -> None:
         respx_mock.post("/v1/inference/chat-completion").mock(side_effect=httpx.TimeoutException("Test timeout error"))
 
         with pytest.raises(APITimeoutError):
-            self.client.post(
-                "/v1/inference/chat-completion",
-                body=cast(
-                    object,
-                    maybe_transform(
-                        dict(
-                            messages=[
-                                {
-                                    "content": "string",
-                                    "role": "user",
-                                }
-                            ],
-                            model_id="model_id",
-                        ),
-                        InferenceChatCompletionParamsNonStreaming,
-                    ),
-                ),
-                cast_to=httpx.Response,
-                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
-            )
+            client.inference.with_streaming_response.chat_completion(
+                messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+                model_id="model_id",
+            ).__enter__()
 
         assert _get_open_connections(self.client) == 0
 
     @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
-    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: LlamaStackClient) -> None:
         respx_mock.post("/v1/inference/chat-completion").mock(return_value=httpx.Response(500))
 
         with pytest.raises(APIStatusError):
-            self.client.post(
-                "/v1/inference/chat-completion",
-                body=cast(
-                    object,
-                    maybe_transform(
-                        dict(
-                            messages=[
-                                {
-                                    "content": "string",
-                                    "role": "user",
-                                }
-                            ],
-                            model_id="model_id",
-                        ),
-                        InferenceChatCompletionParamsNonStreaming,
-                    ),
-                ),
-                cast_to=httpx.Response,
-                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
-            )
-
+            client.inference.with_streaming_response.chat_completion(
+                messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+                model_id="model_id",
+            ).__enter__()
         assert _get_open_connections(self.client) == 0
 
     @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
```
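The refactor replaces raw `client.post(...)` calls (which needed `maybe_transform` and `RAW_RESPONSE_HEADER`) with the public `with_streaming_response` surface. The tests invoke `__enter__()` directly so the request fires while the connection is deliberately left open for the leak check; ordinary code would use a `with` block. A hedged sketch (server URL assumed; `iter_text()` is the usual streaming-response accessor in Stainless-generated clients):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed URL

# Normal usage: the context manager sends the request lazily and releases
# the HTTP connection on exit, which is exactly what the tests bypass on purpose.
with client.inference.with_streaming_response.chat_completion(
    messages=[{"content": "string", "role": "user"}],
    model_id="model_id",
) as response:
    for text in response.iter_text():
        print(text, end="")
```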
```diff
@@ -836,6 +812,28 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
 
         assert response.http_request.headers.get("x-stainless-retry-count") == "42"
 
+    def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        # Test that the proxy environment variables are set correctly
+        monkeypatch.setenv("HTTPS_PROXY", "https://example.org")
+
+        client = DefaultHttpxClient()
+
+        mounts = tuple(client._mounts.items())
+        assert len(mounts) == 1
+        assert mounts[0][0].pattern == "https://"
+
+    @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
+    def test_default_client_creation(self) -> None:
+        # Ensure that the client can be initialized without any exceptions
+        DefaultHttpxClient(
+            verify=True,
+            cert=None,
+            trust_env=True,
+            http1=True,
+            http2=False,
+            limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
+        )
+
     @pytest.mark.respx(base_url=base_url)
     def test_follow_redirects(self, respx_mock: MockRouter) -> None:
         # Test that the default follow_redirects=True allows following redirects
```
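`DefaultHttpxClient` keeps httpx's `trust_env=True` default, so a single `HTTPS_PROXY` variable yields exactly one transport mount keyed by the `https://` URL pattern. Plain httpx shows the same behavior, which is what the test asserts (`_mounts` is private API, inspected here only for illustration):

```python
import os

import httpx

os.environ["HTTPS_PROXY"] = "https://example.org"

# trust_env=True (the default) makes httpx read proxy variables at construction.
client = httpx.Client()
patterns = [pattern.pattern for pattern in client._mounts]
assert patterns == ["https://"]  # one mount, routing https:// traffic via the proxy
```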
```diff
@@ -1495,60 +1493,41 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str
 
     @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
-    async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+    async def test_retrying_timeout_errors_doesnt_leak(
+        self, respx_mock: MockRouter, async_client: AsyncLlamaStackClient
+    ) -> None:
         respx_mock.post("/v1/inference/chat-completion").mock(side_effect=httpx.TimeoutException("Test timeout error"))
 
         with pytest.raises(APITimeoutError):
-            await self.client.post(
-                "/v1/inference/chat-completion",
-                body=cast(
-                    object,
-                    maybe_transform(
-                        dict(
-                            messages=[
-                                {
-                                    "content": "string",
-                                    "role": "user",
-                                }
-                            ],
-                            model_id="model_id",
-                        ),
-                        InferenceChatCompletionParamsNonStreaming,
-                    ),
-                ),
-                cast_to=httpx.Response,
-                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
-            )
+            await async_client.inference.with_streaming_response.chat_completion(
+                messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+                model_id="model_id",
+            ).__aenter__()
 
         assert _get_open_connections(self.client) == 0
 
     @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
-    async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+    async def test_retrying_status_errors_doesnt_leak(
+        self, respx_mock: MockRouter, async_client: AsyncLlamaStackClient
+    ) -> None:
         respx_mock.post("/v1/inference/chat-completion").mock(return_value=httpx.Response(500))
 
         with pytest.raises(APIStatusError):
-            await self.client.post(
-                "/v1/inference/chat-completion",
-                body=cast(
-                    object,
-                    maybe_transform(
-                        dict(
-                            messages=[
-                                {
-                                    "content": "string",
-                                    "role": "user",
-                                }
-                            ],
-                            model_id="model_id",
-                        ),
-                        InferenceChatCompletionParamsNonStreaming,
-                    ),
-                ),
-                cast_to=httpx.Response,
-                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
-            )
-
+            await async_client.inference.with_streaming_response.chat_completion(
+                messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+                model_id="model_id",
+            ).__aenter__()
         assert _get_open_connections(self.client) == 0
 
     @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
```
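The async variant follows the same pattern via `__aenter__()`. Ordinary async code would use `async with`; a self-contained sketch under the same assumptions as the sync example above:

```python
import asyncio

from llama_stack_client import AsyncLlamaStackClient


async def main() -> None:
    client = AsyncLlamaStackClient(base_url="http://localhost:8321")  # assumed URL
    # `async with` is the ordinary form of what the test drives via __aenter__().
    async with client.inference.with_streaming_response.chat_completion(
        messages=[{"content": "string", "role": "user"}],
        model_id="model_id",
    ) as response:
        async for text in response.iter_text():
            print(text, end="")


asyncio.run(main())
```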
```diff
@@ -1702,6 +1681,28 @@ async def test_main() -> None:
 
         time.sleep(0.1)
 
+    async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        # Test that the proxy environment variables are set correctly
+        monkeypatch.setenv("HTTPS_PROXY", "https://example.org")
+
+        client = DefaultAsyncHttpxClient()
+
+        mounts = tuple(client._mounts.items())
+        assert len(mounts) == 1
+        assert mounts[0][0].pattern == "https://"
+
+    @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
+    async def test_default_client_creation(self) -> None:
+        # Ensure that the client can be initialized without any exceptions
+        DefaultAsyncHttpxClient(
+            verify=True,
+            cert=None,
+            trust_env=True,
+            http1=True,
+            http2=False,
+            limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
+        )
+
     @pytest.mark.respx(base_url=base_url)
     async def test_follow_redirects(self, respx_mock: MockRouter) -> None:
         # Test that the default follow_redirects=True allows following redirects
```
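The point of exposing `DefaultHttpxClient` / `DefaultAsyncHttpxClient` is to let callers hand a pre-configured transport to the SDK. A sketch of that wiring; the `_base_client` import path is the one this diff uses, while the `http_client=` constructor kwarg is the customary Stainless surface and is assumed here:

```python
import httpx

from llama_stack_client import AsyncLlamaStackClient
from llama_stack_client._base_client import DefaultAsyncHttpxClient

# Pre-configure connection limits (and disable HTTP/2) on the transport,
# then hand it to the SDK client instead of letting it build its own.
http_client = DefaultAsyncHttpxClient(
    http2=False,
    limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
)
client = AsyncLlamaStackClient(
    base_url="http://localhost:8321",  # assumed URL
    http_client=http_client,
)
```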
