From 5d01674226e6eacde2508eb4b3f866f289eb421e Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 20 Feb 2025 15:05:22 -0800
Subject: [PATCH] Sync updates from stainless branch: yanxi0830/dev

---
 .../resources/agents/turn.py                  | 186 ++++++++-
 .../types/agents/__init__.py                  |   1 +
 .../types/agents/turn_continue_params.py      |  33 ++
 .../agents/turn_response_event_payload.py     |   9 +
 tests/api_resources/agents/test_turn.py       | 384 ++++++++++++++++++
 5 files changed, 612 insertions(+), 1 deletion(-)
 create mode 100644 src/llama_stack_client/types/agents/turn_continue_params.py

diff --git a/src/llama_stack_client/resources/agents/turn.py b/src/llama_stack_client/resources/agents/turn.py
index da659e26..de634074 100644
--- a/src/llama_stack_client/resources/agents/turn.py
+++ b/src/llama_stack_client/resources/agents/turn.py
@@ -23,7 +23,7 @@
 )
 from ..._streaming import Stream, AsyncStream
 from ..._base_client import make_request_options
-from ...types.agents import turn_create_params
+from ...types.agents import turn_create_params, turn_continue_params
 from ...types.agents.turn import Turn
 from ...types.agents.agent_turn_response_stream_chunk import AgentTurnResponseStreamChunk
 
@@ -225,6 +225,92 @@ def retrieve(
             cast_to=Turn,
         )
 
+    @overload
+    def continue_(
+        self,
+        turn_id: str,
+        *,
+        agent_id: str,
+        session_id: str,
+        new_messages: Iterable[turn_continue_params.NewMessage],
+        stream: Literal[False] | NotGiven = NOT_GIVEN,
+        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+        # The extra values given here take precedence over values defined on the client or passed to this method.
+        extra_headers: Headers | None = None,
+        extra_query: Query | None = None,
+        extra_body: Body | None = None,
+        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> Turn:
+        """
+        Args:
+          extra_headers: Send extra headers
+
+          extra_query: Add additional query parameters to the request
+
+          extra_body: Add additional JSON properties to the request
+
+          timeout: Override the client-level default timeout for this request, in seconds
+        """
+        ...
+
+    @overload
+    def continue_(
+        self,
+        turn_id: str,
+        *,
+        agent_id: str,
+        session_id: str,
+        new_messages: Iterable[turn_continue_params.NewMessage],
+        stream: Literal[True],
+        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+        # The extra values given here take precedence over values defined on the client or passed to this method.
+        extra_headers: Headers | None = None,
+        extra_query: Query | None = None,
+        extra_body: Body | None = None,
+        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> Stream[AgentTurnResponseStreamChunk]:
+        """
+        Args:
+          extra_headers: Send extra headers
+
+          extra_query: Add additional query parameters to the request
+
+          extra_body: Add additional JSON properties to the request
+
+          timeout: Override the client-level default timeout for this request, in seconds
+        """
+        ...
+
+    @required_args(["agent_id", "session_id", "new_messages"], ["agent_id", "session_id", "new_messages", "stream"])
+    def continue_(
+        self,
+        turn_id: str,
+        *,
+        agent_id: str,
+        session_id: str,
+        new_messages: Iterable[turn_continue_params.NewMessage],
+        stream: Literal[False] | Literal[True] = False,
+        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+        # The extra values given here take precedence over values defined on the client or passed to this method.
+        extra_headers: Headers | None = None,
+        extra_query: Query | None = None,
+        extra_body: Body | None = None,
+        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> Turn | Stream[AgentTurnResponseStreamChunk]:
+        if not agent_id:
+            raise ValueError(f"Expected a non-empty value for `agent_id` but received {agent_id!r}")
+        if not session_id:
+            raise ValueError(f"Expected a non-empty value for `session_id` but received {session_id!r}")
+        if not turn_id:
+            raise ValueError(f"Expected a non-empty value for `turn_id` but received {turn_id!r}")
+        return self._post(
+            f"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/continue",
+            body=maybe_transform(
+                {
+                    "new_messages": new_messages,
+                    "stream": stream,
+                },
+                turn_continue_params.TurnContinueParamsStreaming
+                if stream
+                else turn_continue_params.TurnContinueParamsNonStreaming,
+            ),
+            options=make_request_options(
+                extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+            ),
+            cast_to=Turn,
+            stream=stream or False,
+            stream_cls=Stream[AgentTurnResponseStreamChunk],
+        )
+
 
 class AsyncTurnResource(AsyncAPIResource):
     @cached_property
@@ -421,6 +507,92 @@ async def retrieve(
             cast_to=Turn,
         )
 
+    @overload
+    async def continue_(
+        self,
+        turn_id: str,
+        *,
+        agent_id: str,
+        session_id: str,
+        new_messages: Iterable[turn_continue_params.NewMessage],
+        stream: Literal[False] | NotGiven = NOT_GIVEN,
+        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+        # The extra values given here take precedence over values defined on the client or passed to this method.
+        extra_headers: Headers | None = None,
+        extra_query: Query | None = None,
+        extra_body: Body | None = None,
+        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> Turn:
+        """
+        Args:
+          extra_headers: Send extra headers
+
+          extra_query: Add additional query parameters to the request
+
+          extra_body: Add additional JSON properties to the request
+
+          timeout: Override the client-level default timeout for this request, in seconds
+        """
+        ...
+
+    @overload
+    async def continue_(
+        self,
+        turn_id: str,
+        *,
+        agent_id: str,
+        session_id: str,
+        new_messages: Iterable[turn_continue_params.NewMessage],
+        stream: Literal[True],
+        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+        # The extra values given here take precedence over values defined on the client or passed to this method.
+        extra_headers: Headers | None = None,
+        extra_query: Query | None = None,
+        extra_body: Body | None = None,
+        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> AsyncStream[AgentTurnResponseStreamChunk]:
+        """
+        Args:
+          extra_headers: Send extra headers
+
+          extra_query: Add additional query parameters to the request
+
+          extra_body: Add additional JSON properties to the request
+
+          timeout: Override the client-level default timeout for this request, in seconds
+        """
+        ...
+
+    @required_args(["agent_id", "session_id", "new_messages"], ["agent_id", "session_id", "new_messages", "stream"])
+    async def continue_(
+        self,
+        turn_id: str,
+        *,
+        agent_id: str,
+        session_id: str,
+        new_messages: Iterable[turn_continue_params.NewMessage],
+        stream: Literal[False] | Literal[True] = False,
+        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+        # The extra values given here take precedence over values defined on the client or passed to this method.
+        extra_headers: Headers | None = None,
+        extra_query: Query | None = None,
+        extra_body: Body | None = None,
+        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> Turn | AsyncStream[AgentTurnResponseStreamChunk]:
+        if not agent_id:
+            raise ValueError(f"Expected a non-empty value for `agent_id` but received {agent_id!r}")
+        if not session_id:
+            raise ValueError(f"Expected a non-empty value for `session_id` but received {session_id!r}")
+        if not turn_id:
+            raise ValueError(f"Expected a non-empty value for `turn_id` but received {turn_id!r}")
+        return await self._post(
+            f"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/continue",
+            body=await async_maybe_transform(
+                {
+                    "new_messages": new_messages,
+                    "stream": stream,
+                },
+                turn_continue_params.TurnContinueParamsStreaming
+                if stream
+                else turn_continue_params.TurnContinueParamsNonStreaming,
+            ),
+            options=make_request_options(
+                extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+            ),
+            cast_to=Turn,
+            stream=stream or False,
+            stream_cls=AsyncStream[AgentTurnResponseStreamChunk],
+        )
+
 
 class TurnResourceWithRawResponse:
     def __init__(self, turn: TurnResource) -> None:
@@ -432,6 +604,9 @@ def __init__(self, turn: TurnResource) -> None:
         self.retrieve = to_raw_response_wrapper(
             turn.retrieve,
         )
+        self.continue_ = to_raw_response_wrapper(
+            turn.continue_,
+        )
 
 
 class AsyncTurnResourceWithRawResponse:
@@ -444,6 +619,9 @@ def __init__(self, turn: AsyncTurnResource) -> None:
         self.retrieve = async_to_raw_response_wrapper(
             turn.retrieve,
         )
+        self.continue_ = async_to_raw_response_wrapper(
+            turn.continue_,
+        )
 
 
 class TurnResourceWithStreamingResponse:
@@ -456,6 +634,9 @@ def __init__(self, turn: TurnResource) -> None:
         self.retrieve = to_streamed_response_wrapper(
             turn.retrieve,
         )
+        self.continue_ = to_streamed_response_wrapper(
+            turn.continue_,
+        )
 
 
 class AsyncTurnResourceWithStreamingResponse:
@@ -468,3 +649,6 @@ def __init__(self, turn: AsyncTurnResource) -> None:
         self.retrieve = async_to_streamed_response_wrapper(
             turn.retrieve,
         )
+        self.continue_ = async_to_streamed_response_wrapper(
+            turn.continue_,
+        )
diff --git a/src/llama_stack_client/types/agents/__init__.py b/src/llama_stack_client/types/agents/__init__.py
index be21f291..e1a831dd 100644
--- a/src/llama_stack_client/types/agents/__init__.py
+++ b/src/llama_stack_client/types/agents/__init__.py
@@ -6,6 +6,7 @@
 from .session import Session as Session
 from .turn_create_params import TurnCreateParams as TurnCreateParams
 from .turn_response_event import TurnResponseEvent as TurnResponseEvent
+from .turn_continue_params import TurnContinueParams as TurnContinueParams
 from .session_create_params import SessionCreateParams as SessionCreateParams
 from .step_retrieve_response import StepRetrieveResponse as StepRetrieveResponse
 from .session_create_response import SessionCreateResponse as SessionCreateResponse
diff --git a/src/llama_stack_client/types/agents/turn_continue_params.py b/src/llama_stack_client/types/agents/turn_continue_params.py
new file mode 100644
index 00000000..e58f0551
--- /dev/null
+++ b/src/llama_stack_client/types/agents/turn_continue_params.py
@@ -0,0 +1,33 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing import Union, Iterable
+from typing_extensions import Literal, Required, TypeAlias, TypedDict
+
+from ..shared_params.user_message import UserMessage
+from ..shared_params.tool_response_message import ToolResponseMessage
+
+__all__ = ["TurnContinueParamsBase", "NewMessage", "TurnContinueParamsNonStreaming", "TurnContinueParamsStreaming"]
+
+
+class TurnContinueParamsBase(TypedDict, total=False):
+    agent_id: Required[str]
+
+    session_id: Required[str]
+
+    new_messages: Required[Iterable[NewMessage]]
+
+
+NewMessage: TypeAlias = Union[UserMessage, ToolResponseMessage]
+
+
+class TurnContinueParamsNonStreaming(TurnContinueParamsBase, total=False):
+    stream: Literal[False]
+
+
+class TurnContinueParamsStreaming(TurnContinueParamsBase):
+    stream: Required[Literal[True]]
+
+
+TurnContinueParams = Union[TurnContinueParamsNonStreaming, TurnContinueParamsStreaming]
diff --git a/src/llama_stack_client/types/agents/turn_response_event_payload.py b/src/llama_stack_client/types/agents/turn_response_event_payload.py
index f12f8b03..e3315cb3 100644
--- a/src/llama_stack_client/types/agents/turn_response_event_payload.py
+++ b/src/llama_stack_client/types/agents/turn_response_event_payload.py
@@ -20,6 +20,7 @@
     "AgentTurnResponseStepCompletePayloadStepDetails",
     "AgentTurnResponseTurnStartPayload",
     "AgentTurnResponseTurnCompletePayload",
+    "AgentTurnResponseTurnAwaitingInputPayload",
 ]
 
 
@@ -72,6 +73,13 @@ class AgentTurnResponseTurnCompletePayload(BaseModel):
     """A single turn in an interaction with an Agentic System."""
 
 
+class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
+    event_type: Literal["turn_awaiting_input"]
+
+    turn: Turn
+    """A single turn in an interaction with an Agentic System."""
+
+
 TurnResponseEventPayload: TypeAlias = Annotated[
     Union[
         AgentTurnResponseStepStartPayload,
@@ -79,6 +87,7 @@ class AgentTurnResponseTurnCompletePayload(BaseModel):
         AgentTurnResponseStepCompletePayload,
         AgentTurnResponseTurnStartPayload,
         AgentTurnResponseTurnCompletePayload,
+        AgentTurnResponseTurnAwaitingInputPayload,
     ],
     PropertyInfo(discriminator="event_type"),
 ]
diff --git a/tests/api_resources/agents/test_turn.py b/tests/api_resources/agents/test_turn.py
index b64bf957..46debc46 100644
--- a/tests/api_resources/agents/test_turn.py
+++ b/tests/api_resources/agents/test_turn.py
@@ -293,6 +293,198 @@ def test_path_params_retrieve(self, client: LlamaStackClient) -> None:
             session_id="session_id",
         )
 
+    @parametrize
+    def test_method_continue_overload_1(self, client: LlamaStackClient) -> None:
+        turn = client.agents.turn.continue_(
+            turn_id="turn_id",
+            agent_id="agent_id",
+            session_id="session_id",
+            new_messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+        )
+        assert_matches_type(Turn, turn, path=["response"])
+
+    @parametrize
+    def test_raw_response_continue_overload_1(self, client: LlamaStackClient) -> None:
+        response = client.agents.turn.with_raw_response.continue_(
+            turn_id="turn_id",
+            agent_id="agent_id",
+            session_id="session_id",
+            new_messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+        )
+
+        assert response.is_closed is True
+        assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+        turn = response.parse()
+        assert_matches_type(Turn, turn, path=["response"])
+
+    @parametrize
+    def test_streaming_response_continue_overload_1(self, client: LlamaStackClient) -> None:
+        with client.agents.turn.with_streaming_response.continue_(
+            turn_id="turn_id",
+            agent_id="agent_id",
+            session_id="session_id",
+            new_messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+        ) as response:
+            assert not response.is_closed
+            assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+            turn = response.parse()
+            assert_matches_type(Turn, turn, path=["response"])
+
+        assert cast(Any, response.is_closed) is True
+
+    @parametrize
+    def test_path_params_continue_overload_1(self, client: LlamaStackClient) -> None:
+        with pytest.raises(ValueError, match=r"Expected a non-empty value for `agent_id` but received ''"):
+            client.agents.turn.with_raw_response.continue_(
+                turn_id="turn_id",
+                agent_id="",
+                session_id="session_id",
+                new_messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+            )
+
+        with pytest.raises(ValueError, match=r"Expected a non-empty value for `session_id` but received ''"):
+            client.agents.turn.with_raw_response.continue_(
+                turn_id="turn_id",
+                agent_id="agent_id",
+                session_id="",
+                new_messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+            )
+
+        with pytest.raises(ValueError, match=r"Expected a non-empty value for `turn_id` but received ''"):
+            client.agents.turn.with_raw_response.continue_(
+                turn_id="",
+                agent_id="agent_id",
+                session_id="session_id",
+                new_messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+            )
+
+    @parametrize
+    def test_method_continue_overload_2(self, client: LlamaStackClient) -> None:
+        turn = client.agents.turn.continue_(
+            turn_id="turn_id",
+            agent_id="agent_id",
+            session_id="session_id",
+            new_messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+        )
+        assert_matches_type(Turn, turn, path=["response"])
+
+    @parametrize
+    def test_raw_response_continue_overload_2(self, client: LlamaStackClient) -> None:
+        response = client.agents.turn.with_raw_response.continue_(
+            turn_id="turn_id",
+            agent_id="agent_id",
+            session_id="session_id",
+            new_messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+        )
+
+        assert response.is_closed is True
+        assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+        turn = response.parse()
+        assert_matches_type(Turn, turn, path=["response"])
+
+    @parametrize
+    def test_streaming_response_continue_overload_2(self, client: LlamaStackClient) -> None:
+        with client.agents.turn.with_streaming_response.continue_(
+            turn_id="turn_id",
+            agent_id="agent_id",
+            session_id="session_id",
+            new_messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+        ) as response:
+            assert not response.is_closed
+            assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+            turn = response.parse()
+            assert_matches_type(Turn, turn, path=["response"])
+
+        assert cast(Any, response.is_closed) is True
+
+    @parametrize
+    def test_path_params_continue_overload_2(self, client: LlamaStackClient) -> None:
+        with pytest.raises(ValueError, match=r"Expected a non-empty value for `agent_id` but received ''"):
+            client.agents.turn.with_raw_response.continue_(
+                turn_id="turn_id",
+                agent_id="",
+                session_id="session_id",
+                new_messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+            )
+
+        with pytest.raises(ValueError, match=r"Expected a non-empty value for `session_id` but received ''"):
+            client.agents.turn.with_raw_response.continue_(
+                turn_id="turn_id",
+                agent_id="agent_id",
+                session_id="",
+                new_messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+            )
+
+        with pytest.raises(ValueError, match=r"Expected a non-empty value for `turn_id` but received ''"):
+            client.agents.turn.with_raw_response.continue_(
+                turn_id="",
+                agent_id="agent_id",
+                session_id="session_id",
+                new_messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+            )
+
 
 class TestAsyncTurn:
     parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
@@ -572,3 +764,195 @@ async def test_path_params_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
                 agent_id="agent_id",
                 session_id="session_id",
             )
+
+    @parametrize
+    async def test_method_continue_overload_1(self, async_client: AsyncLlamaStackClient) -> None:
+        turn = await async_client.agents.turn.continue_(
+            turn_id="turn_id",
+            agent_id="agent_id",
+            session_id="session_id",
+            new_messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+        )
+        assert_matches_type(Turn, turn, path=["response"])
+
+    @parametrize
+    async def test_raw_response_continue_overload_1(self, async_client: AsyncLlamaStackClient) -> None:
+        response = await async_client.agents.turn.with_raw_response.continue_(
+            turn_id="turn_id",
+            agent_id="agent_id",
+            session_id="session_id",
+            new_messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+        )
+
+        assert response.is_closed is True
+        assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+        turn = await response.parse()
+        assert_matches_type(Turn, turn, path=["response"])
+
+    @parametrize
+    async def test_streaming_response_continue_overload_1(self, async_client: AsyncLlamaStackClient) -> None:
+        async with async_client.agents.turn.with_streaming_response.continue_(
+            turn_id="turn_id",
+            agent_id="agent_id",
+            session_id="session_id",
+            new_messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+        ) as response:
+            assert not response.is_closed
+            assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+            turn = await response.parse()
+            assert_matches_type(Turn, turn, path=["response"])
+
+        assert cast(Any, response.is_closed) is True
+
+    @parametrize
+    async def test_path_params_continue_overload_1(self, async_client: AsyncLlamaStackClient) -> None:
+        with pytest.raises(ValueError, match=r"Expected a non-empty value for `agent_id` but received ''"):
+            await async_client.agents.turn.with_raw_response.continue_(
+                turn_id="turn_id",
+                agent_id="",
+                session_id="session_id",
+                new_messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+            )
+
+        with pytest.raises(ValueError, match=r"Expected a non-empty value for `session_id` but received ''"):
+            await async_client.agents.turn.with_raw_response.continue_(
+                turn_id="turn_id",
+                agent_id="agent_id",
+                session_id="",
+                new_messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+            )
+
+        with pytest.raises(ValueError, match=r"Expected a non-empty value for `turn_id` but received ''"):
+            await async_client.agents.turn.with_raw_response.continue_(
+                turn_id="",
+                agent_id="agent_id",
+                session_id="session_id",
+                new_messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+            )
+
+    @parametrize
+    async def test_method_continue_overload_2(self, async_client: AsyncLlamaStackClient) -> None:
+        turn = await async_client.agents.turn.continue_(
+            turn_id="turn_id",
+            agent_id="agent_id",
+            session_id="session_id",
+            new_messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+        )
+        assert_matches_type(Turn, turn, path=["response"])
+
+    @parametrize
+    async def test_raw_response_continue_overload_2(self, async_client: AsyncLlamaStackClient) -> None:
+        response = await async_client.agents.turn.with_raw_response.continue_(
+            turn_id="turn_id",
+            agent_id="agent_id",
+            session_id="session_id",
+            new_messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+        )
+
+        assert response.is_closed is True
+        assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+        turn = await response.parse()
+        assert_matches_type(Turn, turn, path=["response"])
+
+    @parametrize
+    async def test_streaming_response_continue_overload_2(self, async_client: AsyncLlamaStackClient) -> None:
+        async with async_client.agents.turn.with_streaming_response.continue_(
+            turn_id="turn_id",
+            agent_id="agent_id",
+            session_id="session_id",
+            new_messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+        ) as response:
+            assert not response.is_closed
+            assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+            turn = await response.parse()
+            assert_matches_type(Turn, turn, path=["response"])
+
+        assert cast(Any, response.is_closed) is True
+
+    @parametrize
+    async def test_path_params_continue_overload_2(self, async_client: AsyncLlamaStackClient) -> None:
+        with pytest.raises(ValueError, match=r"Expected a non-empty value for `agent_id` but received ''"):
+            await async_client.agents.turn.with_raw_response.continue_(
+                turn_id="turn_id",
+                agent_id="",
+                session_id="session_id",
+                new_messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+            )
+
+        with pytest.raises(ValueError, match=r"Expected a non-empty value for `session_id` but received ''"):
+            await async_client.agents.turn.with_raw_response.continue_(
+                turn_id="turn_id",
+                agent_id="agent_id",
+                session_id="",
+                new_messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+            )
+
+        with pytest.raises(ValueError, match=r"Expected a non-empty value for `turn_id` but received ''"):
+            await async_client.agents.turn.with_raw_response.continue_(
+                turn_id="",
+                agent_id="agent_id",
+                session_id="session_id",
+                new_messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+            )
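
Usage sketch (not part of the diff): the snippet below shows how the `turn.continue_` method added by this patch is expected to be called after a streamed turn stops with a `turn_awaiting_input` event. The base URL and the three IDs are placeholders; in a real session they would come from `agents.create()`, `agents.session.create()`, and a prior `turn.create(..., stream=True)` call.

    from llama_stack_client import LlamaStackClient

    # Placeholder values: substitute the IDs returned by your own agent,
    # session, and the turn that is currently awaiting input.
    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server
    agent_id, session_id, turn_id = "agent_id", "session_id", "turn_id"

    # Resume the paused turn by sending the additional input it asked for.
    turn = client.agents.turn.continue_(
        turn_id=turn_id,
        agent_id=agent_id,
        session_id=session_id,
        new_messages=[
            {
                "role": "user",
                "content": "Here is the extra input you asked for.",
            }
        ],
    )
    print(turn.output_message)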