diff --git a/src/llama_stack_client/_decoders/jsonl.py b/src/llama_stack_client/_decoders/jsonl.py
index e9d29a1c..ac5ac74f 100644
--- a/src/llama_stack_client/_decoders/jsonl.py
+++ b/src/llama_stack_client/_decoders/jsonl.py
@@ -17,11 +17,15 @@ class JSONLDecoder(Generic[_T]):
     into a given type.
     """

-    http_response: httpx.Response | None
+    http_response: httpx.Response
     """The HTTP response this decoder was constructed from"""

     def __init__(
-        self, *, raw_iterator: Iterator[bytes], line_type: type[_T], http_response: httpx.Response | None
+        self,
+        *,
+        raw_iterator: Iterator[bytes],
+        line_type: type[_T],
+        http_response: httpx.Response,
     ) -> None:
         super().__init__()
         self.http_response = http_response
@@ -29,6 +33,13 @@ def __init__(
         self._line_type = line_type
         self._iterator = self.__decode__()

+    def close(self) -> None:
+        """Close the response body stream.
+
+        This is called automatically if you consume the entire stream.
+        """
+        self.http_response.close()
+
     def __decode__(self) -> Iterator[_T]:
         buf = b""
         for chunk in self._raw_iterator:
@@ -63,10 +74,14 @@ class AsyncJSONLDecoder(Generic[_T]):
     into a given type.
     """

-    http_response: httpx.Response | None
+    http_response: httpx.Response

     def __init__(
-        self, *, raw_iterator: AsyncIterator[bytes], line_type: type[_T], http_response: httpx.Response | None
+        self,
+        *,
+        raw_iterator: AsyncIterator[bytes],
+        line_type: type[_T],
+        http_response: httpx.Response,
     ) -> None:
         super().__init__()
         self.http_response = http_response
@@ -74,6 +89,13 @@ def __init__(
         self._line_type = line_type
         self._iterator = self.__decode__()

+    async def close(self) -> None:
+        """Close the response body stream.
+
+        This is called automatically if you consume the entire stream.
+        """
+        await self.http_response.aclose()
+
     async def __decode__(self) -> AsyncIterator[_T]:
         buf = b""
         async for chunk in self._raw_iterator:
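Note on the decoder change above: `http_response` is now required, so the decoder always owns an `httpx.Response` and can expose `close()` for callers that stop iterating early. A minimal usage sketch, assuming the decoder is iterated through the surrounding SDK machinery (the stream variable and early-exit condition below are hypothetical):

    # `stream` is a JSONLDecoder[Item] obtained from a JSONL-returning SDK call (hypothetical)
    try:
        for item in stream:
            if item is None:  # hypothetical early-exit condition
                break
    finally:
        # releases the underlying httpx response if the stream was not fully consumed;
        # a fully consumed stream is closed automatically, as the docstring above notes
        stream.close()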
""" + + # store a reference to the original type we were given before we extract any inner + # types so that we can properly resolve forward references in `TypeAliasType` annotations + original_type = None + # we allow `object` as the input type because otherwise, passing things like # `Literal['value']` will be reported as a type error by type checkers type_ = cast("type[object]", type_) if is_type_alias_type(type_): + original_type = type_ # type: ignore[unreachable] type_ = type_.__value__ # type: ignore[unreachable] # unwrap `Annotated[T, ...]` -> `T` @@ -446,7 +452,7 @@ def construct_type(*, value: object, type_: object) -> object: if is_union(origin): try: - return validate_type(type_=cast("type[object]", type_), value=value) + return validate_type(type_=cast("type[object]", original_type or type_), value=value) except Exception: pass diff --git a/src/llama_stack_client/_response.py b/src/llama_stack_client/_response.py index d7e58fbe..ea35182f 100644 --- a/src/llama_stack_client/_response.py +++ b/src/llama_stack_client/_response.py @@ -144,7 +144,7 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T: return cast( R, cast("type[JSONLDecoder[Any]]", cast_to)( - raw_iterator=self.http_response.iter_bytes(chunk_size=4096), + raw_iterator=self.http_response.iter_bytes(chunk_size=64), line_type=extract_type_arg(cast_to, 0), http_response=self.http_response, ), @@ -154,7 +154,7 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T: return cast( R, cast("type[AsyncJSONLDecoder[Any]]", cast_to)( - raw_iterator=self.http_response.aiter_bytes(chunk_size=4096), + raw_iterator=self.http_response.aiter_bytes(chunk_size=64), line_type=extract_type_arg(cast_to, 0), http_response=self.http_response, ), diff --git a/src/llama_stack_client/_utils/_transform.py b/src/llama_stack_client/_utils/_transform.py index a6b62cad..18afd9d8 100644 --- a/src/llama_stack_client/_utils/_transform.py +++ b/src/llama_stack_client/_utils/_transform.py @@ -25,7 +25,7 @@ is_annotated_type, strip_annotated_type, ) -from .._compat import model_dump, is_typeddict +from .._compat import get_origin, model_dump, is_typeddict _T = TypeVar("_T") @@ -164,9 +164,14 @@ def _transform_recursive( inner_type = annotation stripped_type = strip_annotated_type(inner_type) + origin = get_origin(stripped_type) or stripped_type if is_typeddict(stripped_type) and is_mapping(data): return _transform_typeddict(data, stripped_type) + if origin == dict and is_mapping(data): + items_type = get_args(stripped_type)[1] + return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()} + if ( # List[T] (is_list_type(stripped_type) and is_list(data)) @@ -307,9 +312,14 @@ async def _async_transform_recursive( inner_type = annotation stripped_type = strip_annotated_type(inner_type) + origin = get_origin(stripped_type) or stripped_type if is_typeddict(stripped_type) and is_mapping(data): return await _async_transform_typeddict(data, stripped_type) + if origin == dict and is_mapping(data): + items_type = get_args(stripped_type)[1] + return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()} + if ( # List[T] (is_list_type(stripped_type) and is_list(data)) diff --git a/src/llama_stack_client/resources/eval_tasks.py b/src/llama_stack_client/resources/eval_tasks.py index 82a07839..2afb692f 100644 --- a/src/llama_stack_client/resources/eval_tasks.py +++ b/src/llama_stack_client/resources/eval_tasks.py @@ -50,7 +50,7 @@ def with_streaming_response(self) -> 
diff --git a/src/llama_stack_client/resources/eval_tasks.py b/src/llama_stack_client/resources/eval_tasks.py
index 82a07839..2afb692f 100644
--- a/src/llama_stack_client/resources/eval_tasks.py
+++ b/src/llama_stack_client/resources/eval_tasks.py
@@ -50,7 +50,7 @@ def with_streaming_response(self) -> EvalTasksResourceWithStreamingResponse:

     def retrieve(
         self,
-        eval_task_id: str,
+        task_id: str,
         *,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
         # The extra values given here take precedence over values defined on the client or passed to this method.
@@ -69,10 +69,10 @@ def retrieve(

           timeout: Override the client-level default timeout for this request, in seconds
         """
-        if not eval_task_id:
-            raise ValueError(f"Expected a non-empty value for `eval_task_id` but received {eval_task_id!r}")
+        if not task_id:
+            raise ValueError(f"Expected a non-empty value for `task_id` but received {task_id!r}")
         return self._get(
-            f"/v1/eval-tasks/{eval_task_id}",
+            f"/v1/eval/tasks/{task_id}",
             options=make_request_options(
                 extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
             ),
@@ -90,7 +90,7 @@ def list(
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
     ) -> EvalTaskListResponse:
         return self._get(
-            "/v1/eval-tasks",
+            "/v1/eval/tasks",
             options=make_request_options(
                 extra_headers=extra_headers,
                 extra_query=extra_query,
@@ -105,8 +105,8 @@ def register(
         self,
         *,
         dataset_id: str,
-        eval_task_id: str,
         scoring_functions: List[str],
+        task_id: str,
         metadata: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
         provider_eval_task_id: str | NotGiven = NOT_GIVEN,
         provider_id: str | NotGiven = NOT_GIVEN,
@@ -129,12 +129,12 @@ def register(
         """
         extra_headers = {"Accept": "*/*", **(extra_headers or {})}
         return self._post(
-            "/v1/eval-tasks",
+            "/v1/eval/tasks",
             body=maybe_transform(
                 {
                     "dataset_id": dataset_id,
-                    "eval_task_id": eval_task_id,
                     "scoring_functions": scoring_functions,
+                    "task_id": task_id,
                     "metadata": metadata,
                     "provider_eval_task_id": provider_eval_task_id,
                     "provider_id": provider_id,
@@ -170,7 +170,7 @@ def with_streaming_response(self) -> AsyncEvalTasksResourceWithStreamingResponse

     async def retrieve(
         self,
-        eval_task_id: str,
+        task_id: str,
         *,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
         # The extra values given here take precedence over values defined on the client or passed to this method.
@@ -189,10 +189,10 @@ async def retrieve(

           timeout: Override the client-level default timeout for this request, in seconds
         """
-        if not eval_task_id:
-            raise ValueError(f"Expected a non-empty value for `eval_task_id` but received {eval_task_id!r}")
+        if not task_id:
+            raise ValueError(f"Expected a non-empty value for `task_id` but received {task_id!r}")
         return await self._get(
-            f"/v1/eval-tasks/{eval_task_id}",
+            f"/v1/eval/tasks/{task_id}",
             options=make_request_options(
                 extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
             ),
@@ -210,7 +210,7 @@ async def list(
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
     ) -> EvalTaskListResponse:
         return await self._get(
-            "/v1/eval-tasks",
+            "/v1/eval/tasks",
             options=make_request_options(
                 extra_headers=extra_headers,
                 extra_query=extra_query,
@@ -225,8 +225,8 @@ async def register(
         self,
         *,
         dataset_id: str,
-        eval_task_id: str,
         scoring_functions: List[str],
+        task_id: str,
         metadata: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
         provider_eval_task_id: str | NotGiven = NOT_GIVEN,
         provider_id: str | NotGiven = NOT_GIVEN,
@@ -249,12 +249,12 @@ async def register(
         """
         extra_headers = {"Accept": "*/*", **(extra_headers or {})}
         return await self._post(
-            "/v1/eval-tasks",
+            "/v1/eval/tasks",
             body=await async_maybe_transform(
                 {
                     "dataset_id": dataset_id,
-                    "eval_task_id": eval_task_id,
                     "scoring_functions": scoring_functions,
+                    "task_id": task_id,
                     "metadata": metadata,
                     "provider_eval_task_id": provider_eval_task_id,
                     "provider_id": provider_id,
diff --git a/src/llama_stack_client/types/eval_task_register_params.py b/src/llama_stack_client/types/eval_task_register_params.py
index 417bc2cd..1737ffa7 100644
--- a/src/llama_stack_client/types/eval_task_register_params.py
+++ b/src/llama_stack_client/types/eval_task_register_params.py
@@ -11,10 +11,10 @@ class EvalTaskRegisterParams(TypedDict, total=False):

     dataset_id: Required[str]

-    eval_task_id: Required[str]
-
     scoring_functions: Required[List[str]]

+    task_id: Required[str]
+
     metadata: Dict[str, Union[bool, float, str, Iterable[object], object, None]]

     provider_eval_task_id: str
diff --git a/tests/api_resources/test_eval_tasks.py b/tests/api_resources/test_eval_tasks.py
index 5b18621b..0ac41e88 100644
--- a/tests/api_resources/test_eval_tasks.py
+++ b/tests/api_resources/test_eval_tasks.py
@@ -20,14 +20,14 @@ class TestEvalTasks:
     @parametrize
     def test_method_retrieve(self, client: LlamaStackClient) -> None:
         eval_task = client.eval_tasks.retrieve(
-            "eval_task_id",
+            "task_id",
         )
         assert_matches_type(Optional[EvalTask], eval_task, path=["response"])

     @parametrize
     def test_raw_response_retrieve(self, client: LlamaStackClient) -> None:
         response = client.eval_tasks.with_raw_response.retrieve(
-            "eval_task_id",
+            "task_id",
         )

         assert response.is_closed is True
@@ -38,7 +38,7 @@ def test_raw_response_retrieve(self, client: LlamaStackClient) -> None:
     @parametrize
     def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None:
         with client.eval_tasks.with_streaming_response.retrieve(
-            "eval_task_id",
+            "task_id",
         ) as response:
             assert not response.is_closed
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
@@ -50,7 +50,7 @@ def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None:

     @parametrize
     def test_path_params_retrieve(self, client: LlamaStackClient) -> None:
-        with pytest.raises(ValueError, match=r"Expected a non-empty value for `eval_task_id` but received ''"):
+        with pytest.raises(ValueError, match=r"Expected a non-empty value for `task_id` but received ''"):
            client.eval_tasks.with_raw_response.retrieve(
                "",
            )
@@ -84,8 +84,8 @@ def test_streaming_response_list(self, client: LlamaStackClient) -> None:
     def test_method_register(self, client: LlamaStackClient) -> None:
         eval_task = client.eval_tasks.register(
             dataset_id="dataset_id",
-            eval_task_id="eval_task_id",
             scoring_functions=["string"],
+            task_id="task_id",
         )
         assert eval_task is None

@@ -93,8 +93,8 @@ def test_method_register(self, client: LlamaStackClient) -> None:
     def test_method_register_with_all_params(self, client: LlamaStackClient) -> None:
         eval_task = client.eval_tasks.register(
             dataset_id="dataset_id",
-            eval_task_id="eval_task_id",
             scoring_functions=["string"],
+            task_id="task_id",
             metadata={"foo": True},
             provider_eval_task_id="provider_eval_task_id",
             provider_id="provider_id",
@@ -105,8 +105,8 @@ def test_method_register_with_all_params(self, client: LlamaStackClient) -> None
     def test_raw_response_register(self, client: LlamaStackClient) -> None:
         response = client.eval_tasks.with_raw_response.register(
             dataset_id="dataset_id",
-            eval_task_id="eval_task_id",
             scoring_functions=["string"],
+            task_id="task_id",
         )

         assert response.is_closed is True
@@ -118,8 +118,8 @@ def test_raw_response_register(self, client: LlamaStackClient) -> None:
     def test_streaming_response_register(self, client: LlamaStackClient) -> None:
         with client.eval_tasks.with_streaming_response.register(
             dataset_id="dataset_id",
-            eval_task_id="eval_task_id",
             scoring_functions=["string"],
+            task_id="task_id",
         ) as response:
             assert not response.is_closed
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
@@ -136,14 +136,14 @@ class TestAsyncEvalTasks:
     @parametrize
     async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
         eval_task = await async_client.eval_tasks.retrieve(
-            "eval_task_id",
+            "task_id",
         )
         assert_matches_type(Optional[EvalTask], eval_task, path=["response"])

     @parametrize
     async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
         response = await async_client.eval_tasks.with_raw_response.retrieve(
-            "eval_task_id",
+            "task_id",
         )

         assert response.is_closed is True
@@ -154,7 +154,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient)
     @parametrize
     async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
         async with async_client.eval_tasks.with_streaming_response.retrieve(
-            "eval_task_id",
+            "task_id",
         ) as response:
             assert not response.is_closed
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
@@ -166,7 +166,7 @@ async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackCl

     @parametrize
     async def test_path_params_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
-        with pytest.raises(ValueError, match=r"Expected a non-empty value for `eval_task_id` but received ''"):
+        with pytest.raises(ValueError, match=r"Expected a non-empty value for `task_id` but received ''"):
            await async_client.eval_tasks.with_raw_response.retrieve(
                "",
            )
@@ -200,8 +200,8 @@ async def test_streaming_response_list(self, async_client: AsyncLlamaStackClient
     async def test_method_register(self, async_client: AsyncLlamaStackClient) -> None:
         eval_task = await async_client.eval_tasks.register(
             dataset_id="dataset_id",
-            eval_task_id="eval_task_id",
             scoring_functions=["string"],
+            task_id="task_id",
         )
         assert eval_task is None

@@ -209,8 +209,8 @@ async def test_method_register(self, async_client: AsyncLlamaStackClient) -> Non
     async def test_method_register_with_all_params(self, async_client: AsyncLlamaStackClient) -> None:
         eval_task = await async_client.eval_tasks.register(
             dataset_id="dataset_id",
-            eval_task_id="eval_task_id",
             scoring_functions=["string"],
+            task_id="task_id",
             metadata={"foo": True},
             provider_eval_task_id="provider_eval_task_id",
             provider_id="provider_id",
@@ -221,8 +221,8 @@ async def test_method_register_with_all_params(self, async_client: AsyncLlamaSta
     async def test_raw_response_register(self, async_client: AsyncLlamaStackClient) -> None:
         response = await async_client.eval_tasks.with_raw_response.register(
             dataset_id="dataset_id",
-            eval_task_id="eval_task_id",
             scoring_functions=["string"],
+            task_id="task_id",
         )

         assert response.is_closed is True
@@ -234,8 +234,8 @@ async def test_raw_response_register(self, async_client: AsyncLlamaStackClient)
     async def test_streaming_response_register(self, async_client: AsyncLlamaStackClient) -> None:
         async with async_client.eval_tasks.with_streaming_response.register(
             dataset_id="dataset_id",
-            eval_task_id="eval_task_id",
             scoring_functions=["string"],
+            task_id="task_id",
         ) as response:
             assert not response.is_closed
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
diff --git a/tests/test_transform.py b/tests/test_transform.py
index 364c685e..8ceafb36 100644
--- a/tests/test_transform.py
+++ b/tests/test_transform.py
@@ -2,7 +2,7 @@

 import io
 import pathlib
-from typing import Any, List, Union, TypeVar, Iterable, Optional, cast
+from typing import Any, Dict, List, Union, TypeVar, Iterable, Optional, cast
 from datetime import date, datetime
 from typing_extensions import Required, Annotated, TypedDict

@@ -388,6 +388,15 @@ def my_iter() -> Iterable[Baz8]:
     }


+@parametrize
+@pytest.mark.asyncio
+async def test_dictionary_items(use_async: bool) -> None:
+    class DictItems(TypedDict):
+        foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]
+
+    assert await transform({"foo": {"foo_baz": "bar"}}, Dict[str, DictItems], use_async) == {"foo": {"fooBaz": "bar"}}
+
+
 class TypedDictIterableUnionStr(TypedDict):
     foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")]
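Closing note for SDK users: the visible effect of the eval_tasks changes is the rename of `eval_task_id` to `task_id` and of the `/v1/eval-tasks` routes to `/v1/eval/tasks`. A minimal usage sketch of the new surface (the base URL and example values are illustrative only):

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:5001")  # illustrative base URL

    # registration now takes `task_id` and posts to /v1/eval/tasks
    client.eval_tasks.register(
        dataset_id="dataset_id",
        scoring_functions=["scoring_fn_id"],  # illustrative scoring function id
        task_id="my-eval-task",
    )

    # retrieval takes the task id positionally and hits /v1/eval/tasks/{task_id}
    eval_task = client.eval_tasks.retrieve("my-eval-task")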