diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 567a92b4a..e9739f473 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -54,7 +54,7 @@ from ..tools.structured_output._structured_output_context import StructuredOutputContext from ..tools.watcher import ToolWatcher from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent -from ..types.agent import AgentInput +from ..types.agent import AgentInput, ConcurrentInvocationMode from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException from ..types.traces import AttributeValue @@ -129,6 +129,7 @@ def __init__( structured_output_prompt: str | None = None, tool_executor: ToolExecutor | None = None, retry_strategy: ModelRetryStrategy | _DefaultRetryStrategySentinel | None = _DEFAULT_RETRY_STRATEGY, + concurrent_invocation_mode: ConcurrentInvocationMode = ConcurrentInvocationMode.THROW, ): """Initialize the Agent with the specified configuration. @@ -186,6 +187,11 @@ def __init__( retry_strategy: Strategy for retrying model calls on throttling or other transient errors. Defaults to ModelRetryStrategy with max_attempts=6, initial_delay=4s, max_delay=240s. Implement a custom HookProvider for custom retry logic, or pass None to disable retries. + concurrent_invocation_mode: Mode controlling concurrent invocation behavior. + Defaults to "throw" which raises ConcurrencyException if concurrent invocation is attempted. + Set to "unsafe_reentrant" to skip lock acquisition entirely, allowing concurrent invocations. + Warning: "unsafe_reentrant" makes no guarantees about resulting behavior and is provided + only for advanced use cases where the caller understands the risks. Raises: ValueError: If agent id contains path separators. @@ -263,6 +269,7 @@ def __init__( # Using threading.Lock instead of asyncio.Lock because run_async() creates # separate event loops in different threads, so asyncio.Lock wouldn't work self._invocation_lock = threading.Lock() + self._concurrent_invocation_mode = concurrent_invocation_mode # In the future, we'll have a RetryStrategy base class but until # that API is determined we only allow ModelRetryStrategy @@ -622,14 +629,15 @@ async def stream_async( yield event["data"] ``` """ - # Acquire lock to prevent concurrent invocations + # Conditionally acquire lock based on concurrent_invocation_mode # Using threading.Lock instead of asyncio.Lock because run_async() creates # separate event loops in different threads - acquired = self._invocation_lock.acquire(blocking=False) - if not acquired: - raise ConcurrencyException( - "Agent is already processing a request. Concurrent invocations are not supported." - ) + if self._concurrent_invocation_mode == ConcurrentInvocationMode.THROW: + lock_acquired = self._invocation_lock.acquire(blocking=False) + if not lock_acquired: + raise ConcurrencyException( + "Agent is already processing a request. Concurrent invocations are not supported." + ) try: self._interrupt_state.resume(prompt) @@ -678,7 +686,8 @@ async def stream_async( raise finally: - self._invocation_lock.release() + if self._invocation_lock.locked(): + self._invocation_lock.release() async def _run_loop( self, diff --git a/src/strands/types/agent.py b/src/strands/types/agent.py index aa69149a6..cda01f8aa 100644 --- a/src/strands/types/agent.py +++ b/src/strands/types/agent.py @@ -3,9 +3,26 @@ This module defines the types used for an Agent. """ +from enum import Enum from typing import TypeAlias from .content import ContentBlock, Messages from .interrupt import InterruptResponseContent AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponseContent] | Messages | None + + +class ConcurrentInvocationMode(str, Enum): + """Mode controlling concurrent invocation behavior. + + Values: + THROW: Raises ConcurrencyException if concurrent invocation is attempted (default). + UNSAFE_REENTRANT: Allows concurrent invocations without locking. + + Warning: + The ``UNSAFE_REENTRANT`` mode makes no guarantees about resulting behavior and is + provided only for advanced use cases where the caller understands the risks. + """ + + THROW = "throw" + UNSAFE_REENTRANT = "unsafe_reentrant" diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index eb039185c..d95d26f92 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -26,6 +26,7 @@ from strands.session.repository_session_manager import RepositorySessionManager from strands.telemetry.tracer import serialize from strands.types._events import EventLoopStopEvent, ModelStreamEvent +from strands.types.agent import ConcurrentInvocationMode from strands.types.content import Messages from strands.types.exceptions import ConcurrencyException, ContextWindowOverflowException, EventLoopException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType @@ -2231,20 +2232,17 @@ def test_agent_concurrent_call_raises_exception(): {"role": "assistant", "content": [{"text": "world"}]}, ] ) - agent = Agent(model=model) + agent = Agent(model=model, concurrent_invocation_mode="throw") results = [] errors = [] - lock = threading.Lock() def invoke(): try: result = agent("test") - with lock: - results.append(result) + results.append(result) except ConcurrencyException as e: - with lock: - errors.append(e) + errors.append(e) # Start first thread and wait for it to begin streaming t1 = threading.Thread(target=invoke) @@ -2282,7 +2280,7 @@ def test_agent_concurrent_structured_output_raises_exception(): {"role": "assistant", "content": [{"text": "response2"}]}, ], ) - agent = Agent(model=model) + agent = Agent(model=model, concurrent_invocation_mode="throw") results = [] errors = [] @@ -2320,6 +2318,83 @@ def invoke(): assert "concurrent" in str(errors[0]).lower() and "invocation" in str(errors[0]).lower() +def test_agent_concurrent_call_succeeds_with_unsafe_reentrant_mode(): + """Test that concurrent __call__() calls succeed when concurrent_invocation_mode is 'unsafe_reentrant'.""" + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + {"role": "assistant", "content": [{"text": "world"}]}, + ] + ) + agent = Agent(model=model, concurrent_invocation_mode="unsafe_reentrant") + + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent("test") + with lock: + results.append(result) + except ConcurrencyException as e: + with lock: + errors.append(e) + + # Start first thread and wait for it to begin streaming + t1 = threading.Thread(target=invoke) + t1.start() + model.started_event.wait() # Wait until first thread is in the model.stream() + + # Start second thread while first is still running + t2 = threading.Thread(target=invoke) + t2.start() + + # Let both threads proceed + model.proceed_event.set() + t1.join() + t2.join() + + # Both should succeed, no ConcurrencyException raised + assert len(errors) == 0, f"Expected 0 errors, got {len(errors)}: {errors}" + assert len(results) == 2, f"Expected 2 successes, got {len(results)}" + + +def test_agent_concurrent_invocation_mode_default_is_throw(): + """Test that the default concurrent_invocation_mode is 'throw'.""" + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}]) + agent = Agent(model=model) + + # Verify the default mode + assert agent._concurrent_invocation_mode == "throw" + + +def test_agent_concurrent_invocation_mode_stores_value(): + """Test that concurrent_invocation_mode is stored correctly as instance variable.""" + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}]) + + agent_throw = Agent(model=model, concurrent_invocation_mode="throw") + assert agent_throw._concurrent_invocation_mode == "throw" + + agent_reentrant = Agent(model=model, concurrent_invocation_mode="unsafe_reentrant") + assert agent_reentrant._concurrent_invocation_mode == "unsafe_reentrant" + + +def test_agent_concurrent_invocation_mode_accepts_enum(): + """Test that concurrent_invocation_mode accepts enum values as well as strings.""" + + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}]) + + # Using enum values + agent_throw = Agent(model=model, concurrent_invocation_mode=ConcurrentInvocationMode.THROW) + assert agent_throw._concurrent_invocation_mode == "throw" + assert agent_throw._concurrent_invocation_mode == ConcurrentInvocationMode.THROW + + agent_reentrant = Agent(model=model, concurrent_invocation_mode=ConcurrentInvocationMode.UNSAFE_REENTRANT) + assert agent_reentrant._concurrent_invocation_mode == "unsafe_reentrant" + assert agent_reentrant._concurrent_invocation_mode == ConcurrentInvocationMode.UNSAFE_REENTRANT + + @pytest.mark.asyncio async def test_agent_sequential_invocations_work(): """Test that sequential invocations work correctly after lock is released."""