From addcbfacc32d8c229f74cc78b9ac875895bc2508 Mon Sep 17 00:00:00 2001
From: ehhuang
Date: Thu, 20 Mar 2025 10:00:01 -0700
Subject: [PATCH] Sync updates from stainless branch: ehhuang/dev

---
 src/llama_stack_client/__init__.py            |  6 +++
 .../resources/benchmarks.py                   |  6 +--
 src/llama_stack_client/resources/datasets.py  |  6 +--
 src/llama_stack_client/resources/eval/jobs.py |  6 +--
 src/llama_stack_client/resources/models.py    |  6 +--
 .../resources/post_training/job.py            | 10 ++---
 .../resources/post_training/post_training.py  | 10 ++---
 .../resources/scoring_functions.py            |  6 +--
 src/llama_stack_client/resources/shields.py   |  6 +--
 .../resources/vector_dbs.py                   |  6 +--
 src/llama_stack_client/types/__init__.py      |  1 -
 .../types/algorithm_config_param.py           | 37 -----------------
 .../types/eval/job_status_response.py         |  3 +-
 ...st_training_supervised_fine_tune_params.py | 40 ++++++++++++++++---
 .../types/scoring_fn_params.py                | 12 ++++--
 .../types/scoring_fn_params_param.py          |  6 +--
 .../types/shared/document.py                  |  5 ++-
 .../types/shared_params/document.py           |  5 ++-
 tests/api_resources/eval/test_jobs.py         | 14 +++----
 tests/api_resources/post_training/test_job.py | 26 ++++++------
 tests/api_resources/test_benchmarks.py        | 14 +++----
 tests/api_resources/test_datasets.py          | 14 +++----
 tests/api_resources/test_models.py            | 14 +++----
 tests/api_resources/test_scoring_functions.py | 14 +++----
 tests/api_resources/test_shields.py           | 14 +++----
 tests/api_resources/test_vector_dbs.py        | 14 +++----
 26 files changed, 152 insertions(+), 149 deletions(-)
 delete mode 100644 src/llama_stack_client/types/algorithm_config_param.py

diff --git a/src/llama_stack_client/__init__.py b/src/llama_stack_client/__init__.py
index 30e6e9cb..70ef01a4 100644
--- a/src/llama_stack_client/__init__.py
+++ b/src/llama_stack_client/__init__.py
@@ -37,6 +37,12 @@
 from ._base_client import DefaultHttpxClient, DefaultAsyncHttpxClient
 from ._utils._logs import setup_logging as _setup_logging
 
+from .lib.agents.agent import Agent
+from .lib.agents.event_logger import EventLogger as AgentEventLogger
+from .lib.inference.event_logger import EventLogger as InferenceEventLogger
+from .types.agents.turn_create_params import Document
+from .types.shared_params.document import Document as RAGDocument
+
 __all__ = [
     "types",
     "__version__",
diff --git a/src/llama_stack_client/resources/benchmarks.py b/src/llama_stack_client/resources/benchmarks.py
index fe05e518..f541a6ba 100644
--- a/src/llama_stack_client/resources/benchmarks.py
+++ b/src/llama_stack_client/resources/benchmarks.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Dict, List, Type, Union, Iterable, Optional, cast
+from typing import Dict, List, Type, Union, Iterable, cast
 
 import httpx
 
@@ -58,7 +58,7 @@ def retrieve(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[Benchmark]:
+    ) -> Benchmark:
         """
         Args:
           extra_headers: Send extra headers
@@ -178,7 +178,7 @@ async def retrieve(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[Benchmark]:
+    ) -> Benchmark:
         """
         Args:
           extra_headers: Send extra headers
diff --git a/src/llama_stack_client/resources/datasets.py b/src/llama_stack_client/resources/datasets.py
index 1df5f9c1..ed56ac80 100644
--- a/src/llama_stack_client/resources/datasets.py
+++ b/src/llama_stack_client/resources/datasets.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Dict, Type, Union, Iterable, Optional, cast
+from typing import Dict, Type, Union, Iterable, cast
 from typing_extensions import Literal
 
 import httpx
@@ -61,7 +61,7 @@ def retrieve(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[DatasetRetrieveResponse]:
+    ) -> DatasetRetrieveResponse:
         """
         Args:
           extra_headers: Send extra headers
@@ -286,7 +286,7 @@ async def retrieve(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[DatasetRetrieveResponse]:
+    ) -> DatasetRetrieveResponse:
         """
         Args:
           extra_headers: Send extra headers
diff --git a/src/llama_stack_client/resources/eval/jobs.py b/src/llama_stack_client/resources/eval/jobs.py
index 2b7bd817..408bd4d8 100644
--- a/src/llama_stack_client/resources/eval/jobs.py
+++ b/src/llama_stack_client/resources/eval/jobs.py
@@ -2,8 +2,6 @@
 
 from __future__ import annotations
 
-from typing import Optional
-
 import httpx
 
 from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
@@ -126,7 +124,7 @@ def status(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[JobStatusResponse]:
+    ) -> JobStatusResponse:
         """
         Get the status of a job.
 
@@ -256,7 +254,7 @@ async def status(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[JobStatusResponse]:
+    ) -> JobStatusResponse:
         """
         Get the status of a job.
 
diff --git a/src/llama_stack_client/resources/models.py b/src/llama_stack_client/resources/models.py
index 584c2001..db08a9d5 100644
--- a/src/llama_stack_client/resources/models.py
+++ b/src/llama_stack_client/resources/models.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Dict, Type, Union, Iterable, Optional, cast
+from typing import Dict, Type, Union, Iterable, cast
 from typing_extensions import Literal
 
 import httpx
@@ -59,7 +59,7 @@ def retrieve(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[Model]:
+    ) -> Model:
         """
         Args:
           extra_headers: Send extra headers
@@ -208,7 +208,7 @@ async def retrieve(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[Model]:
+    ) -> Model:
         """
         Args:
           extra_headers: Send extra headers
diff --git a/src/llama_stack_client/resources/post_training/job.py b/src/llama_stack_client/resources/post_training/job.py
index 28f9f66b..bcd31952 100644
--- a/src/llama_stack_client/resources/post_training/job.py
+++ b/src/llama_stack_client/resources/post_training/job.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import List, Type, Optional, cast
+from typing import List, Type, cast
 
 import httpx
 
@@ -81,7 +81,7 @@ def artifacts(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[JobArtifactsResponse]:
+    ) -> JobArtifactsResponse:
         """
         Args:
           extra_headers: Send extra headers
@@ -145,7 +145,7 @@ def status(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[JobStatusResponse]:
+    ) -> JobStatusResponse:
         """
         Args:
           extra_headers: Send extra headers
@@ -221,7 +221,7 @@ async def artifacts(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[JobArtifactsResponse]:
+    ) -> JobArtifactsResponse:
         """
         Args:
           extra_headers: Send extra headers
@@ -285,7 +285,7 @@ async def status(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[JobStatusResponse]:
+    ) -> JobStatusResponse:
         """
         Args:
           extra_headers: Send extra headers
diff --git a/src/llama_stack_client/resources/post_training/post_training.py b/src/llama_stack_client/resources/post_training/post_training.py
index a93a1ebb..944610fc 100644
--- a/src/llama_stack_client/resources/post_training/post_training.py
+++ b/src/llama_stack_client/resources/post_training/post_training.py
@@ -14,10 +14,7 @@
     JobResourceWithStreamingResponse,
     AsyncJobResourceWithStreamingResponse,
 )
-from ...types import (
-    post_training_preference_optimize_params,
-    post_training_supervised_fine_tune_params,
-)
+from ...types import post_training_preference_optimize_params, post_training_supervised_fine_tune_params
 from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
 from ..._utils import (
     maybe_transform,
@@ -33,7 +30,6 @@
 )
 from ..._base_client import make_request_options
 from ...types.post_training_job import PostTrainingJob
-from ...types.algorithm_config_param import AlgorithmConfigParam
 
 __all__ = ["PostTrainingResource", "AsyncPostTrainingResource"]
 
@@ -115,7 +111,7 @@ def supervised_fine_tune(
         logger_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]],
         model: str,
         training_config: post_training_supervised_fine_tune_params.TrainingConfig,
-        algorithm_config: AlgorithmConfigParam | NotGiven = NOT_GIVEN,
+        algorithm_config: post_training_supervised_fine_tune_params.AlgorithmConfig | NotGiven = NOT_GIVEN,
         checkpoint_dir: str | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
         # The extra values given here take precedence over values defined on the client or passed to this method.
@@ -232,7 +228,7 @@ async def supervised_fine_tune(
         logger_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]],
         model: str,
         training_config: post_training_supervised_fine_tune_params.TrainingConfig,
-        algorithm_config: AlgorithmConfigParam | NotGiven = NOT_GIVEN,
+        algorithm_config: post_training_supervised_fine_tune_params.AlgorithmConfig | NotGiven = NOT_GIVEN,
         checkpoint_dir: str | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
         # The extra values given here take precedence over values defined on the client or passed to this method.
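
Note: the two supervised_fine_tune hunks above swap the deleted top-level AlgorithmConfigParam alias for the AlgorithmConfig union now defined inline in post_training_supervised_fine_tune_params.py. Call sites are unchanged structurally: algorithm_config is still a plain dict matching one of the two TypedDicts. A minimal sketch of a LoRA call after this patch (the base URL, model id, job id, and training_config values are illustrative assumptions, not taken from this diff):

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server

    # Matches AlgorithmConfigLoraFinetuningConfig as defined in this patch:
    # alpha, apply_lora_to_mlp, apply_lora_to_output, lora_attn_modules, rank,
    # and type are Required; quantize_base and use_dora stay optional.
    lora_config = {
        "type": "LoRA",
        "alpha": 16,
        "rank": 8,
        "apply_lora_to_mlp": True,
        "apply_lora_to_output": False,
        "lora_attn_modules": ["q_proj", "v_proj"],  # hypothetical module names
    }

    job = client.post_training.supervised_fine_tune(
        job_uuid="sft-example",   # hypothetical job id
        model="llama3.2-3b",      # hypothetical model id
        hyperparam_search_config={},
        logger_config={},
        training_config={         # illustrative subset of TrainingConfig fields
            "n_epochs": 1,
            "max_steps_per_epoch": 100,
            "gradient_accumulation_steps": 1,
        },
        algorithm_config=lora_config,  # typed by the new inline union
    )
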
diff --git a/src/llama_stack_client/resources/scoring_functions.py b/src/llama_stack_client/resources/scoring_functions.py
index 1bc535ef..c152c805 100644
--- a/src/llama_stack_client/resources/scoring_functions.py
+++ b/src/llama_stack_client/resources/scoring_functions.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Type, Optional, cast
+from typing import Type, cast
 
 import httpx
 
@@ -60,7 +60,7 @@ def retrieve(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[ScoringFn]:
+    ) -> ScoringFn:
         """
         Args:
           extra_headers: Send extra headers
@@ -180,7 +180,7 @@ async def retrieve(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[ScoringFn]:
+    ) -> ScoringFn:
         """
         Args:
           extra_headers: Send extra headers
diff --git a/src/llama_stack_client/resources/shields.py b/src/llama_stack_client/resources/shields.py
index 4205f972..150455c3 100644
--- a/src/llama_stack_client/resources/shields.py
+++ b/src/llama_stack_client/resources/shields.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Dict, Type, Union, Iterable, Optional, cast
+from typing import Dict, Type, Union, Iterable, cast
 
 import httpx
 
@@ -58,7 +58,7 @@ def retrieve(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[Shield]:
+    ) -> Shield:
         """
         Args:
           extra_headers: Send extra headers
@@ -173,7 +173,7 @@ async def retrieve(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[Shield]:
+    ) -> Shield:
         """
         Args:
           extra_headers: Send extra headers
diff --git a/src/llama_stack_client/resources/vector_dbs.py b/src/llama_stack_client/resources/vector_dbs.py
index 63f9086f..79d7939d 100644
--- a/src/llama_stack_client/resources/vector_dbs.py
+++ b/src/llama_stack_client/resources/vector_dbs.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Type, Optional, cast
+from typing import Type, cast
 
 import httpx
 
@@ -59,7 +59,7 @@ def retrieve(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[VectorDBRetrieveResponse]:
+    ) -> VectorDBRetrieveResponse:
         """
         Args:
           extra_headers: Send extra headers
@@ -208,7 +208,7 @@ async def retrieve(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> Optional[VectorDBRetrieveResponse]:
+    ) -> VectorDBRetrieveResponse:
         """
         Args:
           extra_headers: Send extra headers
diff --git a/src/llama_stack_client/types/__init__.py b/src/llama_stack_client/types/__init__.py
index b45996a9..2e35f893 100644
--- a/src/llama_stack_client/types/__init__.py
+++ b/src/llama_stack_client/types/__init__.py
@@ -76,7 +76,6 @@
 from .model_register_params import ModelRegisterParams as ModelRegisterParams
 from .query_chunks_response import QueryChunksResponse as QueryChunksResponse
 from .query_condition_param import QueryConditionParam as QueryConditionParam
-from .algorithm_config_param import AlgorithmConfigParam as AlgorithmConfigParam
 from .benchmark_config_param import BenchmarkConfigParam as BenchmarkConfigParam
 from .list_datasets_response import ListDatasetsResponse as ListDatasetsResponse
 from .provider_list_response import ProviderListResponse as ProviderListResponse
diff --git a/src/llama_stack_client/types/algorithm_config_param.py b/src/llama_stack_client/types/algorithm_config_param.py
deleted file mode 100644
index 3f3c0cac..00000000
--- a/src/llama_stack_client/types/algorithm_config_param.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from __future__ import annotations
-
-from typing import List, Union
-from typing_extensions import Literal, Required, TypeAlias, TypedDict
-
-__all__ = ["AlgorithmConfigParam", "LoraFinetuningConfig", "QatFinetuningConfig"]
-
-
-class LoraFinetuningConfig(TypedDict, total=False):
-    alpha: Required[int]
-
-    apply_lora_to_mlp: Required[bool]
-
-    apply_lora_to_output: Required[bool]
-
-    lora_attn_modules: Required[List[str]]
-
-    rank: Required[int]
-
-    type: Required[Literal["LoRA"]]
-
-    quantize_base: bool
-
-    use_dora: bool
-
-
-class QatFinetuningConfig(TypedDict, total=False):
-    group_size: Required[int]
-
-    quantizer_name: Required[str]
-
-    type: Required[Literal["QAT"]]
-
-
-AlgorithmConfigParam: TypeAlias = Union[LoraFinetuningConfig, QatFinetuningConfig]
diff --git a/src/llama_stack_client/types/eval/job_status_response.py b/src/llama_stack_client/types/eval/job_status_response.py
index 3aa0952a..4f02f31d 100644
--- a/src/llama_stack_client/types/eval/job_status_response.py
+++ b/src/llama_stack_client/types/eval/job_status_response.py
@@ -1,8 +1,7 @@
 # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
 
-from typing import Optional
 from typing_extensions import Literal, TypeAlias
 
 __all__ = ["JobStatusResponse"]
 
-JobStatusResponse: TypeAlias = Optional[Literal["completed", "in_progress", "failed", "scheduled"]]
+JobStatusResponse: TypeAlias = Literal["completed", "in_progress", "failed", "scheduled"]
diff --git a/src/llama_stack_client/types/post_training_supervised_fine_tune_params.py b/src/llama_stack_client/types/post_training_supervised_fine_tune_params.py
index fa18742a..68b79782 100644
--- a/src/llama_stack_client/types/post_training_supervised_fine_tune_params.py
+++ b/src/llama_stack_client/types/post_training_supervised_fine_tune_params.py
@@ -2,10 +2,8 @@
 
 from __future__ import annotations
 
-from typing import Dict, Union, Iterable
-from typing_extensions import Literal, Required, TypedDict
-
-from .algorithm_config_param import AlgorithmConfigParam
+from typing import Dict, List, Union, Iterable
+from typing_extensions import Literal, Required, TypeAlias, TypedDict
 
 __all__ = [
     "PostTrainingSupervisedFineTuneParams",
@@ -13,6 +11,9 @@
     "TrainingConfigDataConfig",
     "TrainingConfigOptimizerConfig",
     "TrainingConfigEfficiencyConfig",
+    "AlgorithmConfig",
+    "AlgorithmConfigLoraFinetuningConfig",
+    "AlgorithmConfigQatFinetuningConfig",
 ]
 
 
@@ -27,7 +28,7 @@ class PostTrainingSupervisedFineTuneParams(TypedDict, total=False):
 
     training_config: Required[TrainingConfig]
 
-    algorithm_config: AlgorithmConfigParam
+    algorithm_config: AlgorithmConfig
 
     checkpoint_dir: str
 
@@ -84,3 +85,32 @@ class TrainingConfig(TypedDict, total=False):
     dtype: str
 
     efficiency_config: TrainingConfigEfficiencyConfig
+
+
+class AlgorithmConfigLoraFinetuningConfig(TypedDict, total=False):
+    alpha: Required[int]
+
+    apply_lora_to_mlp: Required[bool]
+
+    apply_lora_to_output: Required[bool]
+
+    lora_attn_modules: Required[List[str]]
+
+    rank: Required[int]
+
+    type: Required[Literal["LoRA"]]
+
+    quantize_base: bool
+
+    use_dora: bool
+
+
+class AlgorithmConfigQatFinetuningConfig(TypedDict, total=False):
+    group_size: Required[int]
+
+    quantizer_name: Required[str]
+
+    type: Required[Literal["QAT"]]
+
+
+AlgorithmConfig: TypeAlias = Union[AlgorithmConfigLoraFinetuningConfig, AlgorithmConfigQatFinetuningConfig]
diff --git a/src/llama_stack_client/types/scoring_fn_params.py b/src/llama_stack_client/types/scoring_fn_params.py
index 5ca23590..6f4a62b0 100644
--- a/src/llama_stack_client/types/scoring_fn_params.py
+++ b/src/llama_stack_client/types/scoring_fn_params.py
@@ -14,7 +14,9 @@ class LlmAsJudgeScoringFnParams(BaseModel):
 
     type: Literal["llm_as_judge"]
 
-    aggregation_functions: Optional[List[Literal["average", "median", "categorical_count", "accuracy"]]] = None
+    aggregation_functions: Optional[
+        List[Literal["average", "weighted_average", "median", "categorical_count", "accuracy"]]
+    ] = None
 
     judge_score_regexes: Optional[List[str]] = None
 
@@ -24,7 +26,9 @@ class LlmAsJudgeScoringFnParams(BaseModel):
 class RegexParserScoringFnParams(BaseModel):
     type: Literal["regex_parser"]
 
-    aggregation_functions: Optional[List[Literal["average", "median", "categorical_count", "accuracy"]]] = None
+    aggregation_functions: Optional[
+        List[Literal["average", "weighted_average", "median", "categorical_count", "accuracy"]]
+    ] = None
 
     parsing_regexes: Optional[List[str]] = None
 
@@ -32,7 +36,9 @@ class RegexParserScoringFnParams(BaseModel):
 class BasicScoringFnParams(BaseModel):
     type: Literal["basic"]
 
-    aggregation_functions: Optional[List[Literal["average", "median", "categorical_count", "accuracy"]]] = None
+    aggregation_functions: Optional[
+        List[Literal["average", "weighted_average", "median", "categorical_count", "accuracy"]]
+    ] = None
 
 
 ScoringFnParams: TypeAlias = Annotated[
diff --git a/src/llama_stack_client/types/scoring_fn_params_param.py b/src/llama_stack_client/types/scoring_fn_params_param.py
index 5b636c27..4c255b52 100644
--- a/src/llama_stack_client/types/scoring_fn_params_param.py
+++ b/src/llama_stack_client/types/scoring_fn_params_param.py
@@ -13,7 +13,7 @@ class LlmAsJudgeScoringFnParams(TypedDict, total=False):
 
     type: Required[Literal["llm_as_judge"]]
 
-    aggregation_functions: List[Literal["average", "median", "categorical_count", "accuracy"]]
+    aggregation_functions: List[Literal["average", "weighted_average", "median", "categorical_count", "accuracy"]]
 
     judge_score_regexes: List[str]
 
@@ -23,7 +23,7 @@ class LlmAsJudgeScoringFnParams(TypedDict, total=False):
 class RegexParserScoringFnParams(TypedDict, total=False):
     type: Required[Literal["regex_parser"]]
 
-    aggregation_functions: List[Literal["average", "median", "categorical_count", "accuracy"]]
+    aggregation_functions: List[Literal["average", "weighted_average", "median", "categorical_count", "accuracy"]]
 
     parsing_regexes: List[str]
 
@@ -31,7 +31,7 @@ class RegexParserScoringFnParams(TypedDict, total=False):
 class BasicScoringFnParams(TypedDict, total=False):
     type: Required[Literal["basic"]]
 
-    aggregation_functions: List[Literal["average", "median", "categorical_count", "accuracy"]]
+    aggregation_functions: List[Literal["average", "weighted_average", "median", "categorical_count", "accuracy"]]
 
 
 ScoringFnParamsParam: TypeAlias = Union[LlmAsJudgeScoringFnParams, RegexParserScoringFnParams, BasicScoringFnParams]
diff --git a/src/llama_stack_client/types/shared/document.py b/src/llama_stack_client/types/shared/document.py
index b9bfa898..67704232 100644
--- a/src/llama_stack_client/types/shared/document.py
+++ b/src/llama_stack_client/types/shared/document.py
@@ -59,10 +59,13 @@ class ContentURL(BaseModel):
 
 class Document(BaseModel):
     content: Content
-    """A image content item"""
+    """The content of the document."""
 
     document_id: str
+    """The unique identifier for the document."""
 
     metadata: Dict[str, Union[bool, float, str, List[object], object, None]]
+    """Additional metadata for the document."""
 
     mime_type: Optional[str] = None
+    """The MIME type of the document."""
diff --git a/src/llama_stack_client/types/shared_params/document.py b/src/llama_stack_client/types/shared_params/document.py
index fd3c3df1..78564cfa 100644
--- a/src/llama_stack_client/types/shared_params/document.py
+++ b/src/llama_stack_client/types/shared_params/document.py
@@ -60,10 +60,13 @@ class ContentURL(TypedDict, total=False):
 
 class Document(TypedDict, total=False):
     content: Required[Content]
-    """A image content item"""
+    """The content of the document."""
 
     document_id: Required[str]
+    """The unique identifier for the document."""
 
     metadata: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]]
+    """Additional metadata for the document."""
 
     mime_type: str
+    """The MIME type of the document."""
diff --git a/tests/api_resources/eval/test_jobs.py b/tests/api_resources/eval/test_jobs.py
index f9b85a08..874f5ff0 100644
--- a/tests/api_resources/eval/test_jobs.py
+++ b/tests/api_resources/eval/test_jobs.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import os
-from typing import Any, Optional, cast
+from typing import Any, cast
 
 import pytest
 
@@ -120,7 +120,7 @@ def test_method_status(self, client: LlamaStackClient) -> None:
             job_id="job_id",
             benchmark_id="benchmark_id",
         )
-        assert_matches_type(Optional[JobStatusResponse], job, path=["response"])
+        assert_matches_type(JobStatusResponse, job, path=["response"])
 
     @parametrize
     def test_raw_response_status(self, client: LlamaStackClient) -> None:
@@ -132,7 +132,7 @@ def test_raw_response_status(self, client: LlamaStackClient) -> None:
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         job = response.parse()
-        assert_matches_type(Optional[JobStatusResponse], job, path=["response"])
+        assert_matches_type(JobStatusResponse, job, path=["response"])
 
     @parametrize
     def test_streaming_response_status(self, client: LlamaStackClient) -> None:
@@ -144,7 +144,7 @@ def test_streaming_response_status(self, client: LlamaStackClient) -> None:
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             job = response.parse()
-            assert_matches_type(Optional[JobStatusResponse], job, path=["response"])
+            assert_matches_type(JobStatusResponse, job, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
@@ -268,7 +268,7 @@ async def test_method_status(self, async_client: AsyncLlamaStackClient) -> None:
             job_id="job_id",
             benchmark_id="benchmark_id",
         )
-        assert_matches_type(Optional[JobStatusResponse], job, path=["response"])
+        assert_matches_type(JobStatusResponse, job, path=["response"])
 
     @parametrize
     async def test_raw_response_status(self, async_client: AsyncLlamaStackClient) -> None:
@@ -280,7 +280,7 @@ async def test_raw_response_status(self, async_client: AsyncLlamaStackClient) ->
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         job = await response.parse()
-        assert_matches_type(Optional[JobStatusResponse], job, path=["response"])
+        assert_matches_type(JobStatusResponse, job, path=["response"])
 
     @parametrize
     async def test_streaming_response_status(self, async_client: AsyncLlamaStackClient) -> None:
@@ -292,7 +292,7 @@ async def test_streaming_response_status(self, async_client: AsyncLlamaStackClie
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             job = await response.parse()
-            assert_matches_type(Optional[JobStatusResponse], job, path=["response"])
+            assert_matches_type(JobStatusResponse, job, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
diff --git a/tests/api_resources/post_training/test_job.py b/tests/api_resources/post_training/test_job.py
index c38838d7..6fca52db 100644
--- a/tests/api_resources/post_training/test_job.py
+++ b/tests/api_resources/post_training/test_job.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import os
-from typing import Any, List, Optional, cast
+from typing import Any, List, cast
 
 import pytest
 
@@ -51,7 +51,7 @@ def test_method_artifacts(self, client: LlamaStackClient) -> None:
         job = client.post_training.job.artifacts(
             job_uuid="job_uuid",
         )
-        assert_matches_type(Optional[JobArtifactsResponse], job, path=["response"])
+        assert_matches_type(JobArtifactsResponse, job, path=["response"])
 
     @parametrize
     def test_raw_response_artifacts(self, client: LlamaStackClient) -> None:
@@ -62,7 +62,7 @@ def test_raw_response_artifacts(self, client: LlamaStackClient) -> None:
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         job = response.parse()
-        assert_matches_type(Optional[JobArtifactsResponse], job, path=["response"])
+        assert_matches_type(JobArtifactsResponse, job, path=["response"])
 
     @parametrize
     def test_streaming_response_artifacts(self, client: LlamaStackClient) -> None:
@@ -73,7 +73,7 @@ def test_streaming_response_artifacts(self, client: LlamaStackClient) -> None:
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             job = response.parse()
-            assert_matches_type(Optional[JobArtifactsResponse], job, path=["response"])
+            assert_matches_type(JobArtifactsResponse, job, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
@@ -113,7 +113,7 @@ def test_method_status(self, client: LlamaStackClient) -> None:
         job = client.post_training.job.status(
             job_uuid="job_uuid",
         )
-        assert_matches_type(Optional[JobStatusResponse], job, path=["response"])
+        assert_matches_type(JobStatusResponse, job, path=["response"])
 
     @parametrize
     def test_raw_response_status(self, client: LlamaStackClient) -> None:
@@ -124,7 +124,7 @@ def test_raw_response_status(self, client: LlamaStackClient) -> None:
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         job = response.parse()
-        assert_matches_type(Optional[JobStatusResponse], job, path=["response"])
+        assert_matches_type(JobStatusResponse, job, path=["response"])
 
     @parametrize
     def test_streaming_response_status(self, client: LlamaStackClient) -> None:
@@ -135,7 +135,7 @@ def test_streaming_response_status(self, client: LlamaStackClient) -> None:
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             job = response.parse()
-            assert_matches_type(Optional[JobStatusResponse], job, path=["response"])
+            assert_matches_type(JobStatusResponse, job, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
@@ -173,7 +173,7 @@ async def test_method_artifacts(self, async_client: AsyncLlamaStackClient) -> No
         job = await async_client.post_training.job.artifacts(
             job_uuid="job_uuid",
         )
-        assert_matches_type(Optional[JobArtifactsResponse], job, path=["response"])
+        assert_matches_type(JobArtifactsResponse, job, path=["response"])
 
     @parametrize
    async def test_raw_response_artifacts(self, async_client: AsyncLlamaStackClient) -> None:
@@ -184,7 +184,7 @@ async def test_raw_response_artifacts(self, async_client: AsyncLlamaStackClient)
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         job = await response.parse()
-        assert_matches_type(Optional[JobArtifactsResponse], job, path=["response"])
+        assert_matches_type(JobArtifactsResponse, job, path=["response"])
 
     @parametrize
     async def test_streaming_response_artifacts(self, async_client: AsyncLlamaStackClient) -> None:
@@ -195,7 +195,7 @@ async def test_streaming_response_artifacts(self, async_client: AsyncLlamaStackC
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             job = await response.parse()
-            assert_matches_type(Optional[JobArtifactsResponse], job, path=["response"])
+            assert_matches_type(JobArtifactsResponse, job, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
@@ -235,7 +235,7 @@ async def test_method_status(self, async_client: AsyncLlamaStackClient) -> None:
         job = await async_client.post_training.job.status(
             job_uuid="job_uuid",
         )
-        assert_matches_type(Optional[JobStatusResponse], job, path=["response"])
+        assert_matches_type(JobStatusResponse, job, path=["response"])
 
     @parametrize
     async def test_raw_response_status(self, async_client: AsyncLlamaStackClient) -> None:
@@ -246,7 +246,7 @@ async def test_raw_response_status(self, async_client: AsyncLlamaStackClient) ->
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         job = await response.parse()
-        assert_matches_type(Optional[JobStatusResponse], job, path=["response"])
+        assert_matches_type(JobStatusResponse, job, path=["response"])
 
     @parametrize
     async def test_streaming_response_status(self, async_client: AsyncLlamaStackClient) -> None:
@@ -257,6 +257,6 @@ async def test_streaming_response_status(self, async_client: AsyncLlamaStackClie
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             job = await response.parse()
-            assert_matches_type(Optional[JobStatusResponse], job, path=["response"])
+            assert_matches_type(JobStatusResponse, job, path=["response"])
 
         assert cast(Any, response.is_closed) is True
diff --git a/tests/api_resources/test_benchmarks.py b/tests/api_resources/test_benchmarks.py
index 03aceead..12cb3870 100644
--- a/tests/api_resources/test_benchmarks.py
+++ b/tests/api_resources/test_benchmarks.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import os
-from typing import Any, Optional, cast
+from typing import Any, cast
 
 import pytest
 
@@ -22,7 +22,7 @@ def test_method_retrieve(self, client: LlamaStackClient) -> None:
         benchmark = client.benchmarks.retrieve(
             "benchmark_id",
         )
-        assert_matches_type(Optional[Benchmark], benchmark, path=["response"])
+        assert_matches_type(Benchmark, benchmark, path=["response"])
 
     @parametrize
     def test_raw_response_retrieve(self, client: LlamaStackClient) -> None:
@@ -33,7 +33,7 @@ def test_raw_response_retrieve(self, client: LlamaStackClient) -> None:
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         benchmark = response.parse()
-        assert_matches_type(Optional[Benchmark], benchmark, path=["response"])
+        assert_matches_type(Benchmark, benchmark, path=["response"])
 
     @parametrize
     def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None:
@@ -44,7 +44,7 @@ def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None:
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             benchmark = response.parse()
-            assert_matches_type(Optional[Benchmark], benchmark, path=["response"])
+            assert_matches_type(Benchmark, benchmark, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
@@ -138,7 +138,7 @@ async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> Non
         benchmark = await async_client.benchmarks.retrieve(
             "benchmark_id",
         )
-        assert_matches_type(Optional[Benchmark], benchmark, path=["response"])
+        assert_matches_type(Benchmark, benchmark, path=["response"])
 
     @parametrize
     async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
@@ -149,7 +149,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient)
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         benchmark = await response.parse()
-        assert_matches_type(Optional[Benchmark], benchmark, path=["response"])
+        assert_matches_type(Benchmark, benchmark, path=["response"])
 
     @parametrize
     async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
@@ -160,7 +160,7 @@ async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackCl
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             benchmark = await response.parse()
-            assert_matches_type(Optional[Benchmark], benchmark, path=["response"])
+            assert_matches_type(Benchmark, benchmark, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
diff --git a/tests/api_resources/test_datasets.py b/tests/api_resources/test_datasets.py
index 7f19e741..010e10d0 100644
--- a/tests/api_resources/test_datasets.py
+++ b/tests/api_resources/test_datasets.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import os
-from typing import Any, Optional, cast
+from typing import Any, cast
 
 import pytest
 
@@ -27,7 +27,7 @@ def test_method_retrieve(self, client: LlamaStackClient) -> None:
         dataset = client.datasets.retrieve(
             "dataset_id",
         )
-        assert_matches_type(Optional[DatasetRetrieveResponse], dataset, path=["response"])
+        assert_matches_type(DatasetRetrieveResponse, dataset, path=["response"])
 
     @parametrize
     def test_raw_response_retrieve(self, client: LlamaStackClient) -> None:
@@ -38,7 +38,7 @@ def test_raw_response_retrieve(self, client: LlamaStackClient) -> None:
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         dataset = response.parse()
-        assert_matches_type(Optional[DatasetRetrieveResponse], dataset, path=["response"])
+        assert_matches_type(DatasetRetrieveResponse, dataset, path=["response"])
 
     @parametrize
     def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None:
@@ -49,7 +49,7 @@ def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None:
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             dataset = response.parse()
-            assert_matches_type(Optional[DatasetRetrieveResponse], dataset, path=["response"])
+            assert_matches_type(DatasetRetrieveResponse, dataset, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
@@ -235,7 +235,7 @@ async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> Non
         dataset = await async_client.datasets.retrieve(
             "dataset_id",
         )
-        assert_matches_type(Optional[DatasetRetrieveResponse], dataset, path=["response"])
+        assert_matches_type(DatasetRetrieveResponse, dataset, path=["response"])
 
     @parametrize
     async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
@@ -246,7 +246,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient)
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         dataset = await response.parse()
-        assert_matches_type(Optional[DatasetRetrieveResponse], dataset, path=["response"])
+        assert_matches_type(DatasetRetrieveResponse, dataset, path=["response"])
 
     @parametrize
     async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
@@ -257,7 +257,7 @@ async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackCl
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             dataset = await response.parse()
-            assert_matches_type(Optional[DatasetRetrieveResponse], dataset, path=["response"])
+            assert_matches_type(DatasetRetrieveResponse, dataset, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py
index c38903d5..a2c8e68a 100644
--- a/tests/api_resources/test_models.py
+++ b/tests/api_resources/test_models.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import os
-from typing import Any, Optional, cast
+from typing import Any, cast
 
 import pytest
 
@@ -22,7 +22,7 @@ def test_method_retrieve(self, client: LlamaStackClient) -> None:
         model = client.models.retrieve(
             "model_id",
         )
-        assert_matches_type(Optional[Model], model, path=["response"])
+        assert_matches_type(Model, model, path=["response"])
 
     @parametrize
     def test_raw_response_retrieve(self, client: LlamaStackClient) -> None:
@@ -33,7 +33,7 @@ def test_raw_response_retrieve(self, client: LlamaStackClient) -> None:
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         model = response.parse()
-        assert_matches_type(Optional[Model], model, path=["response"])
+        assert_matches_type(Model, model, path=["response"])
 
     @parametrize
     def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None:
@@ -44,7 +44,7 @@ def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None:
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             model = response.parse()
-            assert_matches_type(Optional[Model], model, path=["response"])
+            assert_matches_type(Model, model, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
@@ -169,7 +169,7 @@ async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> Non
         model = await async_client.models.retrieve(
             "model_id",
         )
-        assert_matches_type(Optional[Model], model, path=["response"])
+        assert_matches_type(Model, model, path=["response"])
 
     @parametrize
     async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
@@ -180,7 +180,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient)
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         model = await response.parse()
-        assert_matches_type(Optional[Model], model, path=["response"])
+        assert_matches_type(Model, model, path=["response"])
 
     @parametrize
     async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
@@ -191,7 +191,7 @@ async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackCl
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             model = await response.parse()
-            assert_matches_type(Optional[Model], model, path=["response"])
+            assert_matches_type(Model, model, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
diff --git a/tests/api_resources/test_scoring_functions.py b/tests/api_resources/test_scoring_functions.py
index 44f8d3df..5806bf59 100644
--- a/tests/api_resources/test_scoring_functions.py
+++ b/tests/api_resources/test_scoring_functions.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import os
-from typing import Any, Optional, cast
+from typing import Any, cast
 
 import pytest
 
@@ -25,7 +25,7 @@ def test_method_retrieve(self, client: LlamaStackClient) -> None:
         scoring_function = client.scoring_functions.retrieve(
             "scoring_fn_id",
         )
-        assert_matches_type(Optional[ScoringFn], scoring_function, path=["response"])
+        assert_matches_type(ScoringFn, scoring_function, path=["response"])
 
     @parametrize
     def test_raw_response_retrieve(self, client: LlamaStackClient) -> None:
@@ -36,7 +36,7 @@ def test_raw_response_retrieve(self, client: LlamaStackClient) -> None:
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         scoring_function = response.parse()
-        assert_matches_type(Optional[ScoringFn], scoring_function, path=["response"])
+        assert_matches_type(ScoringFn, scoring_function, path=["response"])
 
     @parametrize
     def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None:
@@ -47,7 +47,7 @@ def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None:
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             scoring_function = response.parse()
-            assert_matches_type(Optional[ScoringFn], scoring_function, path=["response"])
+            assert_matches_type(ScoringFn, scoring_function, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
@@ -147,7 +147,7 @@ async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> Non
         scoring_function = await async_client.scoring_functions.retrieve(
             "scoring_fn_id",
         )
-        assert_matches_type(Optional[ScoringFn], scoring_function, path=["response"])
+        assert_matches_type(ScoringFn, scoring_function, path=["response"])
 
     @parametrize
     async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
@@ -158,7 +158,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient)
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         scoring_function = await response.parse()
-        assert_matches_type(Optional[ScoringFn], scoring_function, path=["response"])
+        assert_matches_type(ScoringFn, scoring_function, path=["response"])
 
     @parametrize
     async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
@@ -169,7 +169,7 @@ async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackCl
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             scoring_function = await response.parse()
-            assert_matches_type(Optional[ScoringFn], scoring_function, path=["response"])
+            assert_matches_type(ScoringFn, scoring_function, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
diff --git a/tests/api_resources/test_shields.py b/tests/api_resources/test_shields.py
index a32be825..a351a6f0 100644
--- a/tests/api_resources/test_shields.py
+++ b/tests/api_resources/test_shields.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import os
-from typing import Any, Optional, cast
+from typing import Any, cast
 
 import pytest
 
@@ -22,7 +22,7 @@ def test_method_retrieve(self, client: LlamaStackClient) -> None:
         shield = client.shields.retrieve(
             "identifier",
         )
-        assert_matches_type(Optional[Shield], shield, path=["response"])
+        assert_matches_type(Shield, shield, path=["response"])
 
     @parametrize
     def test_raw_response_retrieve(self, client: LlamaStackClient) -> None:
@@ -33,7 +33,7 @@ def test_raw_response_retrieve(self, client: LlamaStackClient) -> None:
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         shield = response.parse()
-        assert_matches_type(Optional[Shield], shield, path=["response"])
+        assert_matches_type(Shield, shield, path=["response"])
 
     @parametrize
     def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None:
@@ -44,7 +44,7 @@ def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None:
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             shield = response.parse()
-            assert_matches_type(Optional[Shield], shield, path=["response"])
+            assert_matches_type(Shield, shield, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
@@ -130,7 +130,7 @@ async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> Non
         shield = await async_client.shields.retrieve(
             "identifier",
         )
-        assert_matches_type(Optional[Shield], shield, path=["response"])
+        assert_matches_type(Shield, shield, path=["response"])
 
     @parametrize
     async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
@@ -141,7 +141,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient)
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         shield = await response.parse()
-        assert_matches_type(Optional[Shield], shield, path=["response"])
+        assert_matches_type(Shield, shield, path=["response"])
 
     @parametrize
     async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
@@ -152,7 +152,7 @@ async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackCl
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             shield = await response.parse()
-            assert_matches_type(Optional[Shield], shield, path=["response"])
+            assert_matches_type(Shield, shield, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
diff --git a/tests/api_resources/test_vector_dbs.py b/tests/api_resources/test_vector_dbs.py
index 63c5a3f0..d185edf1 100644
--- a/tests/api_resources/test_vector_dbs.py
+++ b/tests/api_resources/test_vector_dbs.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import os
-from typing import Any, Optional, cast
+from typing import Any, cast
 
 import pytest
 
@@ -26,7 +26,7 @@ def test_method_retrieve(self, client: LlamaStackClient) -> None:
         vector_db = client.vector_dbs.retrieve(
             "vector_db_id",
         )
-        assert_matches_type(Optional[VectorDBRetrieveResponse], vector_db, path=["response"])
+        assert_matches_type(VectorDBRetrieveResponse, vector_db, path=["response"])
 
     @parametrize
     def test_raw_response_retrieve(self, client: LlamaStackClient) -> None:
@@ -37,7 +37,7 @@ def test_raw_response_retrieve(self, client: LlamaStackClient) -> None:
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         vector_db = response.parse()
-        assert_matches_type(Optional[VectorDBRetrieveResponse], vector_db, path=["response"])
+        assert_matches_type(VectorDBRetrieveResponse, vector_db, path=["response"])
 
     @parametrize
     def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None:
@@ -48,7 +48,7 @@ def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None:
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             vector_db = response.parse()
-            assert_matches_type(Optional[VectorDBRetrieveResponse], vector_db, path=["response"])
+            assert_matches_type(VectorDBRetrieveResponse, vector_db, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
@@ -176,7 +176,7 @@ async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> Non
         vector_db = await async_client.vector_dbs.retrieve(
             "vector_db_id",
         )
-        assert_matches_type(Optional[VectorDBRetrieveResponse], vector_db, path=["response"])
+        assert_matches_type(VectorDBRetrieveResponse, vector_db, path=["response"])
 
     @parametrize
     async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
@@ -187,7 +187,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient)
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         vector_db = await response.parse()
-        assert_matches_type(Optional[VectorDBRetrieveResponse], vector_db, path=["response"])
+        assert_matches_type(VectorDBRetrieveResponse, vector_db, path=["response"])
 
     @parametrize
     async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
@@ -198,7 +198,7 @@ async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackCl
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             vector_db = await response.parse()
-            assert_matches_type(Optional[VectorDBRetrieveResponse], vector_db, path=["response"])
+            assert_matches_type(VectorDBRetrieveResponse, vector_db, path=["response"])
 
         assert cast(Any, response.is_closed) is True
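
Note: the Optional[...] -> concrete return-type change runs through every retrieve/status/artifacts endpoint and its tests above, so callers no longer need a None-guard that existed only to satisfy the old annotation. A short sketch of the narrowed types, reusing the same calls as the tests (the base URL and the "llama3.2-3b" id are placeholder assumptions; an unknown id still raises an API error rather than returning None):

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed endpoint

    # models.retrieve is now annotated -> Model instead of -> Optional[Model],
    # so attribute access type-checks directly without an `is not None` check.
    model = client.models.retrieve("llama3.2-3b")
    print(model.identifier)

    # Eval job status is now a bare Literal union, not Optional[Literal[...]].
    status = client.eval.jobs.status(job_id="job_id", benchmark_id="benchmark_id")
    assert status in ("completed", "in_progress", "failed", "scheduled")

The patch also re-exports the agent helpers at the package root, so `from llama_stack_client import Agent, AgentEventLogger, InferenceEventLogger, RAGDocument` works after this change.
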