diff --git a/src/llama_stack_client/_client.py b/src/llama_stack_client/_client.py
index 84f767aa..76166b33 100644
--- a/src/llama_stack_client/_client.py
+++ b/src/llama_stack_client/_client.py
@@ -1,61 +1,62 @@
 # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

 from __future__ import annotations

-import json
+import json
 import os
-from typing import Any, Union, Mapping
-from typing_extensions import Self, override
+from typing import Any, Mapping, Union

 import httpx
+from typing_extensions import Self, override

 from . import _exceptions
+from ._base_client import (
+    DEFAULT_MAX_RETRIES,
+    AsyncAPIClient,
+    SyncAPIClient,
+)
+from ._exceptions import APIStatusError
 from ._qs import Querystring
+from ._streaming import AsyncStream as AsyncStream
+from ._streaming import Stream as Stream
 from ._types import (
     NOT_GIVEN,
-    Omit,
-    Timeout,
     NotGiven,
-    Transport,
+    Omit,
     ProxiesTypes,
     RequestOptions,
+    Timeout,
+    Transport,
 )
 from ._utils import (
-    is_given,
     get_async_library,
+    is_given,
 )
 from ._version import __version__
 from .resources import (
-    tools,
+    batch_inference,
+    datasetio,
+    datasets,
+    eval_tasks,
+    inference,
+    inspect,
     memory,
+    memory_banks,
     models,
+    providers,
     routes,
     safety,
-    inspect,
     scoring,
+    scoring_functions,
     shields,
-    datasets,
-    datasetio,
-    inference,
-    providers,
+    synthetic_data_generation,
     telemetry,
-    eval_tasks,
-    toolgroups,
-    memory_banks,
     tool_runtime,
-    batch_inference,
-    scoring_functions,
-    synthetic_data_generation,
-)
-from ._streaming import Stream as Stream, AsyncStream as AsyncStream
-from ._exceptions import APIStatusError
-from ._base_client import (
-    DEFAULT_MAX_RETRIES,
-    SyncAPIClient,
-    AsyncAPIClient,
+    toolgroups,
+    tools,
 )
-from .resources.eval import eval
 from .resources.agents import agents
+from .resources.eval import eval
 from .resources.post_training import post_training

 __all__ = [
@@ -126,15 +127,13 @@ def __init__(
         if base_url is None:
             base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL")
         if base_url is None:
-            base_url = f"http://any-hosted-llama-stack.com"
+            base_url = "http://any-hosted-llama-stack.com"

         if provider_data is not None:
             if default_headers is None:
                 default_headers = {}
-            default_headers.update({
-                "X-LlamaStack-Provider-Data": json.dumps(provider_data),
-                "X-LlamaStack-Client-Version": __version__
-            })
+            default_headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
+            default_headers["X-LlamaStack-Client-Version"] = __version__

         super().__init__(
             version=__version__,
@@ -164,7 +163,9 @@ def __init__(
         self.routes = routes.RoutesResource(self)
         self.safety = safety.SafetyResource(self)
         self.shields = shields.ShieldsResource(self)
-        self.synthetic_data_generation = synthetic_data_generation.SyntheticDataGenerationResource(self)
+        self.synthetic_data_generation = (
+            synthetic_data_generation.SyntheticDataGenerationResource(self)
+        )
         self.telemetry = telemetry.TelemetryResource(self)
         self.datasetio = datasetio.DatasetioResource(self)
         self.scoring = scoring.ScoringResource(self)
@@ -204,10 +205,14 @@ def copy(
         Create a new client instance re-using the same options given to the current
         client with optional overriding.
""" if default_headers is not None and set_default_headers is not None: - raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive") + raise ValueError( + "The `default_headers` and `set_default_headers` arguments are mutually exclusive" + ) if default_query is not None and set_default_query is not None: - raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive") + raise ValueError( + "The `default_query` and `set_default_query` arguments are mutually exclusive" + ) headers = self._custom_headers if default_headers is not None: @@ -248,10 +252,14 @@ def _make_status_error( return _exceptions.BadRequestError(err_msg, response=response, body=body) if response.status_code == 401: - return _exceptions.AuthenticationError(err_msg, response=response, body=body) + return _exceptions.AuthenticationError( + err_msg, response=response, body=body + ) if response.status_code == 403: - return _exceptions.PermissionDeniedError(err_msg, response=response, body=body) + return _exceptions.PermissionDeniedError( + err_msg, response=response, body=body + ) if response.status_code == 404: return _exceptions.NotFoundError(err_msg, response=response, body=body) @@ -260,13 +268,17 @@ def _make_status_error( return _exceptions.ConflictError(err_msg, response=response, body=body) if response.status_code == 422: - return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body) + return _exceptions.UnprocessableEntityError( + err_msg, response=response, body=body + ) if response.status_code == 429: return _exceptions.RateLimitError(err_msg, response=response, body=body) if response.status_code >= 500: - return _exceptions.InternalServerError(err_msg, response=response, body=body) + return _exceptions.InternalServerError( + err_msg, response=response, body=body + ) return APIStatusError(err_msg, response=response, body=body) @@ -288,7 +300,9 @@ class AsyncLlamaStackClient(AsyncAPIClient): routes: routes.AsyncRoutesResource safety: safety.AsyncSafetyResource shields: shields.AsyncShieldsResource - synthetic_data_generation: synthetic_data_generation.AsyncSyntheticDataGenerationResource + synthetic_data_generation: ( + synthetic_data_generation.AsyncSyntheticDataGenerationResource + ) telemetry: telemetry.AsyncTelemetryResource datasetio: datasetio.AsyncDatasetioResource scoring: scoring.AsyncScoringResource @@ -326,15 +340,12 @@ def __init__( if base_url is None: base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL") if base_url is None: - base_url = f"http://any-hosted-llama-stack.com" + base_url = "http://any-hosted-llama-stack.com" if provider_data is not None: if default_headers is None: default_headers = {} - default_headers.update({ - "X-LlamaStack-Provider-Data": json.dumps(provider_data), - "X-LlamaStack-Client-Version": __version__ - }) + default_headers["X-LlamaStack-ProviderData"] = json.dumps(provider_data) super().__init__( version=__version__, @@ -364,7 +375,9 @@ def __init__( self.routes = routes.AsyncRoutesResource(self) self.safety = safety.AsyncSafetyResource(self) self.shields = shields.AsyncShieldsResource(self) - self.synthetic_data_generation = synthetic_data_generation.AsyncSyntheticDataGenerationResource(self) + self.synthetic_data_generation = ( + synthetic_data_generation.AsyncSyntheticDataGenerationResource(self) + ) self.telemetry = telemetry.AsyncTelemetryResource(self) self.datasetio = datasetio.AsyncDatasetioResource(self) self.scoring = scoring.AsyncScoringResource(self) @@ -404,10 
         Create a new client instance re-using the same options given to the current
         client with optional overriding.
         """
         if default_headers is not None and set_default_headers is not None:
-            raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
+            raise ValueError(
+                "The `default_headers` and `set_default_headers` arguments are mutually exclusive"
+            )

         if default_query is not None and set_default_query is not None:
-            raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")
+            raise ValueError(
+                "The `default_query` and `set_default_query` arguments are mutually exclusive"
+            )

         headers = self._custom_headers
         if default_headers is not None:
@@ -448,10 +467,14 @@ def _make_status_error(
             return _exceptions.BadRequestError(err_msg, response=response, body=body)

         if response.status_code == 401:
-            return _exceptions.AuthenticationError(err_msg, response=response, body=body)
+            return _exceptions.AuthenticationError(
+                err_msg, response=response, body=body
+            )

         if response.status_code == 403:
-            return _exceptions.PermissionDeniedError(err_msg, response=response, body=body)
+            return _exceptions.PermissionDeniedError(
+                err_msg, response=response, body=body
+            )

         if response.status_code == 404:
             return _exceptions.NotFoundError(err_msg, response=response, body=body)
@@ -460,138 +483,232 @@ def _make_status_error(
             return _exceptions.ConflictError(err_msg, response=response, body=body)

         if response.status_code == 422:
-            return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body)
+            return _exceptions.UnprocessableEntityError(
+                err_msg, response=response, body=body
+            )

         if response.status_code == 429:
             return _exceptions.RateLimitError(err_msg, response=response, body=body)

         if response.status_code >= 500:
-            return _exceptions.InternalServerError(err_msg, response=response, body=body)
+            return _exceptions.InternalServerError(
+                err_msg, response=response, body=body
+            )

         return APIStatusError(err_msg, response=response, body=body)


 class LlamaStackClientWithRawResponse:
     def __init__(self, client: LlamaStackClient) -> None:
-        self.toolgroups = toolgroups.ToolgroupsResourceWithRawResponse(client.toolgroups)
+        self.toolgroups = toolgroups.ToolgroupsResourceWithRawResponse(
+            client.toolgroups
+        )
         self.tools = tools.ToolsResourceWithRawResponse(client.tools)
-        self.tool_runtime = tool_runtime.ToolRuntimeResourceWithRawResponse(client.tool_runtime)
+        self.tool_runtime = tool_runtime.ToolRuntimeResourceWithRawResponse(
+            client.tool_runtime
+        )
         self.agents = agents.AgentsResourceWithRawResponse(client.agents)
-        self.batch_inference = batch_inference.BatchInferenceResourceWithRawResponse(client.batch_inference)
+        self.batch_inference = batch_inference.BatchInferenceResourceWithRawResponse(
+            client.batch_inference
+        )
         self.datasets = datasets.DatasetsResourceWithRawResponse(client.datasets)
         self.eval = eval.EvalResourceWithRawResponse(client.eval)
         self.inspect = inspect.InspectResourceWithRawResponse(client.inspect)
         self.inference = inference.InferenceResourceWithRawResponse(client.inference)
         self.memory = memory.MemoryResourceWithRawResponse(client.memory)
-        self.memory_banks = memory_banks.MemoryBanksResourceWithRawResponse(client.memory_banks)
+        self.memory_banks = memory_banks.MemoryBanksResourceWithRawResponse(
+            client.memory_banks
+        )
         self.models = models.ModelsResourceWithRawResponse(client.models)
-        self.post_training = post_training.PostTrainingResourceWithRawResponse(client.post_training)
+        self.post_training = post_training.PostTrainingResourceWithRawResponse(
+            client.post_training
+        )
         self.providers = providers.ProvidersResourceWithRawResponse(client.providers)
         self.routes = routes.RoutesResourceWithRawResponse(client.routes)
         self.safety = safety.SafetyResourceWithRawResponse(client.safety)
         self.shields = shields.ShieldsResourceWithRawResponse(client.shields)
-        self.synthetic_data_generation = synthetic_data_generation.SyntheticDataGenerationResourceWithRawResponse(
-            client.synthetic_data_generation
+        self.synthetic_data_generation = (
+            synthetic_data_generation.SyntheticDataGenerationResourceWithRawResponse(
+                client.synthetic_data_generation
+            )
         )
         self.telemetry = telemetry.TelemetryResourceWithRawResponse(client.telemetry)
         self.datasetio = datasetio.DatasetioResourceWithRawResponse(client.datasetio)
         self.scoring = scoring.ScoringResourceWithRawResponse(client.scoring)
-        self.scoring_functions = scoring_functions.ScoringFunctionsResourceWithRawResponse(client.scoring_functions)
+        self.scoring_functions = (
+            scoring_functions.ScoringFunctionsResourceWithRawResponse(
+                client.scoring_functions
+            )
+        )
         self.eval_tasks = eval_tasks.EvalTasksResourceWithRawResponse(client.eval_tasks)


 class AsyncLlamaStackClientWithRawResponse:
     def __init__(self, client: AsyncLlamaStackClient) -> None:
-        self.toolgroups = toolgroups.AsyncToolgroupsResourceWithRawResponse(client.toolgroups)
+        self.toolgroups = toolgroups.AsyncToolgroupsResourceWithRawResponse(
+            client.toolgroups
+        )
         self.tools = tools.AsyncToolsResourceWithRawResponse(client.tools)
-        self.tool_runtime = tool_runtime.AsyncToolRuntimeResourceWithRawResponse(client.tool_runtime)
+        self.tool_runtime = tool_runtime.AsyncToolRuntimeResourceWithRawResponse(
+            client.tool_runtime
+        )
         self.agents = agents.AsyncAgentsResourceWithRawResponse(client.agents)
-        self.batch_inference = batch_inference.AsyncBatchInferenceResourceWithRawResponse(client.batch_inference)
+        self.batch_inference = (
+            batch_inference.AsyncBatchInferenceResourceWithRawResponse(
+                client.batch_inference
+            )
+        )
         self.datasets = datasets.AsyncDatasetsResourceWithRawResponse(client.datasets)
         self.eval = eval.AsyncEvalResourceWithRawResponse(client.eval)
         self.inspect = inspect.AsyncInspectResourceWithRawResponse(client.inspect)
-        self.inference = inference.AsyncInferenceResourceWithRawResponse(client.inference)
+        self.inference = inference.AsyncInferenceResourceWithRawResponse(
+            client.inference
+        )
         self.memory = memory.AsyncMemoryResourceWithRawResponse(client.memory)
-        self.memory_banks = memory_banks.AsyncMemoryBanksResourceWithRawResponse(client.memory_banks)
+        self.memory_banks = memory_banks.AsyncMemoryBanksResourceWithRawResponse(
+            client.memory_banks
+        )
         self.models = models.AsyncModelsResourceWithRawResponse(client.models)
-        self.post_training = post_training.AsyncPostTrainingResourceWithRawResponse(client.post_training)
-        self.providers = providers.AsyncProvidersResourceWithRawResponse(client.providers)
+        self.post_training = post_training.AsyncPostTrainingResourceWithRawResponse(
+            client.post_training
+        )
+        self.providers = providers.AsyncProvidersResourceWithRawResponse(
+            client.providers
+        )
         self.routes = routes.AsyncRoutesResourceWithRawResponse(client.routes)
         self.safety = safety.AsyncSafetyResourceWithRawResponse(client.safety)
         self.shields = shields.AsyncShieldsResourceWithRawResponse(client.shields)
         self.synthetic_data_generation = synthetic_data_generation.AsyncSyntheticDataGenerationResourceWithRawResponse(
             client.synthetic_data_generation
         )
-        self.telemetry = telemetry.AsyncTelemetryResourceWithRawResponse(client.telemetry)
-        self.datasetio = datasetio.AsyncDatasetioResourceWithRawResponse(client.datasetio)
+        self.telemetry = telemetry.AsyncTelemetryResourceWithRawResponse(
+            client.telemetry
+        )
+        self.datasetio = datasetio.AsyncDatasetioResourceWithRawResponse(
+            client.datasetio
+        )
         self.scoring = scoring.AsyncScoringResourceWithRawResponse(client.scoring)
-        self.scoring_functions = scoring_functions.AsyncScoringFunctionsResourceWithRawResponse(
-            client.scoring_functions
+        self.scoring_functions = (
+            scoring_functions.AsyncScoringFunctionsResourceWithRawResponse(
+                client.scoring_functions
+            )
+        )
+        self.eval_tasks = eval_tasks.AsyncEvalTasksResourceWithRawResponse(
+            client.eval_tasks
         )
-        self.eval_tasks = eval_tasks.AsyncEvalTasksResourceWithRawResponse(client.eval_tasks)


 class LlamaStackClientWithStreamedResponse:
     def __init__(self, client: LlamaStackClient) -> None:
-        self.toolgroups = toolgroups.ToolgroupsResourceWithStreamingResponse(client.toolgroups)
+        self.toolgroups = toolgroups.ToolgroupsResourceWithStreamingResponse(
+            client.toolgroups
+        )
         self.tools = tools.ToolsResourceWithStreamingResponse(client.tools)
-        self.tool_runtime = tool_runtime.ToolRuntimeResourceWithStreamingResponse(client.tool_runtime)
+        self.tool_runtime = tool_runtime.ToolRuntimeResourceWithStreamingResponse(
+            client.tool_runtime
+        )
         self.agents = agents.AgentsResourceWithStreamingResponse(client.agents)
-        self.batch_inference = batch_inference.BatchInferenceResourceWithStreamingResponse(client.batch_inference)
+        self.batch_inference = (
+            batch_inference.BatchInferenceResourceWithStreamingResponse(
+                client.batch_inference
+            )
+        )
         self.datasets = datasets.DatasetsResourceWithStreamingResponse(client.datasets)
         self.eval = eval.EvalResourceWithStreamingResponse(client.eval)
         self.inspect = inspect.InspectResourceWithStreamingResponse(client.inspect)
-        self.inference = inference.InferenceResourceWithStreamingResponse(client.inference)
+        self.inference = inference.InferenceResourceWithStreamingResponse(
+            client.inference
+        )
         self.memory = memory.MemoryResourceWithStreamingResponse(client.memory)
-        self.memory_banks = memory_banks.MemoryBanksResourceWithStreamingResponse(client.memory_banks)
+        self.memory_banks = memory_banks.MemoryBanksResourceWithStreamingResponse(
+            client.memory_banks
+        )
         self.models = models.ModelsResourceWithStreamingResponse(client.models)
-        self.post_training = post_training.PostTrainingResourceWithStreamingResponse(client.post_training)
-        self.providers = providers.ProvidersResourceWithStreamingResponse(client.providers)
+        self.post_training = post_training.PostTrainingResourceWithStreamingResponse(
+            client.post_training
+        )
+        self.providers = providers.ProvidersResourceWithStreamingResponse(
+            client.providers
+        )
         self.routes = routes.RoutesResourceWithStreamingResponse(client.routes)
         self.safety = safety.SafetyResourceWithStreamingResponse(client.safety)
         self.shields = shields.ShieldsResourceWithStreamingResponse(client.shields)
         self.synthetic_data_generation = synthetic_data_generation.SyntheticDataGenerationResourceWithStreamingResponse(
             client.synthetic_data_generation
         )
-        self.telemetry = telemetry.TelemetryResourceWithStreamingResponse(client.telemetry)
-        self.datasetio = datasetio.DatasetioResourceWithStreamingResponse(client.datasetio)
+        self.telemetry = telemetry.TelemetryResourceWithStreamingResponse(
+            client.telemetry
+        )
+        self.datasetio = datasetio.DatasetioResourceWithStreamingResponse(
+            client.datasetio
+        )
         self.scoring = scoring.ScoringResourceWithStreamingResponse(client.scoring)
-        self.scoring_functions = scoring_functions.ScoringFunctionsResourceWithStreamingResponse(
-            client.scoring_functions
+        self.scoring_functions = (
+            scoring_functions.ScoringFunctionsResourceWithStreamingResponse(
+                client.scoring_functions
+            )
+        )
+        self.eval_tasks = eval_tasks.EvalTasksResourceWithStreamingResponse(
+            client.eval_tasks
         )
-        self.eval_tasks = eval_tasks.EvalTasksResourceWithStreamingResponse(client.eval_tasks)


 class AsyncLlamaStackClientWithStreamedResponse:
     def __init__(self, client: AsyncLlamaStackClient) -> None:
-        self.toolgroups = toolgroups.AsyncToolgroupsResourceWithStreamingResponse(client.toolgroups)
+        self.toolgroups = toolgroups.AsyncToolgroupsResourceWithStreamingResponse(
+            client.toolgroups
+        )
         self.tools = tools.AsyncToolsResourceWithStreamingResponse(client.tools)
-        self.tool_runtime = tool_runtime.AsyncToolRuntimeResourceWithStreamingResponse(client.tool_runtime)
+        self.tool_runtime = tool_runtime.AsyncToolRuntimeResourceWithStreamingResponse(
+            client.tool_runtime
+        )
         self.agents = agents.AsyncAgentsResourceWithStreamingResponse(client.agents)
-        self.batch_inference = batch_inference.AsyncBatchInferenceResourceWithStreamingResponse(client.batch_inference)
-        self.datasets = datasets.AsyncDatasetsResourceWithStreamingResponse(client.datasets)
+        self.batch_inference = (
+            batch_inference.AsyncBatchInferenceResourceWithStreamingResponse(
+                client.batch_inference
+            )
+        )
+        self.datasets = datasets.AsyncDatasetsResourceWithStreamingResponse(
+            client.datasets
+        )
         self.eval = eval.AsyncEvalResourceWithStreamingResponse(client.eval)
         self.inspect = inspect.AsyncInspectResourceWithStreamingResponse(client.inspect)
-        self.inference = inference.AsyncInferenceResourceWithStreamingResponse(client.inference)
+        self.inference = inference.AsyncInferenceResourceWithStreamingResponse(
+            client.inference
+        )
         self.memory = memory.AsyncMemoryResourceWithStreamingResponse(client.memory)
-        self.memory_banks = memory_banks.AsyncMemoryBanksResourceWithStreamingResponse(client.memory_banks)
+        self.memory_banks = memory_banks.AsyncMemoryBanksResourceWithStreamingResponse(
+            client.memory_banks
+        )
         self.models = models.AsyncModelsResourceWithStreamingResponse(client.models)
-        self.post_training = post_training.AsyncPostTrainingResourceWithStreamingResponse(client.post_training)
-        self.providers = providers.AsyncProvidersResourceWithStreamingResponse(client.providers)
+        self.post_training = (
+            post_training.AsyncPostTrainingResourceWithStreamingResponse(
+                client.post_training
+            )
+        )
+        self.providers = providers.AsyncProvidersResourceWithStreamingResponse(
+            client.providers
+        )
         self.routes = routes.AsyncRoutesResourceWithStreamingResponse(client.routes)
         self.safety = safety.AsyncSafetyResourceWithStreamingResponse(client.safety)
         self.shields = shields.AsyncShieldsResourceWithStreamingResponse(client.shields)
-        self.synthetic_data_generation = (
-            synthetic_data_generation.AsyncSyntheticDataGenerationResourceWithStreamingResponse(
-                client.synthetic_data_generation
-            )
+        self.synthetic_data_generation = synthetic_data_generation.AsyncSyntheticDataGenerationResourceWithStreamingResponse(
+            client.synthetic_data_generation
+        )
+        self.telemetry = telemetry.AsyncTelemetryResourceWithStreamingResponse(
+            client.telemetry
+        )
+        self.datasetio = datasetio.AsyncDatasetioResourceWithStreamingResponse(
+            client.datasetio
         )
-        self.telemetry = telemetry.AsyncTelemetryResourceWithStreamingResponse(client.telemetry)
-        self.datasetio = datasetio.AsyncDatasetioResourceWithStreamingResponse(client.datasetio)
         self.scoring = scoring.AsyncScoringResourceWithStreamingResponse(client.scoring)
-        self.scoring_functions = scoring_functions.AsyncScoringFunctionsResourceWithStreamingResponse(
-            client.scoring_functions
+        self.scoring_functions = (
+            scoring_functions.AsyncScoringFunctionsResourceWithStreamingResponse(
+                client.scoring_functions
+            )
+        )
+        self.eval_tasks = eval_tasks.AsyncEvalTasksResourceWithStreamingResponse(
+            client.eval_tasks
         )
-        self.eval_tasks = eval_tasks.AsyncEvalTasksResourceWithStreamingResponse(client.eval_tasks)


 Client = LlamaStackClient
diff --git a/src/llama_stack_client/resources/tool_runtime.py b/src/llama_stack_client/resources/tool_runtime.py
index 94d04db7..5b579199 100644
--- a/src/llama_stack_client/resources/tool_runtime.py
+++ b/src/llama_stack_client/resources/tool_runtime.py
@@ -23,7 +23,7 @@
 )
 from .._base_client import make_request_options
 from ..types.tool_def import ToolDef
-from ..types.shared_params.url import URL
+from ..types.mcp_config_param import McpConfigParam
 from ..types.tool_invocation_result import ToolInvocationResult

 __all__ = ["ToolRuntimeResource", "AsyncToolRuntimeResource"]
@@ -103,7 +103,7 @@ def list_tools(
         self,
         *,
         tool_group_id: str | NotGiven = NOT_GIVEN,
-        mcp_endpoint: URL | NotGiven = NOT_GIVEN,
+        mcp_config: McpConfigParam | NotGiven = NOT_GIVEN,
         x_llama_stack_client_version: str | NotGiven = NOT_GIVEN,
         x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
@@ -135,9 +135,7 @@ def list_tools(
         }
         return self._post(
             "/alpha/tool-runtime/list-tools",
-            body=maybe_transform(
-                {"mcp_endpoint": mcp_endpoint}, tool_runtime_list_tools_params.ToolRuntimeListToolsParams
-            ),
+            body=maybe_transform({"mcp_config": mcp_config}, tool_runtime_list_tools_params.ToolRuntimeListToolsParams),
             options=make_request_options(
                 extra_headers=extra_headers,
                 extra_query=extra_query,
@@ -225,7 +223,7 @@ async def list_tools(
         self,
         *,
         tool_group_id: str | NotGiven = NOT_GIVEN,
-        mcp_endpoint: URL | NotGiven = NOT_GIVEN,
+        mcp_config: McpConfigParam | NotGiven = NOT_GIVEN,
         x_llama_stack_client_version: str | NotGiven = NOT_GIVEN,
         x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
@@ -258,7 +256,7 @@ async def list_tools(
         return await self._post(
             "/alpha/tool-runtime/list-tools",
             body=await async_maybe_transform(
-                {"mcp_endpoint": mcp_endpoint}, tool_runtime_list_tools_params.ToolRuntimeListToolsParams
+                {"mcp_config": mcp_config}, tool_runtime_list_tools_params.ToolRuntimeListToolsParams
             ),
             options=make_request_options(
                 extra_headers=extra_headers,
diff --git a/src/llama_stack_client/resources/toolgroups.py b/src/llama_stack_client/resources/toolgroups.py
index 437158a5..22fddf57 100644
--- a/src/llama_stack_client/resources/toolgroups.py
+++ b/src/llama_stack_client/resources/toolgroups.py
@@ -6,7 +6,11 @@

 import httpx

-from ..types import toolgroup_get_params, toolgroup_register_params, toolgroup_unregister_params
+from ..types import (
+    toolgroup_get_params,
+    toolgroup_register_params,
+    toolgroup_unregister_params,
+)
 from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
 from .._utils import (
     maybe_transform,
@@ -23,7 +27,7 @@
 )
 from .._base_client import make_request_options
 from ..types.tool_group import ToolGroup
-from ..types.shared_params.url import URL
+from ..types.mcp_config_param import McpConfigParam

 __all__ = ["ToolgroupsResource", "AsyncToolgroupsResource"]
@@ -140,7 +144,7 @@ def register(
         provider_id: str,
         toolgroup_id: str,
         args: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
-        mcp_endpoint: URL | NotGiven = NOT_GIVEN,
+        mcp_config: McpConfigParam | NotGiven = NOT_GIVEN,
         x_llama_stack_client_version: str | NotGiven = NOT_GIVEN,
         x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
@@ -179,7 +183,7 @@ def register(
                     "provider_id": provider_id,
                     "toolgroup_id": toolgroup_id,
                     "args": args,
-                    "mcp_endpoint": mcp_endpoint,
+                    "mcp_config": mcp_config,
                 },
                 toolgroup_register_params.ToolgroupRegisterParams,
             ),
@@ -350,7 +354,7 @@ async def register(
         provider_id: str,
         toolgroup_id: str,
         args: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
-        mcp_endpoint: URL | NotGiven = NOT_GIVEN,
+        mcp_config: McpConfigParam | NotGiven = NOT_GIVEN,
         x_llama_stack_client_version: str | NotGiven = NOT_GIVEN,
         x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
         # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
@@ -389,7 +393,7 @@ async def register(
                     "provider_id": provider_id,
                     "toolgroup_id": toolgroup_id,
                     "args": args,
-                    "mcp_endpoint": mcp_endpoint,
+                    "mcp_config": mcp_config,
                 },
                 toolgroup_register_params.ToolgroupRegisterParams,
             ),
diff --git a/src/llama_stack_client/types/__init__.py b/src/llama_stack_client/types/__init__.py
index be552ab7..e20fd754 100644
--- a/src/llama_stack_client/types/__init__.py
+++ b/src/llama_stack_client/types/__init__.py
@@ -26,6 +26,7 @@
 from .shield import Shield as Shield
 from .tool_def import ToolDef as ToolDef
 from .eval_task import EvalTask as EvalTask
+from .mcp_config import McpConfig as McpConfig
 from .route_info import RouteInfo as RouteInfo
 from .scoring_fn import ScoringFn as ScoringFn
 from .tool_group import ToolGroup as ToolGroup
@@ -37,6 +38,7 @@
 from .tool_def_param import ToolDefParam as ToolDefParam
 from .token_log_probs import TokenLogProbs as TokenLogProbs
 from .tool_get_params import ToolGetParams as ToolGetParams
+from .mcp_config_param import McpConfigParam as McpConfigParam
 from .shield_call_step import ShieldCallStep as ShieldCallStep
 from .span_with_status import SpanWithStatus as SpanWithStatus
 from .tool_list_params import ToolListParams as ToolListParams
diff --git a/src/llama_stack_client/types/mcp_config.py b/src/llama_stack_client/types/mcp_config.py
new file mode 100644
index 00000000..0ae7e95d
--- /dev/null
+++ b/src/llama_stack_client/types/mcp_config.py
@@ -0,0 +1,28 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import Dict, List, Union, Optional
+from typing_extensions import Literal, TypeAlias
+
+from .._models import BaseModel
+from .shared.url import URL
+
+__all__ = ["McpConfig", "McpInlineConfig", "McpRemoteConfig"]
+
+
+class McpInlineConfig(BaseModel):
+    command: str
+
+    type: Literal["inline"]
+
+    args: Optional[List[str]] = None
+
+    env: Optional[Dict[str, Union[bool, float, str, List[object], object, None]]] = None
+
+
+class McpRemoteConfig(BaseModel):
+    mcp_endpoint: URL
+
+    type: Literal["remote"]
+
+
+McpConfig: TypeAlias = Union[McpInlineConfig, McpRemoteConfig]
diff --git a/src/llama_stack_client/types/mcp_config_param.py b/src/llama_stack_client/types/mcp_config_param.py
new file mode 100644
index 00000000..34250c61
--- /dev/null
+++ b/src/llama_stack_client/types/mcp_config_param.py
@@ -0,0 +1,29 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing import Dict, List, Union, Iterable
+from typing_extensions import Literal, Required, TypeAlias, TypedDict
+
+from .shared_params.url import URL
+
+__all__ = ["McpConfigParam", "McpInlineConfig", "McpRemoteConfig"]
+
+
+class McpInlineConfig(TypedDict, total=False):
+    command: Required[str]
+
+    type: Required[Literal["inline"]]
+
+    args: List[str]
+
+    env: Dict[str, Union[bool, float, str, Iterable[object], object, None]]
+
+
+class McpRemoteConfig(TypedDict, total=False):
+    mcp_endpoint: Required[URL]
+
+    type: Required[Literal["remote"]]
+
+
+McpConfigParam: TypeAlias = Union[McpInlineConfig, McpRemoteConfig]
diff --git a/src/llama_stack_client/types/tool_group.py b/src/llama_stack_client/types/tool_group.py
index 82d2e057..0af8311b 100644
--- a/src/llama_stack_client/types/tool_group.py
+++ b/src/llama_stack_client/types/tool_group.py
@@ -4,7 +4,7 @@
 from typing_extensions import Literal

 from .._models import BaseModel
-from .shared.url import URL
+from .mcp_config import McpConfig

 __all__ = ["ToolGroup"]
@@ -20,4 +20,4 @@ class ToolGroup(BaseModel):

     args: Optional[Dict[str, Union[bool, float, str, List[object], object, None]]] = None

-    mcp_endpoint: Optional[URL] = None
+    mcp_config: Optional[McpConfig] = None
diff --git a/src/llama_stack_client/types/tool_runtime_list_tools_params.py b/src/llama_stack_client/types/tool_runtime_list_tools_params.py
index fb273dd8..e42be717 100644
--- a/src/llama_stack_client/types/tool_runtime_list_tools_params.py
+++ b/src/llama_stack_client/types/tool_runtime_list_tools_params.py
@@ -5,7 +5,7 @@
 from typing_extensions import Annotated, TypedDict

 from .._utils import PropertyInfo
-from .shared_params.url import URL
+from .mcp_config_param import McpConfigParam

 __all__ = ["ToolRuntimeListToolsParams"]
@@ -13,7 +13,7 @@ class ToolRuntimeListToolsParams(TypedDict, total=False):
     tool_group_id: str

-    mcp_endpoint: URL
+    mcp_config: McpConfigParam

     x_llama_stack_client_version: Annotated[str, PropertyInfo(alias="X-LlamaStack-Client-Version")]
diff --git a/src/llama_stack_client/types/toolgroup_register_params.py b/src/llama_stack_client/types/toolgroup_register_params.py
index 880c4480..f937d5e7 100644
--- a/src/llama_stack_client/types/toolgroup_register_params.py
+++ b/src/llama_stack_client/types/toolgroup_register_params.py
@@ -6,7 +6,7 @@
 from typing_extensions import Required, Annotated, TypedDict

 from .._utils import PropertyInfo
-from .shared_params.url import URL
+from .mcp_config_param import McpConfigParam

 __all__ = ["ToolgroupRegisterParams"]
@@ -18,7 +18,7 @@ class ToolgroupRegisterParams(TypedDict, total=False):

     args: Dict[str, Union[bool, float, str, Iterable[object], object, None]]

-    mcp_endpoint: URL
+    mcp_config: McpConfigParam

     x_llama_stack_client_version: Annotated[str, PropertyInfo(alias="X-LlamaStack-Client-Version")]
diff --git a/tests/api_resources/test_tool_runtime.py b/tests/api_resources/test_tool_runtime.py
index 76c23798..7d01cbda 100644
--- a/tests/api_resources/test_tool_runtime.py
+++ b/tests/api_resources/test_tool_runtime.py
@@ -79,7 +79,12 @@ def test_method_list_tools(self, client: LlamaStackClient) -> None:
     def test_method_list_tools_with_all_params(self, client: LlamaStackClient) -> None:
         tool_runtime = client.tool_runtime.list_tools(
             tool_group_id="tool_group_id",
-            mcp_endpoint={"uri": "uri"},
+            mcp_config={
+                "command": "command",
+                "type": "inline",
+                "args": ["string"],
+                "env": {"foo": True},
+            },
             x_llama_stack_client_version="X-LlamaStack-Client-Version",
             x_llama_stack_provider_data="X-LlamaStack-Provider-Data",
         )
@@ -174,7 +179,12 @@ async def test_method_list_tools(self, async_client: AsyncLlamaStackClient) -> N
     async def test_method_list_tools_with_all_params(self, async_client: AsyncLlamaStackClient) -> None:
         tool_runtime = await async_client.tool_runtime.list_tools(
             tool_group_id="tool_group_id",
-            mcp_endpoint={"uri": "uri"},
+            mcp_config={
+                "command": "command",
+                "type": "inline",
+                "args": ["string"],
+                "env": {"foo": True},
+            },
             x_llama_stack_client_version="X-LlamaStack-Client-Version",
             x_llama_stack_provider_data="X-LlamaStack-Provider-Data",
         )
diff --git a/tests/api_resources/test_toolgroups.py b/tests/api_resources/test_toolgroups.py
index 7fb4344a..26e936e0 100644
--- a/tests/api_resources/test_toolgroups.py
+++ b/tests/api_resources/test_toolgroups.py
@@ -118,7 +118,12 @@ def test_method_register_with_all_params(self, client: LlamaStackClient) -> None
             provider_id="provider_id",
             toolgroup_id="toolgroup_id",
             args={"foo": True},
-            mcp_endpoint={"uri": "uri"},
+            mcp_config={
+                "command": "command",
+                "type": "inline",
+                "args": ["string"],
+                "env": {"foo": True},
+            },
             x_llama_stack_client_version="X-LlamaStack-Client-Version",
             x_llama_stack_provider_data="X-LlamaStack-Provider-Data",
         )
@@ -293,7 +298,12 @@ async def test_method_register_with_all_params(self, async_client: AsyncLlamaSta
             provider_id="provider_id",
             toolgroup_id="toolgroup_id",
             args={"foo": True},
-            mcp_endpoint={"uri": "uri"},
+            mcp_config={
+                "command": "command",
+                "type": "inline",
+                "args": ["string"],
+                "env": {"foo": True},
+            },
             x_llama_stack_client_version="X-LlamaStack-Client-Version",
             x_llama_stack_provider_data="X-LlamaStack-Provider-Data",
        )
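
Reviewer note: taken together, these hunks replace the URL-typed `mcp_endpoint` argument with the new `McpConfigParam` union, which describes either a locally spawned ("inline") or a network-reachable ("remote") MCP server. A minimal sketch of how a caller might exercise both variants after this change is below; the base URL, provider ID, toolgroup IDs, command, args, and env values are illustrative placeholders, not values taken from this diff.

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:5000")  # placeholder URL

    # "remote" variant: the old mcp_endpoint={"uri": ...} argument moves inside
    # the union, next to an explicit type tag.
    client.toolgroups.register(
        provider_id="model-context-protocol",  # placeholder provider ID
        toolgroup_id="mcp::remote_tools",      # placeholder toolgroup ID
        mcp_config={
            "type": "remote",
            "mcp_endpoint": {"uri": "http://localhost:8000/sse"},  # placeholder URI
        },
    )

    # "inline" variant: new with this change; the server is spawned from a command.
    client.toolgroups.register(
        provider_id="model-context-protocol",
        toolgroup_id="mcp::inline_tools",
        mcp_config={
            "type": "inline",
            "command": "uvx",                  # placeholder command
            "args": ["mcp-server-git"],        # placeholder args
            "env": {"GIT_DIR": "/tmp/repo"},   # placeholder env
        },
    )

    # tool_runtime.list_tools accepts the same union via its mcp_config parameter.
    tools = client.tool_runtime.list_tools(tool_group_id="mcp::inline_tools")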