Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from ..tools.structured_output._structured_output_context import StructuredOutputContext
from ..tools.watcher import ToolWatcher
from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent
from ..types.agent import AgentInput
from ..types.agent import AgentInput, ConcurrentInvocationMode
from ..types.content import ContentBlock, Message, Messages, SystemContentBlock
from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException
from ..types.traces import AttributeValue
Expand Down Expand Up @@ -129,6 +129,7 @@ def __init__(
structured_output_prompt: str | None = None,
tool_executor: ToolExecutor | None = None,
retry_strategy: ModelRetryStrategy | _DefaultRetryStrategySentinel | None = _DEFAULT_RETRY_STRATEGY,
concurrent_invocation_mode: ConcurrentInvocationMode = ConcurrentInvocationMode.THROW,
):
"""Initialize the Agent with the specified configuration.

Expand Down Expand Up @@ -186,6 +187,11 @@ def __init__(
retry_strategy: Strategy for retrying model calls on throttling or other transient errors.
Defaults to ModelRetryStrategy with max_attempts=6, initial_delay=4s, max_delay=240s.
Implement a custom HookProvider for custom retry logic, or pass None to disable retries.
concurrent_invocation_mode: Mode controlling concurrent invocation behavior.
Defaults to "throw" which raises ConcurrencyException if concurrent invocation is attempted.
Set to "unsafe_reentrant" to skip lock acquisition entirely, allowing concurrent invocations.
Warning: "unsafe_reentrant" makes no guarantees about resulting behavior and is provided
only for advanced use cases where the caller understands the risks.

Raises:
ValueError: If agent id contains path separators.
Expand Down Expand Up @@ -263,6 +269,7 @@ def __init__(
# Using threading.Lock instead of asyncio.Lock because run_async() creates
# separate event loops in different threads, so asyncio.Lock wouldn't work
self._invocation_lock = threading.Lock()
self._concurrent_invocation_mode = concurrent_invocation_mode

# In the future, we'll have a RetryStrategy base class but until
# that API is determined we only allow ModelRetryStrategy
Expand Down Expand Up @@ -622,14 +629,15 @@ async def stream_async(
yield event["data"]
```
"""
# Acquire lock to prevent concurrent invocations
# Conditionally acquire lock based on concurrent_invocation_mode
# Using threading.Lock instead of asyncio.Lock because run_async() creates
# separate event loops in different threads
acquired = self._invocation_lock.acquire(blocking=False)
if not acquired:
raise ConcurrencyException(
"Agent is already processing a request. Concurrent invocations are not supported."
)
if self._concurrent_invocation_mode == ConcurrentInvocationMode.THROW:
lock_acquired = self._invocation_lock.acquire(blocking=False)
if not lock_acquired:
raise ConcurrencyException(
"Agent is already processing a request. Concurrent invocations are not supported."
)

try:
self._interrupt_state.resume(prompt)
Expand Down Expand Up @@ -678,7 +686,8 @@ async def stream_async(
raise

finally:
self._invocation_lock.release()
if self._invocation_lock.locked():
self._invocation_lock.release()

async def _run_loop(
self,
Expand Down
17 changes: 17 additions & 0 deletions src/strands/types/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,26 @@
This module defines the types used for an Agent.
"""

from enum import Enum
from typing import TypeAlias

from .content import ContentBlock, Messages
from .interrupt import InterruptResponseContent

AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponseContent] | Messages | None


class ConcurrentInvocationMode(str, Enum):
"""Mode controlling concurrent invocation behavior.

Values:
THROW: Raises ConcurrencyException if concurrent invocation is attempted (default).
UNSAFE_REENTRANT: Allows concurrent invocations without locking.

Warning:
The ``UNSAFE_REENTRANT`` mode makes no guarantees about resulting behavior and is
provided only for advanced use cases where the caller understands the risks.
"""

THROW = "throw"
UNSAFE_REENTRANT = "unsafe_reentrant"
89 changes: 82 additions & 7 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from strands.session.repository_session_manager import RepositorySessionManager
from strands.telemetry.tracer import serialize
from strands.types._events import EventLoopStopEvent, ModelStreamEvent
from strands.types.agent import ConcurrentInvocationMode
from strands.types.content import Messages
from strands.types.exceptions import ConcurrencyException, ContextWindowOverflowException, EventLoopException
from strands.types.session import Session, SessionAgent, SessionMessage, SessionType
Expand Down Expand Up @@ -2231,20 +2232,17 @@ def test_agent_concurrent_call_raises_exception():
{"role": "assistant", "content": [{"text": "world"}]},
]
)
agent = Agent(model=model)
agent = Agent(model=model, concurrent_invocation_mode="throw")

results = []
errors = []
lock = threading.Lock()

def invoke():
try:
result = agent("test")
with lock:
results.append(result)
results.append(result)
except ConcurrencyException as e:
with lock:
errors.append(e)
errors.append(e)

# Start first thread and wait for it to begin streaming
t1 = threading.Thread(target=invoke)
Expand Down Expand Up @@ -2282,7 +2280,7 @@ def test_agent_concurrent_structured_output_raises_exception():
{"role": "assistant", "content": [{"text": "response2"}]},
],
)
agent = Agent(model=model)
agent = Agent(model=model, concurrent_invocation_mode="throw")

results = []
errors = []
Expand Down Expand Up @@ -2320,6 +2318,83 @@ def invoke():
assert "concurrent" in str(errors[0]).lower() and "invocation" in str(errors[0]).lower()


def test_agent_concurrent_call_succeeds_with_unsafe_reentrant_mode():
"""Test that concurrent __call__() calls succeed when concurrent_invocation_mode is 'unsafe_reentrant'."""
model = SyncEventMockedModel(
[
{"role": "assistant", "content": [{"text": "hello"}]},
{"role": "assistant", "content": [{"text": "world"}]},
]
)
agent = Agent(model=model, concurrent_invocation_mode="unsafe_reentrant")

results = []
errors = []
lock = threading.Lock()

def invoke():
try:
result = agent("test")
with lock:
results.append(result)
except ConcurrencyException as e:
with lock:
errors.append(e)

# Start first thread and wait for it to begin streaming
t1 = threading.Thread(target=invoke)
t1.start()
model.started_event.wait() # Wait until first thread is in the model.stream()

# Start second thread while first is still running
t2 = threading.Thread(target=invoke)
t2.start()

# Let both threads proceed
model.proceed_event.set()
t1.join()
t2.join()

# Both should succeed, no ConcurrencyException raised
assert len(errors) == 0, f"Expected 0 errors, got {len(errors)}: {errors}"
assert len(results) == 2, f"Expected 2 successes, got {len(results)}"


def test_agent_concurrent_invocation_mode_default_is_throw():
"""Test that the default concurrent_invocation_mode is 'throw'."""
model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}])
agent = Agent(model=model)

# Verify the default mode
assert agent._concurrent_invocation_mode == "throw"


def test_agent_concurrent_invocation_mode_stores_value():
"""Test that concurrent_invocation_mode is stored correctly as instance variable."""
model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}])

agent_throw = Agent(model=model, concurrent_invocation_mode="throw")
assert agent_throw._concurrent_invocation_mode == "throw"

agent_reentrant = Agent(model=model, concurrent_invocation_mode="unsafe_reentrant")
assert agent_reentrant._concurrent_invocation_mode == "unsafe_reentrant"


def test_agent_concurrent_invocation_mode_accepts_enum():
"""Test that concurrent_invocation_mode accepts enum values as well as strings."""

model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}])

# Using enum values
agent_throw = Agent(model=model, concurrent_invocation_mode=ConcurrentInvocationMode.THROW)
assert agent_throw._concurrent_invocation_mode == "throw"
assert agent_throw._concurrent_invocation_mode == ConcurrentInvocationMode.THROW

agent_reentrant = Agent(model=model, concurrent_invocation_mode=ConcurrentInvocationMode.UNSAFE_REENTRANT)
assert agent_reentrant._concurrent_invocation_mode == "unsafe_reentrant"
assert agent_reentrant._concurrent_invocation_mode == ConcurrentInvocationMode.UNSAFE_REENTRANT


@pytest.mark.asyncio
async def test_agent_sequential_invocations_work():
"""Test that sequential invocations work correctly after lock is released."""
Expand Down
Loading