2929
3030from .agents .active_streaming_tool import ActiveStreamingTool
3131from .agents .base_agent import BaseAgent
32+ from .agents .base_agent import BaseAgentState
3233from .agents .context_cache_config import ContextCacheConfig
3334from .agents .invocation_context import InvocationContext
3435from .agents .invocation_context import new_invocation_context_id
@@ -272,7 +273,8 @@ async def run_async(
272273 * ,
273274 user_id : str ,
274275 session_id : str ,
275- new_message : types .Content ,
276+ invocation_id : Optional [str ] = None ,
277+ new_message : Optional [types .Content ] = None ,
276278 state_delta : Optional [dict [str , Any ]] = None ,
277279 run_config : Optional [RunConfig ] = None ,
278280 ) -> AsyncGenerator [Event , None ]:
@@ -281,6 +283,8 @@ async def run_async(
281283 Args:
282284 user_id: The user ID of the session.
283285 session_id: The session ID of the session.
286+ invocation_id: The invocation ID of the session, set this to resume an
287+ interrupted invocation.
284288 new_message: A new message to append to the session.
285289 state_delta: Optional state changes to apply to the session.
286290 run_config: The run config for the agent.
@@ -289,29 +293,57 @@ async def run_async(
289293 The events generated by the agent.
290294
291295 Raises:
292- ValueError: If the session is not found.
296+ ValueError: If the session is not found; If both invocation_id and
297+ new_message are None.
293298 """
294299 run_config = run_config or RunConfig ()
295300
296- if not new_message .role :
301+ if new_message and not new_message .role :
297302 new_message .role = 'user'
298303
299304 async def _run_with_trace (
300- new_message : types .Content ,
305+ new_message : Optional [types .Content ] = None ,
306+ invocation_id : Optional [str ] = None ,
301307 ) -> AsyncGenerator [Event , None ]:
302308 with tracer .start_as_current_span ('invocation' ):
303309 session = await self .session_service .get_session (
304310 app_name = self .app_name , user_id = user_id , session_id = session_id
305311 )
306312 if not session :
307313 raise ValueError (f'Session not found: { session_id } ' )
308-
309- invocation_context = await self ._setup_context_for_new_invocation (
310- session = session ,
311- new_message = new_message ,
312- run_config = run_config ,
313- state_delta = state_delta ,
314- )
314+ if not invocation_id and not new_message :
315+ raise ValueError ('Both invocation_id and new_message are None.' )
316+
317+ if invocation_id :
318+ if (
319+ not self .resumability_config
320+ or not self .resumability_config .is_resumable
321+ ):
322+ raise ValueError (
323+ f'invocation_id: { invocation_id } is provided but the app is not'
324+ ' resumable.'
325+ )
326+ invocation_context = await self ._setup_context_for_resumed_invocation (
327+ session = session ,
328+ new_message = new_message ,
329+ invocation_id = invocation_id ,
330+ run_config = run_config ,
331+ state_delta = state_delta ,
332+ )
333+ if invocation_context .end_of_agents .get (self .agent .name ):
334+ # Directly return if the root agent has already ended.
335+ # TODO: Handle the case where the invocation-to-resume started from
336+ # a sub_agent:
337+ # invocation1: root_agent -> sub_agent1
338+ # invocation2: sub_agent1 [paused][resume]
339+ return
340+ else :
341+ invocation_context = await self ._setup_context_for_new_invocation (
342+ session = session ,
343+ new_message = new_message , # new_message is not None.
344+ run_config = run_config ,
345+ state_delta = state_delta ,
346+ )
315347
316348 async def execute (ctx : InvocationContext ) -> AsyncGenerator [Event ]:
317349 async with Aclosing (ctx .agent .run_async (ctx )) as agen :
@@ -329,7 +361,7 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
329361 async for event in agen :
330362 yield event
331363
332- async with Aclosing (_run_with_trace (new_message )) as agen :
364+ async with Aclosing (_run_with_trace (new_message , invocation_id )) as agen :
333365 async for event in agen :
334366 yield event
335367
@@ -462,6 +494,11 @@ async def _append_new_message_to_session(
462494 author = 'user' ,
463495 content = new_message ,
464496 )
497+ # If new_message is a function response, find the matching function call
498+ # and use its branch as the new event's branch.
499+ if function_call := invocation_context ._find_matching_function_call (event ):
500+ event .branch = function_call .branch
501+
465502 await self .session_service .append_event (session = session , event = event )
466503
467504 async def run_live (
@@ -692,10 +729,82 @@ async def _setup_context_for_new_invocation(
692729 invocation_context .agent = self ._find_agent_to_run (session , self .agent )
693730 return invocation_context
694731
732+ async def _setup_context_for_resumed_invocation (
733+ self ,
734+ * ,
735+ session : Session ,
736+ new_message : Optional [types .Content ],
737+ invocation_id : Optional [str ],
738+ run_config : RunConfig ,
739+ state_delta : Optional [dict [str , Any ]],
740+ ) -> InvocationContext :
741+ """Sets up the context for a resumed invocation.
742+
743+ Args:
744+ session: The session to setup the invocation context for.
745+ new_message: The new message to process and append to the session.
746+ invocation_id: The invocation id to resume.
747+ run_config: The run config of the agent.
748+ state_delta: Optional state changes to apply to the session.
749+
750+ Returns:
751+ The invocation context for the resumed invocation.
752+
753+ Raises:
754+ ValueError: If the session has no events to resume; If no user message is
755+ available for resuming the invocation; Or if the app is not resumable.
756+ """
757+ if not session .events :
758+ raise ValueError (f'Session { session .id } has no events to resume.' )
759+
760+ # Step 1: Maybe retrive a previous user message for the invocation.
761+ user_message = new_message or self ._find_user_message_for_invocation (
762+ session .events , invocation_id
763+ )
764+ if not user_message :
765+ raise ValueError (
766+ f'No user message available for resuming invocation: { invocation_id } '
767+ )
768+ # Step 2: Create invocation context.
769+ invocation_context = self ._new_invocation_context (
770+ session ,
771+ new_message = user_message ,
772+ run_config = run_config ,
773+ invocation_id = invocation_id ,
774+ )
775+ # Step 3: Maybe handle new message.
776+ if new_message :
777+ await self ._handle_new_message (
778+ session = session ,
779+ new_message = user_message ,
780+ invocation_context = invocation_context ,
781+ run_config = run_config ,
782+ state_delta = state_delta ,
783+ )
784+ # Step 4: Populate agent states for the current invocation.
785+ invocation_context .populate_invocation_agent_states ()
786+ return invocation_context
787+
788+ def _find_user_message_for_invocation (
789+ self , events : list [Event ], invocation_id : str
790+ ) -> Optional [types .Content ]:
791+ """Finds the user message that started a specific invocation."""
792+ for event in events :
793+ if (
794+ event .invocation_id == invocation_id
795+ and event .author == 'user'
796+ and event .content
797+ and event .content .parts
798+ and event .content .parts [0 ].text
799+ ):
800+ return event .content
801+ return None
802+
695803 def _new_invocation_context (
696804 self ,
697805 session : Session ,
698806 * ,
807+ invocation_id : Optional [str ] = None ,
699808 new_message : Optional [types .Content ] = None ,
700809 live_request_queue : Optional [LiveRequestQueue ] = None ,
701810 run_config : Optional [RunConfig ] = None ,
@@ -704,6 +813,7 @@ def _new_invocation_context(
704813
705814 Args:
706815 session: The session for the context.
816+ invocation_id: The invocation id for the context.
707817 new_message: The new message for the context.
708818 live_request_queue: The live request queue for the context.
709819 run_config: The run config for the context.
@@ -712,7 +822,7 @@ def _new_invocation_context(
712822 The new invocation context.
713823 """
714824 run_config = run_config or RunConfig ()
715- invocation_id = new_invocation_context_id ()
825+ invocation_id = invocation_id or new_invocation_context_id ()
716826
717827 if run_config .support_cfc and isinstance (self .agent , LlmAgent ):
718828 model_name = self .agent .canonical_model .model
0 commit comments