Skip to content

Commit 0b8f28e

Browse files
committed
fix: enable datasets,post_training and eval_tsaks works
1 parent 862e900 commit 0b8f28e

File tree

8 files changed

+17
-17
lines changed

8 files changed

+17
-17
lines changed

src/llama_stack_client/lib/cli/datasets/list.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def list_datasets(ctx):
2121
console = Console()
2222
headers = ["identifier", "provider_id", "metadata", "type", "purpose"]
2323

24-
datasets_list_response = client.datasets.list()
24+
datasets_list_response = client.beta.datasets.list()
2525
if datasets_list_response:
2626
table = Table()
2727
for header in headers:

src/llama_stack_client/lib/cli/datasets/register.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def register(
6868
if not url:
6969
raise click.BadParameter("URL is required when dataset path is not specified")
7070

71-
response = client.datasets.register(
71+
response = client.beta.datasets.register(
7272
dataset_id=dataset_id,
7373
source={"uri": url},
7474
metadata=metadata,

src/llama_stack_client/lib/cli/datasets/unregister.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@
1616
def unregister(ctx, dataset_id: str):
1717
"""Remove a dataset"""
1818
client = ctx.obj["client"]
19-
client.datasets.unregister(dataset_id=dataset_id)
19+
client.beta.datasets.unregister(dataset_id=dataset_id)
2020
click.echo(f"Dataset '{dataset_id}' unregistered successfully")

src/llama_stack_client/lib/cli/eval/run_benchmark.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,16 @@ def run_benchmark(
9696
client = ctx.obj["client"]
9797

9898
for benchmark_id in benchmark_ids:
99-
benchmark = client.benchmarks.retrieve(benchmark_id=benchmark_id)
99+
benchmark = client.alpha.benchmarks.retrieve(benchmark_id=benchmark_id)
100100
scoring_functions = benchmark.scoring_functions
101101
dataset_id = benchmark.dataset_id
102102

103-
results = client.datasets.iterrows(dataset_id=dataset_id, limit=-1 if num_examples is None else num_examples)
103+
results = client.beta.datasets.iterrows(dataset_id=dataset_id, limit=-1 if num_examples is None else num_examples)
104104

105105
output_res = {}
106106

107107
for i, r in enumerate(tqdm(results.data)):
108-
eval_res = client.eval.evaluate_rows(
108+
eval_res = client.alpha.eval.evaluate_rows(
109109
benchmark_id=benchmark_id,
110110
input_rows=[r],
111111
scoring_functions=scoring_functions,

src/llama_stack_client/lib/cli/eval/run_scoring.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,15 @@ def run_scoring(
7777
output_res = {}
7878

7979
if dataset_id is not None:
80-
dataset = client.datasets.retrieve(dataset_id=dataset_id)
80+
dataset = client.beta.datasets.retrieve(dataset_id=dataset_id)
8181
if not dataset:
8282
click.BadParameter(
8383
f"Dataset {dataset_id} not found. Please register using llama-stack-client datasets register"
8484
)
8585

8686
# TODO: this will eventually be replaced with jobs polling from server vis score_bath
8787
# For now, get all datasets rows via datasets API
88-
results = client.datasets.iterrows(dataset_id=dataset_id, limit=-1 if num_examples is None else num_examples)
88+
results = client.beta.datasets.iterrows(dataset_id=dataset_id, limit=-1 if num_examples is None else num_examples)
8989
rows = results.rows
9090

9191
if dataset_path is not None:

src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,12 @@ def register(
4949
except json.JSONDecodeError as err:
5050
raise click.BadParameter("Metadata must be valid JSON") from err
5151

52-
response = client.eval_tasks.register(
53-
eval_task_id=eval_task_id,
52+
response = client.alpha.benchmarks.register(
53+
benchmark_id=eval_task_id,
5454
dataset_id=dataset_id,
5555
scoring_functions=scoring_functions,
5656
provider_id=provider_id,
57-
provider_eval_task_id=provider_eval_task_id,
57+
provider_benchmark_id=provider_eval_task_id,
5858
metadata=metadata,
5959
)
6060
if response:

src/llama_stack_client/lib/cli/eval_tasks/list.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def list_eval_tasks(ctx):
2121
client = ctx.obj["client"]
2222
console = Console()
2323
headers = []
24-
eval_tasks_list_response = client.eval_tasks.list()
24+
eval_tasks_list_response = client.alpha.benchmarks.list()
2525
if eval_tasks_list_response and len(eval_tasks_list_response) > 0:
2626
headers = sorted(eval_tasks_list_response[0].__dict__.keys())
2727

src/llama_stack_client/lib/cli/post_training/post_training.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def supervised_fine_tune(
4444
client = ctx.obj["client"]
4545
console = Console()
4646

47-
post_training_job = client.post_training.supervised_fine_tune(
47+
post_training_job = client.alpha.post_training.supervised_fine_tune(
4848
job_uuid=job_uuid,
4949
model=model,
5050
algorithm_config=algorithm_config,
@@ -66,7 +66,7 @@ def get_training_jobs(ctx):
6666
client = ctx.obj["client"]
6767
console = Console()
6868

69-
post_training_jobs = client.post_training.job.list()
69+
post_training_jobs = client.alpha.post_training.job.list()
7070
console.print([post_training_job.job_uuid for post_training_job in post_training_jobs])
7171

7272

@@ -80,7 +80,7 @@ def get_training_job_status(ctx, job_uuid: str):
8080
client = ctx.obj["client"]
8181
console = Console()
8282

83-
job_status_reponse = client.post_training.job.status(job_uuid=job_uuid)
83+
job_status_reponse = client.alpha.post_training.job.status(job_uuid=job_uuid)
8484
console.print(job_status_reponse)
8585

8686

@@ -94,7 +94,7 @@ def get_training_job_artifacts(ctx, job_uuid: str):
9494
client = ctx.obj["client"]
9595
console = Console()
9696

97-
job_artifacts = client.post_training.job.artifacts(job_uuid=job_uuid)
97+
job_artifacts = client.alpha.post_training.job.artifacts(job_uuid=job_uuid)
9898
console.print(job_artifacts)
9999

100100

@@ -107,7 +107,7 @@ def cancel_training_job(ctx, job_uuid: str):
107107
"""Cancel the training job"""
108108
client = ctx.obj["client"]
109109

110-
client.post_training.job.cancel(job_uuid=job_uuid)
110+
client.alpha.post_training.job.cancel(job_uuid=job_uuid)
111111

112112

113113
# Register subcommands

0 commit comments

Comments
 (0)