diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 3718a29c5..8335cbd18 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -2,11 +2,17 @@ from . import agent, models, telemetry, types from .agent.agent import Agent +from .agent.serializers import JSONSerializer, PickleSerializer, StateSerializer +from .agent.state import AgentState from .tools.decorator import tool from .types.tools import ToolContext __all__ = [ "Agent", + "AgentState", + "JSONSerializer", + "PickleSerializer", + "StateSerializer", "agent", "models", "tool", diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index 6618d3328..017dc8a79 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -4,6 +4,7 @@ - Agent: The main interface for interacting with AI models and tools - ConversationManager: Classes for managing conversation history and context windows +- Serializers: Pluggable serialization strategies for agent state (JSONSerializer, PickleSerializer) """ from .agent import Agent @@ -14,12 +15,18 @@ SlidingWindowConversationManager, SummarizingConversationManager, ) +from .serializers import JSONSerializer, PickleSerializer, StateSerializer +from .state import AgentState __all__ = [ "Agent", "AgentResult", + "AgentState", "ConversationManager", + "JSONSerializer", "NullConversationManager", + "PickleSerializer", "SlidingWindowConversationManager", + "StateSerializer", "SummarizingConversationManager", ] diff --git a/src/strands/agent/serializers.py b/src/strands/agent/serializers.py new file mode 100644 index 000000000..6c2ab0991 --- /dev/null +++ b/src/strands/agent/serializers.py @@ -0,0 +1,159 @@ +"""State serializers for agent state management. + +This module provides pluggable serialization strategies for AgentState: +- JSONSerializer: Default serializer, backward compatible, validates on set() +- PickleSerializer: Supports any picklable Python object, validates on set() +- StateSerializer: Protocol for custom serializers +""" + +import copy +import json +import pickle +from typing import Any, Protocol, runtime_checkable + + +@runtime_checkable +class StateSerializer(Protocol): + """Protocol for state serializers. + + Custom serializers can implement this protocol to provide + alternative serialization strategies for agent state. + """ + + def serialize(self, data: dict[str, Any]) -> bytes: + """Serialize state dict to bytes. + + Args: + data: Dictionary of state data to serialize + + Returns: + Serialized state as bytes + """ + ... + + def deserialize(self, data: bytes) -> dict[str, Any]: + """Deserialize bytes back to state dict. + + Args: + data: Serialized state bytes + + Returns: + Deserialized state dictionary + """ + ... + + def validate(self, value: Any) -> None: + """Validate a value can be serialized. + + Serializers that accept any value should implement this as a no-op. + + Args: + value: The value to validate + + Raises: + ValueError: If value cannot be serialized by this serializer + """ + ... + + +class JSONSerializer: + """JSON-based state serializer. + + Default serializer that provides: + - Human-readable serialization format + - Validation on set() to maintain current behavior + - Backward compatibility with existing code + """ + + def serialize(self, data: dict[str, Any]) -> bytes: + """Serialize state dict to JSON bytes. + + Args: + data: Dictionary of state data to serialize + + Returns: + JSON serialized state as bytes + """ + return json.dumps(data).encode("utf-8") + + def deserialize(self, data: bytes) -> dict[str, Any]: + """Deserialize JSON bytes back to state dict. + + Args: + data: JSON serialized state bytes + + Returns: + Deserialized state dictionary + """ + result: dict[str, Any] = json.loads(data.decode("utf-8")) + return result + + def validate(self, value: Any) -> None: + """Validate that a value is JSON serializable. + + Args: + value: The value to validate + + Raises: + ValueError: If value is not JSON serializable + """ + try: + json.dumps(value) + except (TypeError, ValueError) as e: + raise ValueError( + f"Value is not JSON serializable: {type(value).__name__}. " + f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed." + ) from e + + +class PickleSerializer: + """Pickle-based state serializer. + + Provides: + - Support for any picklable Python object (datetime, UUID, dataclass, Pydantic models, etc.) + - Validation on set() to catch unpicklable objects (DB connections, file handles, etc.) + + Security Warning: + Pickle can execute arbitrary code during deserialization. + Only unpickle data from trusted sources. + """ + + def serialize(self, data: dict[str, Any]) -> bytes: + """Serialize state dict using pickle. + + Args: + data: Dictionary of state data to serialize + + Returns: + Pickle serialized state as bytes + """ + return pickle.dumps(copy.deepcopy(data)) + + def deserialize(self, data: bytes) -> dict[str, Any]: + """Deserialize pickle bytes back to state dict. + + Args: + data: Pickle serialized state bytes + + Returns: + Deserialized state dictionary + """ + result: dict[str, Any] = pickle.loads(data) # noqa: S301 + return result + + def validate(self, value: Any) -> None: + """Validate that a value can be pickled. + + Args: + value: The value to validate + + Raises: + ValueError: If value cannot be pickled + """ + try: + pickle.dumps(value) + except TypeError as e: + raise ValueError( + f"Value is not picklable: {type(value).__name__}. " + f"Objects like database connections, file handles, and sockets cannot be serialized." + ) from e diff --git a/src/strands/agent/state.py b/src/strands/agent/state.py index c323041a3..2e660ff88 100644 --- a/src/strands/agent/state.py +++ b/src/strands/agent/state.py @@ -1,6 +1,147 @@ -"""Agent state management.""" +"""Agent state management. -from ..types.json_dict import JSONSerializableDict +Provides flexible state container with pluggable serialization. +""" -# Type alias for agent state -AgentState = JSONSerializableDict +import copy +from typing import Any + +from .serializers import JSONSerializer, StateSerializer + + +class AgentState: + """Flexible state container with pluggable serialization. + + AgentState provides a key-value store for agent state with: + - Pluggable serialization (JSON by default, Pickle for rich types) + - Backward compatible API with existing code + + Example: + Basic usage (backward compatible): + ```python + state = AgentState() + state.set("count", 42) + state.get("count") # Returns 42 + ``` + + Rich types with PickleSerializer: + ```python + from strands.agent.serializers import PickleSerializer + from datetime import datetime + + state = AgentState(serializer=PickleSerializer()) + state.set("created_at", datetime.now()) # Works with Pickle + ``` + """ + + def __init__( + self, + initial_state: dict[str, Any] | None = None, + serializer: StateSerializer | None = None, + ): + """Initialize AgentState. + + Args: + initial_state: Optional initial state dictionary + serializer: Serializer to use for state persistence. + Defaults to JSONSerializer for backward compatibility. + + Raises: + ValueError: If initial_state contains non-serializable values (with JSONSerializer) + """ + self._serializer = serializer if serializer is not None else JSONSerializer() + self._data: dict[str, Any] + + if initial_state: + # Validate initial state + self._serializer.validate(initial_state) + self._data = copy.deepcopy(initial_state) + else: + self._data = {} + + @property + def serializer(self) -> StateSerializer: + """Get the current serializer. + + Returns: + The serializer used for state persistence + """ + return self._serializer + + @serializer.setter + def serializer(self, value: StateSerializer) -> None: + """Set the serializer. + + Args: + value: New serializer to use for state persistence + """ + self._serializer = value + + def set(self, key: str, value: Any) -> None: + """Set a value in the store. + + Args: + key: The key to store the value under + value: The value to store + + Raises: + ValueError: If key is invalid, or if value is not serializable + """ + self._validate_key(key) + self._serializer.validate(value) + self._data[key] = copy.deepcopy(value) + + def get(self, key: str | None = None) -> Any: + """Get a value or entire data. + + Args: + key: The key to retrieve (if None, returns entire data dict) + + Returns: + The stored value, entire data dict, or None if not found + """ + if key is None: + return copy.deepcopy(self._data) + else: + return copy.deepcopy(self._data.get(key)) + + def delete(self, key: str) -> None: + """Delete a specific key from the store. + + Args: + key: The key to delete + """ + self._validate_key(key) + self._data.pop(key, None) + + def serialize(self) -> bytes: + """Serialize state. + + Returns: + Serialized state as bytes + """ + return self._serializer.serialize(self._data) + + def deserialize(self, data: bytes) -> None: + """Deserialize state. + + Args: + data: Serialized state bytes to restore + """ + self._data = self._serializer.deserialize(data) + + def _validate_key(self, key: str) -> None: + """Validate that a key is valid. + + Args: + key: The key to validate + + Raises: + ValueError: If key is invalid + """ + if key is None: + raise ValueError("Key cannot be None") + if not isinstance(key, str): + raise ValueError("Key must be a string") + if not key.strip(): + raise ValueError("Key cannot be empty") diff --git a/tests/strands/agent/test_agent_state.py b/tests/strands/agent/test_agent_state.py index bc2321a56..028441f42 100644 --- a/tests/strands/agent/test_agent_state.py +++ b/tests/strands/agent/test_agent_state.py @@ -1,8 +1,11 @@ """Tests for AgentState class.""" +from datetime import datetime +from uuid import uuid4 + import pytest -from strands import Agent, tool +from strands import Agent, JSONSerializer, PickleSerializer, tool from strands.agent.state import AgentState from strands.types.content import Messages @@ -143,3 +146,111 @@ def update_state(agent: Agent): assert agent.state.get("hello") == "world" assert agent.state.get("foo") == "baz" + + +def test_default_serializer_is_json(): + """Test that default serializer is JSONSerializer.""" + state = AgentState() + assert isinstance(state.serializer, JSONSerializer) + + +def test_pickle_serializer_allows_rich_types(): + """Test that PickleSerializer allows datetime, UUID, and other rich types.""" + state = AgentState(serializer=PickleSerializer()) + + # Rich types that don't work with JSONSerializer + now = datetime.now() + user_id = uuid4() + + state.set("created_at", now) + state.set("user_id", user_id) + state.set("config", {"nested": now}) + + assert state.get("created_at") == now + assert state.get("user_id") == user_id + + +def test_json_serializer_rejects_rich_types(): + """Test that JSONSerializer rejects datetime and other non-JSON types.""" + state = AgentState(serializer=JSONSerializer()) + + with pytest.raises(ValueError, match="not JSON serializable"): + state.set("created_at", datetime.now()) + + +def test_serialize_deserialize_json(): + """Test serialize and deserialize with JSONSerializer.""" + state = AgentState(serializer=JSONSerializer()) + state.set("name", "test") + state.set("count", 42) + + # Serialize + data = state.serialize() + assert isinstance(data, bytes) + + # Deserialize into new state + new_state = AgentState(serializer=JSONSerializer()) + new_state.deserialize(data) + + assert new_state.get("name") == "test" + assert new_state.get("count") == 42 + + +def test_serialize_deserialize_pickle(): + """Test serialize and deserialize with PickleSerializer.""" + state = AgentState(serializer=PickleSerializer()) + now = datetime.now() + user_id = uuid4() + state.set("created_at", now) + state.set("user_id", user_id) + + # Serialize + data = state.serialize() + assert isinstance(data, bytes) + + # Deserialize into new state + new_state = AgentState(serializer=PickleSerializer()) + new_state.deserialize(data) + + assert new_state.get("created_at") == now + assert new_state.get("user_id") == user_id + + +def test_serializer_property(): + """Test serializer property getter and setter.""" + state = AgentState(serializer=JSONSerializer()) + assert isinstance(state.serializer, JSONSerializer) + + # Change serializer + state.serializer = PickleSerializer() + assert isinstance(state.serializer, PickleSerializer) + + +def test_agent_state_with_pickle_allows_datetime(): + """Test using datetime in agent state with PickleSerializer.""" + agent_messages: Messages = [ + {"role": "assistant", "content": [{"text": "Hello!"}]}, + ] + mocked_model_provider = MockedModelProvider(agent_messages) + + now = datetime.now() + agent = Agent( + model=mocked_model_provider, + state=AgentState(serializer=PickleSerializer()), + ) + + agent.state.set("created_at", now) + assert agent.state.get("created_at") == now + + +def test_pickle_serializer_rejects_unpicklable(): + """Test that PickleSerializer rejects unpicklable objects like DB connections.""" + import sqlite3 + + state = AgentState(serializer=PickleSerializer()) + conn = sqlite3.connect(":memory:") + + with pytest.raises(ValueError, match="not picklable"): + state.set("connection", conn) + + conn.close()