From d484aa62a6d8c0b6645c70c56f3264d253b1c608 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Fri, 20 Jun 2025 11:48:57 -0700 Subject: [PATCH] Sync updates from stainless branch: ehhuang/dev --- .../resources/vector_stores/vector_stores.py | 4 +- .../types/response_object.py | 10 - src/llama_stack_client/types/vector_store.py | 16 +- .../types/vector_store_search_params.py | 10 +- tests/api_resources/test_vector_stores.py | 10 +- tests/conftest.py | 2 + tests/test_client.py | 179 +++++++++--------- 7 files changed, 124 insertions(+), 107 deletions(-) diff --git a/src/llama_stack_client/resources/vector_stores/vector_stores.py b/src/llama_stack_client/resources/vector_stores/vector_stores.py index de0b8205..79ab9db3 100644 --- a/src/llama_stack_client/resources/vector_stores/vector_stores.py +++ b/src/llama_stack_client/resources/vector_stores/vector_stores.py @@ -318,7 +318,7 @@ def search( query: Union[str, List[str]], filters: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN, max_num_results: int | NotGiven = NOT_GIVEN, - ranking_options: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN, + ranking_options: vector_store_search_params.RankingOptions | NotGiven = NOT_GIVEN, rewrite_query: bool | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -651,7 +651,7 @@ async def search( query: Union[str, List[str]], filters: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN, max_num_results: int | NotGiven = NOT_GIVEN, - ranking_options: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN, + ranking_options: vector_store_search_params.RankingOptions | NotGiven = NOT_GIVEN, rewrite_query: bool | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
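The hunk above narrows ranking_options on both the sync and async search() methods from an untyped Dict[str, Union[...]] to the generated vector_store_search_params.RankingOptions TypedDict. A minimal caller-side sketch of the new shape, assuming a locally running server and a hypothetical vector store ID; only the annotation changes, so existing dict literals keep working at runtime but are now checked against the "ranker" / "score_threshold" keys:

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server address

    results = client.vector_stores.search(
        vector_store_id="vs_123",           # hypothetical store ID
        query="how do I rotate API keys?",
        max_num_results=10,
        ranking_options={
            "ranker": "default",            # assumed ranker name
            "score_threshold": 0.5,
        },
        rewrite_query=False,
    )
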
diff --git a/src/llama_stack_client/types/response_object.py b/src/llama_stack_client/types/response_object.py index c09c1a10..a00115bb 100644 --- a/src/llama_stack_client/types/response_object.py +++ b/src/llama_stack_client/types/response_object.py @@ -189,16 +189,6 @@ class Error(BaseModel): class ResponseObject(BaseModel): - @property - def output_text(self) -> str: - texts: List[str] = [] - for output in self.output: - if output.type == "message": - for content in output.content: - if content.type == "output_text": - texts.append(content.text) - return "".join(texts) - id: str created_at: int diff --git a/src/llama_stack_client/types/vector_store.py b/src/llama_stack_client/types/vector_store.py index e766af15..5dc4ad3a 100644 --- a/src/llama_stack_client/types/vector_store.py +++ b/src/llama_stack_client/types/vector_store.py @@ -5,7 +5,19 @@ from .._models import BaseModel -__all__ = ["VectorStore"] +__all__ = ["VectorStore", "FileCounts"] + + +class FileCounts(BaseModel): + cancelled: int + + completed: int + + failed: int + + in_progress: int + + total: int class VectorStore(BaseModel): @@ -13,7 +25,7 @@ class VectorStore(BaseModel): created_at: int - file_counts: Dict[str, int] + file_counts: FileCounts metadata: Dict[str, Union[bool, float, str, List[object], object, None]] diff --git a/src/llama_stack_client/types/vector_store_search_params.py b/src/llama_stack_client/types/vector_store_search_params.py index 0c8545df..c7e86cd0 100644 --- a/src/llama_stack_client/types/vector_store_search_params.py +++ b/src/llama_stack_client/types/vector_store_search_params.py @@ -5,7 +5,7 @@ from typing import Dict, List, Union, Iterable from typing_extensions import Required, TypedDict -__all__ = ["VectorStoreSearchParams"] +__all__ = ["VectorStoreSearchParams", "RankingOptions"] class VectorStoreSearchParams(TypedDict, total=False): @@ -18,8 +18,14 @@ class VectorStoreSearchParams(TypedDict, total=False): max_num_results: int """Maximum number of results to return (1 to 50 inclusive, default 10).""" - ranking_options: Dict[str, Union[bool, float, str, Iterable[object], object, None]] + ranking_options: RankingOptions """Ranking options for fine-tuning the search results.""" rewrite_query: bool """Whether to rewrite the natural language query for vector search (default false)""" + + +class RankingOptions(TypedDict, total=False): + ranker: str + + score_threshold: float diff --git a/tests/api_resources/test_vector_stores.py b/tests/api_resources/test_vector_stores.py index 81b59ac1..bd63d5e7 100644 --- a/tests/api_resources/test_vector_stores.py +++ b/tests/api_resources/test_vector_stores.py @@ -242,7 +242,10 @@ def test_method_search_with_all_params(self, client: LlamaStackClient) -> None: query="string", filters={"foo": True}, max_num_results=0, - ranking_options={"foo": True}, + ranking_options={ + "ranker": "ranker", + "score_threshold": 0, + }, rewrite_query=True, ) assert_matches_type(VectorStoreSearchResponse, vector_store, path=["response"]) @@ -505,7 +508,10 @@ async def test_method_search_with_all_params(self, async_client: AsyncLlamaStack query="string", filters={"foo": True}, max_num_results=0, - ranking_options={"foo": True}, + ranking_options={ + "ranker": "ranker", + "score_threshold": 0, + }, rewrite_query=True, ) assert_matches_type(VectorStoreSearchResponse, vector_store, path=["response"]) diff --git a/tests/conftest.py b/tests/conftest.py index dd04ad98..ed5e8a48 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +# File generated from our OpenAPI 
spec by Stainless. See CONTRIBUTING.md for details. + from __future__ import annotations import os diff --git a/tests/test_client.py b/tests/test_client.py index b3d169b0..59472837 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -23,17 +23,16 @@ from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient, APIResponseValidationError from llama_stack_client._types import Omit -from llama_stack_client._utils import maybe_transform from llama_stack_client._models import BaseModel, FinalRequestOptions -from llama_stack_client._constants import RAW_RESPONSE_HEADER from llama_stack_client._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError from llama_stack_client._base_client import ( DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, + DefaultHttpxClient, + DefaultAsyncHttpxClient, make_request_options, ) -from llama_stack_client.types.inference_chat_completion_params import InferenceChatCompletionParamsNonStreaming from .utils import update_env @@ -677,60 +676,37 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: + def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: LlamaStackClient) -> None: respx_mock.post("/v1/inference/chat-completion").mock(side_effect=httpx.TimeoutException("Test timeout error")) with pytest.raises(APITimeoutError): - self.client.post( - "/v1/inference/chat-completion", - body=cast( - object, - maybe_transform( - dict( - messages=[ - { - "content": "string", - "role": "user", - } - ], - model_id="model_id", - ), - InferenceChatCompletionParamsNonStreaming, - ), - ), - cast_to=httpx.Response, - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, - ) + client.inference.with_streaming_response.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + } + ], + model_id="model_id", + ).__enter__() assert _get_open_connections(self.client) == 0 @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: + def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: LlamaStackClient) -> None: respx_mock.post("/v1/inference/chat-completion").mock(return_value=httpx.Response(500)) with pytest.raises(APIStatusError): - self.client.post( - "/v1/inference/chat-completion", - body=cast( - object, - maybe_transform( - dict( - messages=[ - { - "content": "string", - "role": "user", - } - ], - model_id="model_id", - ), - InferenceChatCompletionParamsNonStreaming, - ), - ), - cast_to=httpx.Response, - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, - ) - + client.inference.with_streaming_response.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + } + ], + model_id="model_id", + ).__enter__() assert _get_open_connections(self.client) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @@ -836,6 +812,28 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: assert response.http_request.headers.get("x-stainless-retry-count") == "42" + def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Test that the proxy environment variables are set correctly + 
monkeypatch.setenv("HTTPS_PROXY", "https://example.org") + + client = DefaultHttpxClient() + + mounts = tuple(client._mounts.items()) + assert len(mounts) == 1 + assert mounts[0][0].pattern == "https://" + + @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning") + def test_default_client_creation(self) -> None: + # Ensure that the client can be initialized without any exceptions + DefaultHttpxClient( + verify=True, + cert=None, + trust_env=True, + http1=True, + http2=False, + limits=httpx.Limits(max_connections=100, max_keepalive_connections=20), + ) + @pytest.mark.respx(base_url=base_url) def test_follow_redirects(self, respx_mock: MockRouter) -> None: # Test that the default follow_redirects=True allows following redirects @@ -1495,60 +1493,41 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: + async def test_retrying_timeout_errors_doesnt_leak( + self, respx_mock: MockRouter, async_client: AsyncLlamaStackClient + ) -> None: respx_mock.post("/v1/inference/chat-completion").mock(side_effect=httpx.TimeoutException("Test timeout error")) with pytest.raises(APITimeoutError): - await self.client.post( - "/v1/inference/chat-completion", - body=cast( - object, - maybe_transform( - dict( - messages=[ - { - "content": "string", - "role": "user", - } - ], - model_id="model_id", - ), - InferenceChatCompletionParamsNonStreaming, - ), - ), - cast_to=httpx.Response, - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, - ) + await async_client.inference.with_streaming_response.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + } + ], + model_id="model_id", + ).__aenter__() assert _get_open_connections(self.client) == 0 @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: + async def test_retrying_status_errors_doesnt_leak( + self, respx_mock: MockRouter, async_client: AsyncLlamaStackClient + ) -> None: respx_mock.post("/v1/inference/chat-completion").mock(return_value=httpx.Response(500)) with pytest.raises(APIStatusError): - await self.client.post( - "/v1/inference/chat-completion", - body=cast( - object, - maybe_transform( - dict( - messages=[ - { - "content": "string", - "role": "user", - } - ], - model_id="model_id", - ), - InferenceChatCompletionParamsNonStreaming, - ), - ), - cast_to=httpx.Response, - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, - ) - + await async_client.inference.with_streaming_response.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + } + ], + model_id="model_id", + ).__aenter__() assert _get_open_connections(self.client) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @@ -1702,6 +1681,28 @@ async def test_main() -> None: time.sleep(0.1) + async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Test that the proxy environment variables are set correctly + monkeypatch.setenv("HTTPS_PROXY", "https://example.org") + + client = DefaultAsyncHttpxClient() + + mounts = tuple(client._mounts.items()) + assert len(mounts) == 1 + assert mounts[0][0].pattern == "https://" + + 
@pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning") + async def test_default_client_creation(self) -> None: + # Ensure that the client can be initialized without any exceptions + DefaultAsyncHttpxClient( + verify=True, + cert=None, + trust_env=True, + http1=True, + http2=False, + limits=httpx.Limits(max_connections=100, max_keepalive_connections=20), + ) + @pytest.mark.respx(base_url=base_url) async def test_follow_redirects(self, respx_mock: MockRouter) -> None: # Test that the default follow_redirects=True allows following redirects
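
Two of the changes above alter how downstream code reads API objects: the output_text convenience property is dropped from ResponseObject, and VectorStore.file_counts is promoted from Dict[str, int] to the new FileCounts model. A rough caller-side sketch of the adjustments, written as standalone helpers so the objects they take are assumed to come from earlier API calls; the first loop simply mirrors the removed property:

    def collect_output_text(response) -> str:
        # Mirrors the removed ResponseObject.output_text property.
        texts = []
        for output in response.output:
            if output.type == "message":
                for content in output.content:
                    if content.type == "output_text":
                        texts.append(content.text)
        return "".join(texts)

    def files_done(store) -> str:
        # file_counts is now the FileCounts model rather than a plain dict:
        # store.file_counts["completed"]  ->  store.file_counts.completed
        return f"{store.file_counts.completed} of {store.file_counts.total} files processed"

The rewritten client tests also start importing DefaultHttpxClient and DefaultAsyncHttpxClient from llama_stack_client._base_client and exercise their proxy-mount and constructor behavior. A sketch of how an application might hand a customized transport to the clients, assuming the client constructors accept an http_client argument as in other Stainless-generated SDKs:

    import httpx

    from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient
    from llama_stack_client._base_client import DefaultHttpxClient, DefaultAsyncHttpxClient

    # Proxy environment variables such as HTTPS_PROXY are picked up automatically,
    # which is what the new test_proxy_environment_variables cases verify.
    http_client = DefaultHttpxClient(
        limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
    )
    client = LlamaStackClient(
        base_url="http://localhost:8321",  # assumed local server address
        http_client=http_client,           # assumed constructor parameter
    )

    async_http_client = DefaultAsyncHttpxClient(
        limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
    )
    async_client = AsyncLlamaStackClient(
        base_url="http://localhost:8321",
        http_client=async_http_client,
    )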