Commit d3bb258

Sync updates from stainless branch: yanxi0830/dev (#209)
# What does this PR do?

[Provide a short summary of what this PR does and why. Link to relevant issues if applicable.]

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

[Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*]

[//]: # (## Documentation)
[//]: # (- [ ] Added a Changelog entry if the change is significant)
1 parent 6ff7e05 commit d3bb258

File tree

10 files changed: +66 −61 lines changed

- src/llama_stack_client/resources/eval/jobs.py
- src/llama_stack_client/resources/post_training/post_training.py
- src/llama_stack_client/types/__init__.py
- src/llama_stack_client/types/algorithm_config_param.py
- src/llama_stack_client/types/eval/__init__.py
- src/llama_stack_client/types/eval/job_status_response.py
- src/llama_stack_client/types/job.py
- src/llama_stack_client/types/post_training_supervised_fine_tune_params.py
- src/llama_stack_client/types/tool_invocation_result.py
- tests/api_resources/eval/test_jobs.py

src/llama_stack_client/resources/eval/jobs.py

Lines changed: 5 additions & 5 deletions

```diff
@@ -13,9 +13,9 @@
     async_to_raw_response_wrapper,
     async_to_streamed_response_wrapper,
 )
+from ...types.job import Job
 from ..._base_client import make_request_options
 from ...types.evaluate_response import EvaluateResponse
-from ...types.eval.job_status_response import JobStatusResponse
 
 __all__ = ["JobsResource", "AsyncJobsResource"]
 
@@ -124,7 +124,7 @@ def status(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> JobStatusResponse:
+    ) -> Job:
         """
         Get the status of a job.
 
@@ -146,7 +146,7 @@ def status(
             options=make_request_options(
                 extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
             ),
-            cast_to=JobStatusResponse,
+            cast_to=Job,
         )
 
 
@@ -254,7 +254,7 @@ async def status(
         extra_query: Query | None = None,
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
-    ) -> JobStatusResponse:
+    ) -> Job:
         """
         Get the status of a job.
 
@@ -276,7 +276,7 @@ async def status(
             options=make_request_options(
                 extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
             ),
-            cast_to=JobStatusResponse,
+            cast_to=Job,
         )
```
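For callers, `status()` now deserializes into the shared `Job` model instead of the eval-specific `JobStatusResponse`. A minimal sketch, assuming a hypothetical local base URL and the placeholder IDs used in the test fixtures below:

```python
from llama_stack_client import LlamaStackClient

# Hypothetical endpoint and placeholder IDs, mirroring the test fixtures.
client = LlamaStackClient(base_url="http://127.0.0.1:4010")

# status() now returns the shared Job model (cast_to=Job).
job = client.eval.jobs.status(job_id="job_id", benchmark_id="benchmark_id")
print(job.job_id)
```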

src/llama_stack_client/resources/post_training/post_training.py

Lines changed: 7 additions & 3 deletions

```diff
@@ -14,7 +14,10 @@
     JobResourceWithStreamingResponse,
     AsyncJobResourceWithStreamingResponse,
 )
-from ...types import post_training_preference_optimize_params, post_training_supervised_fine_tune_params
+from ...types import (
+    post_training_preference_optimize_params,
+    post_training_supervised_fine_tune_params,
+)
 from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
 from ..._utils import (
     maybe_transform,
@@ -30,6 +33,7 @@
 )
 from ..._base_client import make_request_options
 from ...types.post_training_job import PostTrainingJob
+from ...types.algorithm_config_param import AlgorithmConfigParam
 
 __all__ = ["PostTrainingResource", "AsyncPostTrainingResource"]
 
@@ -111,7 +115,7 @@ def supervised_fine_tune(
         logger_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]],
         model: str,
         training_config: post_training_supervised_fine_tune_params.TrainingConfig,
-        algorithm_config: post_training_supervised_fine_tune_params.AlgorithmConfig | NotGiven = NOT_GIVEN,
+        algorithm_config: AlgorithmConfigParam | NotGiven = NOT_GIVEN,
         checkpoint_dir: str | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
         # The extra values given here take precedence over values defined on the client or passed to this method.
@@ -228,7 +232,7 @@ async def supervised_fine_tune(
         logger_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]],
         model: str,
         training_config: post_training_supervised_fine_tune_params.TrainingConfig,
-        algorithm_config: post_training_supervised_fine_tune_params.AlgorithmConfig | NotGiven = NOT_GIVEN,
+        algorithm_config: AlgorithmConfigParam | NotGiven = NOT_GIVEN,
         checkpoint_dir: str | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
         # The extra values given here take precedence over values defined on the client or passed to this method.
```
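The net effect for callers: `algorithm_config` now accepts the shared `AlgorithmConfigParam` union, so a plain dict matching one of its `TypedDict` members type-checks. A sketch showing only the parameters visible in these hunks; arguments not shown here are omitted, and the model name and config values are placeholders:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://127.0.0.1:4010")  # hypothetical endpoint

# A plain dict satisfying the LoRA member of AlgorithmConfigParam.
lora_config = {
    "type": "LoRA",
    "alpha": 16,
    "rank": 8,
    "apply_lora_to_mlp": True,
    "apply_lora_to_output": False,
    "lora_attn_modules": ["q_proj", "v_proj"],  # placeholder module names
}

job = client.post_training.supervised_fine_tune(
    model="example-base-model",  # placeholder model identifier
    logger_config={},            # placeholder
    training_config={},          # placeholder; real calls need a populated TrainingConfig
    algorithm_config=lora_config,
)
```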

src/llama_stack_client/types/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -76,6 +76,7 @@
 from .model_register_params import ModelRegisterParams as ModelRegisterParams
 from .query_chunks_response import QueryChunksResponse as QueryChunksResponse
 from .query_condition_param import QueryConditionParam as QueryConditionParam
+from .algorithm_config_param import AlgorithmConfigParam as AlgorithmConfigParam
 from .benchmark_config_param import BenchmarkConfigParam as BenchmarkConfigParam
 from .list_datasets_response import ListDatasetsResponse as ListDatasetsResponse
 from .provider_list_response import ProviderListResponse as ProviderListResponse
```
src/llama_stack_client/types/algorithm_config_param.py

Lines changed: 37 additions & 0 deletions

```diff
@@ -0,0 +1,37 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing import List, Union
+from typing_extensions import Literal, Required, TypeAlias, TypedDict
+
+__all__ = ["AlgorithmConfigParam", "LoraFinetuningConfig", "QatFinetuningConfig"]
+
+
+class LoraFinetuningConfig(TypedDict, total=False):
+    alpha: Required[int]
+
+    apply_lora_to_mlp: Required[bool]
+
+    apply_lora_to_output: Required[bool]
+
+    lora_attn_modules: Required[List[str]]
+
+    rank: Required[int]
+
+    type: Required[Literal["LoRA"]]
+
+    quantize_base: bool
+
+    use_dora: bool
+
+
+class QatFinetuningConfig(TypedDict, total=False):
+    group_size: Required[int]
+
+    quantizer_name: Required[str]
+
+    type: Required[Literal["QAT"]]
+
+
+AlgorithmConfigParam: TypeAlias = Union[LoraFinetuningConfig, QatFinetuningConfig]
```
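Since `AlgorithmConfigParam` is a `Union` of `TypedDict`s discriminated by the required `type` key, both variants can be written as plain dicts. A sketch with illustrative values (module and quantizer names are placeholders):

```python
from llama_stack_client.types import AlgorithmConfigParam

# LoRA variant: every Required[...] key must be present; total=False keys are optional.
lora: AlgorithmConfigParam = {
    "type": "LoRA",
    "alpha": 16,
    "rank": 8,
    "apply_lora_to_mlp": True,
    "apply_lora_to_output": False,
    "lora_attn_modules": ["q_proj", "v_proj"],  # placeholder module names
    "use_dora": False,  # optional key
}

# QAT variant.
qat: AlgorithmConfigParam = {
    "type": "QAT",
    "quantizer_name": "int4",  # placeholder quantizer name
    "group_size": 32,
}
```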
src/llama_stack_client/types/eval/__init__.py

Lines changed: 0 additions & 2 deletions

```diff
@@ -1,5 +1,3 @@
 # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
 
 from __future__ import annotations
-
-from .job_status_response import JobStatusResponse as JobStatusResponse
```

src/llama_stack_client/types/eval/job_status_response.py

Lines changed: 0 additions & 7 deletions
This file was deleted.

src/llama_stack_client/types/job.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -1,5 +1,6 @@
 # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
 
+from typing_extensions import Literal
 
 from .._models import BaseModel
 
@@ -8,3 +9,5 @@
 
 class Job(BaseModel):
     job_id: str
+
+    status: Literal["completed", "in_progress", "failed", "scheduled"]
```
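With `status` now carried on the model, polling until a terminal state becomes straightforward. A sketch, reusing the hypothetical endpoint and placeholder IDs from the earlier snippet:

```python
import time

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://127.0.0.1:4010")  # hypothetical endpoint

# Poll until the job leaves the non-terminal states.
job = client.eval.jobs.status(job_id="job_id", benchmark_id="benchmark_id")
while job.status in ("scheduled", "in_progress"):
    time.sleep(5)  # arbitrary polling interval
    job = client.eval.jobs.status(job_id="job_id", benchmark_id="benchmark_id")

print("terminal status:", job.status)  # "completed" or "failed"
```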

src/llama_stack_client/types/post_training_supervised_fine_tune_params.py

Lines changed: 5 additions & 35 deletions

```diff
@@ -2,18 +2,17 @@
 
 from __future__ import annotations
 
-from typing import Dict, List, Union, Iterable
-from typing_extensions import Literal, Required, TypeAlias, TypedDict
+from typing import Dict, Union, Iterable
+from typing_extensions import Literal, Required, TypedDict
+
+from .algorithm_config_param import AlgorithmConfigParam
 
 __all__ = [
     "PostTrainingSupervisedFineTuneParams",
     "TrainingConfig",
     "TrainingConfigDataConfig",
     "TrainingConfigOptimizerConfig",
     "TrainingConfigEfficiencyConfig",
-    "AlgorithmConfig",
-    "AlgorithmConfigLoraFinetuningConfig",
-    "AlgorithmConfigQatFinetuningConfig",
 ]
 
 
@@ -28,7 +27,7 @@ class PostTrainingSupervisedFineTuneParams(TypedDict, total=False):
 
     training_config: Required[TrainingConfig]
 
-    algorithm_config: AlgorithmConfig
+    algorithm_config: AlgorithmConfigParam
 
     checkpoint_dir: str
 
@@ -85,32 +84,3 @@ class TrainingConfig(TypedDict, total=False):
     dtype: str
 
     efficiency_config: TrainingConfigEfficiencyConfig
-
-
-class AlgorithmConfigLoraFinetuningConfig(TypedDict, total=False):
-    alpha: Required[int]
-
-    apply_lora_to_mlp: Required[bool]
-
-    apply_lora_to_output: Required[bool]
-
-    lora_attn_modules: Required[List[str]]
-
-    rank: Required[int]
-
-    type: Required[Literal["LoRA"]]
-
-    quantize_base: bool
-
-    use_dora: bool
-
-
-class AlgorithmConfigQatFinetuningConfig(TypedDict, total=False):
-    group_size: Required[int]
-
-    quantizer_name: Required[str]
-
-    type: Required[Literal["QAT"]]
-
-
-AlgorithmConfig: TypeAlias = Union[AlgorithmConfigLoraFinetuningConfig, AlgorithmConfigQatFinetuningConfig]
```

src/llama_stack_client/types/tool_invocation_result.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -9,7 +9,7 @@
 
 
 class ToolInvocationResult(BaseModel):
-    content: InterleavedContent
+    content: Optional[InterleavedContent] = None
     """A image content item"""
 
     error_code: Optional[int] = None
```
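Because `content` is now `Optional`, downstream code should guard against `None` (e.g. on a failed invocation). A minimal sketch, assuming the model's remaining fields are likewise optional:

```python
from llama_stack_client.types.tool_invocation_result import ToolInvocationResult

# content may now be absent; construct a result without it (illustrative values).
result = ToolInvocationResult(error_code=1)

if result.content is None:
    print("no content returned, error_code:", result.error_code)
else:
    print(result.content)
```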

tests/api_resources/eval/test_jobs.py

Lines changed: 7 additions & 8 deletions

```diff
@@ -9,8 +9,7 @@
 
 from tests.utils import assert_matches_type
 from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient
-from llama_stack_client.types import EvaluateResponse
-from llama_stack_client.types.eval import JobStatusResponse
+from llama_stack_client.types import Job, EvaluateResponse
 
 base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
 
@@ -120,7 +119,7 @@ def test_method_status(self, client: LlamaStackClient) -> None:
             job_id="job_id",
             benchmark_id="benchmark_id",
         )
-        assert_matches_type(JobStatusResponse, job, path=["response"])
+        assert_matches_type(Job, job, path=["response"])
 
     @parametrize
     def test_raw_response_status(self, client: LlamaStackClient) -> None:
@@ -132,7 +131,7 @@ def test_raw_response_status(self, client: LlamaStackClient) -> None:
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         job = response.parse()
-        assert_matches_type(JobStatusResponse, job, path=["response"])
+        assert_matches_type(Job, job, path=["response"])
 
     @parametrize
     def test_streaming_response_status(self, client: LlamaStackClient) -> None:
@@ -144,7 +143,7 @@ def test_streaming_response_status(self, client: LlamaStackClient) -> None:
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             job = response.parse()
-            assert_matches_type(JobStatusResponse, job, path=["response"])
+            assert_matches_type(Job, job, path=["response"])
 
         assert cast(Any, response.is_closed) is True
 
@@ -268,7 +267,7 @@ async def test_method_status(self, async_client: AsyncLlamaStackClient) -> None:
             job_id="job_id",
             benchmark_id="benchmark_id",
         )
-        assert_matches_type(JobStatusResponse, job, path=["response"])
+        assert_matches_type(Job, job, path=["response"])
 
     @parametrize
     async def test_raw_response_status(self, async_client: AsyncLlamaStackClient) -> None:
@@ -280,7 +279,7 @@ async def test_raw_response_status(self, async_client: AsyncLlamaStackClient) -> None:
         assert response.is_closed is True
         assert response.http_request.headers.get("X-Stainless-Lang") == "python"
         job = await response.parse()
-        assert_matches_type(JobStatusResponse, job, path=["response"])
+        assert_matches_type(Job, job, path=["response"])
 
     @parametrize
     async def test_streaming_response_status(self, async_client: AsyncLlamaStackClient) -> None:
@@ -292,7 +291,7 @@ async def test_streaming_response_status(self, async_client: AsyncLlamaStackClient) -> None:
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
 
             job = await response.parse()
-            assert_matches_type(JobStatusResponse, job, path=["response"])
+            assert_matches_type(Job, job, path=["response"])
 
         assert cast(Any, response.is_closed) is True
```
