@@ -318,7 +318,7 @@ def search(
query: Union[str, List[str]],
filters: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
max_num_results: int | NotGiven = NOT_GIVEN,
ranking_options: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
ranking_options: vector_store_search_params.RankingOptions | NotGiven = NOT_GIVEN,
rewrite_query: bool | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
@@ -651,7 +651,7 @@ async def search(
query: Union[str, List[str]],
filters: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
max_num_results: int | NotGiven = NOT_GIVEN,
ranking_options: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
ranking_options: vector_store_search_params.RankingOptions | NotGiven = NOT_GIVEN,
rewrite_query: bool | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
10 changes: 0 additions & 10 deletions src/llama_stack_client/types/response_object.py
@@ -189,16 +189,6 @@ class Error(BaseModel):


class ResponseObject(BaseModel):
@property
def output_text(self) -> str:
texts: List[str] = []
for output in self.output:
if output.type == "message":
for content in output.content:
if content.type == "output_text":
texts.append(content.text)
return "".join(texts)

id: str

created_at: int
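The output_text convenience property is dropped from ResponseObject here; callers that relied on it can rebuild the same string from the fields the removed code used. A minimal sketch of an equivalent helper (the standalone function name is ours, not part of the SDK):

from typing import List

from llama_stack_client.types.response_object import ResponseObject


def collect_output_text(response: ResponseObject) -> str:
    # Mirrors the removed property: concatenate every output_text content part.
    texts: List[str] = []
    for output in response.output:
        if output.type == "message":
            for content in output.content:
                if content.type == "output_text":
                    texts.append(content.text)
    return "".join(texts)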
16 changes: 14 additions & 2 deletions src/llama_stack_client/types/vector_store.py
@@ -5,15 +5,27 @@

from .._models import BaseModel

__all__ = ["VectorStore"]
__all__ = ["VectorStore", "FileCounts"]


class FileCounts(BaseModel):
cancelled: int

completed: int

failed: int

in_progress: int

total: int


class VectorStore(BaseModel):
id: str

created_at: int

file_counts: Dict[str, int]
file_counts: FileCounts

metadata: Dict[str, Union[bool, float, str, List[object], object, None]]

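file_counts is now parsed into the structured FileCounts model instead of a plain Dict[str, int], so the counts are read as attributes. A small sketch using only the fields defined above (the helper itself is illustrative, not part of the SDK):

from llama_stack_client.types.vector_store import VectorStore


def ingestion_summary(store: VectorStore) -> str:
    # Attribute access replaces the old dict lookups such as file_counts["completed"].
    counts = store.file_counts
    return (
        f"{counts.completed}/{counts.total} files indexed "
        f"({counts.in_progress} in progress, {counts.failed} failed, {counts.cancelled} cancelled)"
    )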
10 changes: 8 additions & 2 deletions src/llama_stack_client/types/vector_store_search_params.py
@@ -5,7 +5,7 @@
from typing import Dict, List, Union, Iterable
from typing_extensions import Required, TypedDict

__all__ = ["VectorStoreSearchParams"]
__all__ = ["VectorStoreSearchParams", "RankingOptions"]


class VectorStoreSearchParams(TypedDict, total=False):
@@ -18,8 +18,14 @@ class VectorStoreSearchParams(TypedDict, total=False):
max_num_results: int
"""Maximum number of results to return (1 to 50 inclusive, default 10)."""

ranking_options: Dict[str, Union[bool, float, str, Iterable[object], object, None]]
ranking_options: RankingOptions
"""Ranking options for fine-tuning the search results."""

rewrite_query: bool
"""Whether to rewrite the natural language query for vector search (default false)"""


class RankingOptions(TypedDict, total=False):
ranker: str

score_threshold: float
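
ranking_options on the search params is now the RankingOptions TypedDict rather than an open-ended dict, so type checkers can flag unknown keys. A usage sketch, assuming the vector_stores.search call exercised in the tests below (the vector store id and ranker value are placeholders):

from llama_stack_client import LlamaStackClient

client = LlamaStackClient()

# ranking_options is checked against vector_store_search_params.RankingOptions.
results = client.vector_stores.search(
    vector_store_id="vs_123",  # placeholder id; the exact positional/keyword form may differ
    query="what changed in the file counts model?",
    max_num_results=5,
    ranking_options={
        "ranker": "default",  # placeholder ranker name
        "score_threshold": 0.25,
    },
    rewrite_query=True,
)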
10 changes: 8 additions & 2 deletions tests/api_resources/test_vector_stores.py
@@ -242,7 +242,10 @@ def test_method_search_with_all_params(self, client: LlamaStackClient) -> None:
query="string",
filters={"foo": True},
max_num_results=0,
ranking_options={"foo": True},
ranking_options={
"ranker": "ranker",
"score_threshold": 0,
},
rewrite_query=True,
)
assert_matches_type(VectorStoreSearchResponse, vector_store, path=["response"])
@@ -505,7 +508,10 @@ async def test_method_search_with_all_params(self, async_client: AsyncLlamaStack
query="string",
filters={"foo": True},
max_num_results=0,
ranking_options={"foo": True},
ranking_options={
"ranker": "ranker",
"score_threshold": 0,
},
rewrite_query=True,
)
assert_matches_type(VectorStoreSearchResponse, vector_store, path=["response"])
2 changes: 2 additions & 0 deletions tests/conftest.py
@@ -1,3 +1,5 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

from __future__ import annotations

import os
179 changes: 90 additions & 89 deletions tests/test_client.py
@@ -23,17 +23,16 @@

from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient, APIResponseValidationError
from llama_stack_client._types import Omit
from llama_stack_client._utils import maybe_transform
from llama_stack_client._models import BaseModel, FinalRequestOptions
from llama_stack_client._constants import RAW_RESPONSE_HEADER
from llama_stack_client._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError
from llama_stack_client._base_client import (
DEFAULT_TIMEOUT,
HTTPX_DEFAULT_TIMEOUT,
BaseClient,
DefaultHttpxClient,
DefaultAsyncHttpxClient,
make_request_options,
)
from llama_stack_client.types.inference_chat_completion_params import InferenceChatCompletionParamsNonStreaming

from .utils import update_env

@@ -677,60 +676,37 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str

@mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: LlamaStackClient) -> None:
respx_mock.post("/v1/inference/chat-completion").mock(side_effect=httpx.TimeoutException("Test timeout error"))

with pytest.raises(APITimeoutError):
self.client.post(
"/v1/inference/chat-completion",
body=cast(
object,
maybe_transform(
dict(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
),
InferenceChatCompletionParamsNonStreaming,
),
),
cast_to=httpx.Response,
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
)
client.inference.with_streaming_response.chat_completion(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
).__enter__()

assert _get_open_connections(self.client) == 0

@mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: LlamaStackClient) -> None:
respx_mock.post("/v1/inference/chat-completion").mock(return_value=httpx.Response(500))

with pytest.raises(APIStatusError):
self.client.post(
"/v1/inference/chat-completion",
body=cast(
object,
maybe_transform(
dict(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
),
InferenceChatCompletionParamsNonStreaming,
),
),
cast_to=httpx.Response,
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
)

client.inference.with_streaming_response.chat_completion(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
).__enter__()
assert _get_open_connections(self.client) == 0

@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@@ -836,6 +812,28 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:

assert response.http_request.headers.get("x-stainless-retry-count") == "42"

def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
# Test that the proxy environment variables are set correctly
monkeypatch.setenv("HTTPS_PROXY", "https://example.org")

client = DefaultHttpxClient()

mounts = tuple(client._mounts.items())
assert len(mounts) == 1
assert mounts[0][0].pattern == "https://"

@pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
def test_default_client_creation(self) -> None:
# Ensure that the client can be initialized without any exceptions
DefaultHttpxClient(
verify=True,
cert=None,
trust_env=True,
http1=True,
http2=False,
limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
)

@pytest.mark.respx(base_url=base_url)
def test_follow_redirects(self, respx_mock: MockRouter) -> None:
# Test that the default follow_redirects=True allows following redirects
@@ -1495,60 +1493,41 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte

@mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
async def test_retrying_timeout_errors_doesnt_leak(
self, respx_mock: MockRouter, async_client: AsyncLlamaStackClient
) -> None:
respx_mock.post("/v1/inference/chat-completion").mock(side_effect=httpx.TimeoutException("Test timeout error"))

with pytest.raises(APITimeoutError):
await self.client.post(
"/v1/inference/chat-completion",
body=cast(
object,
maybe_transform(
dict(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
),
InferenceChatCompletionParamsNonStreaming,
),
),
cast_to=httpx.Response,
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
)
await async_client.inference.with_streaming_response.chat_completion(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
).__aenter__()

assert _get_open_connections(self.client) == 0

@mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
async def test_retrying_status_errors_doesnt_leak(
self, respx_mock: MockRouter, async_client: AsyncLlamaStackClient
) -> None:
respx_mock.post("/v1/inference/chat-completion").mock(return_value=httpx.Response(500))

with pytest.raises(APIStatusError):
await self.client.post(
"/v1/inference/chat-completion",
body=cast(
object,
maybe_transform(
dict(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
),
InferenceChatCompletionParamsNonStreaming,
),
),
cast_to=httpx.Response,
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
)

await async_client.inference.with_streaming_response.chat_completion(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
).__aenter__()
assert _get_open_connections(self.client) == 0

@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@@ -1702,6 +1681,28 @@ async def test_main() -> None:

time.sleep(0.1)

async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
# Test that the proxy environment variables are set correctly
monkeypatch.setenv("HTTPS_PROXY", "https://example.org")

client = DefaultAsyncHttpxClient()

mounts = tuple(client._mounts.items())
assert len(mounts) == 1
assert mounts[0][0].pattern == "https://"

@pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
async def test_default_client_creation(self) -> None:
# Ensure that the client can be initialized without any exceptions
DefaultAsyncHttpxClient(
verify=True,
cert=None,
trust_env=True,
http1=True,
http2=False,
limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
)

@pytest.mark.respx(base_url=base_url)
async def test_follow_redirects(self, respx_mock: MockRouter) -> None:
# Test that the default follow_redirects=True allows following redirects
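The retry-leak tests above now drive the request through inference.with_streaming_response.chat_completion rather than a hand-built client.post, calling __enter__()/__aenter__() directly only to fire the request. In application code the same surface would normally be used as a context manager; a sketch:

from llama_stack_client import LlamaStackClient

client = LlamaStackClient()

with client.inference.with_streaming_response.chat_completion(
    messages=[{"content": "Hello!", "role": "user"}],
    model_id="model_id",  # placeholder model identifier, as in the tests
) as response:
    # Work with the streamed response here; the wrapper's exact API is not shown in this diff.
    ...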