diff --git a/src/llama_stack_client/resources/eval/jobs.py b/src/llama_stack_client/resources/eval/jobs.py
index 408bd4d8..d46b63f9 100644
--- a/src/llama_stack_client/resources/eval/jobs.py
+++ b/src/llama_stack_client/resources/eval/jobs.py
@@ -13,9 +13,9 @@
     async_to_raw_response_wrapper,
     async_to_streamed_response_wrapper,
 )
+from ...types.job import Job
 from ..._base_client import make_request_options
 from ...types.evaluate_response import EvaluateResponse
-from ...types.eval.job_status_response import JobStatusResponse
 
 
 __all__ = ["JobsResource", "AsyncJobsResource"]
@@ -124,7 +124,7 @@ def status(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> JobStatusResponse:
+    ) -> Job:
         """
         Get the status of a job.
 
@@ -146,7 +146,7 @@
             options=make_request_options(
                 extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
             ),
-            cast_to=JobStatusResponse,
+            cast_to=Job,
         )
 
 
@@ -254,7 +254,7 @@ async def status(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> JobStatusResponse:
+    ) -> Job:
         """
         Get the status of a job.
 
@@ -276,7 +276,7 @@
             options=make_request_options(
                 extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
             ),
-            cast_to=JobStatusResponse,
+            cast_to=Job,
         )
 
 
diff --git a/src/llama_stack_client/resources/post_training/post_training.py b/src/llama_stack_client/resources/post_training/post_training.py
index 944610fc..a93a1ebb 100644
--- a/src/llama_stack_client/resources/post_training/post_training.py
+++ b/src/llama_stack_client/resources/post_training/post_training.py
@@ -14,7 +14,10 @@
     JobResourceWithStreamingResponse,
     AsyncJobResourceWithStreamingResponse,
 )
-from ...types import post_training_preference_optimize_params, post_training_supervised_fine_tune_params
+from ...types import (
+    post_training_preference_optimize_params,
+    post_training_supervised_fine_tune_params,
+)
 from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
 from ..._utils import (
     maybe_transform,
@@ -30,6 +33,7 @@
 )
 from ..._base_client import make_request_options
 from ...types.post_training_job import PostTrainingJob
+from ...types.algorithm_config_param import AlgorithmConfigParam
 
 
 __all__ = ["PostTrainingResource", "AsyncPostTrainingResource"]
@@ -111,7 +115,7 @@ def supervised_fine_tune(
         logger_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]],
         model: str,
         training_config: post_training_supervised_fine_tune_params.TrainingConfig,
-        algorithm_config: post_training_supervised_fine_tune_params.AlgorithmConfig | NotGiven = NOT_GIVEN,
+        algorithm_config: AlgorithmConfigParam | NotGiven = NOT_GIVEN,
         checkpoint_dir: str | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
         # The extra values given here take precedence over values defined on the client or passed to this method.
@@ -228,7 +232,7 @@ async def supervised_fine_tune(
         logger_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]],
         model: str,
         training_config: post_training_supervised_fine_tune_params.TrainingConfig,
-        algorithm_config: post_training_supervised_fine_tune_params.AlgorithmConfig | NotGiven = NOT_GIVEN,
+        algorithm_config: AlgorithmConfigParam | NotGiven = NOT_GIVEN,
         checkpoint_dir: str | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
         # The extra values given here take precedence over values defined on the client or passed to this method.
diff --git a/src/llama_stack_client/types/__init__.py b/src/llama_stack_client/types/__init__.py
index 2e35f893..b45996a9 100644
--- a/src/llama_stack_client/types/__init__.py
+++ b/src/llama_stack_client/types/__init__.py
@@ -76,6 +76,7 @@
 from .model_register_params import ModelRegisterParams as ModelRegisterParams
 from .query_chunks_response import QueryChunksResponse as QueryChunksResponse
 from .query_condition_param import QueryConditionParam as QueryConditionParam
+from .algorithm_config_param import AlgorithmConfigParam as AlgorithmConfigParam
 from .benchmark_config_param import BenchmarkConfigParam as BenchmarkConfigParam
 from .list_datasets_response import ListDatasetsResponse as ListDatasetsResponse
 from .provider_list_response import ProviderListResponse as ProviderListResponse
diff --git a/src/llama_stack_client/types/algorithm_config_param.py b/src/llama_stack_client/types/algorithm_config_param.py
new file mode 100644
index 00000000..3f3c0cac
--- /dev/null
+++ b/src/llama_stack_client/types/algorithm_config_param.py
@@ -0,0 +1,37 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing import List, Union
+from typing_extensions import Literal, Required, TypeAlias, TypedDict
+
+__all__ = ["AlgorithmConfigParam", "LoraFinetuningConfig", "QatFinetuningConfig"]
+
+
+class LoraFinetuningConfig(TypedDict, total=False):
+    alpha: Required[int]
+
+    apply_lora_to_mlp: Required[bool]
+
+    apply_lora_to_output: Required[bool]
+
+    lora_attn_modules: Required[List[str]]
+
+    rank: Required[int]
+
+    type: Required[Literal["LoRA"]]
+
+    quantize_base: bool
+
+    use_dora: bool
+
+
+class QatFinetuningConfig(TypedDict, total=False):
+    group_size: Required[int]
+
+    quantizer_name: Required[str]
+
+    type: Required[Literal["QAT"]]
+
+
+AlgorithmConfigParam: TypeAlias = Union[LoraFinetuningConfig, QatFinetuningConfig]
diff --git a/src/llama_stack_client/types/eval/__init__.py b/src/llama_stack_client/types/eval/__init__.py
index a0c3f3bc..f8ee8b14 100644
--- a/src/llama_stack_client/types/eval/__init__.py
+++ b/src/llama_stack_client/types/eval/__init__.py
@@ -1,5 +1,3 @@
 # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
 
 from __future__ import annotations
-
-from .job_status_response import JobStatusResponse as JobStatusResponse
diff --git a/src/llama_stack_client/types/eval/job_status_response.py b/src/llama_stack_client/types/eval/job_status_response.py
deleted file mode 100644
index 4f02f31d..00000000
--- a/src/llama_stack_client/types/eval/job_status_response.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from typing_extensions import Literal, TypeAlias
-
-__all__ = ["JobStatusResponse"]
-
-JobStatusResponse: TypeAlias = Literal["completed", "in_progress", "failed", "scheduled"]
diff --git a/src/llama_stack_client/types/job.py b/src/llama_stack_client/types/job.py
index 25c33c4c..74c6beb7 100644
--- a/src/llama_stack_client/types/job.py
+++ b/src/llama_stack_client/types/job.py
@@ -1,5 +1,6 @@
 # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
 
+from typing_extensions import Literal
 from .._models import BaseModel
 
 
@@ -8,3 +9,5 @@
 
 class Job(BaseModel):
     job_id: str
+
+    status: Literal["completed", "in_progress", "failed", "scheduled"]
diff --git a/src/llama_stack_client/types/post_training_supervised_fine_tune_params.py b/src/llama_stack_client/types/post_training_supervised_fine_tune_params.py
index 68b79782..fa18742a 100644
--- a/src/llama_stack_client/types/post_training_supervised_fine_tune_params.py
+++ b/src/llama_stack_client/types/post_training_supervised_fine_tune_params.py
@@ -2,8 +2,10 @@
 
 from __future__ import annotations
 
-from typing import Dict, List, Union, Iterable
-from typing_extensions import Literal, Required, TypeAlias, TypedDict
+from typing import Dict, Union, Iterable
+from typing_extensions import Literal, Required, TypedDict
+
+from .algorithm_config_param import AlgorithmConfigParam
 
 __all__ = [
     "PostTrainingSupervisedFineTuneParams",
@@ -11,9 +13,6 @@
     "TrainingConfigDataConfig",
     "TrainingConfigOptimizerConfig",
     "TrainingConfigEfficiencyConfig",
-    "AlgorithmConfig",
-    "AlgorithmConfigLoraFinetuningConfig",
-    "AlgorithmConfigQatFinetuningConfig",
 ]
 
 
@@ -28,7 +27,7 @@ class PostTrainingSupervisedFineTuneParams(TypedDict, total=False):
 
     training_config: Required[TrainingConfig]
 
-    algorithm_config: AlgorithmConfig
+    algorithm_config: AlgorithmConfigParam
 
     checkpoint_dir: str
 
@@ -85,32 +84,3 @@ class TrainingConfig(TypedDict, total=False):
     dtype: str
 
     efficiency_config: TrainingConfigEfficiencyConfig
-
-
-class AlgorithmConfigLoraFinetuningConfig(TypedDict, total=False):
-    alpha: Required[int]
-
-    apply_lora_to_mlp: Required[bool]
-
-    apply_lora_to_output: Required[bool]
-
-    lora_attn_modules: Required[List[str]]
-
-    rank: Required[int]
-
-    type: Required[Literal["LoRA"]]
-
-    quantize_base: bool
-
-    use_dora: bool
-
-
-class AlgorithmConfigQatFinetuningConfig(TypedDict, total=False):
-    group_size: Required[int]
-
-    quantizer_name: Required[str]
-
-    type: Required[Literal["QAT"]]
-
-
-AlgorithmConfig: TypeAlias = Union[AlgorithmConfigLoraFinetuningConfig, AlgorithmConfigQatFinetuningConfig]
diff --git a/src/llama_stack_client/types/tool_invocation_result.py b/src/llama_stack_client/types/tool_invocation_result.py
index a28160bb..01f7db28 100644
--- a/src/llama_stack_client/types/tool_invocation_result.py
+++ b/src/llama_stack_client/types/tool_invocation_result.py
@@ -9,7 +9,7 @@
 
 
 class ToolInvocationResult(BaseModel):
-    content: InterleavedContent
+    content: Optional[InterleavedContent] = None
     """A image content item"""
 
     error_code: Optional[int] = None
diff --git a/tests/api_resources/eval/test_jobs.py b/tests/api_resources/eval/test_jobs.py
index 874f5ff0..5f289c74 100644
--- a/tests/api_resources/eval/test_jobs.py
+++ b/tests/api_resources/eval/test_jobs.py
@@ -9,8 +9,7 @@
 
 from tests.utils import assert_matches_type
 from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient
-from llama_stack_client.types import EvaluateResponse
-from llama_stack_client.types.eval import JobStatusResponse
+from llama_stack_client.types import Job, EvaluateResponse
 
 base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
 
@@ -120,7 +119,7 @@ def test_method_status(self, client: LlamaStackClient) -> None:
             job_id="job_id",
             benchmark_id="benchmark_id",
         )
-        assert_matches_type(JobStatusResponse, job, path=["response"])
+        assert_matches_type(Job, job, path=["response"])
 
     @parametrize
     def test_raw_response_status(self, client: LlamaStackClient) -> None:
@@ -132,7 +131,7 @@ def test_raw_response_status(self, client: LlamaStackClient) -> None:
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         job = response.parse()
-        assert_matches_type(JobStatusResponse, job, path=["response"])
+        assert_matches_type(Job, job, path=["response"])
 
     @parametrize
     def test_streaming_response_status(self, client: LlamaStackClient) -> None:
@@ -144,7 +143,7 @@ def test_streaming_response_status(self, client: LlamaStackClient) -> None:
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             job = response.parse()
-            assert_matches_type(JobStatusResponse, job, path=["response"])
+            assert_matches_type(Job, job, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
@@ -268,7 +267,7 @@ async def test_method_status(self, async_client: AsyncLlamaStackClient) -> None:
             job_id="job_id",
             benchmark_id="benchmark_id",
         )
-        assert_matches_type(JobStatusResponse, job, path=["response"])
+        assert_matches_type(Job, job, path=["response"])
 
     @parametrize
     async def test_raw_response_status(self, async_client: AsyncLlamaStackClient) -> None:
@@ -280,7 +279,7 @@ async def test_raw_response_status(self, async_client: AsyncLlamaStackClient) ->
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         job = await response.parse()
-        assert_matches_type(JobStatusResponse, job, path=["response"])
+        assert_matches_type(Job, job, path=["response"])
 
     @parametrize
     async def test_streaming_response_status(self, async_client: AsyncLlamaStackClient) -> None:
@@ -292,7 +291,7 @@ async def test_streaming_response_status(self, async_client: AsyncLlamaStackClie
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             job = await response.parse()
-            assert_matches_type(JobStatusResponse, job, path=["response"])
+            assert_matches_type(Job, job, path=["response"])
 
         assert cast(Any, response.is_closed) is True
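
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the reshaped surface is exercised after this change, assuming a reachable Llama Stack server. The base URL, job/benchmark IDs, hyperparameter values, and LoRA attention-module names are placeholders, not values taken from this diff.

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://127.0.0.1:4010")  # placeholder URL

# eval.jobs.status() now returns the shared Job model instead of the removed
# JobStatusResponse literal, so callers read both the ID and the status off
# the model:
job = client.eval.jobs.status(job_id="job_id", benchmark_id="benchmark_id")
if job.status == "completed":  # or "in_progress", "failed", "scheduled"
    print(f"job {job.job_id} finished")

# algorithm_config is now typed by the shared AlgorithmConfigParam union; a
# value is still a plain dict matching one of its TypedDict variants (here
# the LoRA variant) and is passed as algorithm_config=lora_config to
# client.post_training.supervised_fine_tune():
lora_config = {
    "type": "LoRA",
    "alpha": 16,  # placeholder hyperparameters
    "rank": 8,
    "apply_lora_to_mlp": True,
    "apply_lora_to_output": False,
    "lora_attn_modules": ["q_proj", "v_proj"],  # placeholder module names
}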