Skip to content

Commit 7e862d8

Browse files
authored
Merge pull request #26 from meta-llama/dineshyv/create
Resource create
2 parents ca2ff31 + 80cd023 commit 7e862d8

File tree

11 files changed

+311
-112
lines changed

11 files changed

+311
-112
lines changed

src/llama_stack_client/lib/cli/datasets/datasets.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
# the root directory of this source tree.
66

77
import click
8+
import yaml
9+
from typing import Optional
10+
import json
811

12+
from llama_models.llama3.api.datatypes import URL
913
from .list import list_datasets
1014

1115

@@ -15,5 +19,49 @@ def datasets():
1519
pass
1620

1721

22+
@datasets.command()
23+
@click.option("--dataset-id", required=True, help="Id of the dataset")
24+
@click.option("--provider-id", help="Provider ID for the dataset", default=None)
25+
@click.option("--provider-dataset-id", help="Provider's dataset ID", default=None)
26+
@click.option("--metadata", type=str, help="Metadata of the dataset")
27+
@click.option("--url", type=str, help="URL of the dataset", required=True)
28+
@click.option("--schema", type=str, help="JSON schema of the dataset", required=True)
29+
@click.pass_context
30+
def register(
31+
ctx,
32+
dataset_id: str,
33+
provider_id: Optional[str],
34+
provider_dataset_id: Optional[str],
35+
metadata: Optional[str],
36+
url: str,
37+
schema: str,
38+
):
39+
"""Create a new dataset"""
40+
client = ctx.obj["client"]
41+
42+
try:
43+
dataset_schema = json.loads(schema)
44+
except json.JSONDecodeError:
45+
raise click.BadParameter("Schema must be valid JSON")
46+
47+
if metadata:
48+
try:
49+
metadata = json.loads(metadata)
50+
except json.JSONDecodeError:
51+
raise click.BadParameter("Metadata must be valid JSON")
52+
53+
response = client.datasets.register(
54+
dataset_id=dataset_id,
55+
dataset_schema=dataset_schema,
56+
url={"uri": url},
57+
provider_id=provider_id,
58+
provider_dataset_id=provider_dataset_id,
59+
metadata=metadata,
60+
)
61+
if response:
62+
click.echo(yaml.dump(response.dict()))
63+
64+
1865
# Register subcommands
1966
datasets.add_command(list_datasets)
67+
datasets.add_command(register)

src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77

88
import click
9+
import json
10+
import yaml
11+
from typing import Optional
912

1013
from .list import list_eval_tasks
1114

@@ -16,5 +19,44 @@ def eval_tasks():
1619
pass
1720

1821

22+
@eval_tasks.command()
23+
@click.option("--eval-task-id", required=True, help="ID of the eval task")
24+
@click.option("--dataset-id", required=True, help="ID of the dataset to evaluate")
25+
@click.option("--scoring-functions", required=True, multiple=True, help="Scoring functions to use for evaluation")
26+
@click.option("--provider-id", help="Provider ID for the eval task", default=None)
27+
@click.option("--provider-eval-task-id", help="Provider's eval task ID", default=None)
28+
@click.option("--metadata", type=str, help="Metadata for the eval task in JSON format")
29+
@click.pass_context
30+
def register(
31+
ctx,
32+
eval_task_id: str,
33+
dataset_id: str,
34+
scoring_functions: tuple[str, ...],
35+
provider_id: Optional[str],
36+
provider_eval_task_id: Optional[str],
37+
metadata: Optional[str],
38+
):
39+
"""Register a new eval task"""
40+
client = ctx.obj["client"]
41+
42+
if metadata:
43+
try:
44+
metadata = json.loads(metadata)
45+
except json.JSONDecodeError:
46+
raise click.BadParameter("Metadata must be valid JSON")
47+
48+
response = client.eval_tasks.register(
49+
eval_task_id=eval_task_id,
50+
dataset_id=dataset_id,
51+
scoring_functions=scoring_functions,
52+
provider_id=provider_id,
53+
provider_eval_task_id=provider_eval_task_id,
54+
metadata=metadata,
55+
)
56+
if response:
57+
click.echo(yaml.dump(response.dict()))
58+
59+
1960
# Register subcommands
2061
eval_tasks.add_command(list_eval_tasks)
62+
eval_tasks.add_command(register)

src/llama_stack_client/lib/cli/eval_tasks/list.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ def list_eval_tasks(ctx):
1616

1717
client = ctx.obj["client"]
1818

19-
headers = ["identifier", "provider_id", "description", "type"]
20-
19+
headers = []
2120
eval_tasks_list_response = client.eval_tasks.list()
21+
if eval_tasks_list_response and len(eval_tasks_list_response) > 0:
22+
headers = sorted(eval_tasks_list_response[0].__dict__.keys())
23+
2224
if eval_tasks_list_response:
2325
print_table_from_response(eval_tasks_list_response, headers)

src/llama_stack_client/lib/cli/memory_banks/list.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

src/llama_stack_client/lib/cli/memory_banks/memory_banks.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
# the root directory of this source tree.
66

77
import click
8-
9-
from .list import list_memory_banks
8+
from typing import Optional
9+
import yaml
10+
from llama_stack_client.lib.cli.common.utils import print_table_from_response
1011

1112

1213
@click.group()
@@ -15,5 +16,70 @@ def memory_banks():
1516
pass
1617

1718

19+
@click.command("list")
20+
@click.pass_context
21+
def list(ctx):
22+
"""Show available memory banks on distribution endpoint"""
23+
24+
client = ctx.obj["client"]
25+
26+
memory_banks_list_response = client.memory_banks.list()
27+
headers = []
28+
if memory_banks_list_response and len(memory_banks_list_response) > 0:
29+
headers = sorted(memory_banks_list_response[0].__dict__.keys())
30+
31+
if memory_banks_list_response:
32+
print_table_from_response(memory_banks_list_response, headers)
33+
34+
35+
@memory_banks.command()
36+
@click.option("--memory-bank-id", required=True, help="Id of the memory bank")
37+
@click.option("--type", type=click.Choice(["vector", "keyvalue", "keyword", "graph"]), required=True)
38+
@click.option("--provider-id", help="Provider ID for the memory bank", default=None)
39+
@click.option("--provider-memory-bank-id", help="Provider's memory bank ID", default=None)
40+
@click.option("--chunk-size", type=int, help="Chunk size in tokens (for vector type)", default=512)
41+
@click.option("--embedding-model", type=str, help="Embedding model (for vector type)", default="all-MiniLM-L6-v2")
42+
@click.option("--overlap-size", type=int, help="Overlap size in tokens (for vector type)", default=64)
43+
@click.pass_context
44+
def create(
45+
ctx,
46+
memory_bank_id: str,
47+
type: str,
48+
provider_id: Optional[str],
49+
provider_memory_bank_id: Optional[str],
50+
chunk_size: Optional[int],
51+
embedding_model: Optional[str],
52+
overlap_size: Optional[int],
53+
):
54+
"""Create a new memory bank"""
55+
client = ctx.obj["client"]
56+
57+
config = None
58+
if type == "vector":
59+
config = {
60+
"type": "vector",
61+
"chunk_size_in_tokens": chunk_size,
62+
"embedding_model": embedding_model,
63+
}
64+
if overlap_size:
65+
config["overlap_size_in_tokens"] = overlap_size
66+
elif type == "keyvalue":
67+
config = {"type": "keyvalue"}
68+
elif type == "keyword":
69+
config = {"type": "keyword"}
70+
elif type == "graph":
71+
config = {"type": "graph"}
72+
73+
response = client.memory_banks.register(
74+
memory_bank_id=memory_bank_id,
75+
params=config,
76+
provider_id=provider_id,
77+
provider_memory_bank_id=provider_memory_bank_id,
78+
)
79+
if response:
80+
click.echo(yaml.dump(response.dict()))
81+
82+
1883
# Register subcommands
19-
memory_banks.add_command(list_memory_banks)
84+
memory_banks.add_command(list)
85+
memory_banks.add_command(create)

src/llama_stack_client/lib/cli/models/get.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

src/llama_stack_client/lib/cli/models/list.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

src/llama_stack_client/lib/cli/models/models.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
# the root directory of this source tree.
66

77
import click
8-
from llama_stack_client.lib.cli.models.get import get_model
9-
from llama_stack_client.lib.cli.models.list import list_models
8+
from tabulate import tabulate
9+
from llama_stack_client.lib.cli.common.utils import print_table_from_response
10+
from typing import Optional
1011

1112

1213
@click.group()
@@ -15,6 +16,63 @@ def models():
1516
pass
1617

1718

19+
@click.command(name="list", help="Show available llama models at distribution endpoint")
20+
@click.pass_context
21+
def list_models(ctx):
22+
client = ctx.obj["client"]
23+
24+
headers = ["identifier", "provider_id", "provider_resource_id", "metadata"]
25+
response = client.models.list()
26+
if response:
27+
print_table_from_response(response, headers)
28+
29+
30+
@click.command(name="get")
31+
@click.argument("model_id")
32+
@click.pass_context
33+
def get_model(ctx, model_id: str):
34+
"""Show available llama models at distribution endpoint"""
35+
client = ctx.obj["client"]
36+
37+
models_get_response = client.models.retrieve(identifier=model_id)
38+
39+
if not models_get_response:
40+
click.echo(
41+
f"Model {model_id} is not found at distribution endpoint. "
42+
"Please ensure endpoint is serving specified model."
43+
)
44+
return
45+
46+
headers = sorted(models_get_response.__dict__.keys())
47+
rows = []
48+
rows.append([models_get_response.__dict__[headers[i]] for i in range(len(headers))])
49+
50+
click.echo(tabulate(rows, headers=headers, tablefmt="grid"))
51+
52+
53+
@click.command(name="register", help="Register a new model at distribution endpoint")
54+
@click.argument("model_id")
55+
@click.option("--provider-id", help="Provider ID for the model", default=None)
56+
@click.option("--provider-model-id", help="Provider's model ID", default=None)
57+
@click.option("--metadata", help="JSON metadata for the model", default=None)
58+
@click.pass_context
59+
def register_model(
60+
ctx, model_id: str, provider_id: Optional[str], provider_model_id: Optional[str], metadata: Optional[str]
61+
):
62+
"""Register a new model at distribution endpoint"""
63+
client = ctx.obj["client"]
64+
65+
try:
66+
response = client.models.register(
67+
model_id=model_id, provider_id=provider_id, provider_model_id=provider_model_id, metadata=metadata
68+
)
69+
if response:
70+
click.echo(f"Successfully registered model {model_id}")
71+
except Exception as e:
72+
click.echo(f"Failed to register model: {str(e)}")
73+
74+
1875
# Register subcommands
1976
models.add_command(list_models)
2077
models.add_command(get_model)
78+
models.add_command(register_model)

0 commit comments

Comments
 (0)