From b5742c598a8d5fe49ec0cc87662466318d16f951 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Tue, 27 Jan 2026 17:57:27 -0800 Subject: [PATCH 1/9] Support specifying types via handler and executor decorators --- .../agent_framework/_workflows/_executor.py | 143 +++++++++---- .../_workflows/_function_executor.py | 76 ++++++- .../_workflows/_typing_utils.py | 35 ++- .../core/tests/workflow/test_executor.py | 201 ++++++++++++++++++ .../tests/workflow/test_function_executor.py | 200 +++++++++++++++++ .../core/tests/workflow/test_typing_utils.py | 70 +++++- .../_start-here/step1_executors_and_edges.py | 95 ++++++++- 7 files changed, 761 insertions(+), 59 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index 49f3dafd06..127466ea8e 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -5,8 +5,9 @@ import functools import inspect import logging +import types from collections.abc import Awaitable, Callable -from typing import Any, TypeVar +from typing import Any, TypeVar, overload from ..observability import create_processing_span from ._events import ( @@ -326,34 +327,37 @@ def _discover_handlers(self) -> None: """Discover message handlers in the executor class.""" # Use __class__.__dict__ to avoid accessing pydantic's dynamic attributes for attr_name in dir(self.__class__): + # Narrow the exception scope - only catch AttributeError when accessing the attribute try: attr = getattr(self.__class__, attr_name) - # Discover @handler methods - if callable(attr) and hasattr(attr, "_handler_spec"): - handler_spec = attr._handler_spec # type: ignore - message_type = handler_spec["message_type"] - - # Keep full generic types for handler registration to avoid conflicts - if self._handlers.get(message_type) is not None: - raise ValueError(f"Duplicate handler for type {message_type} in {self.__class__.__name__}") - - # Get the bound method - bound_method = getattr(self, attr_name) - self._handlers[message_type] = bound_method - - # Add to unified handler specs list - self._handler_specs.append({ - "name": handler_spec["name"], - "message_type": message_type, - "output_types": handler_spec.get("output_types", []), - "workflow_output_types": handler_spec.get("workflow_output_types", []), - "ctx_annotation": handler_spec.get("ctx_annotation"), - "source": "class_method", # Distinguish from instance handlers if needed - }) except AttributeError: - # Skip attributes that may not be accessible + # Skip attributes that may not be accessible (e.g., dynamic descriptors) + logger.debug(f"Could not access attribute {attr_name} on {self.__class__.__name__}") continue + # Discover @handler methods - let AttributeError propagate for malformed handler specs + if callable(attr) and hasattr(attr, "_handler_spec"): + handler_spec = attr._handler_spec # type: ignore + message_type = handler_spec["message_type"] + + # Keep full generic types for handler registration to avoid conflicts + if self._handlers.get(message_type) is not None: + raise ValueError(f"Duplicate handler for type {message_type} in {self.__class__.__name__}") + + # Get the bound method + bound_method = getattr(self, attr_name) + self._handlers[message_type] = bound_method + + # Add to unified handler specs list + self._handler_specs.append({ + "name": handler_spec["name"], + "message_type": message_type, + "output_types": handler_spec.get("output_types", []), + "workflow_output_types": handler_spec.get("workflow_output_types", []), + "ctx_annotation": handler_spec.get("ctx_annotation"), + "source": "class_method", # Distinguish from instance handlers if needed + }) + def can_handle(self, message: Message) -> bool: """Check if the executor can handle a given message type. @@ -529,35 +533,88 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: ContextT = TypeVar("ContextT", bound="WorkflowContext[Any, Any]") +@overload def handler( func: Callable[[ExecutorT, Any, ContextT], Awaitable[Any]], -) -> Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]: +) -> Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]: ... + + +@overload +def handler( + *, + input_type: type | types.UnionType | None = None, + output_type: type | types.UnionType | None = None, +) -> Callable[ + [Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]], + Callable[[ExecutorT, Any, ContextT], Awaitable[Any]], +]: ... + + +def handler( + func: Callable[[ExecutorT, Any, ContextT], Awaitable[Any]] | None = None, + *, + input_type: type | types.UnionType | None = None, + output_type: type | types.UnionType | None = None, +) -> ( + Callable[[ExecutorT, Any, ContextT], Awaitable[Any]] + | Callable[ + [Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]], + Callable[[ExecutorT, Any, ContextT], Awaitable[Any]], + ] +): """Decorator to register a handler for an executor. Args: - func: The function to decorate. Can be None when used without parameters. + func: The function to decorate. Can be None when used with parameters. + input_type: Optional explicit input type(s) for this handler. Supports union types + (e.g., ``str | int``). When provided, takes precedence over introspection from + the function's message parameter annotation. + output_type: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. + Supports union types (e.g., ``str | int``). When provided, takes precedence over + introspection from the ``WorkflowContext`` generic parameters. Returns: The decorated function with handler metadata. Example: + # Using introspection (existing behavior) @handler async def handle_string(self, message: str, ctx: WorkflowContext[str]) -> None: ... - @handler - async def handle_data(self, message: dict, ctx: WorkflowContext[str | int]) -> None: + # Using explicit types (takes precedence over introspection) + @handler(input_type=str | int, output_type=bool) + async def handle_data(self, message: Any, ctx: WorkflowContext) -> None: + ... + + # Only specifying input_type (output_type falls back to introspection) + @handler(input_type=MyCustomType) + async def handle_custom(self, message: Any, ctx: WorkflowContext[str]) -> None: ... """ + from ._typing_utils import normalize_type_to_list def decorator( func: Callable[[ExecutorT, Any, ContextT], Awaitable[Any]], ) -> Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]: # Extract the message type and validate using unified validation - message_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = ( - _validate_handler_signature(func) + introspected_message_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = ( + _validate_handler_signature(func, skip_message_annotation=input_type is not None) ) + # Use explicit types if provided, otherwise fall back to introspection + message_type = input_type if input_type is not None else introspected_message_type + + # Validate that we have a message type - this should never happen if signature + # validation passed, but provides a clear error if type information is missing + if message_type is None: + raise ValueError( + f"Handler {func.__name__} requires either a message parameter type annotation " + "or an explicit input_type parameter" + ) + + final_output_types = normalize_type_to_list(output_type) if output_type is not None else inferred_output_types + # Get signature for preservation sig = inspect.signature(func) @@ -574,14 +631,19 @@ async def wrapper(self: ExecutorT, message: Any, ctx: ContextT) -> Any: "name": func.__name__, "message_type": message_type, # Keep output_types and workflow_output_types in spec for validators - "output_types": inferred_output_types, + "output_types": final_output_types, "workflow_output_types": inferred_workflow_output_types, "ctx_annotation": ctx_annotation, } return wrapper - return decorator(func) + # Handle both @handler and @handler(...) usage patterns + if func is not None: + # Called as @handler without parentheses + return decorator(func) + # Called as @handler(...) with parentheses + return decorator # endregion: Handler Decorator @@ -589,14 +651,21 @@ async def wrapper(self: ExecutorT, message: Any, ctx: ContextT) -> Any: # region Handler Validation -def _validate_handler_signature(func: Callable[..., Any]) -> tuple[type, Any, list[type[Any]], list[type[Any]]]: +def _validate_handler_signature( + func: Callable[..., Any], + *, + skip_message_annotation: bool = False, +) -> tuple[type | None, Any, list[type[Any]], list[type[Any]]]: """Validate function signature for executor functions. Args: func: The function to validate + skip_message_annotation: If True, skip validation that message parameter has a type + annotation. Used when input_type is explicitly provided to the @handler decorator. Returns: - Tuple of (message_type, ctx_annotation, output_types, workflow_output_types) + Tuple of (message_type, ctx_annotation, output_types, workflow_output_types). + message_type may be None if skip_message_annotation is True and no annotation exists. Raises: ValueError: If the function signature is invalid @@ -609,9 +678,9 @@ def _validate_handler_signature(func: Callable[..., Any]) -> tuple[type, Any, li if len(params) != expected_counts: raise ValueError(f"Handler {func.__name__} must have {param_description}. Got {len(params)} parameters.") - # Check message parameter has type annotation + # Check message parameter has type annotation (unless skipped) message_param = params[1] - if message_param.annotation == inspect.Parameter.empty: + if not skip_message_annotation and message_param.annotation == inspect.Parameter.empty: raise ValueError(f"Handler {func.__name__} must have a type annotation for the message parameter") # Validate ctx parameter is WorkflowContext and extract type args @@ -620,7 +689,7 @@ def _validate_handler_signature(func: Callable[..., Any]) -> tuple[type, Any, li ctx_param.annotation, f"parameter '{ctx_param.name}'", "Handler" ) - message_type = message_param.annotation + message_type = message_param.annotation if message_param.annotation != inspect.Parameter.empty else None ctx_annotation = ctx_param.annotation return message_type, ctx_annotation, output_types, workflow_output_types diff --git a/python/packages/core/agent_framework/_workflows/_function_executor.py b/python/packages/core/agent_framework/_workflows/_function_executor.py index d7b68c10fd..f6167976d3 100644 --- a/python/packages/core/agent_framework/_workflows/_function_executor.py +++ b/python/packages/core/agent_framework/_workflows/_function_executor.py @@ -18,11 +18,13 @@ import asyncio import inspect import sys +import types import typing from collections.abc import Awaitable, Callable from typing import Any from ._executor import Executor +from ._typing_utils import normalize_type_to_list from ._workflow_context import WorkflowContext, validate_workflow_context_annotation if sys.version_info >= (3, 11): @@ -41,12 +43,25 @@ class FunctionExecutor(Executor): blocking the event loop. """ - def __init__(self, func: Callable[..., Any], id: str | None = None): + def __init__( + self, + func: Callable[..., Any], + id: str | None = None, + *, + input_type: type | types.UnionType | None = None, + output_type: type | types.UnionType | None = None, + ): """Initialize the FunctionExecutor with a user-defined function. Args: func: The function to wrap as an executor (can be sync or async) id: Optional executor ID. If None, uses the function name. + input_type: Optional explicit input type(s) for this executor. Supports union types + (e.g., ``str | int``). When provided, takes precedence over introspection from + the function's message parameter annotation. + output_type: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. + Supports union types (e.g., ``str | int``). When provided, takes precedence over + introspection from the ``WorkflowContext`` generic parameters. Raises: ValueError: If func is a staticmethod or classmethod (use @handler on instance methods instead) @@ -61,7 +76,20 @@ def __init__(self, func: Callable[..., Any], id: str | None = None): ) # Validate function signature and extract types - message_type, ctx_annotation, output_types, workflow_output_types = _validate_function_signature(func) + introspected_message_type, ctx_annotation, inferred_output_types, workflow_output_types = ( + _validate_function_signature(func, skip_message_annotation=input_type is not None) + ) + + # Use explicit types if provided, otherwise fall back to introspection + message_type = input_type if input_type is not None else introspected_message_type + output_types = normalize_type_to_list(output_type) if output_type is not None else inferred_output_types + + # Validate that we have a message type - provides a clear error if type information is missing + if message_type is None: + raise ValueError( + f"Function {func.__name__} requires either a message parameter type annotation " + "or an explicit input_type parameter" + ) # Store the original function self._original_func = func @@ -127,11 +155,20 @@ def executor(func: Callable[..., Any]) -> FunctionExecutor: ... @overload -def executor(*, id: str | None = None) -> Callable[[Callable[..., Any]], FunctionExecutor]: ... +def executor( + *, + id: str | None = None, + input_type: type | types.UnionType | None = None, + output_type: type | types.UnionType | None = None, +) -> Callable[[Callable[..., Any]], FunctionExecutor]: ... def executor( - func: Callable[..., Any] | None = None, *, id: str | None = None + func: Callable[..., Any] | None = None, + *, + id: str | None = None, + input_type: type | types.UnionType | None = None, + output_type: type | types.UnionType | None = None, ) -> Callable[[Callable[..., Any]], FunctionExecutor] | FunctionExecutor: """Decorator that converts a standalone function into a FunctionExecutor instance. @@ -162,6 +199,12 @@ def process_data(data: str): return data.upper() + # Using explicit types (takes precedence over introspection): + @executor(id="my_executor", input_type=str | int, output_type=bool) + async def process(message: Any, ctx: WorkflowContext): + await ctx.send_message(True) + + # For class-based executors, use @handler instead: class MyExecutor(Executor): def __init__(self): @@ -174,6 +217,12 @@ async def process(self, data: str, ctx: WorkflowContext[str]): Args: func: The function to decorate (when used without parentheses) id: Optional custom ID for the executor. If None, uses the function name. + input_type: Optional explicit input type(s) for this executor. Supports union types + (e.g., ``str | int``). When provided, takes precedence over introspection from + the function's message parameter annotation. + output_type: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. + Supports union types (e.g., ``str | int``). When provided, takes precedence over + introspection from the ``WorkflowContext`` generic parameters. Returns: A FunctionExecutor instance that can be wired into a Workflow. @@ -183,7 +232,7 @@ async def process(self, data: str, ctx: WorkflowContext[str]): """ def wrapper(func: Callable[..., Any]) -> FunctionExecutor: - return FunctionExecutor(func, id=id) + return FunctionExecutor(func, id=id, input_type=input_type, output_type=output_type) # If func is provided, this means @executor was used without parentheses if func is not None: @@ -198,14 +247,21 @@ def wrapper(func: Callable[..., Any]) -> FunctionExecutor: # region Function Validation -def _validate_function_signature(func: Callable[..., Any]) -> tuple[type, Any, list[type[Any]], list[type[Any]]]: +def _validate_function_signature( + func: Callable[..., Any], + *, + skip_message_annotation: bool = False, +) -> tuple[type | None, Any, list[type[Any]], list[type[Any]]]: """Validate function signature for executor functions. Args: func: The function to validate + skip_message_annotation: If True, skip validation that message parameter has a type + annotation. Used when input_type is explicitly provided to the @executor decorator. Returns: - Tuple of (message_type, ctx_annotation, output_types, workflow_output_types) + Tuple of (message_type, ctx_annotation, output_types, workflow_output_types). + message_type may be None if skip_message_annotation is True and no annotation exists. Raises: ValueError: If the function signature is invalid @@ -220,13 +276,15 @@ def _validate_function_signature(func: Callable[..., Any]) -> tuple[type, Any, l f"Function instance {func.__name__} must have {param_description}. Got {len(params)} parameters." ) - # Check message parameter has type annotation + # Check message parameter has type annotation (unless skipped) message_param = params[0] - if message_param.annotation == inspect.Parameter.empty: + if not skip_message_annotation and message_param.annotation == inspect.Parameter.empty: raise ValueError(f"Function instance {func.__name__} must have a type annotation for the message parameter") type_hints = typing.get_type_hints(func) message_type = type_hints.get(message_param.name, message_param.annotation) + if message_type == inspect.Parameter.empty: + message_type = None # Check if there's a context parameter if len(params) == 2: diff --git a/python/packages/core/agent_framework/_workflows/_typing_utils.py b/python/packages/core/agent_framework/_workflows/_typing_utils.py index 5619fb9bf3..9516d29438 100644 --- a/python/packages/core/agent_framework/_workflows/_typing_utils.py +++ b/python/packages/core/agent_framework/_workflows/_typing_utils.py @@ -1,14 +1,45 @@ # Copyright (c) Microsoft. All rights reserved. -import logging from types import UnionType from typing import Any, TypeVar, Union, cast, get_args, get_origin -logger = logging.getLogger(__name__) +from agent_framework import get_logger + +logger = get_logger("agent_framework._workflows._typing_utils") T = TypeVar("T") +def normalize_type_to_list(type_annotation: type[Any] | UnionType | None) -> list[type[Any]]: + """Normalize a type annotation (possibly a union) to a list of concrete types. + + Args: + type_annotation: A type, union type (using | or Union[]), or None + + Returns: + A list of types. For union types, returns all members. + For None, returns an empty list. + For Optional[T] (Union[T, None]), returns [T, type(None)]. + + Examples: + - normalize_type_to_list(str) -> [str] + - normalize_type_to_list(str | int) -> [str, int] + - normalize_type_to_list(Union[str, int]) -> [str, int] + - normalize_type_to_list(None) -> [] + """ + if type_annotation is None: + return [] + + origin = get_origin(type_annotation) + + # Handle Union types (str | int or Union[str, int]) + if origin is Union or origin is UnionType: + return list(get_args(type_annotation)) + + # Single type + return [type_annotation] + + def is_instance_of(data: Any, target_type: type | UnionType | Any) -> bool: """Check if the data is an instance of the target type. diff --git a/python/packages/core/tests/workflow/test_executor.py b/python/packages/core/tests/workflow/test_executor.py index a812f6dae6..ab5de9b6d5 100644 --- a/python/packages/core/tests/workflow/test_executor.py +++ b/python/packages/core/tests/workflow/test_executor.py @@ -538,3 +538,204 @@ async def mutator(messages: list[ChatMessage], ctx: WorkflowContext[list[ChatMes f"{[m.text for m in mutator_invoked.data]}" ) assert mutator_invoked.data[0].text == "hello" + + +# region: Tests for @handler decorator with explicit input_type and output_type + + +class TestHandlerExplicitTypes: + """Test suite for @handler decorator with explicit input_type and output_type parameters.""" + + def test_handler_with_explicit_input_type(self): + """Test that explicit input_type takes precedence over introspection.""" + from typing import Any + + class ExplicitInputExecutor(Executor): + @handler(input_type=str) + async def handle(self, message: Any, ctx: WorkflowContext) -> None: + pass + + exec_instance = ExplicitInputExecutor(id="explicit_input") + + # Handler should be registered for str (explicit), not Any (introspected) + assert str in exec_instance._handlers + assert len(exec_instance._handlers) == 1 + + # Can handle str messages + assert exec_instance.can_handle(Message(data="hello", source_id="mock")) + # Cannot handle int messages (since explicit type is str) + assert not exec_instance.can_handle(Message(data=42, source_id="mock")) + + def test_handler_with_explicit_output_type(self): + """Test that explicit output_type takes precedence over introspection.""" + + class ExplicitOutputExecutor(Executor): + @handler(output_type=int) + async def handle(self, message: str, ctx: WorkflowContext[str]) -> None: + pass + + exec_instance = ExplicitOutputExecutor(id="explicit_output") + + # Handler spec should have int as output type (explicit), not str (introspected) + handler_func = exec_instance._handlers[str] + assert handler_func._handler_spec["output_types"] == [int] + + # Executor output_types property should reflect explicit type + assert int in exec_instance.output_types + assert str not in exec_instance.output_types + + def test_handler_with_explicit_input_and_output_types(self): + """Test that both explicit input_type and output_type work together.""" + from typing import Any + + class ExplicitBothExecutor(Executor): + @handler(input_type=dict, output_type=list) + async def handle(self, message: Any, ctx: WorkflowContext) -> None: + pass + + exec_instance = ExplicitBothExecutor(id="explicit_both") + + # Handler should be registered for dict (explicit input type) + assert dict in exec_instance._handlers + assert len(exec_instance._handlers) == 1 + + # Output type should be list (explicit) + handler_func = exec_instance._handlers[dict] + assert handler_func._handler_spec["output_types"] == [list] + + # Verify can_handle + assert exec_instance.can_handle(Message(data={"key": "value"}, source_id="mock")) + assert not exec_instance.can_handle(Message(data="string", source_id="mock")) + + def test_handler_with_explicit_union_input_type(self): + """Test that explicit union input_type is handled correctly.""" + from typing import Any + + class UnionInputExecutor(Executor): + @handler(input_type=str | int) + async def handle(self, message: Any, ctx: WorkflowContext) -> None: + pass + + exec_instance = UnionInputExecutor(id="union_input") + + # Handler should be registered for the union type + # The union type itself is stored as the key + assert len(exec_instance._handlers) == 1 + + # Can handle both str and int messages + assert exec_instance.can_handle(Message(data="hello", source_id="mock")) + assert exec_instance.can_handle(Message(data=42, source_id="mock")) + # Cannot handle float + assert not exec_instance.can_handle(Message(data=3.14, source_id="mock")) + + def test_handler_with_explicit_union_output_type(self): + """Test that explicit union output_type is normalized to a list.""" + from typing import Any + + class UnionOutputExecutor(Executor): + @handler(output_type=str | int | bool) + async def handle(self, message: Any, ctx: WorkflowContext) -> None: + pass + + exec_instance = UnionOutputExecutor(id="union_output") + + # Output types should be a list with all union members + assert set(exec_instance.output_types) == {str, int, bool} + + def test_handler_explicit_types_precedence_over_introspection(self): + """Test that explicit types always take precedence over introspected types.""" + + class PrecedenceExecutor(Executor): + # Introspection would give: input=str, output=[int] + # Explicit gives: input=bytes, output=[float] + @handler(input_type=bytes, output_type=float) + async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: + pass + + exec_instance = PrecedenceExecutor(id="precedence") + + # Should use explicit input type (bytes), not introspected (str) + assert bytes in exec_instance._handlers + assert str not in exec_instance._handlers + + # Should use explicit output type (float), not introspected (int) + assert float in exec_instance.output_types + assert int not in exec_instance.output_types + + def test_handler_fallback_to_introspection_when_no_explicit_types(self): + """Test that introspection is used when no explicit types are provided.""" + + class IntrospectedExecutor(Executor): + @handler + async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: + pass + + exec_instance = IntrospectedExecutor(id="introspected") + + # Should use introspected types + assert str in exec_instance._handlers + assert int in exec_instance.output_types + + def test_handler_partial_explicit_types(self): + """Test that partial explicit types work (only input_type or only output_type).""" + + # Only explicit input_type, introspect output_type + class OnlyInputExecutor(Executor): + @handler(input_type=bytes) + async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: + pass + + exec_input = OnlyInputExecutor(id="only_input") + assert bytes in exec_input._handlers # Explicit + assert int in exec_input.output_types # Introspected + + # Only explicit output_type, introspect input_type + class OnlyOutputExecutor(Executor): + @handler(output_type=float) + async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: + pass + + exec_output = OnlyOutputExecutor(id="only_output") + assert str in exec_output._handlers # Introspected + assert float in exec_output.output_types # Explicit + assert int not in exec_output.output_types # Not introspected when explicit provided + + def test_handler_explicit_input_type_allows_no_message_annotation(self): + """Test that explicit input_type allows handler without message type annotation.""" + + class NoAnnotationExecutor(Executor): + @handler(input_type=str) + async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + exec_instance = NoAnnotationExecutor(id="no_annotation") + + # Should work with explicit input_type + assert str in exec_instance._handlers + assert exec_instance.can_handle(Message(data="hello", source_id="mock")) + + def test_handler_multiple_handlers_mixed_explicit_and_introspected(self): + """Test executor with multiple handlers, some with explicit types and some introspected.""" + + class MixedExecutor(Executor): + @handler(input_type=str, output_type=int) + async def handle_explicit(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + @handler + async def handle_introspected(self, message: float, ctx: WorkflowContext[bool]) -> None: + pass + + exec_instance = MixedExecutor(id="mixed") + + # Should have both handlers + assert len(exec_instance._handlers) == 2 + assert str in exec_instance._handlers # Explicit + assert float in exec_instance._handlers # Introspected + + # Should have both output types + assert int in exec_instance.output_types # Explicit + assert bool in exec_instance.output_types # Introspected + + +# endregion: Tests for @handler decorator with explicit input_type and output_type diff --git a/python/packages/core/tests/workflow/test_function_executor.py b/python/packages/core/tests/workflow/test_function_executor.py index a034f42a38..4168342857 100644 --- a/python/packages/core/tests/workflow/test_function_executor.py +++ b/python/packages/core/tests/workflow/test_function_executor.py @@ -535,3 +535,203 @@ class C: async_static = static_wrapped assert asyncio.iscoroutinefunction(C.async_static) # Works via descriptor protocol + + +class TestExecutorExplicitTypes: + """Test suite for @executor decorator with explicit input_type and output_type parameters.""" + + def test_executor_with_explicit_input_type(self): + """Test that explicit input_type takes precedence over introspection.""" + + @executor(input_type=str) + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Handler should be registered for str (explicit) + assert str in process._handlers + assert len(process._handlers) == 1 + + # Can handle str messages + assert process.can_handle(Message(data="hello", source_id="mock")) + # Cannot handle int messages + assert not process.can_handle(Message(data=42, source_id="mock")) + + def test_executor_with_explicit_output_type(self): + """Test that explicit output_type takes precedence over introspection.""" + + @executor(output_type=int) + async def process(message: str, ctx: WorkflowContext[str]) -> None: + pass + + # Handler spec should have int as output type (explicit), not str (introspected) + spec = process._handler_specs[0] + assert spec["output_types"] == [int] + + # Executor output_types property should reflect explicit type + assert int in process.output_types + assert str not in process.output_types + + def test_executor_with_explicit_input_and_output_types(self): + """Test that both explicit input_type and output_type work together.""" + + @executor(id="explicit_both", input_type=dict, output_type=list) + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Handler should be registered for dict (explicit input type) + assert dict in process._handlers + assert len(process._handlers) == 1 + + # Output type should be list (explicit) + spec = process._handler_specs[0] + assert spec["output_types"] == [list] + + # Verify can_handle + assert process.can_handle(Message(data={"key": "value"}, source_id="mock")) + assert not process.can_handle(Message(data="string", source_id="mock")) + + def test_executor_with_explicit_union_input_type(self): + """Test that explicit union input_type is handled correctly.""" + + @executor(input_type=str | int) + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Handler should be registered for the union type + assert len(process._handlers) == 1 + + # Can handle both str and int messages + assert process.can_handle(Message(data="hello", source_id="mock")) + assert process.can_handle(Message(data=42, source_id="mock")) + # Cannot handle float + assert not process.can_handle(Message(data=3.14, source_id="mock")) + + def test_executor_with_explicit_union_output_type(self): + """Test that explicit union output_type is normalized to a list.""" + + @executor(output_type=str | int | bool) + async def process(message: Any, ctx: WorkflowContext) -> None: + pass + + # Output types should be a list with all union members + assert set(process.output_types) == {str, int, bool} + + def test_executor_explicit_types_precedence_over_introspection(self): + """Test that explicit types always take precedence over introspected types.""" + + # Introspection would give: input=str, output=[int] + # Explicit gives: input=bytes, output=[float] + @executor(input_type=bytes, output_type=float) + async def process(message: str, ctx: WorkflowContext[int]) -> None: + pass + + # Should use explicit input type (bytes), not introspected (str) + assert bytes in process._handlers + assert str not in process._handlers + + # Should use explicit output type (float), not introspected (int) + assert float in process.output_types + assert int not in process.output_types + + def test_executor_fallback_to_introspection_when_no_explicit_types(self): + """Test that introspection is used when no explicit types are provided.""" + + @executor + async def process(message: str, ctx: WorkflowContext[int]) -> None: + pass + + # Should use introspected types + assert str in process._handlers + assert int in process.output_types + + def test_executor_partial_explicit_types(self): + """Test that partial explicit types work (only input_type or only output_type).""" + + # Only explicit input_type, introspect output_type + @executor(input_type=bytes) + async def process_input(message: str, ctx: WorkflowContext[int]) -> None: + pass + + assert bytes in process_input._handlers # Explicit + assert int in process_input.output_types # Introspected + + # Only explicit output_type, introspect input_type + @executor(output_type=float) + async def process_output(message: str, ctx: WorkflowContext[int]) -> None: + pass + + assert str in process_output._handlers # Introspected + assert float in process_output.output_types # Explicit + assert int not in process_output.output_types # Not introspected when explicit provided + + def test_executor_explicit_input_type_allows_no_message_annotation(self): + """Test that explicit input_type allows function without message type annotation.""" + + @executor(input_type=str) + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Should work with explicit input_type + assert str in process._handlers + assert process.can_handle(Message(data="hello", source_id="mock")) + + def test_executor_explicit_types_with_id(self): + """Test that explicit types work together with id parameter.""" + + @executor(id="custom_id", input_type=bytes, output_type=int) + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + assert process.id == "custom_id" + assert bytes in process._handlers + assert int in process.output_types + + def test_executor_explicit_types_with_single_param_function(self): + """Test that explicit input_type works with single-parameter functions.""" + + @executor(input_type=str) + async def process(message): # type: ignore[no-untyped-def] + return message.upper() + + # Should work with explicit input_type + assert str in process._handlers + assert process.can_handle(Message(data="hello", source_id="mock")) + assert not process.can_handle(Message(data=42, source_id="mock")) + + def test_executor_explicit_types_with_sync_function(self): + """Test that explicit types work with synchronous functions.""" + + @executor(input_type=int, output_type=str) + def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + assert int in process._handlers + assert str in process.output_types + + def test_function_executor_constructor_with_explicit_types(self): + """Test FunctionExecutor constructor with explicit input_type and output_type.""" + + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + func_exec = FunctionExecutor(process, id="test", input_type=dict, output_type=list) + + assert dict in func_exec._handlers + spec = func_exec._handler_specs[0] + assert spec["message_type"] is dict + assert spec["output_types"] == [list] + + def test_executor_explicit_union_types_via_typing_union(self): + """Test that Union[] syntax also works for explicit types.""" + from typing import Union + + @executor(input_type=Union[str, int], output_type=Union[bool, float]) + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Can handle both str and int + assert process.can_handle(Message(data="hello", source_id="mock")) + assert process.can_handle(Message(data=42, source_id="mock")) + + # Output types should include both + assert set(process.output_types) == {bool, float} diff --git a/python/packages/core/tests/workflow/test_typing_utils.py b/python/packages/core/tests/workflow/test_typing_utils.py index 4294f35f4b..c882fc7813 100644 --- a/python/packages/core/tests/workflow/test_typing_utils.py +++ b/python/packages/core/tests/workflow/test_typing_utils.py @@ -1,16 +1,84 @@ # Copyright (c) Microsoft. All rights reserved. from dataclasses import dataclass -from typing import Any, Generic, TypeVar, Union +from typing import Any, Generic, Optional, TypeVar, Union from agent_framework import RequestInfoEvent from agent_framework._workflows._typing_utils import ( deserialize_type, is_instance_of, is_type_compatible, + normalize_type_to_list, serialize_type, ) +# region: normalize_type_to_list tests + + +def test_normalize_type_to_list_single_type() -> None: + """Test normalize_type_to_list with single types.""" + assert normalize_type_to_list(str) == [str] + assert normalize_type_to_list(int) == [int] + assert normalize_type_to_list(float) == [float] + assert normalize_type_to_list(bool) == [bool] + assert normalize_type_to_list(list) == [list] + assert normalize_type_to_list(dict) == [dict] + + +def test_normalize_type_to_list_none() -> None: + """Test normalize_type_to_list with None returns empty list.""" + assert normalize_type_to_list(None) == [] + + +def test_normalize_type_to_list_union_pipe_syntax() -> None: + """Test normalize_type_to_list with union types using | syntax.""" + result = normalize_type_to_list(str | int) + assert set(result) == {str, int} + + result = normalize_type_to_list(str | int | bool) + assert set(result) == {str, int, bool} + + +def test_normalize_type_to_list_union_typing_syntax() -> None: + """Test normalize_type_to_list with Union[] from typing module.""" + result = normalize_type_to_list(Union[str, int]) + assert set(result) == {str, int} + + result = normalize_type_to_list(Union[str, int, bool]) + assert set(result) == {str, int, bool} + + +def test_normalize_type_to_list_optional() -> None: + """Test normalize_type_to_list with Optional types (Union[T, None]).""" + # Optional[str] is Union[str, None] + result = normalize_type_to_list(Optional[str]) + assert str in result + assert type(None) in result + assert len(result) == 2 + + # str | None is equivalent + result = normalize_type_to_list(str | None) + assert str in result + assert type(None) in result + assert len(result) == 2 + + +def test_normalize_type_to_list_custom_types() -> None: + """Test normalize_type_to_list with custom class types.""" + + @dataclass + class CustomMessage: + content: str + + result = normalize_type_to_list(CustomMessage) + assert result == [CustomMessage] + + result = normalize_type_to_list(CustomMessage | str) + assert set(result) == {CustomMessage, str} + + +# endregion: normalize_type_to_list tests + def test_basic_types() -> None: """Test basic built-in types.""" diff --git a/python/samples/getting_started/workflows/_start-here/step1_executors_and_edges.py b/python/samples/getting_started/workflows/_start-here/step1_executors_and_edges.py index b5c80062dd..d070173885 100644 --- a/python/samples/getting_started/workflows/_start-here/step1_executors_and_edges.py +++ b/python/samples/getting_started/workflows/_start-here/step1_executors_and_edges.py @@ -33,6 +33,16 @@ Simple steps can use this form; a terminal step can yield output using ctx.yield_output() to provide workflow results. +- Explicit type parameters with @handler: + Instead of relying on type introspection from function signatures, you can explicitly + specify `input_type` and/or `output_type` on the @handler decorator. These explicit + types take precedence over introspection and support union types (e.g., `str | int`). + + Examples: + @handler(input_type=str | int) # Accepts str or int, output from introspection + @handler(output_type=str | int) # Input from introspection, outputs str or int + @handler(input_type=str, output_type=int) # Both explicitly specified + - Fluent WorkflowBuilder API: add_edge(A, B) to connect nodes, set_start_executor(A), then build() -> Workflow. @@ -45,8 +55,8 @@ """ -# Example 1: A custom Executor subclass -# ------------------------------------ +# Example 1: A custom Executor subclass using introspection (traditional approach) +# --------------------------------------------------------------------------------- # # Subclassing Executor lets you define a named node with lifecycle hooks if needed. # The work itself is implemented in an async method decorated with @handler. @@ -70,14 +80,15 @@ async def to_upper_case(self, text: str, ctx: WorkflowContext[str]) -> None: Note: The WorkflowContext is parameterized with the type this handler will emit. Here WorkflowContext[str] means downstream nodes should expect str. """ + result = text.upper() # Send the result to the next executor in the workflow. await ctx.send_message(result) -# Example 2: A standalone function-based executor -# ----------------------------------------------- +# Example 2: A standalone function-based executor using introspection +# -------------------------------------------------------------------- # # For simple steps you can skip subclassing and define an async function with the # same signature pattern (typed input + WorkflowContext[T_Out, T_W_Out]) and decorate it with @@ -101,30 +112,94 @@ async def reverse_text(text: str, ctx: WorkflowContext[Never, str]) -> None: await ctx.yield_output(result) +# Example 3: Using explicit type parameters on @handler +# ----------------------------------------------------- +# +# Instead of relying on type introspection, you can explicitly specify input_type +# and/or output_type on the @handler decorator. These take precedence over introspection +# and support union types (e.g., str | int). +# +# This is useful when: +# - You want to accept multiple types (union types) without complex type annotations +# - The function signature uses Any or a base type for flexibility +# - You want to decouple the runtime type routing from the static type annotations + + +class ExclamationAdder(Executor): + """An executor that adds exclamation marks, demonstrating explicit @handler types. + + This example shows how to use explicit input_type and output_type parameters + on the @handler decorator instead of relying on introspection from the function + signature. This approach is especially useful for union types. + """ + + def __init__(self, id: str): + super().__init__(id=id) + + @handler(input_type=str, output_type=str) + async def add_exclamation(self, message: str, ctx: WorkflowContext) -> None: + """Add exclamation marks to the input. + + Note: The input_type=str and output_type=str are explicitly specified on @handler, + so the framework uses those instead of introspecting the function signature. + The WorkflowContext here has no type parameters because the explicit types + on @handler take precedence. + """ + result = f"{message}!!!" + await ctx.send_message(result) + + async def main(): - """Build and run a simple 2-step workflow using the fluent builder API.""" + """Build and run workflows using the fluent builder API.""" + # Workflow 1: Using introspection-based type detection + # ----------------------------------------------------- upper_case = UpperCase(id="upper_case_executor") # Build the workflow using a fluent pattern: # 1) add_edge(from_node, to_node) defines a directed edge upper_case -> reverse_text # 2) set_start_executor(node) declares the entry point # 3) build() finalizes and returns an immutable Workflow object - workflow = WorkflowBuilder().add_edge(upper_case, reverse_text).set_start_executor(upper_case).build() + workflow1 = WorkflowBuilder().add_edge(upper_case, reverse_text).set_start_executor(upper_case).build() # Run the workflow by sending the initial message to the start node. # The run(...) call returns an event collection; its get_outputs() method # retrieves the outputs yielded by any terminal nodes. - events = await workflow.run("hello world") - print(events.get_outputs()) - # Summarize the final run state (e.g., IDLE) - print("Final state:", events.get_final_state()) + print("Workflow 1 (introspection-based types):") + events1 = await workflow1.run("hello world") + print(events1.get_outputs()) + print("Final state:", events1.get_final_state()) + + # Workflow 2: Using explicit type parameters on @handler + # ------------------------------------------------------- + exclamation_adder = ExclamationAdder(id="exclamation_adder") + + # This workflow demonstrates the explicit input_type/output_type feature: + # exclamation_adder uses @handler(input_type=str, output_type=str) to + # explicitly declare types instead of relying on introspection. + workflow2 = ( + WorkflowBuilder() + .add_edge(upper_case, exclamation_adder) + .add_edge(exclamation_adder, reverse_text) + .set_start_executor(upper_case) + .build() + ) + + print("\nWorkflow 2 (explicit @handler types):") + events2 = await workflow2.run("hello world") + print(events2.get_outputs()) + print("Final state:", events2.get_final_state()) """ Sample Output: + Workflow 1 (introspection-based types): ['DLROW OLLEH'] Final state: WorkflowRunState.IDLE + + Workflow 2 (explicit @handler types): + ['!!!DLROW OLLEH'] + Final state: WorkflowRunState.IDLE """ From 06504a0b1f6758df549f4da5b5d6133b0d7afcde Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Wed, 28 Jan 2026 07:58:22 -0800 Subject: [PATCH 2/9] Add handling for string types --- .../agent_framework/_workflows/_executor.py | 63 ++++++++++------- .../_workflows/_function_executor.py | 53 +++++++++----- .../_workflows/_typing_utils.py | 55 +++++++++++++++ .../core/tests/workflow/test_executor.py | 64 +++++++++++++++++ .../tests/workflow/test_function_executor.py | 54 +++++++++++++++ .../core/tests/workflow/test_typing_utils.py | 69 +++++++++++++++++++ 6 files changed, 314 insertions(+), 44 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index 127466ea8e..31ffc7fd6f 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -542,8 +542,8 @@ def handler( @overload def handler( *, - input_type: type | types.UnionType | None = None, - output_type: type | types.UnionType | None = None, + input_type: type | types.UnionType | str | None = None, + output_type: type | types.UnionType | str | None = None, ) -> Callable[ [Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]], Callable[[ExecutorT, Any, ContextT], Awaitable[Any]], @@ -553,8 +553,8 @@ def handler( def handler( func: Callable[[ExecutorT, Any, ContextT], Awaitable[Any]] | None = None, *, - input_type: type | types.UnionType | None = None, - output_type: type | types.UnionType | None = None, + input_type: type | types.UnionType | str | None = None, + output_type: type | types.UnionType | str | None = None, ) -> ( Callable[[ExecutorT, Any, ContextT], Awaitable[Any]] | Callable[ @@ -567,43 +567,52 @@ def handler( Args: func: The function to decorate. Can be None when used with parameters. input_type: Optional explicit input type(s) for this handler. Supports union types - (e.g., ``str | int``). When provided, takes precedence over introspection from - the function's message parameter annotation. + (e.g., ``str | int``) and string forward references (e.g., ``"MyType | int"``). + When provided, takes precedence over introspection from the function's message + parameter annotation. output_type: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. - Supports union types (e.g., ``str | int``). When provided, takes precedence over - introspection from the ``WorkflowContext`` generic parameters. + Supports union types (e.g., ``str | int``) and string forward references. + When provided, takes precedence over introspection from the ``WorkflowContext`` + generic parameters. Returns: The decorated function with handler metadata. Example: - # Using introspection (existing behavior) - @handler - async def handle_string(self, message: str, ctx: WorkflowContext[str]) -> None: - ... - - # Using explicit types (takes precedence over introspection) - @handler(input_type=str | int, output_type=bool) - async def handle_data(self, message: Any, ctx: WorkflowContext) -> None: - ... - - # Only specifying input_type (output_type falls back to introspection) - @handler(input_type=MyCustomType) - async def handle_custom(self, message: Any, ctx: WorkflowContext[str]) -> None: - ... + .. code-block:: python + + # Using introspection (existing behavior) + @handler + async def handle_string(self, message: str, ctx: WorkflowContext[str]) -> None: ... + + + # Using explicit types (takes precedence over introspection) + @handler(input_type=str | int, output_type=bool) + async def handle_data(self, message: Any, ctx: WorkflowContext) -> None: ... + + + # Using string forward references + @handler(input_type="MyCustomType | int", output_type="ResponseType") + async def handle_custom(self, message: Any, ctx: WorkflowContext) -> None: ... """ - from ._typing_utils import normalize_type_to_list + from ._typing_utils import normalize_type_to_list, resolve_type_annotation def decorator( func: Callable[[ExecutorT, Any, ContextT], Awaitable[Any]], ) -> Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]: + # Resolve string forward references using the function's globals + resolved_input_type = resolve_type_annotation(input_type, func.__globals__) if input_type is not None else None + resolved_output_type = ( + resolve_type_annotation(output_type, func.__globals__) if output_type is not None else None + ) + # Extract the message type and validate using unified validation introspected_message_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = ( - _validate_handler_signature(func, skip_message_annotation=input_type is not None) + _validate_handler_signature(func, skip_message_annotation=resolved_input_type is not None) ) # Use explicit types if provided, otherwise fall back to introspection - message_type = input_type if input_type is not None else introspected_message_type + message_type = resolved_input_type if resolved_input_type is not None else introspected_message_type # Validate that we have a message type - this should never happen if signature # validation passed, but provides a clear error if type information is missing @@ -613,7 +622,9 @@ def decorator( "or an explicit input_type parameter" ) - final_output_types = normalize_type_to_list(output_type) if output_type is not None else inferred_output_types + final_output_types = ( + normalize_type_to_list(resolved_output_type) if resolved_output_type is not None else inferred_output_types + ) # Get signature for preservation sig = inspect.signature(func) diff --git a/python/packages/core/agent_framework/_workflows/_function_executor.py b/python/packages/core/agent_framework/_workflows/_function_executor.py index f6167976d3..54108c14d0 100644 --- a/python/packages/core/agent_framework/_workflows/_function_executor.py +++ b/python/packages/core/agent_framework/_workflows/_function_executor.py @@ -24,7 +24,7 @@ from typing import Any from ._executor import Executor -from ._typing_utils import normalize_type_to_list +from ._typing_utils import normalize_type_to_list, resolve_type_annotation from ._workflow_context import WorkflowContext, validate_workflow_context_annotation if sys.version_info >= (3, 11): @@ -48,8 +48,8 @@ def __init__( func: Callable[..., Any], id: str | None = None, *, - input_type: type | types.UnionType | None = None, - output_type: type | types.UnionType | None = None, + input_type: type | types.UnionType | str | None = None, + output_type: type | types.UnionType | str | None = None, ): """Initialize the FunctionExecutor with a user-defined function. @@ -57,11 +57,13 @@ def __init__( func: The function to wrap as an executor (can be sync or async) id: Optional executor ID. If None, uses the function name. input_type: Optional explicit input type(s) for this executor. Supports union types - (e.g., ``str | int``). When provided, takes precedence over introspection from - the function's message parameter annotation. + (e.g., ``str | int``) and string forward references (e.g., ``"MyType | int"``). + When provided, takes precedence over introspection from the function's message + parameter annotation. output_type: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. - Supports union types (e.g., ``str | int``). When provided, takes precedence over - introspection from the ``WorkflowContext`` generic parameters. + Supports union types (e.g., ``str | int``) and string forward references. + When provided, takes precedence over introspection from the ``WorkflowContext`` + generic parameters. Raises: ValueError: If func is a staticmethod or classmethod (use @handler on instance methods instead) @@ -75,14 +77,22 @@ def __init__( f"or create an Executor subclass and use @handler on instance methods instead." ) + # Resolve string forward references using the function's globals + resolved_input_type = resolve_type_annotation(input_type, func.__globals__) if input_type is not None else None + resolved_output_type = ( + resolve_type_annotation(output_type, func.__globals__) if output_type is not None else None + ) + # Validate function signature and extract types introspected_message_type, ctx_annotation, inferred_output_types, workflow_output_types = ( - _validate_function_signature(func, skip_message_annotation=input_type is not None) + _validate_function_signature(func, skip_message_annotation=resolved_input_type is not None) ) # Use explicit types if provided, otherwise fall back to introspection - message_type = input_type if input_type is not None else introspected_message_type - output_types = normalize_type_to_list(output_type) if output_type is not None else inferred_output_types + message_type = resolved_input_type if resolved_input_type is not None else introspected_message_type + output_types = ( + normalize_type_to_list(resolved_output_type) if resolved_output_type is not None else inferred_output_types + ) # Validate that we have a message type - provides a clear error if type information is missing if message_type is None: @@ -158,8 +168,8 @@ def executor(func: Callable[..., Any]) -> FunctionExecutor: ... def executor( *, id: str | None = None, - input_type: type | types.UnionType | None = None, - output_type: type | types.UnionType | None = None, + input_type: type | types.UnionType | str | None = None, + output_type: type | types.UnionType | str | None = None, ) -> Callable[[Callable[..., Any]], FunctionExecutor]: ... @@ -167,8 +177,8 @@ def executor( func: Callable[..., Any] | None = None, *, id: str | None = None, - input_type: type | types.UnionType | None = None, - output_type: type | types.UnionType | None = None, + input_type: type | types.UnionType | str | None = None, + output_type: type | types.UnionType | str | None = None, ) -> Callable[[Callable[..., Any]], FunctionExecutor] | FunctionExecutor: """Decorator that converts a standalone function into a FunctionExecutor instance. @@ -205,6 +215,11 @@ async def process(message: Any, ctx: WorkflowContext): await ctx.send_message(True) + # Using string forward references: + @executor(input_type="MyCustomType | int", output_type="ResponseType") + async def process(message: Any, ctx: WorkflowContext): ... + + # For class-based executors, use @handler instead: class MyExecutor(Executor): def __init__(self): @@ -218,11 +233,13 @@ async def process(self, data: str, ctx: WorkflowContext[str]): func: The function to decorate (when used without parentheses) id: Optional custom ID for the executor. If None, uses the function name. input_type: Optional explicit input type(s) for this executor. Supports union types - (e.g., ``str | int``). When provided, takes precedence over introspection from - the function's message parameter annotation. + (e.g., ``str | int``) and string forward references (e.g., ``"MyType | int"``). + When provided, takes precedence over introspection from the function's message + parameter annotation. output_type: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. - Supports union types (e.g., ``str | int``). When provided, takes precedence over - introspection from the ``WorkflowContext`` generic parameters. + Supports union types (e.g., ``str | int``) and string forward references. + When provided, takes precedence over introspection from the ``WorkflowContext`` + generic parameters. Returns: A FunctionExecutor instance that can be wired into a Workflow. diff --git a/python/packages/core/agent_framework/_workflows/_typing_utils.py b/python/packages/core/agent_framework/_workflows/_typing_utils.py index 9516d29438..d0e9490a24 100644 --- a/python/packages/core/agent_framework/_workflows/_typing_utils.py +++ b/python/packages/core/agent_framework/_workflows/_typing_utils.py @@ -10,6 +10,61 @@ T = TypeVar("T") +def resolve_type_annotation( + type_annotation: type[Any] | UnionType | str | None, + globalns: dict[str, Any] | None = None, + localns: dict[str, Any] | None = None, +) -> type[Any] | UnionType | None: + """Resolve a type annotation, including string forward references. + + Args: + type_annotation: A type, union type, string forward reference, or None + globalns: Global namespace for resolving forward references (typically func.__globals__) + localns: Local namespace for resolving forward references + + Returns: + The resolved type annotation. For string annotations, evaluates them in the + provided namespace. Returns None if type_annotation is None. + + Raises: + NameError: If a forward reference cannot be resolved in the provided namespaces + SyntaxError: If a string annotation contains invalid Python syntax + + Note: + This function uses eval() to resolve string type annotations. This is the same + approach used by Python's typing.get_type_hints() and typing.ForwardRef internally. + Security is managed by: (1) strings come from decorator parameters in source code, + not runtime user input, and (2) the eval namespace is restricted to the function's + module globals plus Union/Optional from typing. + + Examples: + - resolve_type_annotation(str) -> str + - resolve_type_annotation("str | int", {"str": str, "int": int}) -> str | int + - resolve_type_annotation("MyClass", {"MyClass": MyClass}) -> MyClass + """ + if type_annotation is None: + return None + + if isinstance(type_annotation, str): + # Resolve string forward reference by evaluating it. + # This uses eval() which is the same approach as Python's typing.get_type_hints() + # and typing.ForwardRef._evaluate(). The namespace is restricted to the function's + # globals plus typing constructs, and input comes from developer source code. + eval_globalns = globalns.copy() if globalns else {} + eval_globalns.setdefault("Union", Union) + eval_globalns.setdefault("Optional", __import__("typing").Optional) + + try: + return eval(type_annotation, eval_globalns, localns) # noqa: S307 # nosec B307 + except NameError as e: + raise NameError( + f"Could not resolve type annotation '{type_annotation}'. " + f"Make sure the type is defined or imported. Original error: {e}" + ) from e + + return type_annotation + + def normalize_type_to_list(type_annotation: type[Any] | UnionType | None) -> list[type[Any]]: """Normalize a type annotation (possibly a union) to a list of concrete types. diff --git a/python/packages/core/tests/workflow/test_executor.py b/python/packages/core/tests/workflow/test_executor.py index ab5de9b6d5..b34015d9b5 100644 --- a/python/packages/core/tests/workflow/test_executor.py +++ b/python/packages/core/tests/workflow/test_executor.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. +from dataclasses import dataclass + import pytest from agent_framework import ( @@ -16,6 +18,27 @@ ) +# Module-level types for string forward reference tests +@dataclass +class ForwardRefMessage: + content: str + + +@dataclass +class ForwardRefTypeA: + value: str + + +@dataclass +class ForwardRefTypeB: + value: int + + +@dataclass +class ForwardRefResponse: + result: str + + def test_executor_without_id(): """Test that an executor without an ID raises an error when trying to run.""" @@ -737,5 +760,46 @@ async def handle_introspected(self, message: float, ctx: WorkflowContext[bool]) assert int in exec_instance.output_types # Explicit assert bool in exec_instance.output_types # Introspected + def test_handler_with_string_forward_reference_input_type(self): + """Test that string forward references work for input_type.""" + + class StringRefExecutor(Executor): + @handler(input_type="ForwardRefMessage") + async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + exec_instance = StringRefExecutor(id="string_ref") + + # Should resolve the string to the actual type + assert ForwardRefMessage in exec_instance._handlers + assert exec_instance.can_handle(Message(data=ForwardRefMessage("hello"), source_id="mock")) + + def test_handler_with_string_forward_reference_union(self): + """Test that string forward references work with union types.""" + + class StringUnionExecutor(Executor): + @handler(input_type="ForwardRefTypeA | ForwardRefTypeB") + async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + exec_instance = StringUnionExecutor(id="string_union") + + # Should handle both types + assert exec_instance.can_handle(Message(data=ForwardRefTypeA("hello"), source_id="mock")) + assert exec_instance.can_handle(Message(data=ForwardRefTypeB(42), source_id="mock")) + + def test_handler_with_string_forward_reference_output_type(self): + """Test that string forward references work for output_type.""" + + class StringOutputExecutor(Executor): + @handler(input_type=str, output_type="ForwardRefResponse") + async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + exec_instance = StringOutputExecutor(id="string_output") + + # Should resolve the string output type + assert ForwardRefResponse in exec_instance.output_types + # endregion: Tests for @handler decorator with explicit input_type and output_type diff --git a/python/packages/core/tests/workflow/test_function_executor.py b/python/packages/core/tests/workflow/test_function_executor.py index 4168342857..71d6cb34e2 100644 --- a/python/packages/core/tests/workflow/test_function_executor.py +++ b/python/packages/core/tests/workflow/test_function_executor.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. +from dataclasses import dataclass from typing import Any import pytest @@ -14,6 +15,27 @@ ) +# Module-level types for string forward reference tests +@dataclass +class FuncExecForwardRefMessage: + content: str + + +@dataclass +class FuncExecForwardRefTypeA: + value: str + + +@dataclass +class FuncExecForwardRefTypeB: + value: int + + +@dataclass +class FuncExecForwardRefResponse: + result: str + + class TestFunctionExecutor: """Test suite for FunctionExecutor and @executor decorator.""" @@ -735,3 +757,35 @@ async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-unt # Output types should include both assert set(process.output_types) == {bool, float} + + def test_executor_with_string_forward_reference_input_type(self): + """Test that string forward references work for input_type.""" + + @executor(input_type="FuncExecForwardRefMessage") + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Should resolve the string to the actual type + assert FuncExecForwardRefMessage in process._handlers + assert process.can_handle(Message(data=FuncExecForwardRefMessage("hello"), source_id="mock")) + + def test_executor_with_string_forward_reference_union(self): + """Test that string forward references work with union types.""" + + @executor(input_type="FuncExecForwardRefTypeA | FuncExecForwardRefTypeB") + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Should handle both types + assert process.can_handle(Message(data=FuncExecForwardRefTypeA("hello"), source_id="mock")) + assert process.can_handle(Message(data=FuncExecForwardRefTypeB(42), source_id="mock")) + + def test_executor_with_string_forward_reference_output_type(self): + """Test that string forward references work for output_type.""" + + @executor(input_type=str, output_type="FuncExecForwardRefResponse") + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Should resolve the string output type + assert FuncExecForwardRefResponse in process.output_types diff --git a/python/packages/core/tests/workflow/test_typing_utils.py b/python/packages/core/tests/workflow/test_typing_utils.py index c882fc7813..3e8d1051e7 100644 --- a/python/packages/core/tests/workflow/test_typing_utils.py +++ b/python/packages/core/tests/workflow/test_typing_utils.py @@ -3,12 +3,15 @@ from dataclasses import dataclass from typing import Any, Generic, Optional, TypeVar, Union +import pytest + from agent_framework import RequestInfoEvent from agent_framework._workflows._typing_utils import ( deserialize_type, is_instance_of, is_type_compatible, normalize_type_to_list, + resolve_type_annotation, serialize_type, ) @@ -80,6 +83,72 @@ class CustomMessage: # endregion: normalize_type_to_list tests +# region: resolve_type_annotation tests + + +def test_resolve_type_annotation_none() -> None: + """Test resolve_type_annotation with None returns None.""" + assert resolve_type_annotation(None) is None + + +def test_resolve_type_annotation_actual_types() -> None: + """Test resolve_type_annotation passes through actual types unchanged.""" + assert resolve_type_annotation(str) is str + assert resolve_type_annotation(int) is int + assert resolve_type_annotation(str | int) == str | int + + +def test_resolve_type_annotation_string_builtin() -> None: + """Test resolve_type_annotation resolves string references to builtin types.""" + result = resolve_type_annotation("str", {"str": str}) + assert result is str + + result = resolve_type_annotation("int", {"int": int}) + assert result is int + + +def test_resolve_type_annotation_string_union() -> None: + """Test resolve_type_annotation resolves string union types.""" + result = resolve_type_annotation("str | int", {"str": str, "int": int}) + assert result == str | int + + +def test_resolve_type_annotation_string_custom_type() -> None: + """Test resolve_type_annotation resolves string references to custom types.""" + + @dataclass + class MyCustomType: + value: int + + result = resolve_type_annotation("MyCustomType", {"MyCustomType": MyCustomType}) + assert result is MyCustomType + + result = resolve_type_annotation("MyCustomType | str", {"MyCustomType": MyCustomType, "str": str}) + assert set(result.__args__) == {MyCustomType, str} # type: ignore[union-attr] + + +def test_resolve_type_annotation_string_typing_union() -> None: + """Test resolve_type_annotation resolves Union[] syntax in strings.""" + result = resolve_type_annotation("Union[str, int]", {"str": str, "int": int}) + assert set(result.__args__) == {str, int} # type: ignore[union-attr] + + +def test_resolve_type_annotation_string_optional() -> None: + """Test resolve_type_annotation resolves Optional[] syntax in strings.""" + result = resolve_type_annotation("Optional[str]", {"str": str}) + assert str in result.__args__ # type: ignore[union-attr] + assert type(None) in result.__args__ # type: ignore[union-attr] + + +def test_resolve_type_annotation_unresolvable_raises() -> None: + """Test resolve_type_annotation raises NameError for unresolvable types.""" + with pytest.raises(NameError, match="Could not resolve type annotation"): + resolve_type_annotation("NonExistentType", {}) + + +# endregion: resolve_type_annotation tests + + def test_basic_types() -> None: """Test basic built-in types.""" assert is_instance_of(5, int) From 69638860cf7efeb49189e8425e629ea6b2c1ca20 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Thu, 29 Jan 2026 11:36:32 -0800 Subject: [PATCH 3/9] Fix typing --- .../_workflows/_agent_executor.py | 3 +- .../agent_framework/_workflows/_executor.py | 47 ++++++-- .../_workflows/_function_executor.py | 46 ++++++-- .../_workflows/_typing_utils.py | 7 +- .../agent_framework/_workflows/_validation.py | 7 +- .../agent_framework/_workflows/_workflow.py | 7 +- .../_workflows/_workflow_executor.py | 9 +- .../core/tests/workflow/test_executor.py | 109 ++++++++++++++++++ .../tests/workflow/test_function_executor.py | 106 +++++++++++++++++ 9 files changed, 310 insertions(+), 31 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 9beaf06a65..3503f87717 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -2,6 +2,7 @@ import logging import sys +import types from dataclasses import dataclass from typing import Any, cast @@ -110,7 +111,7 @@ def output_response(self) -> bool: return self._output_response @property - def workflow_output_types(self) -> list[type[Any]]: + def workflow_output_types(self) -> list[type[Any] | types.UnionType]: # Override to declare AgentResponse as a possible output type only if enabled. if self._output_response: return [AgentResponse] diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index 31ffc7fd6f..e0b0338de2 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -201,7 +201,9 @@ def __init__( from builtins import type as builtin_type - self._handlers: dict[builtin_type[Any], Callable[[Any, WorkflowContext[Any, Any]], Awaitable[None]]] = {} + self._handlers: dict[ + builtin_type[Any] | types.UnionType, Callable[[Any, WorkflowContext[Any, Any]], Awaitable[None]] + ] = {} self._handler_specs: list[dict[str, Any]] = [] if not defer_discovery: self._discover_handlers() @@ -386,10 +388,10 @@ def _register_instance_handler( self, name: str, func: Callable[[Any, WorkflowContext[Any]], Awaitable[Any]], - message_type: type, + message_type: type | types.UnionType, ctx_annotation: Any, - output_types: list[type], - workflow_output_types: list[type], + output_types: list[type[Any] | types.UnionType], + workflow_output_types: list[type[Any] | types.UnionType], ) -> None: """Register a handler at instance level. @@ -415,7 +417,7 @@ def _register_instance_handler( }) @property - def input_types(self) -> list[type[Any]]: + def input_types(self) -> list[type[Any] | types.UnionType]: """Get the list of input types that this executor can handle. Returns: @@ -424,13 +426,13 @@ def input_types(self) -> list[type[Any]]: return list(self._handlers.keys()) @property - def output_types(self) -> list[type[Any]]: + def output_types(self) -> list[type[Any] | types.UnionType]: """Get the list of output types that this executor can produce via send_message(). Returns: A list of the output types inferred from the handlers' WorkflowContext[T] annotations. """ - output_types: set[type[Any]] = set() + output_types: set[type[Any] | types.UnionType] = set() # Collect output types from all handlers for handler_spec in self._handler_specs + self._response_handler_specs: @@ -440,13 +442,13 @@ def output_types(self) -> list[type[Any]]: return list(output_types) @property - def workflow_output_types(self) -> list[type[Any]]: + def workflow_output_types(self) -> list[type[Any] | types.UnionType]: """Get the list of workflow output types that this executor can produce via yield_output(). Returns: A list of the workflow output types inferred from handlers' WorkflowContext[T, U] annotations. """ - output_types: set[type[Any]] = set() + output_types: set[type[Any] | types.UnionType] = set() # Collect workflow output types from all handlers for handler_spec in self._handler_specs + self._response_handler_specs: @@ -544,6 +546,7 @@ def handler( *, input_type: type | types.UnionType | str | None = None, output_type: type | types.UnionType | str | None = None, + workflow_output_type: type | types.UnionType | str | None = None, ) -> Callable[ [Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]], Callable[[ExecutorT, Any, ContextT], Awaitable[Any]], @@ -555,6 +558,7 @@ def handler( *, input_type: type | types.UnionType | str | None = None, output_type: type | types.UnionType | str | None = None, + workflow_output_type: type | types.UnionType | str | None = None, ) -> ( Callable[[ExecutorT, Any, ContextT], Awaitable[Any]] | Callable[ @@ -573,7 +577,11 @@ def handler( output_type: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. Supports union types (e.g., ``str | int``) and string forward references. When provided, takes precedence over introspection from the ``WorkflowContext`` - generic parameters. + first generic parameter (T_Out). + workflow_output_type: Optional explicit output type(s) that can be yielded via + ``ctx.yield_output()``. Supports union types (e.g., ``str | int``) and string + forward references. When provided, takes precedence over introspection from the + ``WorkflowContext`` second generic parameter (T_W_Out). Returns: The decorated function with handler metadata. @@ -594,6 +602,13 @@ async def handle_data(self, message: Any, ctx: WorkflowContext) -> None: ... # Using string forward references @handler(input_type="MyCustomType | int", output_type="ResponseType") async def handle_custom(self, message: Any, ctx: WorkflowContext) -> None: ... + + + # Specifying both output types (send_message and yield_output) + @handler(input_type=str, output_type=int, workflow_output_type=bool) + async def handle_full(self, message: Any, ctx: WorkflowContext) -> None: + await ctx.send_message(42) # int - matches output_type + await ctx.yield_output(True) # bool - matches workflow_output_type """ from ._typing_utils import normalize_type_to_list, resolve_type_annotation @@ -605,6 +620,11 @@ def decorator( resolved_output_type = ( resolve_type_annotation(output_type, func.__globals__) if output_type is not None else None ) + resolved_workflow_output_type = ( + resolve_type_annotation(workflow_output_type, func.__globals__) + if workflow_output_type is not None + else None + ) # Extract the message type and validate using unified validation introspected_message_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = ( @@ -625,6 +645,11 @@ def decorator( final_output_types = ( normalize_type_to_list(resolved_output_type) if resolved_output_type is not None else inferred_output_types ) + final_workflow_output_types = ( + normalize_type_to_list(resolved_workflow_output_type) + if resolved_workflow_output_type is not None + else inferred_workflow_output_types + ) # Get signature for preservation sig = inspect.signature(func) @@ -643,7 +668,7 @@ async def wrapper(self: ExecutorT, message: Any, ctx: ContextT) -> Any: "message_type": message_type, # Keep output_types and workflow_output_types in spec for validators "output_types": final_output_types, - "workflow_output_types": inferred_workflow_output_types, + "workflow_output_types": final_workflow_output_types, "ctx_annotation": ctx_annotation, } diff --git a/python/packages/core/agent_framework/_workflows/_function_executor.py b/python/packages/core/agent_framework/_workflows/_function_executor.py index 54108c14d0..7e18311d13 100644 --- a/python/packages/core/agent_framework/_workflows/_function_executor.py +++ b/python/packages/core/agent_framework/_workflows/_function_executor.py @@ -50,6 +50,7 @@ def __init__( *, input_type: type | types.UnionType | str | None = None, output_type: type | types.UnionType | str | None = None, + workflow_output_type: type | types.UnionType | str | None = None, ): """Initialize the FunctionExecutor with a user-defined function. @@ -63,7 +64,11 @@ def __init__( output_type: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. Supports union types (e.g., ``str | int``) and string forward references. When provided, takes precedence over introspection from the ``WorkflowContext`` - generic parameters. + first generic parameter (T_Out). + workflow_output_type: Optional explicit output type(s) that can be yielded via + ``ctx.yield_output()``. Supports union types (e.g., ``str | int``) and string + forward references. When provided, takes precedence over introspection from the + ``WorkflowContext`` second generic parameter (T_W_Out). Raises: ValueError: If func is a staticmethod or classmethod (use @handler on instance methods instead) @@ -82,16 +87,28 @@ def __init__( resolved_output_type = ( resolve_type_annotation(output_type, func.__globals__) if output_type is not None else None ) + resolved_workflow_output_type = ( + resolve_type_annotation(workflow_output_type, func.__globals__) + if workflow_output_type is not None + else None + ) # Validate function signature and extract types - introspected_message_type, ctx_annotation, inferred_output_types, workflow_output_types = ( + introspected_message_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = ( _validate_function_signature(func, skip_message_annotation=resolved_input_type is not None) ) # Use explicit types if provided, otherwise fall back to introspection message_type = resolved_input_type if resolved_input_type is not None else introspected_message_type - output_types = ( - normalize_type_to_list(resolved_output_type) if resolved_output_type is not None else inferred_output_types + output_types: list[type[Any] | types.UnionType] = ( + normalize_type_to_list(resolved_output_type) + if resolved_output_type is not None + else list(inferred_output_types) + ) + final_workflow_output_types: list[type[Any] | types.UnionType] = ( + normalize_type_to_list(resolved_workflow_output_type) + if resolved_workflow_output_type is not None + else list(inferred_workflow_output_types) ) # Validate that we have a message type - provides a clear error if type information is missing @@ -144,7 +161,7 @@ async def wrapped_func(message: Any, ctx: WorkflowContext[Any]) -> Any: message_type=message_type, ctx_annotation=ctx_annotation, output_types=output_types, - workflow_output_types=workflow_output_types, + workflow_output_types=final_workflow_output_types, ) # Now we can safely call _discover_handlers (it won't find any class-level handlers) @@ -170,6 +187,7 @@ def executor( id: str | None = None, input_type: type | types.UnionType | str | None = None, output_type: type | types.UnionType | str | None = None, + workflow_output_type: type | types.UnionType | str | None = None, ) -> Callable[[Callable[..., Any]], FunctionExecutor]: ... @@ -179,6 +197,7 @@ def executor( id: str | None = None, input_type: type | types.UnionType | str | None = None, output_type: type | types.UnionType | str | None = None, + workflow_output_type: type | types.UnionType | str | None = None, ) -> Callable[[Callable[..., Any]], FunctionExecutor] | FunctionExecutor: """Decorator that converts a standalone function into a FunctionExecutor instance. @@ -220,6 +239,13 @@ async def process(message: Any, ctx: WorkflowContext): async def process(message: Any, ctx: WorkflowContext): ... + # Specifying both output types (send_message and yield_output): + @executor(input_type=str, output_type=int, workflow_output_type=bool) + async def process(message: Any, ctx: WorkflowContext): + await ctx.send_message(42) # int - matches output_type + await ctx.yield_output(True) # bool - matches workflow_output_type + + # For class-based executors, use @handler instead: class MyExecutor(Executor): def __init__(self): @@ -239,7 +265,11 @@ async def process(self, data: str, ctx: WorkflowContext[str]): output_type: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. Supports union types (e.g., ``str | int``) and string forward references. When provided, takes precedence over introspection from the ``WorkflowContext`` - generic parameters. + first generic parameter (T_Out). + workflow_output_type: Optional explicit output type(s) that can be yielded via + ``ctx.yield_output()``. Supports union types (e.g., ``str | int``) and string + forward references. When provided, takes precedence over introspection from the + ``WorkflowContext`` second generic parameter (T_W_Out). Returns: A FunctionExecutor instance that can be wired into a Workflow. @@ -249,7 +279,9 @@ async def process(self, data: str, ctx: WorkflowContext[str]): """ def wrapper(func: Callable[..., Any]) -> FunctionExecutor: - return FunctionExecutor(func, id=id, input_type=input_type, output_type=output_type) + return FunctionExecutor( + func, id=id, input_type=input_type, output_type=output_type, workflow_output_type=workflow_output_type + ) # If func is provided, this means @executor was used without parentheses if func is not None: diff --git a/python/packages/core/agent_framework/_workflows/_typing_utils.py b/python/packages/core/agent_framework/_workflows/_typing_utils.py index d0e9490a24..491a7649bd 100644 --- a/python/packages/core/agent_framework/_workflows/_typing_utils.py +++ b/python/packages/core/agent_framework/_workflows/_typing_utils.py @@ -55,7 +55,10 @@ def resolve_type_annotation( eval_globalns.setdefault("Optional", __import__("typing").Optional) try: - return eval(type_annotation, eval_globalns, localns) # noqa: S307 # nosec B307 + return cast( + "type[Any] | UnionType", + eval(type_annotation, eval_globalns, localns), # noqa: S307 # nosec B307 + ) except NameError as e: raise NameError( f"Could not resolve type annotation '{type_annotation}'. " @@ -65,7 +68,7 @@ def resolve_type_annotation( return type_annotation -def normalize_type_to_list(type_annotation: type[Any] | UnionType | None) -> list[type[Any]]: +def normalize_type_to_list(type_annotation: type[Any] | UnionType | None) -> list[type[Any] | UnionType]: """Normalize a type annotation (possibly a union) to a list of concrete types. Args: diff --git a/python/packages/core/agent_framework/_workflows/_validation.py b/python/packages/core/agent_framework/_workflows/_validation.py index fc59bb94e1..ff8a74028d 100644 --- a/python/packages/core/agent_framework/_workflows/_validation.py +++ b/python/packages/core/agent_framework/_workflows/_validation.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. import logging +import types from collections import defaultdict from collections.abc import Sequence from enum import Enum @@ -55,8 +56,8 @@ def __init__( self, source_executor_id: str, target_executor_id: str, - source_types: list[type[Any]], - target_types: list[type[Any]], + source_types: list[type[Any] | types.UnionType], + target_types: list[type[Any] | types.UnionType], ): # Use a placeholder for incompatible types - will be computed in WorkflowGraphValidator super().__init__( @@ -253,7 +254,7 @@ def _validate_edge_type_compatibility(self, edge: Edge, edge_group: EdgeGroup) - # Check if any source output type is compatible with any target input type compatible = False - compatible_pairs: list[tuple[type[Any], type[Any]]] = [] + compatible_pairs: list[tuple[type[Any] | types.UnionType, type[Any] | types.UnionType]] = [] for source_type in source_output_types: for target_type in target_input_types: diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index d6c612bff6..9cf06367b4 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -5,6 +5,7 @@ import hashlib import json import logging +import types import uuid from collections.abc import AsyncIterable, Awaitable, Callable from typing import Any @@ -815,7 +816,7 @@ def graph_signature_hash(self) -> str: return self._graph_signature_hash @property - def input_types(self) -> list[type[Any]]: + def input_types(self) -> list[type[Any] | types.UnionType]: """Get the input types of the workflow. The input types are the list of input types of the start executor. @@ -827,7 +828,7 @@ def input_types(self) -> list[type[Any]]: return start_executor.input_types @property - def output_types(self) -> list[type[Any]]: + def output_types(self) -> list[type[Any] | types.UnionType]: """Get the output types of the workflow. The output types are the list of all workflow output types from executors @@ -836,7 +837,7 @@ def output_types(self) -> list[type[Any]]: Returns: A list of output types that the workflow can produce. """ - output_types: set[type[Any]] = set() + output_types: set[type[Any] | types.UnionType] = set() for executor in self.executors.values(): workflow_output_types = executor.workflow_output_types diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index 69f24bcf2c..67eb45b029 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -3,6 +3,7 @@ import asyncio import logging import sys +import types import uuid from dataclasses import dataclass from typing import TYPE_CHECKING, Any @@ -302,13 +303,13 @@ def __init__( self._propagate_request = propagate_request @property - def input_types(self) -> list[type[Any]]: + def input_types(self) -> list[type[Any] | types.UnionType]: """Get the input types based on the underlying workflow's input types plus WorkflowExecutor-specific types. Returns: A list of input types that the WorkflowExecutor can accept. """ - input_types = list(self.workflow.input_types) + input_types: list[type[Any] | types.UnionType] = list(self.workflow.input_types) # WorkflowExecutor can also handle SubWorkflowResponseMessage for sub-workflow responses if SubWorkflowResponseMessage not in input_types: @@ -317,7 +318,7 @@ def input_types(self) -> list[type[Any]]: return input_types @property - def output_types(self) -> list[type[Any]]: + def output_types(self) -> list[type[Any] | types.UnionType]: """Get the output types based on the underlying workflow's output types. Returns: @@ -325,7 +326,7 @@ def output_types(self) -> list[type[Any]]: Includes the SubWorkflowRequestMessage type if any executor in the sub-workflow is request-response capable. """ - output_types = list(self.workflow.output_types) + output_types: list[type[Any] | types.UnionType] = list(self.workflow.output_types) is_request_response_capable = any( executor.is_request_response_capable for executor in self.workflow.executors.values() diff --git a/python/packages/core/tests/workflow/test_executor.py b/python/packages/core/tests/workflow/test_executor.py index b34015d9b5..1bf0ec577f 100644 --- a/python/packages/core/tests/workflow/test_executor.py +++ b/python/packages/core/tests/workflow/test_executor.py @@ -801,5 +801,114 @@ async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[n # Should resolve the string output type assert ForwardRefResponse in exec_instance.output_types + def test_handler_with_explicit_workflow_output_type(self): + """Test that explicit workflow_output_type takes precedence over introspection.""" + + class ExplicitWorkflowOutputExecutor(Executor): + @handler(workflow_output_type=bool) + async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: + pass + + exec_instance = ExplicitWorkflowOutputExecutor(id="explicit_workflow_output") + + # Handler spec should have bool as workflow_output_type (explicit) + handler_func = exec_instance._handlers[str] + assert handler_func._handler_spec["workflow_output_types"] == [bool] + + # Executor workflow_output_types property should reflect explicit type + assert bool in exec_instance.workflow_output_types + # output_types should still come from introspection (int from WorkflowContext[int]) + assert int in exec_instance.output_types + + def test_handler_with_explicit_workflow_output_type_precedence(self): + """Test that explicit workflow_output_type overrides introspected WorkflowContext second param.""" + + class PrecedenceExecutor(Executor): + @handler(workflow_output_type=str) + async def handle(self, message: int, ctx: WorkflowContext[int, bool]) -> None: + pass + + exec_instance = PrecedenceExecutor(id="precedence") + + # workflow_output_types should be str (explicit), not bool (introspected from ctx) + assert str in exec_instance.workflow_output_types + assert bool not in exec_instance.workflow_output_types + + def test_handler_with_all_explicit_types(self): + """Test that all three explicit type parameters work together.""" + from typing import Any + + class AllExplicitExecutor(Executor): + @handler(input_type=str, output_type=int, workflow_output_type=bool) + async def handle(self, message: Any, ctx: WorkflowContext) -> None: + pass + + exec_instance = AllExplicitExecutor(id="all_explicit") + + # Check input type + assert str in exec_instance._handlers + assert exec_instance.can_handle(Message(data="hello", source_id="mock")) + + # Check output_type + assert int in exec_instance.output_types + + # Check workflow_output_type + assert bool in exec_instance.workflow_output_types + + def test_handler_with_union_workflow_output_type(self): + """Test that union types work for workflow_output_type.""" + + class UnionWorkflowOutputExecutor(Executor): + @handler(workflow_output_type=str | int) + async def handle(self, message: str, ctx: WorkflowContext) -> None: + pass + + exec_instance = UnionWorkflowOutputExecutor(id="union_workflow_output") + + # Should include both types from union + assert str in exec_instance.workflow_output_types + assert int in exec_instance.workflow_output_types + + def test_handler_with_string_forward_reference_workflow_output_type(self): + """Test that string forward references work for workflow_output_type.""" + + class StringWorkflowOutputExecutor(Executor): + @handler(input_type=str, workflow_output_type="ForwardRefResponse") + async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + exec_instance = StringWorkflowOutputExecutor(id="string_workflow_output") + + # Should resolve the string workflow_output_type + assert ForwardRefResponse in exec_instance.workflow_output_types + + def test_handler_with_string_forward_reference_union_workflow_output_type(self): + """Test that string forward reference union types work for workflow_output_type.""" + + class StringUnionWorkflowOutputExecutor(Executor): + @handler(input_type=str, workflow_output_type="ForwardRefTypeA | ForwardRefTypeB") + async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + exec_instance = StringUnionWorkflowOutputExecutor(id="string_union_workflow_output") + + # Should resolve both types from string union + assert ForwardRefTypeA in exec_instance.workflow_output_types + assert ForwardRefTypeB in exec_instance.workflow_output_types + + def test_handler_fallback_to_introspection_for_workflow_output_type(self): + """Test that workflow_output_type falls back to introspection when not explicitly provided.""" + + class IntrospectedWorkflowOutputExecutor(Executor): + @handler + async def handle(self, message: str, ctx: WorkflowContext[int, bool]) -> None: + pass + + exec_instance = IntrospectedWorkflowOutputExecutor(id="introspected_workflow_output") + + # Should use introspected types from WorkflowContext[int, bool] + assert int in exec_instance.output_types + assert bool in exec_instance.workflow_output_types + # endregion: Tests for @handler decorator with explicit input_type and output_type diff --git a/python/packages/core/tests/workflow/test_function_executor.py b/python/packages/core/tests/workflow/test_function_executor.py index 71d6cb34e2..5d8b310752 100644 --- a/python/packages/core/tests/workflow/test_function_executor.py +++ b/python/packages/core/tests/workflow/test_function_executor.py @@ -789,3 +789,109 @@ async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-unt # Should resolve the string output type assert FuncExecForwardRefResponse in process.output_types + + def test_executor_with_explicit_workflow_output_type(self): + """Test that explicit workflow_output_type takes precedence over introspection.""" + + @executor(workflow_output_type=bool) + async def process(message: str, ctx: WorkflowContext[int]) -> None: + pass + + # Handler spec should have bool as workflow_output_type (explicit) + spec = process._handler_specs[0] + assert spec["workflow_output_types"] == [bool] + + # Executor workflow_output_types property should reflect explicit type + assert bool in process.workflow_output_types + # output_types should still come from introspection (int from WorkflowContext[int]) + assert int in process.output_types + + def test_executor_with_explicit_workflow_output_type_precedence(self): + """Test that explicit workflow_output_type overrides introspected WorkflowContext second param.""" + + @executor(workflow_output_type=str) + async def process(message: int, ctx: WorkflowContext[int, bool]) -> None: + pass + + # workflow_output_types should be str (explicit), not bool (introspected from ctx) + assert str in process.workflow_output_types + assert bool not in process.workflow_output_types + + def test_executor_with_all_explicit_types(self): + """Test that all three explicit type parameters work together.""" + from typing import Any + + @executor(input_type=str, output_type=int, workflow_output_type=bool) + async def process(message: Any, ctx: WorkflowContext) -> None: + pass + + # Check input type + assert str in process._handlers + assert process.can_handle(Message(data="hello", source_id="mock")) + + # Check output_type + assert int in process.output_types + + # Check workflow_output_type + assert bool in process.workflow_output_types + + def test_executor_with_union_workflow_output_type(self): + """Test that union types work for workflow_output_type.""" + + @executor(workflow_output_type=str | int) + async def process(message: str, ctx: WorkflowContext) -> None: + pass + + # Should include both types from union + assert str in process.workflow_output_types + assert int in process.workflow_output_types + + def test_executor_with_string_forward_reference_workflow_output_type(self): + """Test that string forward references work for workflow_output_type.""" + + @executor(input_type=str, workflow_output_type="FuncExecForwardRefResponse") + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Should resolve the string workflow_output_type + assert FuncExecForwardRefResponse in process.workflow_output_types + + def test_executor_with_string_forward_reference_union_workflow_output_type(self): + """Test that string forward reference union types work for workflow_output_type.""" + + @executor(input_type=str, workflow_output_type="FuncExecForwardRefTypeA | FuncExecForwardRefTypeB") + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Should resolve both types from string union + assert FuncExecForwardRefTypeA in process.workflow_output_types + assert FuncExecForwardRefTypeB in process.workflow_output_types + + def test_executor_fallback_to_introspection_for_workflow_output_type(self): + """Test that workflow_output_type falls back to introspection when not explicitly provided.""" + + @executor + async def process(message: str, ctx: WorkflowContext[int, bool]) -> None: + pass + + # Should use introspected types from WorkflowContext[int, bool] + assert int in process.output_types + assert bool in process.workflow_output_types + + def test_function_executor_constructor_with_workflow_output_type(self): + """Test FunctionExecutor constructor accepts workflow_output_type parameter.""" + + async def my_func(message: str, ctx: WorkflowContext) -> None: + pass + + exec_instance = FunctionExecutor( + my_func, + id="test_constructor", + input_type=str, + output_type=int, + workflow_output_type=bool, + ) + + assert str in exec_instance._handlers + assert int in exec_instance.output_types + assert bool in exec_instance.workflow_output_types From b275fd2aca85aaaa11e14c141de9107c9380367c Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Fri, 30 Jan 2026 12:49:39 -0800 Subject: [PATCH 4/9] Address PR feedback --- .../agent_framework/_workflows/_executor.py | 53 ++++++++----------- .../_workflows/_typing_utils.py | 4 -- .../core/tests/workflow/test_executor.py | 38 ++++++------- .../tests/workflow/test_request_info_mixin.py | 1 - .../_start-here/step1_executors_and_edges.py | 22 ++++---- 5 files changed, 53 insertions(+), 65 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index e0b0338de2..cb893c4477 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -21,7 +21,7 @@ from ._request_info_mixin import RequestInfoMixin from ._runner_context import Message, MessageType, RunnerContext from ._shared_state import SharedState -from ._typing_utils import is_instance_of +from ._typing_utils import is_instance_of, normalize_type_to_list, resolve_type_annotation from ._workflow_context import WorkflowContext, validate_workflow_context_annotation logger = logging.getLogger(__name__) @@ -329,15 +329,13 @@ def _discover_handlers(self) -> None: """Discover message handlers in the executor class.""" # Use __class__.__dict__ to avoid accessing pydantic's dynamic attributes for attr_name in dir(self.__class__): - # Narrow the exception scope - only catch AttributeError when accessing the attribute try: attr = getattr(self.__class__, attr_name) except AttributeError: # Skip attributes that may not be accessible (e.g., dynamic descriptors) - logger.debug(f"Could not access attribute {attr_name} on {self.__class__.__name__}") continue - # Discover @handler methods - let AttributeError propagate for malformed handler specs + # Discover @handler methods if callable(attr) and hasattr(attr, "_handler_spec"): handler_spec = attr._handler_spec # type: ignore message_type = handler_spec["message_type"] @@ -357,7 +355,6 @@ def _discover_handlers(self) -> None: "output_types": handler_spec.get("output_types", []), "workflow_output_types": handler_spec.get("workflow_output_types", []), "ctx_annotation": handler_spec.get("ctx_annotation"), - "source": "class_method", # Distinguish from instance handlers if needed }) def can_handle(self, message: Message) -> bool: @@ -413,7 +410,6 @@ def _register_instance_handler( "ctx_annotation": ctx_annotation, "output_types": output_types, "workflow_output_types": workflow_output_types, - "source": "instance_method", # Distinguish from class handlers if needed }) @property @@ -544,9 +540,9 @@ def handler( @overload def handler( *, - input_type: type | types.UnionType | str | None = None, - output_type: type | types.UnionType | str | None = None, - workflow_output_type: type | types.UnionType | str | None = None, + input: type | types.UnionType | str | None = None, + output: type | types.UnionType | str | None = None, + workflow_output: type | types.UnionType | str | None = None, ) -> Callable[ [Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]], Callable[[ExecutorT, Any, ContextT], Awaitable[Any]], @@ -556,9 +552,9 @@ def handler( def handler( func: Callable[[ExecutorT, Any, ContextT], Awaitable[Any]] | None = None, *, - input_type: type | types.UnionType | str | None = None, - output_type: type | types.UnionType | str | None = None, - workflow_output_type: type | types.UnionType | str | None = None, + input: type | types.UnionType | str | None = None, + output: type | types.UnionType | str | None = None, + workflow_output: type | types.UnionType | str | None = None, ) -> ( Callable[[ExecutorT, Any, ContextT], Awaitable[Any]] | Callable[ @@ -570,15 +566,15 @@ def handler( Args: func: The function to decorate. Can be None when used with parameters. - input_type: Optional explicit input type(s) for this handler. Supports union types + input: Optional explicit input type(s) for this handler. Supports union types (e.g., ``str | int``) and string forward references (e.g., ``"MyType | int"``). When provided, takes precedence over introspection from the function's message parameter annotation. - output_type: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. + output: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. Supports union types (e.g., ``str | int``) and string forward references. When provided, takes precedence over introspection from the ``WorkflowContext`` first generic parameter (T_Out). - workflow_output_type: Optional explicit output type(s) that can be yielded via + workflow_output: Optional explicit output type(s) that can be yielded via ``ctx.yield_output()``. Supports union types (e.g., ``str | int``) and string forward references. When provided, takes precedence over introspection from the ``WorkflowContext`` second generic parameter (T_W_Out). @@ -595,38 +591,35 @@ async def handle_string(self, message: str, ctx: WorkflowContext[str]) -> None: # Using explicit types (takes precedence over introspection) - @handler(input_type=str | int, output_type=bool) + @handler(input=str | int, output=bool) async def handle_data(self, message: Any, ctx: WorkflowContext) -> None: ... # Using string forward references - @handler(input_type="MyCustomType | int", output_type="ResponseType") + @handler(input="MyCustomType | int", output="ResponseType") async def handle_custom(self, message: Any, ctx: WorkflowContext) -> None: ... # Specifying both output types (send_message and yield_output) - @handler(input_type=str, output_type=int, workflow_output_type=bool) + @handler(input=str, output=int, workflow_output=bool) async def handle_full(self, message: Any, ctx: WorkflowContext) -> None: - await ctx.send_message(42) # int - matches output_type - await ctx.yield_output(True) # bool - matches workflow_output_type + await ctx.send_message(42) # int - matches output + await ctx.yield_output(True) # bool - matches workflow_output """ - from ._typing_utils import normalize_type_to_list, resolve_type_annotation def decorator( func: Callable[[ExecutorT, Any, ContextT], Awaitable[Any]], ) -> Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]: # Resolve string forward references using the function's globals - resolved_input_type = resolve_type_annotation(input_type, func.__globals__) if input_type is not None else None - resolved_output_type = ( - resolve_type_annotation(output_type, func.__globals__) if output_type is not None else None - ) + resolved_input_type = resolve_type_annotation(input, func.__globals__) if input is not None else None + resolved_output_type = resolve_type_annotation(output, func.__globals__) if output is not None else None resolved_workflow_output_type = ( - resolve_type_annotation(workflow_output_type, func.__globals__) - if workflow_output_type is not None - else None + resolve_type_annotation(workflow_output, func.__globals__) if workflow_output is not None else None ) - # Extract the message type and validate using unified validation + # Extract the message type and validate using unified validation. + # This runs even when explicit params are provided to allow mixing: + # e.g., input from decorator, output from WorkflowContext annotation. introspected_message_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = ( _validate_handler_signature(func, skip_message_annotation=resolved_input_type is not None) ) @@ -639,7 +632,7 @@ def decorator( if message_type is None: raise ValueError( f"Handler {func.__name__} requires either a message parameter type annotation " - "or an explicit input_type parameter" + "or an explicit input parameter" ) final_output_types = ( diff --git a/python/packages/core/agent_framework/_workflows/_typing_utils.py b/python/packages/core/agent_framework/_workflows/_typing_utils.py index 491a7649bd..3fe42fd053 100644 --- a/python/packages/core/agent_framework/_workflows/_typing_utils.py +++ b/python/packages/core/agent_framework/_workflows/_typing_utils.py @@ -3,10 +3,6 @@ from types import UnionType from typing import Any, TypeVar, Union, cast, get_args, get_origin -from agent_framework import get_logger - -logger = get_logger("agent_framework._workflows._typing_utils") - T = TypeVar("T") diff --git a/python/packages/core/tests/workflow/test_executor.py b/python/packages/core/tests/workflow/test_executor.py index 1bf0ec577f..3f7c3cbe1f 100644 --- a/python/packages/core/tests/workflow/test_executor.py +++ b/python/packages/core/tests/workflow/test_executor.py @@ -574,7 +574,7 @@ def test_handler_with_explicit_input_type(self): from typing import Any class ExplicitInputExecutor(Executor): - @handler(input_type=str) + @handler(input=str) async def handle(self, message: Any, ctx: WorkflowContext) -> None: pass @@ -593,7 +593,7 @@ def test_handler_with_explicit_output_type(self): """Test that explicit output_type takes precedence over introspection.""" class ExplicitOutputExecutor(Executor): - @handler(output_type=int) + @handler(output=int) async def handle(self, message: str, ctx: WorkflowContext[str]) -> None: pass @@ -612,7 +612,7 @@ def test_handler_with_explicit_input_and_output_types(self): from typing import Any class ExplicitBothExecutor(Executor): - @handler(input_type=dict, output_type=list) + @handler(input=dict, output=list) async def handle(self, message: Any, ctx: WorkflowContext) -> None: pass @@ -635,7 +635,7 @@ def test_handler_with_explicit_union_input_type(self): from typing import Any class UnionInputExecutor(Executor): - @handler(input_type=str | int) + @handler(input=str | int) async def handle(self, message: Any, ctx: WorkflowContext) -> None: pass @@ -656,7 +656,7 @@ def test_handler_with_explicit_union_output_type(self): from typing import Any class UnionOutputExecutor(Executor): - @handler(output_type=str | int | bool) + @handler(output=str | int | bool) async def handle(self, message: Any, ctx: WorkflowContext) -> None: pass @@ -671,7 +671,7 @@ def test_handler_explicit_types_precedence_over_introspection(self): class PrecedenceExecutor(Executor): # Introspection would give: input=str, output=[int] # Explicit gives: input=bytes, output=[float] - @handler(input_type=bytes, output_type=float) + @handler(input=bytes, output=float) async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: pass @@ -704,7 +704,7 @@ def test_handler_partial_explicit_types(self): # Only explicit input_type, introspect output_type class OnlyInputExecutor(Executor): - @handler(input_type=bytes) + @handler(input=bytes) async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: pass @@ -714,7 +714,7 @@ async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: # Only explicit output_type, introspect input_type class OnlyOutputExecutor(Executor): - @handler(output_type=float) + @handler(output=float) async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: pass @@ -727,7 +727,7 @@ def test_handler_explicit_input_type_allows_no_message_annotation(self): """Test that explicit input_type allows handler without message type annotation.""" class NoAnnotationExecutor(Executor): - @handler(input_type=str) + @handler(input=str) async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -741,7 +741,7 @@ def test_handler_multiple_handlers_mixed_explicit_and_introspected(self): """Test executor with multiple handlers, some with explicit types and some introspected.""" class MixedExecutor(Executor): - @handler(input_type=str, output_type=int) + @handler(input=str, output=int) async def handle_explicit(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -764,7 +764,7 @@ def test_handler_with_string_forward_reference_input_type(self): """Test that string forward references work for input_type.""" class StringRefExecutor(Executor): - @handler(input_type="ForwardRefMessage") + @handler(input="ForwardRefMessage") async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -778,7 +778,7 @@ def test_handler_with_string_forward_reference_union(self): """Test that string forward references work with union types.""" class StringUnionExecutor(Executor): - @handler(input_type="ForwardRefTypeA | ForwardRefTypeB") + @handler(input="ForwardRefTypeA | ForwardRefTypeB") async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -792,7 +792,7 @@ def test_handler_with_string_forward_reference_output_type(self): """Test that string forward references work for output_type.""" class StringOutputExecutor(Executor): - @handler(input_type=str, output_type="ForwardRefResponse") + @handler(input=str, output="ForwardRefResponse") async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -805,7 +805,7 @@ def test_handler_with_explicit_workflow_output_type(self): """Test that explicit workflow_output_type takes precedence over introspection.""" class ExplicitWorkflowOutputExecutor(Executor): - @handler(workflow_output_type=bool) + @handler(workflow_output=bool) async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: pass @@ -824,7 +824,7 @@ def test_handler_with_explicit_workflow_output_type_precedence(self): """Test that explicit workflow_output_type overrides introspected WorkflowContext second param.""" class PrecedenceExecutor(Executor): - @handler(workflow_output_type=str) + @handler(workflow_output=str) async def handle(self, message: int, ctx: WorkflowContext[int, bool]) -> None: pass @@ -839,7 +839,7 @@ def test_handler_with_all_explicit_types(self): from typing import Any class AllExplicitExecutor(Executor): - @handler(input_type=str, output_type=int, workflow_output_type=bool) + @handler(input=str, output=int, workflow_output=bool) async def handle(self, message: Any, ctx: WorkflowContext) -> None: pass @@ -859,7 +859,7 @@ def test_handler_with_union_workflow_output_type(self): """Test that union types work for workflow_output_type.""" class UnionWorkflowOutputExecutor(Executor): - @handler(workflow_output_type=str | int) + @handler(workflow_output=str | int) async def handle(self, message: str, ctx: WorkflowContext) -> None: pass @@ -873,7 +873,7 @@ def test_handler_with_string_forward_reference_workflow_output_type(self): """Test that string forward references work for workflow_output_type.""" class StringWorkflowOutputExecutor(Executor): - @handler(input_type=str, workflow_output_type="ForwardRefResponse") + @handler(input=str, workflow_output="ForwardRefResponse") async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -886,7 +886,7 @@ def test_handler_with_string_forward_reference_union_workflow_output_type(self): """Test that string forward reference union types work for workflow_output_type.""" class StringUnionWorkflowOutputExecutor(Executor): - @handler(input_type=str, workflow_output_type="ForwardRefTypeA | ForwardRefTypeB") + @handler(input=str, workflow_output="ForwardRefTypeA | ForwardRefTypeB") async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass diff --git a/python/packages/core/tests/workflow/test_request_info_mixin.py b/python/packages/core/tests/workflow/test_request_info_mixin.py index d5528f721d..d89794ef82 100644 --- a/python/packages/core/tests/workflow/test_request_info_mixin.py +++ b/python/packages/core/tests/workflow/test_request_info_mixin.py @@ -247,7 +247,6 @@ async def test_handler(self, original_request: str, response: int, ctx: Workflow assert "output_types" in spec assert "workflow_output_types" in spec assert "ctx_annotation" in spec - assert spec["source"] == "class_method" def test_multiple_discovery_calls_raise_error(self): """Test that multiple calls to _discover_response_handlers raise an error for duplicates.""" diff --git a/python/samples/getting_started/workflows/_start-here/step1_executors_and_edges.py b/python/samples/getting_started/workflows/_start-here/step1_executors_and_edges.py index d070173885..e9e3cb3592 100644 --- a/python/samples/getting_started/workflows/_start-here/step1_executors_and_edges.py +++ b/python/samples/getting_started/workflows/_start-here/step1_executors_and_edges.py @@ -35,13 +35,13 @@ - Explicit type parameters with @handler: Instead of relying on type introspection from function signatures, you can explicitly - specify `input_type` and/or `output_type` on the @handler decorator. These explicit + specify `input` and/or `output` on the @handler decorator. These explicit types take precedence over introspection and support union types (e.g., `str | int`). Examples: - @handler(input_type=str | int) # Accepts str or int, output from introspection - @handler(output_type=str | int) # Input from introspection, outputs str or int - @handler(input_type=str, output_type=int) # Both explicitly specified + @handler(input=str | int) # Accepts str or int, output from introspection + @handler(output=str | int) # Input from introspection, outputs str or int + @handler(input=str, output=int) # Both explicitly specified - Fluent WorkflowBuilder API: add_edge(A, B) to connect nodes, set_start_executor(A), then build() -> Workflow. @@ -115,8 +115,8 @@ async def reverse_text(text: str, ctx: WorkflowContext[Never, str]) -> None: # Example 3: Using explicit type parameters on @handler # ----------------------------------------------------- # -# Instead of relying on type introspection, you can explicitly specify input_type -# and/or output_type on the @handler decorator. These take precedence over introspection +# Instead of relying on type introspection, you can explicitly specify input +# and/or output on the @handler decorator. These take precedence over introspection # and support union types (e.g., str | int). # # This is useful when: @@ -128,7 +128,7 @@ async def reverse_text(text: str, ctx: WorkflowContext[Never, str]) -> None: class ExclamationAdder(Executor): """An executor that adds exclamation marks, demonstrating explicit @handler types. - This example shows how to use explicit input_type and output_type parameters + This example shows how to use explicit input and output parameters on the @handler decorator instead of relying on introspection from the function signature. This approach is especially useful for union types. """ @@ -136,11 +136,11 @@ class ExclamationAdder(Executor): def __init__(self, id: str): super().__init__(id=id) - @handler(input_type=str, output_type=str) + @handler(input=str, output=str) async def add_exclamation(self, message: str, ctx: WorkflowContext) -> None: """Add exclamation marks to the input. - Note: The input_type=str and output_type=str are explicitly specified on @handler, + Note: The input=str and output=str are explicitly specified on @handler, so the framework uses those instead of introspecting the function signature. The WorkflowContext here has no type parameters because the explicit types on @handler take precedence. @@ -174,8 +174,8 @@ async def main(): # ------------------------------------------------------- exclamation_adder = ExclamationAdder(id="exclamation_adder") - # This workflow demonstrates the explicit input_type/output_type feature: - # exclamation_adder uses @handler(input_type=str, output_type=str) to + # This workflow demonstrates the explicit input/output feature: + # exclamation_adder uses @handler(input=str, output=str) to # explicitly declare types instead of relying on introspection. workflow2 = ( WorkflowBuilder() From 1b9fce00cd6a14ec33b4faca8499332121465d36 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Fri, 30 Jan 2026 13:15:50 -0800 Subject: [PATCH 5/9] All or nothing for handler typing approach --- .../agent_framework/_workflows/_executor.py | 101 ++++++++++-------- .../core/tests/workflow/test_executor.py | 64 ++++++----- .../_start-here/step1_executors_and_edges.py | 19 ++-- 3 files changed, 105 insertions(+), 79 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index cb893c4477..a51801f2e1 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -564,20 +564,24 @@ def handler( ): """Decorator to register a handler for an executor. + Type information can be provided in two mutually exclusive ways: + + 1. **Introspection** (default): Types are inferred from function signature annotations. + Use type annotations on the message parameter and WorkflowContext generic parameters. + + 2. **Explicit parameters**: Types are specified via decorator parameters (input, output, + workflow_output). When ANY explicit parameter is provided, ALL types must come from + explicit parameters - introspection is completely disabled. The ``input`` parameter + is required; ``output`` and ``workflow_output`` are optional (default to no outputs). + Args: func: The function to decorate. Can be None when used with parameters. - input: Optional explicit input type(s) for this handler. Supports union types - (e.g., ``str | int``) and string forward references (e.g., ``"MyType | int"``). - When provided, takes precedence over introspection from the function's message - parameter annotation. - output: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. + input: Explicit input type(s) for this handler. Required when using explicit mode. Supports union types (e.g., ``str | int``) and string forward references. - When provided, takes precedence over introspection from the ``WorkflowContext`` - first generic parameter (T_Out). - workflow_output: Optional explicit output type(s) that can be yielded via - ``ctx.yield_output()``. Supports union types (e.g., ``str | int``) and string - forward references. When provided, takes precedence over introspection from the - ``WorkflowContext`` second generic parameter (T_W_Out). + output: Explicit output type(s) that can be sent via ``ctx.send_message()``. + Optional; defaults to no outputs if not specified. + workflow_output: Explicit output type(s) that can be yielded via ``ctx.yield_output()``. + Optional; defaults to no outputs if not specified. Returns: The decorated function with handler metadata. @@ -585,22 +589,22 @@ def handler( Example: .. code-block:: python - # Using introspection (existing behavior) + # Mode 1: Introspection - types from annotations @handler async def handle_string(self, message: str, ctx: WorkflowContext[str]) -> None: ... - # Using explicit types (takes precedence over introspection) + # Mode 2: Explicit types - ALL types from decorator params @handler(input=str | int, output=bool) async def handle_data(self, message: Any, ctx: WorkflowContext) -> None: ... - # Using string forward references + # Explicit with string forward references @handler(input="MyCustomType | int", output="ResponseType") async def handle_custom(self, message: Any, ctx: WorkflowContext) -> None: ... - # Specifying both output types (send_message and yield_output) + # Explicit with all three type parameters @handler(input=str, output=int, workflow_output=bool) async def handle_full(self, message: Any, ctx: WorkflowContext) -> None: await ctx.send_message(42) # int - matches output @@ -610,39 +614,50 @@ async def handle_full(self, message: Any, ctx: WorkflowContext) -> None: def decorator( func: Callable[[ExecutorT, Any, ContextT], Awaitable[Any]], ) -> Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]: - # Resolve string forward references using the function's globals - resolved_input_type = resolve_type_annotation(input, func.__globals__) if input is not None else None - resolved_output_type = resolve_type_annotation(output, func.__globals__) if output is not None else None - resolved_workflow_output_type = ( - resolve_type_annotation(workflow_output, func.__globals__) if workflow_output is not None else None - ) + # Check if ANY explicit type parameter was provided - if so, use ONLY explicit params. + # This is "all or nothing" - no mixing of explicit params with introspection. + use_explicit_types = input is not None or output is not None or workflow_output is not None + + if use_explicit_types: + # Resolve string forward references using the function's globals + resolved_input_type = resolve_type_annotation(input, func.__globals__) if input is not None else None + resolved_output_type = resolve_type_annotation(output, func.__globals__) if output is not None else None + resolved_workflow_output_type = ( + resolve_type_annotation(workflow_output, func.__globals__) if workflow_output is not None else None + ) - # Extract the message type and validate using unified validation. - # This runs even when explicit params are provided to allow mixing: - # e.g., input from decorator, output from WorkflowContext annotation. - introspected_message_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = ( - _validate_handler_signature(func, skip_message_annotation=resolved_input_type is not None) - ) + # Validate signature structure (correct number of params, ctx is WorkflowContext) + # but skip type extraction since we're using explicit types + _validate_handler_signature(func, skip_message_annotation=True) - # Use explicit types if provided, otherwise fall back to introspection - message_type = resolved_input_type if resolved_input_type is not None else introspected_message_type + # Use explicit types only - missing params default to empty + message_type = resolved_input_type + if message_type is None: + raise ValueError(f"Handler {func.__name__} with explicit type parameters must specify 'input' type") - # Validate that we have a message type - this should never happen if signature - # validation passed, but provides a clear error if type information is missing - if message_type is None: - raise ValueError( - f"Handler {func.__name__} requires either a message parameter type annotation " - "or an explicit input parameter" + final_output_types = normalize_type_to_list(resolved_output_type) if resolved_output_type else [] + final_workflow_output_types = ( + normalize_type_to_list(resolved_workflow_output_type) if resolved_workflow_output_type else [] + ) + # Get ctx_annotation for consistency (even though types come from explicit params) + ctx_annotation = ( + inspect.signature(func).parameters[list(inspect.signature(func).parameters.keys())[2]].annotation + ) + else: + # Use introspection for ALL types - no explicit params provided + introspected_message_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = ( + _validate_handler_signature(func, skip_message_annotation=False) ) - final_output_types = ( - normalize_type_to_list(resolved_output_type) if resolved_output_type is not None else inferred_output_types - ) - final_workflow_output_types = ( - normalize_type_to_list(resolved_workflow_output_type) - if resolved_workflow_output_type is not None - else inferred_workflow_output_types - ) + message_type = introspected_message_type + if message_type is None: + raise ValueError( + f"Handler {func.__name__} requires either a message parameter type annotation " + "or explicit type parameters (input, output, workflow_output)" + ) + + final_output_types = inferred_output_types + final_workflow_output_types = inferred_workflow_output_types # Get signature for preservation sig = inspect.signature(func) diff --git a/python/packages/core/tests/workflow/test_executor.py b/python/packages/core/tests/workflow/test_executor.py index 3f7c3cbe1f..c9dc46d9e9 100644 --- a/python/packages/core/tests/workflow/test_executor.py +++ b/python/packages/core/tests/workflow/test_executor.py @@ -590,16 +590,16 @@ async def handle(self, message: Any, ctx: WorkflowContext) -> None: assert not exec_instance.can_handle(Message(data=42, source_id="mock")) def test_handler_with_explicit_output_type(self): - """Test that explicit output_type takes precedence over introspection.""" + """Test that explicit output works when input is also specified.""" class ExplicitOutputExecutor(Executor): - @handler(output=int) + @handler(input=str, output=int) async def handle(self, message: str, ctx: WorkflowContext[str]) -> None: pass exec_instance = ExplicitOutputExecutor(id="explicit_output") - # Handler spec should have int as output type (explicit), not str (introspected) + # Handler spec should have int as output type (explicit) handler_func = exec_instance._handlers[str] assert handler_func._handler_spec["output_types"] == [int] @@ -652,11 +652,11 @@ async def handle(self, message: Any, ctx: WorkflowContext) -> None: assert not exec_instance.can_handle(Message(data=3.14, source_id="mock")) def test_handler_with_explicit_union_output_type(self): - """Test that explicit union output_type is normalized to a list.""" + """Test that explicit union output is normalized to a list.""" from typing import Any class UnionOutputExecutor(Executor): - @handler(output=str | int | bool) + @handler(input=bytes, output=str | int | bool) async def handle(self, message: Any, ctx: WorkflowContext) -> None: pass @@ -699,10 +699,10 @@ async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: assert str in exec_instance._handlers assert int in exec_instance.output_types - def test_handler_partial_explicit_types(self): - """Test that partial explicit types work (only input_type or only output_type).""" + def test_handler_explicit_mode_requires_input(self): + """Test that using any explicit type param requires input to be specified.""" - # Only explicit input_type, introspect output_type + # Only explicit input - output defaults to empty (no introspection) class OnlyInputExecutor(Executor): @handler(input=bytes) async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: @@ -710,18 +710,23 @@ async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: exec_input = OnlyInputExecutor(id="only_input") assert bytes in exec_input._handlers # Explicit - assert int in exec_input.output_types # Introspected + assert exec_input.output_types == [] # No output types (not introspected) - # Only explicit output_type, introspect input_type - class OnlyOutputExecutor(Executor): - @handler(output=float) - async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: - pass + # Only explicit output without input should raise error + with pytest.raises(ValueError, match="must specify 'input' type"): + + class OnlyOutputExecutor(Executor): + @handler(output=float) + async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: + pass - exec_output = OnlyOutputExecutor(id="only_output") - assert str in exec_output._handlers # Introspected - assert float in exec_output.output_types # Explicit - assert int not in exec_output.output_types # Not introspected when explicit provided + # Only explicit workflow_output without input should raise error + with pytest.raises(ValueError, match="must specify 'input' type"): + + class OnlyWorkflowOutputExecutor(Executor): + @handler(workflow_output=bool) + async def handle(self, message: str, ctx: WorkflowContext[int, str]) -> None: + pass def test_handler_explicit_input_type_allows_no_message_annotation(self): """Test that explicit input_type allows handler without message type annotation.""" @@ -802,10 +807,10 @@ async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[n assert ForwardRefResponse in exec_instance.output_types def test_handler_with_explicit_workflow_output_type(self): - """Test that explicit workflow_output_type takes precedence over introspection.""" + """Test that explicit workflow_output works when input is also specified.""" class ExplicitWorkflowOutputExecutor(Executor): - @handler(workflow_output=bool) + @handler(input=str, workflow_output=bool) async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: pass @@ -817,21 +822,24 @@ async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: # Executor workflow_output_types property should reflect explicit type assert bool in exec_instance.workflow_output_types - # output_types should still come from introspection (int from WorkflowContext[int]) - assert int in exec_instance.output_types + # output_types should be empty (explicit mode, output not specified) + assert exec_instance.output_types == [] - def test_handler_with_explicit_workflow_output_type_precedence(self): - """Test that explicit workflow_output_type overrides introspected WorkflowContext second param.""" + def test_handler_with_explicit_workflow_output_and_output(self): + """Test that explicit workflow_output works alongside explicit output.""" class PrecedenceExecutor(Executor): - @handler(workflow_output=str) + @handler(input=int, output=float, workflow_output=str) async def handle(self, message: int, ctx: WorkflowContext[int, bool]) -> None: pass exec_instance = PrecedenceExecutor(id="precedence") - # workflow_output_types should be str (explicit), not bool (introspected from ctx) + # All types should come from explicit params + assert int in exec_instance._handlers + assert float in exec_instance.output_types assert str in exec_instance.workflow_output_types + # Introspected types should NOT be present assert bool not in exec_instance.workflow_output_types def test_handler_with_all_explicit_types(self): @@ -856,10 +864,10 @@ async def handle(self, message: Any, ctx: WorkflowContext) -> None: assert bool in exec_instance.workflow_output_types def test_handler_with_union_workflow_output_type(self): - """Test that union types work for workflow_output_type.""" + """Test that union types work for workflow_output.""" class UnionWorkflowOutputExecutor(Executor): - @handler(workflow_output=str | int) + @handler(input=str, workflow_output=str | int) async def handle(self, message: str, ctx: WorkflowContext) -> None: pass diff --git a/python/samples/getting_started/workflows/_start-here/step1_executors_and_edges.py b/python/samples/getting_started/workflows/_start-here/step1_executors_and_edges.py index e9e3cb3592..7c9f7a4cbb 100644 --- a/python/samples/getting_started/workflows/_start-here/step1_executors_and_edges.py +++ b/python/samples/getting_started/workflows/_start-here/step1_executors_and_edges.py @@ -35,13 +35,15 @@ - Explicit type parameters with @handler: Instead of relying on type introspection from function signatures, you can explicitly - specify `input` and/or `output` on the @handler decorator. These explicit - types take precedence over introspection and support union types (e.g., `str | int`). + specify `input`, `output`, and/or `workflow_output` on the @handler decorator. + This is "all or nothing": when ANY explicit parameter is provided, ALL types come + from explicit parameters (introspection is disabled). The `input` parameter is + required; `output` and `workflow_output` are optional. Examples: - @handler(input=str | int) # Accepts str or int, output from introspection - @handler(output=str | int) # Input from introspection, outputs str or int - @handler(input=str, output=int) # Both explicitly specified + @handler(input=str | int) # Accepts str or int, no outputs + @handler(input=str, output=int) # Accepts str, outputs int + @handler(input=str, output=int, workflow_output=bool) # All three specified - Fluent WorkflowBuilder API: add_edge(A, B) to connect nodes, set_start_executor(A), then build() -> Workflow. @@ -115,9 +117,10 @@ async def reverse_text(text: str, ctx: WorkflowContext[Never, str]) -> None: # Example 3: Using explicit type parameters on @handler # ----------------------------------------------------- # -# Instead of relying on type introspection, you can explicitly specify input -# and/or output on the @handler decorator. These take precedence over introspection -# and support union types (e.g., str | int). +# Instead of relying on type introspection, you can explicitly specify input, +# output, and/or workflow_output on the @handler decorator. This is "all or nothing": +# when ANY explicit parameter is provided, ALL types come from explicit parameters +# (introspection is completely disabled). The input parameter is required. # # This is useful when: # - You want to accept multiple types (union types) without complex type annotations From 919a1c017e6b058a5251c37b3acb103daa8cb498 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Fri, 30 Jan 2026 13:32:25 -0800 Subject: [PATCH 6/9] Fix mypy issues --- .../core/agent_framework/_workflows/_executor.py | 2 +- .../_workflows/_function_executor.py | 2 +- .../_workflows/_request_info_mixin.py | 3 ++- .../_workflows/_workflow_context.py | 14 ++++++++------ 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index a51801f2e1..d7ca1efa97 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -699,7 +699,7 @@ def _validate_handler_signature( func: Callable[..., Any], *, skip_message_annotation: bool = False, -) -> tuple[type | None, Any, list[type[Any]], list[type[Any]]]: +) -> tuple[type | None, Any, list[type[Any] | types.UnionType], list[type[Any] | types.UnionType]]: """Validate function signature for executor functions. Args: diff --git a/python/packages/core/agent_framework/_workflows/_function_executor.py b/python/packages/core/agent_framework/_workflows/_function_executor.py index 7e18311d13..0bd829f843 100644 --- a/python/packages/core/agent_framework/_workflows/_function_executor.py +++ b/python/packages/core/agent_framework/_workflows/_function_executor.py @@ -300,7 +300,7 @@ def _validate_function_signature( func: Callable[..., Any], *, skip_message_annotation: bool = False, -) -> tuple[type | None, Any, list[type[Any]], list[type[Any]]]: +) -> tuple[type | None, Any, list[type[Any] | types.UnionType], list[type[Any] | types.UnionType]]: """Validate function signature for executor functions. Args: diff --git a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py index d4c5f6d2bc..489a331a75 100644 --- a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py +++ b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py @@ -6,6 +6,7 @@ import logging from builtins import type as builtin_type from collections.abc import Awaitable, Callable +from types import UnionType from typing import TYPE_CHECKING, Any, TypeVar from ._typing_utils import is_instance_of, is_type_compatible @@ -195,7 +196,7 @@ async def wrapper(self: ExecutorT, original_request: Any, response: Any, ctx: Co def _validate_response_handler_signature( func: Callable[..., Any], -) -> tuple[type, type, Any, list[type[Any]], list[type[Any]]]: +) -> tuple[type, type, Any, list[type[Any] | UnionType], list[type[Any] | UnionType]]: """Validate function signature for executor functions. Args: diff --git a/python/packages/core/agent_framework/_workflows/_workflow_context.py b/python/packages/core/agent_framework/_workflows/_workflow_context.py index 893f0ccfe9..708cdf3c51 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_context.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_context.py @@ -38,7 +38,9 @@ logger = logging.getLogger(__name__) -def infer_output_types_from_ctx_annotation(ctx_annotation: Any) -> tuple[list[type[Any]], list[type[Any]]]: +def infer_output_types_from_ctx_annotation( + ctx_annotation: Any, +) -> tuple[list[type[Any] | UnionType], list[type[Any] | UnionType]]: """Infer message types and workflow output types from the WorkflowContext generic parameters. Examples: @@ -81,8 +83,8 @@ def infer_output_types_from_ctx_annotation(ctx_annotation: Any) -> tuple[list[ty return [cast(type[Any], Any)], [] if t_origin in (Union, UnionType): - message_types = [arg for arg in get_args(t) if arg is not Any and arg is not Never] - return message_types, [] + msg_types: list[type[Any] | UnionType] = [arg for arg in get_args(t) if arg is not Any and arg is not Never] + return msg_types, [] if t is Never: return [], [] @@ -92,7 +94,7 @@ def infer_output_types_from_ctx_annotation(ctx_annotation: Any) -> tuple[list[ty t_out, t_w_out = args[:2] # Take first two args in case there are more # Process T_Out for message_types - message_types = [] + message_types: list[type[Any] | UnionType] = [] t_out_origin = get_origin(t_out) if t_out is Any: message_types = [cast(type[Any], Any)] @@ -103,7 +105,7 @@ def infer_output_types_from_ctx_annotation(ctx_annotation: Any) -> tuple[list[ty message_types = [t_out] # Process T_W_Out for workflow_output_types - workflow_output_types = [] + workflow_output_types: list[type[Any] | UnionType] = [] t_w_out_origin = get_origin(t_w_out) if t_w_out is Any: workflow_output_types = [cast(type[Any], Any)] @@ -129,7 +131,7 @@ def validate_workflow_context_annotation( annotation: Any, parameter_name: str, context_description: str, -) -> tuple[list[type[Any]], list[type[Any]]]: +) -> tuple[list[type[Any] | UnionType], list[type[Any] | UnionType]]: """Validate a WorkflowContext annotation and return inferred types. Args: From 200f36a3cdca6a16f3242c70b4ba70ed9e58beb2 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Tue, 3 Feb 2026 09:15:43 +0900 Subject: [PATCH 7/9] type support for request info --- .../agent_framework/_workflows/_executor.py | 15 +- .../_workflows/_function_executor.py | 63 +++--- .../_workflows/_request_info_mixin.py | 202 ++++++++++++++---- .../tests/workflow/test_function_executor.py | 52 ++--- .../tests/workflow/test_request_info_mixin.py | 167 +++++++++++++++ ...ff_with_tool_approval_checkpoint_resume.py | 79 +++---- 6 files changed, 424 insertions(+), 154 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index d7ca1efa97..18adc4b904 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -349,13 +349,7 @@ def _discover_handlers(self) -> None: self._handlers[message_type] = bound_method # Add to unified handler specs list - self._handler_specs.append({ - "name": handler_spec["name"], - "message_type": message_type, - "output_types": handler_spec.get("output_types", []), - "workflow_output_types": handler_spec.get("workflow_output_types", []), - "ctx_annotation": handler_spec.get("ctx_annotation"), - }) + self._handler_specs.append({**handler_spec}) def can_handle(self, message: Message) -> bool: """Check if the executor can handle a given message type. @@ -595,18 +589,19 @@ async def handle_string(self, message: str, ctx: WorkflowContext[str]) -> None: # Mode 2: Explicit types - ALL types from decorator params + # Note: No type annotations on function parameters when using explicit types @handler(input=str | int, output=bool) - async def handle_data(self, message: Any, ctx: WorkflowContext) -> None: ... + async def handle_data(self, message, ctx): ... # Explicit with string forward references @handler(input="MyCustomType | int", output="ResponseType") - async def handle_custom(self, message: Any, ctx: WorkflowContext) -> None: ... + async def handle_custom(self, message, ctx): ... # Explicit with all three type parameters @handler(input=str, output=int, workflow_output=bool) - async def handle_full(self, message: Any, ctx: WorkflowContext) -> None: + async def handle_full(self, message, ctx): await ctx.send_message(42) # int - matches output await ctx.yield_output(True) # bool - matches workflow_output """ diff --git a/python/packages/core/agent_framework/_workflows/_function_executor.py b/python/packages/core/agent_framework/_workflows/_function_executor.py index 0bd829f843..cac77d8173 100644 --- a/python/packages/core/agent_framework/_workflows/_function_executor.py +++ b/python/packages/core/agent_framework/_workflows/_function_executor.py @@ -48,24 +48,24 @@ def __init__( func: Callable[..., Any], id: str | None = None, *, - input_type: type | types.UnionType | str | None = None, - output_type: type | types.UnionType | str | None = None, - workflow_output_type: type | types.UnionType | str | None = None, + input: type | types.UnionType | str | None = None, + output: type | types.UnionType | str | None = None, + workflow_output: type | types.UnionType | str | None = None, ): """Initialize the FunctionExecutor with a user-defined function. Args: func: The function to wrap as an executor (can be sync or async) id: Optional executor ID. If None, uses the function name. - input_type: Optional explicit input type(s) for this executor. Supports union types + input: Optional explicit input type(s) for this executor. Supports union types (e.g., ``str | int``) and string forward references (e.g., ``"MyType | int"``). When provided, takes precedence over introspection from the function's message parameter annotation. - output_type: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. + output: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. Supports union types (e.g., ``str | int``) and string forward references. When provided, takes precedence over introspection from the ``WorkflowContext`` first generic parameter (T_Out). - workflow_output_type: Optional explicit output type(s) that can be yielded via + workflow_output: Optional explicit output type(s) that can be yielded via ``ctx.yield_output()``. Supports union types (e.g., ``str | int``) and string forward references. When provided, takes precedence over introspection from the ``WorkflowContext`` second generic parameter (T_W_Out). @@ -83,14 +83,10 @@ def __init__( ) # Resolve string forward references using the function's globals - resolved_input_type = resolve_type_annotation(input_type, func.__globals__) if input_type is not None else None - resolved_output_type = ( - resolve_type_annotation(output_type, func.__globals__) if output_type is not None else None - ) + resolved_input_type = resolve_type_annotation(input, func.__globals__) if input is not None else None + resolved_output_type = resolve_type_annotation(output, func.__globals__) if output is not None else None resolved_workflow_output_type = ( - resolve_type_annotation(workflow_output_type, func.__globals__) - if workflow_output_type is not None - else None + resolve_type_annotation(workflow_output, func.__globals__) if workflow_output is not None else None ) # Validate function signature and extract types @@ -185,9 +181,9 @@ def executor(func: Callable[..., Any]) -> FunctionExecutor: ... def executor( *, id: str | None = None, - input_type: type | types.UnionType | str | None = None, - output_type: type | types.UnionType | str | None = None, - workflow_output_type: type | types.UnionType | str | None = None, + input: type | types.UnionType | str | None = None, + output: type | types.UnionType | str | None = None, + workflow_output: type | types.UnionType | str | None = None, ) -> Callable[[Callable[..., Any]], FunctionExecutor]: ... @@ -195,9 +191,9 @@ def executor( func: Callable[..., Any] | None = None, *, id: str | None = None, - input_type: type | types.UnionType | str | None = None, - output_type: type | types.UnionType | str | None = None, - workflow_output_type: type | types.UnionType | str | None = None, + input: type | types.UnionType | str | None = None, + output: type | types.UnionType | str | None = None, + workflow_output: type | types.UnionType | str | None = None, ) -> Callable[[Callable[..., Any]], FunctionExecutor] | FunctionExecutor: """Decorator that converts a standalone function into a FunctionExecutor instance. @@ -229,21 +225,22 @@ def process_data(data: str): # Using explicit types (takes precedence over introspection): - @executor(id="my_executor", input_type=str | int, output_type=bool) - async def process(message: Any, ctx: WorkflowContext): + # Note: No type annotations on function parameters when using explicit types + @executor(id="my_executor", input=str | int, output=bool) + async def process(message, ctx): await ctx.send_message(True) # Using string forward references: - @executor(input_type="MyCustomType | int", output_type="ResponseType") - async def process(message: Any, ctx: WorkflowContext): ... + @executor(input="MyCustomType | int", output="ResponseType") + async def process(message, ctx): ... # Specifying both output types (send_message and yield_output): - @executor(input_type=str, output_type=int, workflow_output_type=bool) - async def process(message: Any, ctx: WorkflowContext): - await ctx.send_message(42) # int - matches output_type - await ctx.yield_output(True) # bool - matches workflow_output_type + @executor(input=str, output=int, workflow_output=bool) + async def process(message, ctx): + await ctx.send_message(42) # int - matches output + await ctx.yield_output(True) # bool - matches workflow_output # For class-based executors, use @handler instead: @@ -258,15 +255,15 @@ async def process(self, data: str, ctx: WorkflowContext[str]): Args: func: The function to decorate (when used without parentheses) id: Optional custom ID for the executor. If None, uses the function name. - input_type: Optional explicit input type(s) for this executor. Supports union types + input: Optional explicit input type(s) for this executor. Supports union types (e.g., ``str | int``) and string forward references (e.g., ``"MyType | int"``). When provided, takes precedence over introspection from the function's message parameter annotation. - output_type: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. + output: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. Supports union types (e.g., ``str | int``) and string forward references. When provided, takes precedence over introspection from the ``WorkflowContext`` first generic parameter (T_Out). - workflow_output_type: Optional explicit output type(s) that can be yielded via + workflow_output: Optional explicit output type(s) that can be yielded via ``ctx.yield_output()``. Supports union types (e.g., ``str | int``) and string forward references. When provided, takes precedence over introspection from the ``WorkflowContext`` second generic parameter (T_W_Out). @@ -279,9 +276,7 @@ async def process(self, data: str, ctx: WorkflowContext[str]): """ def wrapper(func: Callable[..., Any]) -> FunctionExecutor: - return FunctionExecutor( - func, id=id, input_type=input_type, output_type=output_type, workflow_output_type=workflow_output_type - ) + return FunctionExecutor(func, id=id, input=input, output=output, workflow_output=workflow_output) # If func is provided, this means @executor was used without parentheses if func is not None: @@ -306,7 +301,7 @@ def _validate_function_signature( Args: func: The function to validate skip_message_annotation: If True, skip validation that message parameter has a type - annotation. Used when input_type is explicitly provided to the @executor decorator. + annotation. Used when input is explicitly provided to the @executor decorator. Returns: Tuple of (message_type, ctx_annotation, output_types, workflow_output_types). diff --git a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py index 489a331a75..cd9aa5c3d6 100644 --- a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py +++ b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py @@ -4,14 +4,21 @@ import functools import inspect import logging +import sys +import types from builtins import type as builtin_type from collections.abc import Awaitable, Callable from types import UnionType from typing import TYPE_CHECKING, Any, TypeVar -from ._typing_utils import is_instance_of, is_type_compatible +from ._typing_utils import is_instance_of, is_type_compatible, normalize_type_to_list, resolve_type_annotation from ._workflow_context import WorkflowContext, validate_workflow_context_annotation +if sys.version_info >= (3, 11): + from typing import overload # pragma: no cover +else: + from typing_extensions import overload # pragma: no cover + if TYPE_CHECKING: from ._executor import Executor @@ -87,15 +94,7 @@ def _discover_response_handlers(self) -> None: ) self._response_handlers[request_type, response_type] = getattr(self, attr_name) - self._response_handler_specs.append({ - "name": handler_spec["name"], - "request_type": request_type, - "response_type": response_type, - "output_types": handler_spec.get("output_types", []), - "workflow_output_types": handler_spec.get("workflow_output_types", []), - "ctx_annotation": handler_spec.get("ctx_annotation"), - "source": "class_method", # Distinguish from instance handlers if needed - }) + self._response_handler_specs.append({**handler_spec, "source": "class_method"}) except AttributeError: continue # Skip non-callable attributes or those without handler spec @@ -111,13 +110,64 @@ def _discover_response_handlers(self) -> None: # region Handler Decorator +@overload def response_handler( func: Callable[[ExecutorT, Any, Any, ContextT], Awaitable[None]], -) -> Callable[[ExecutorT, Any, Any, ContextT], Awaitable[None]]: +) -> Callable[[ExecutorT, Any, Any, ContextT], Awaitable[None]]: ... + + +@overload +def response_handler( + func: None = None, + *, + request: type | types.UnionType | str | None = None, + response: type | types.UnionType | str | None = None, + output: type | types.UnionType | str | None = None, + workflow_output: type | types.UnionType | str | None = None, +) -> Callable[ + [Callable[[ExecutorT, Any, Any, ContextT], Awaitable[None]]], + Callable[[ExecutorT, Any, Any, ContextT], Awaitable[None]], +]: ... + + +def response_handler( + func: Callable[[ExecutorT, Any, Any, ContextT], Awaitable[None]] | None = None, + *, + request: type | types.UnionType | str | None = None, + response: type | types.UnionType | str | None = None, + output: type | types.UnionType | str | None = None, + workflow_output: type | types.UnionType | str | None = None, +) -> ( + Callable[[ExecutorT, Any, Any, ContextT], Awaitable[None]] + | Callable[ + [Callable[[ExecutorT, Any, Any, ContextT], Awaitable[None]]], + Callable[[ExecutorT, Any, Any, ContextT], Awaitable[None]], + ] +): """Decorator to register a handler to handle responses for a request. + Type information can be provided in two mutually exclusive ways: + + 1. **Introspection** (default): Types are inferred from function signature annotations. + Use type annotations on the original_request, response parameters and WorkflowContext + generic parameters. + + 2. **Explicit parameters**: Types are specified via decorator parameters (request, response, + output, workflow_output). When ANY explicit parameter is provided, ALL types must come + from explicit parameters - introspection is completely disabled. The ``request`` and + ``response`` parameters are required; ``output`` and ``workflow_output`` are optional + (default to no outputs). + Args: - func: The function to decorate. + func: The function to decorate. Can be None when used with parameters. + request: Explicit request type for this handler (the original_request parameter type). + Required when using explicit mode. Supports union types and string forward references. + response: Explicit response type for this handler (the response parameter type). + Required when using explicit mode. Supports union types and string forward references. + output: Explicit output type(s) that can be sent via ``ctx.send_message()``. + Optional; defaults to no outputs if not specified. + workflow_output: Explicit output type(s) that can be yielded via ``ctx.yield_output()``. + Optional; defaults to no outputs if not specified. Returns: The decorated function with handler metadata. @@ -125,6 +175,7 @@ def response_handler( Example: .. code-block:: python + # Mode 1: Introspection - types from annotations @handler async def run(self, message: int, context: WorkflowContext[str]) -> None: # Example of a handler that sends a request @@ -144,31 +195,75 @@ async def handle_response( ... - @response_handler - async def handle_response( - self, - original_request: CustomRequest, - response: dict, - context: WorkflowContext[int], - ) -> None: - # Example of a response handler for a request expecting a dict response - ... + # Mode 2: Explicit types - ALL types from decorator params + # Note: No type annotations on function parameters when using explicit types + @response_handler(request=CustomRequest, response=dict, output=int) + async def handle_response(self, original_request, response, context): + # Example of a response handler with explicit types + await context.send_message(42) + + + # Explicit with string forward references + @response_handler(request="MyRequest", response="MyResponse") + async def handle_response(self, original_request, response, context): ... """ def decorator( func: Callable[[ExecutorT, Any, Any, ContextT], Awaitable[None]], ) -> Callable[[ExecutorT, Any, Any, ContextT], Awaitable[None]]: - request_type, response_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = ( - _validate_response_handler_signature(func) - ) + # Check if ANY explicit type parameter was provided - if so, use ONLY explicit params. + # This is "all or nothing" - no mixing of explicit params with introspection. + use_explicit_types = request is not None or response is not None or output is not None or workflow_output is not None + + if use_explicit_types: + # Resolve string forward references using the function's globals + resolved_request_type = resolve_type_annotation(request, func.__globals__) if request is not None else None + resolved_response_type = ( + resolve_type_annotation(response, func.__globals__) if response is not None else None + ) + resolved_output_type = resolve_type_annotation(output, func.__globals__) if output is not None else None + resolved_workflow_output_type = ( + resolve_type_annotation(workflow_output, func.__globals__) if workflow_output is not None else None + ) + + # Validate signature structure but skip type extraction + _validate_response_handler_signature(func, skip_annotations=True) + + # Validate required parameters + if resolved_request_type is None: + raise ValueError( + f"Response handler {func.__name__} with explicit type parameters must specify 'request' type" + ) + if resolved_response_type is None: + raise ValueError( + f"Response handler {func.__name__} with explicit type parameters must specify 'response' type" + ) + + final_request_type = resolved_request_type + final_response_type = resolved_response_type + final_output_types = normalize_type_to_list(resolved_output_type) if resolved_output_type else [] + final_workflow_output_types = ( + normalize_type_to_list(resolved_workflow_output_type) if resolved_workflow_output_type else [] + ) + # Get ctx_annotation for consistency + ctx_annotation = ( + inspect.signature(func).parameters[list(inspect.signature(func).parameters.keys())[3]].annotation + ) + if ctx_annotation == inspect.Parameter.empty: + ctx_annotation = None + else: + # Use introspection - all types from annotations + final_request_type, final_response_type, ctx_annotation, final_output_types, final_workflow_output_types = ( + _validate_response_handler_signature(func) + ) # Get signature for preservation sig = inspect.signature(func) @functools.wraps(func) - async def wrapper(self: ExecutorT, original_request: Any, response: Any, ctx: ContextT) -> Any: + async def wrapper(self: ExecutorT, original_request: Any, response_msg: Any, ctx: ContextT) -> Any: """Wrapper function to call the handler.""" - return await func(self, original_request, response, ctx) + return await func(self, original_request, response_msg, ctx) # Preserve the original function signature for introspection during validation with contextlib.suppress(AttributeError, TypeError): @@ -176,17 +271,22 @@ async def wrapper(self: ExecutorT, original_request: Any, response: Any, ctx: Co wrapper._response_handler_spec = { # type: ignore "name": func.__name__, - "request_type": request_type, - "response_type": response_type, + "request_type": final_request_type, + "response_type": final_response_type, # Keep output_types and workflow_output_types in spec for validators - "output_types": inferred_output_types, - "workflow_output_types": inferred_workflow_output_types, + "output_types": final_output_types, + "workflow_output_types": final_workflow_output_types, "ctx_annotation": ctx_annotation, } return wrapper - return decorator(func) + # If func is provided, this means @response_handler was used without parentheses + if func is not None: + return decorator(func) + + # Otherwise, return the wrapper for @response_handler(...) with parameters + return decorator # endregion: Handler Decorator @@ -196,14 +296,19 @@ async def wrapper(self: ExecutorT, original_request: Any, response: Any, ctx: Co def _validate_response_handler_signature( func: Callable[..., Any], -) -> tuple[type, type, Any, list[type[Any] | UnionType], list[type[Any] | UnionType]]: - """Validate function signature for executor functions. + *, + skip_annotations: bool = False, +) -> tuple[type | None, type | None, Any, list[type[Any] | UnionType], list[type[Any] | UnionType]]: + """Validate function signature for response handler functions. Args: func: The function to validate + skip_annotations: If True, skip validation that request/response parameters have type + annotations. Used when types are explicitly provided to the @response_handler decorator. Returns: - Tuple of (request_type, response_type, ctx_annotation, output_types, workflow_output_types) + Tuple of (request_type, response_type, ctx_annotation, output_types, workflow_output_types). + request_type and response_type may be None if skip_annotations is True and no annotations exist. Raises: ValueError: If the function signature is invalid @@ -216,33 +321,36 @@ def _validate_response_handler_signature( # to the original request when registering the handler, while maintaining # the order of parameters as if the response handler is a normal handler. expected_counts = 4 # self, original_request, message, ctx - param_description = "(self, original_request: TRequest, message: TResponse, ctx: WorkflowContext[U, V])" + param_description = "(self, original_request, response, ctx)" if len(params) != expected_counts: raise ValueError( f"Response handler {func.__name__} must have {param_description}. Got {len(params)} parameters." ) - # Check original_request parameter exists + # Check original_request parameter exists and has annotation (unless skipped) original_request_param = params[1] - if original_request_param.annotation == inspect.Parameter.empty: + if not skip_annotations and original_request_param.annotation == inspect.Parameter.empty: raise ValueError( f"Response handler {func.__name__} must have a type annotation for the original_request parameter" ) - # Check response parameter has type annotation + # Check response parameter has type annotation (unless skipped) response_param = params[2] - if response_param.annotation == inspect.Parameter.empty: - raise ValueError(f"Response handler {func.__name__} must have a type annotation for the message parameter") + if not skip_annotations and response_param.annotation == inspect.Parameter.empty: + raise ValueError(f"Response handler {func.__name__} must have a type annotation for the response parameter") - # Validate ctx parameter is WorkflowContext and extract type args + # Validate ctx parameter is WorkflowContext and extract type args (if annotated) ctx_param = params[3] - output_types, workflow_output_types = validate_workflow_context_annotation( - ctx_param.annotation, f"parameter '{ctx_param.name}'", "Response handler" - ) + if ctx_param.annotation != inspect.Parameter.empty: + output_types, workflow_output_types = validate_workflow_context_annotation( + ctx_param.annotation, f"parameter '{ctx_param.name}'", "Response handler" + ) + else: + output_types, workflow_output_types = [], [] - request_type = original_request_param.annotation - response_type = response_param.annotation - ctx_annotation = ctx_param.annotation + request_type = original_request_param.annotation if original_request_param.annotation != inspect.Parameter.empty else None + response_type = response_param.annotation if response_param.annotation != inspect.Parameter.empty else None + ctx_annotation = ctx_param.annotation if ctx_param.annotation != inspect.Parameter.empty else None return request_type, response_type, ctx_annotation, output_types, workflow_output_types diff --git a/python/packages/core/tests/workflow/test_function_executor.py b/python/packages/core/tests/workflow/test_function_executor.py index 5d8b310752..a06f1445e1 100644 --- a/python/packages/core/tests/workflow/test_function_executor.py +++ b/python/packages/core/tests/workflow/test_function_executor.py @@ -565,7 +565,7 @@ class TestExecutorExplicitTypes: def test_executor_with_explicit_input_type(self): """Test that explicit input_type takes precedence over introspection.""" - @executor(input_type=str) + @executor(input=str) async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -581,7 +581,7 @@ async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-unt def test_executor_with_explicit_output_type(self): """Test that explicit output_type takes precedence over introspection.""" - @executor(output_type=int) + @executor(output=int) async def process(message: str, ctx: WorkflowContext[str]) -> None: pass @@ -596,7 +596,7 @@ async def process(message: str, ctx: WorkflowContext[str]) -> None: def test_executor_with_explicit_input_and_output_types(self): """Test that both explicit input_type and output_type work together.""" - @executor(id="explicit_both", input_type=dict, output_type=list) + @executor(id="explicit_both", input=dict, output=list) async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -615,7 +615,7 @@ async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-unt def test_executor_with_explicit_union_input_type(self): """Test that explicit union input_type is handled correctly.""" - @executor(input_type=str | int) + @executor(input=str | int) async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -631,7 +631,7 @@ async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-unt def test_executor_with_explicit_union_output_type(self): """Test that explicit union output_type is normalized to a list.""" - @executor(output_type=str | int | bool) + @executor(output=str | int | bool) async def process(message: Any, ctx: WorkflowContext) -> None: pass @@ -643,7 +643,7 @@ def test_executor_explicit_types_precedence_over_introspection(self): # Introspection would give: input=str, output=[int] # Explicit gives: input=bytes, output=[float] - @executor(input_type=bytes, output_type=float) + @executor(input=bytes, output=float) async def process(message: str, ctx: WorkflowContext[int]) -> None: pass @@ -670,7 +670,7 @@ def test_executor_partial_explicit_types(self): """Test that partial explicit types work (only input_type or only output_type).""" # Only explicit input_type, introspect output_type - @executor(input_type=bytes) + @executor(input=bytes) async def process_input(message: str, ctx: WorkflowContext[int]) -> None: pass @@ -678,7 +678,7 @@ async def process_input(message: str, ctx: WorkflowContext[int]) -> None: assert int in process_input.output_types # Introspected # Only explicit output_type, introspect input_type - @executor(output_type=float) + @executor(output=float) async def process_output(message: str, ctx: WorkflowContext[int]) -> None: pass @@ -689,7 +689,7 @@ async def process_output(message: str, ctx: WorkflowContext[int]) -> None: def test_executor_explicit_input_type_allows_no_message_annotation(self): """Test that explicit input_type allows function without message type annotation.""" - @executor(input_type=str) + @executor(input=str) async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -700,7 +700,7 @@ async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-unt def test_executor_explicit_types_with_id(self): """Test that explicit types work together with id parameter.""" - @executor(id="custom_id", input_type=bytes, output_type=int) + @executor(id="custom_id", input=bytes, output=int) async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -711,7 +711,7 @@ async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-unt def test_executor_explicit_types_with_single_param_function(self): """Test that explicit input_type works with single-parameter functions.""" - @executor(input_type=str) + @executor(input=str) async def process(message): # type: ignore[no-untyped-def] return message.upper() @@ -723,7 +723,7 @@ async def process(message): # type: ignore[no-untyped-def] def test_executor_explicit_types_with_sync_function(self): """Test that explicit types work with synchronous functions.""" - @executor(input_type=int, output_type=str) + @executor(input=int, output=str) def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -736,7 +736,7 @@ def test_function_executor_constructor_with_explicit_types(self): async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass - func_exec = FunctionExecutor(process, id="test", input_type=dict, output_type=list) + func_exec = FunctionExecutor(process, id="test", input=dict, output=list) assert dict in func_exec._handlers spec = func_exec._handler_specs[0] @@ -747,7 +747,7 @@ def test_executor_explicit_union_types_via_typing_union(self): """Test that Union[] syntax also works for explicit types.""" from typing import Union - @executor(input_type=Union[str, int], output_type=Union[bool, float]) + @executor(input=Union[str, int], output=Union[bool, float]) async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -761,7 +761,7 @@ async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-unt def test_executor_with_string_forward_reference_input_type(self): """Test that string forward references work for input_type.""" - @executor(input_type="FuncExecForwardRefMessage") + @executor(input="FuncExecForwardRefMessage") async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -772,7 +772,7 @@ async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-unt def test_executor_with_string_forward_reference_union(self): """Test that string forward references work with union types.""" - @executor(input_type="FuncExecForwardRefTypeA | FuncExecForwardRefTypeB") + @executor(input="FuncExecForwardRefTypeA | FuncExecForwardRefTypeB") async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -783,7 +783,7 @@ async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-unt def test_executor_with_string_forward_reference_output_type(self): """Test that string forward references work for output_type.""" - @executor(input_type=str, output_type="FuncExecForwardRefResponse") + @executor(input=str, output="FuncExecForwardRefResponse") async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -793,7 +793,7 @@ async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-unt def test_executor_with_explicit_workflow_output_type(self): """Test that explicit workflow_output_type takes precedence over introspection.""" - @executor(workflow_output_type=bool) + @executor(workflow_output=bool) async def process(message: str, ctx: WorkflowContext[int]) -> None: pass @@ -809,7 +809,7 @@ async def process(message: str, ctx: WorkflowContext[int]) -> None: def test_executor_with_explicit_workflow_output_type_precedence(self): """Test that explicit workflow_output_type overrides introspected WorkflowContext second param.""" - @executor(workflow_output_type=str) + @executor(workflow_output=str) async def process(message: int, ctx: WorkflowContext[int, bool]) -> None: pass @@ -821,7 +821,7 @@ def test_executor_with_all_explicit_types(self): """Test that all three explicit type parameters work together.""" from typing import Any - @executor(input_type=str, output_type=int, workflow_output_type=bool) + @executor(input=str, output=int, workflow_output=bool) async def process(message: Any, ctx: WorkflowContext) -> None: pass @@ -838,7 +838,7 @@ async def process(message: Any, ctx: WorkflowContext) -> None: def test_executor_with_union_workflow_output_type(self): """Test that union types work for workflow_output_type.""" - @executor(workflow_output_type=str | int) + @executor(workflow_output=str | int) async def process(message: str, ctx: WorkflowContext) -> None: pass @@ -849,7 +849,7 @@ async def process(message: str, ctx: WorkflowContext) -> None: def test_executor_with_string_forward_reference_workflow_output_type(self): """Test that string forward references work for workflow_output_type.""" - @executor(input_type=str, workflow_output_type="FuncExecForwardRefResponse") + @executor(input=str, workflow_output="FuncExecForwardRefResponse") async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -859,7 +859,7 @@ async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-unt def test_executor_with_string_forward_reference_union_workflow_output_type(self): """Test that string forward reference union types work for workflow_output_type.""" - @executor(input_type=str, workflow_output_type="FuncExecForwardRefTypeA | FuncExecForwardRefTypeB") + @executor(input=str, workflow_output="FuncExecForwardRefTypeA | FuncExecForwardRefTypeB") async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -887,9 +887,9 @@ async def my_func(message: str, ctx: WorkflowContext) -> None: exec_instance = FunctionExecutor( my_func, id="test_constructor", - input_type=str, - output_type=int, - workflow_output_type=bool, + input=str, + output=int, + workflow_output=bool, ) assert str in exec_instance._handlers diff --git a/python/packages/core/tests/workflow/test_request_info_mixin.py b/python/packages/core/tests/workflow/test_request_info_mixin.py index d89794ef82..23b7663a0c 100644 --- a/python/packages/core/tests/workflow/test_request_info_mixin.py +++ b/python/packages/core/tests/workflow/test_request_info_mixin.py @@ -785,3 +785,170 @@ async def child_handler(self, original_request: str, response: bool, ctx: Workfl # Should not support unregistered combinations assert child.is_request_supported(str, str) is False assert child.is_request_supported(int, str) is False + + +class TestResponseHandlerExplicitTypes: + """Test cases for response_handler with explicit type parameters.""" + + def test_response_handler_with_explicit_types(self): + """Test response_handler with explicit request and response types.""" + + @response_handler(request=str, response=int) + async def test_handler(self, original_request, response, ctx) -> None: + pass + + spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue] + assert spec["name"] == "test_handler" + assert spec["request_type"] is str + assert spec["response_type"] is int + + def test_response_handler_with_explicit_output_types(self): + """Test response_handler with explicit output and workflow_output types.""" + + @response_handler(request=str, response=int, output=bool, workflow_output=float) + async def test_handler(self, original_request, response, ctx) -> None: + pass + + spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue] + assert spec["request_type"] is str + assert spec["response_type"] is int + assert bool in spec["output_types"] + assert float in spec["workflow_output_types"] + + def test_response_handler_with_union_types(self): + """Test response_handler with union types.""" + + @response_handler(request=str | int, response=bool | float) + async def test_handler(self, original_request, response, ctx) -> None: + pass + + spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue] + assert spec["request_type"] == str | int + assert spec["response_type"] == bool | float + + def test_response_handler_with_string_forward_references(self): + """Test response_handler with string forward references.""" + + @response_handler(request="str", response="int") + async def test_handler(self, original_request, response, ctx) -> None: + pass + + spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue] + assert spec["request_type"] is str + assert spec["response_type"] is int + + def test_response_handler_explicit_missing_request_raises_error(self): + """Test that using explicit types without request raises an error.""" + with pytest.raises(ValueError, match="must specify 'request' type"): + + @response_handler(response=int) + async def test_handler(self, original_request, response, ctx) -> None: + pass + + def test_response_handler_explicit_missing_response_raises_error(self): + """Test that using explicit types without response raises an error.""" + with pytest.raises(ValueError, match="must specify 'response' type"): + + @response_handler(request=str) + async def test_handler(self, original_request, response, ctx) -> None: + pass + + def test_response_handler_explicit_only_output_raises_error(self): + """Test that using only output without request/response raises an error.""" + with pytest.raises(ValueError, match="must specify 'request' type"): + + @response_handler(output=bool) + async def test_handler(self, original_request, response, ctx) -> None: + pass + + def test_executor_with_explicit_response_handlers(self): + """Test an executor with explicit type response handlers.""" + + class TestExecutor(Executor): + def __init__(self): + super().__init__(id="test_executor") + + @handler + async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: + pass + + @response_handler(request=str, response=int, output=bool) + async def handle_explicit(self, original_request, response, ctx) -> None: + pass + + executor = TestExecutor() + + # Should be request-response capable + assert executor.is_request_response_capable is True + + # Should have registered handler + response_handlers = executor._response_handlers # type: ignore[reportAttributeAccessIssue] + assert len(response_handlers) == 1 + assert (str, int) in response_handlers + + # Check specs + specs = executor._response_handler_specs # type: ignore[reportAttributeAccessIssue] + assert len(specs) == 1 + assert specs[0]["request_type"] is str + assert specs[0]["response_type"] is int + assert bool in specs[0]["output_types"] + + def test_response_handler_explicit_callable(self): + """Test that explicit type response handlers can be called.""" + + class TestExecutor(Executor): + def __init__(self): + super().__init__(id="test_executor") + self.handled_request = None + self.handled_response = None + + @handler + async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: + pass + + @response_handler(request=str, response=int) + async def handle_response(self, original_request, response, ctx) -> None: + self.handled_request = original_request + self.handled_response = response + + executor = TestExecutor() + + # Get the handler + response_handler_func = executor._response_handlers[(str, int)] # type: ignore[reportAttributeAccessIssue] + + # Call the handler + asyncio.run(response_handler_func("test_request", 42, None)) # type: ignore[reportArgumentType] + + assert executor.handled_request == "test_request" + assert executor.handled_response == 42 + + def test_mixed_introspection_and_explicit_handlers(self): + """Test executor with both introspection and explicit type handlers.""" + + class TestExecutor(Executor): + def __init__(self): + super().__init__(id="test_executor") + + @handler + async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: + pass + + # Introspection-based handler + @response_handler + async def handle_introspection( + self, original_request: str, response: int, ctx: WorkflowContext[str] + ) -> None: + pass + + # Explicit type handler + @response_handler(request=dict, response=bool) + async def handle_explicit(self, original_request, response, ctx) -> None: + pass + + executor = TestExecutor() + + # Should have both handlers + response_handlers = executor._response_handlers # type: ignore[reportAttributeAccessIssue] + assert len(response_handlers) == 2 + assert (str, int) in response_handlers + assert (dict, bool) in response_handlers diff --git a/python/samples/getting_started/workflows/checkpoint/handoff_with_tool_approval_checkpoint_resume.py b/python/samples/getting_started/workflows/checkpoint/handoff_with_tool_approval_checkpoint_resume.py index 7ee4d2cf14..d8409918f4 100644 --- a/python/samples/getting_started/workflows/checkpoint/handoff_with_tool_approval_checkpoint_resume.py +++ b/python/samples/getting_started/workflows/checkpoint/handoff_with_tool_approval_checkpoint_resume.py @@ -7,17 +7,18 @@ from typing import cast from agent_framework import ( + AgentResponse, ChatAgent, ChatMessage, + Content, FileCheckpointStorage, - FunctionApprovalRequestContent, + HandoffAgentUserRequest, HandoffBuilder, - HandoffUserInputRequest, RequestInfoEvent, Workflow, WorkflowOutputEvent, WorkflowStatusEvent, - ai_function, + tool, ) from agent_framework.azure import AzureOpenAIChatClient from azure.identity import AzureCliCredential @@ -26,7 +27,7 @@ Sample: Handoff Workflow with Tool Approvals + Checkpoint Resume Demonstrates the two-step pattern for resuming a handoff workflow from a checkpoint -while handling both HandoffUserInputRequest prompts and FunctionApprovalRequestContent +while handling both HandoffAgentUserRequest prompts and function approval request Content for tool calls (e.g., submit_refund). Scenario: @@ -51,7 +52,7 @@ CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True) -@ai_function(approval_mode="always_require") +@tool(approval_mode="always_require") def submit_refund(refund_description: str, amount: str, order_id: str) -> str: """Capture a refund request for manual review before processing.""" return f"refund recorded for order {order_id} (amount: {amount}) with details: {refund_description}" @@ -102,7 +103,7 @@ def create_workflow(checkpoint_storage: FileCheckpointStorage) -> tuple[Workflow name="checkpoint_handoff_demo", participants=[triage, refund, order], ) - .set_coordinator("triage_agent") + .with_start_agent(triage) .with_checkpointing(checkpoint_storage) .with_termination_condition( # Terminate after 5 user messages for this demo @@ -114,30 +115,32 @@ def create_workflow(checkpoint_storage: FileCheckpointStorage) -> tuple[Workflow return workflow, triage, refund, order -def _print_handoff_request(request: HandoffUserInputRequest, request_id: str) -> None: +def _print_handoff_request(request: HandoffAgentUserRequest, request_id: str) -> None: """Log pending handoff request details for debugging.""" print(f"\n{'=' * 60}") print("WORKFLOW PAUSED - User input needed") print(f"Request ID: {request_id}") - print(f"Awaiting agent: {request.awaiting_agent_id}") - print(f"Prompt: {request.prompt}") - - # Note: After checkpoint restore, conversation may be empty because it's not serialized - # to prevent duplication (the conversation is preserved in the coordinator's state). - # See issue #2667. - if request.conversation: - print("\nConversation so far:") - for msg in request.conversation[-3:]: - author = msg.author_name or msg.role.value - snippet = msg.text[:120] + "..." if len(msg.text) > 120 else msg.text - print(f" {author}: {snippet}") - else: - print("\n(Conversation restored from checkpoint - context preserved in workflow state)") + + _print_handoff_agent_user_request(request.agent_response) print(f"{'=' * 60}\n") -def _print_function_approval_request(request: FunctionApprovalRequestContent, request_id: str) -> None: +def _print_handoff_agent_user_request(response: AgentResponse) -> None: + """Display the agent's response messages when requesting user input.""" + if not response.messages: + print("(No agent messages)") + return + + print("\n[Agent is requesting your input...]") + for message in response.messages: + if not message.text: + continue + speaker = message.author_name or message.role.value + print(f" {speaker}: {message.text}") + + +def _print_function_approval_request(request: Content, request_id: str) -> None: """Log pending tool approval details for debugging.""" args = request.function_call.parse_arguments() or {} print(f"\n{'=' * 60}") @@ -157,14 +160,14 @@ def _build_responses_for_requests( """Create response payloads for each pending request.""" responses: dict[str, object] = {} for request in pending_requests: - if isinstance(request.data, HandoffUserInputRequest): + if isinstance(request.data, HandoffAgentUserRequest): if user_response is None: - raise ValueError("User response is required for HandoffUserInputRequest") - responses[request.request_id] = user_response - elif isinstance(request.data, FunctionApprovalRequestContent): + raise ValueError("User response is required for HandoffAgentUserRequest") + responses[request.request_id] = HandoffAgentUserRequest.create_response(user_response) + elif isinstance(request.data, Content) and request.data.type == "function_approval_request": if approve_tools is None: - raise ValueError("Approval decision is required for FunctionApprovalRequestContent") - responses[request.request_id] = request.data.create_response(approved=approve_tools) + raise ValueError("Approval decision is required for function approval request") + responses[request.request_id] = request.data.to_function_approval_response(approved=approve_tools) else: raise ValueError(f"Unsupported request type: {type(request.data)}") return responses @@ -199,9 +202,9 @@ async def run_until_user_input_needed( elif isinstance(event, RequestInfoEvent): pending_requests.append(event) - if isinstance(event.data, HandoffUserInputRequest): + if isinstance(event.data, HandoffAgentUserRequest): _print_handoff_request(event.data, event.request_id) - elif isinstance(event.data, FunctionApprovalRequestContent): + elif isinstance(event.data, Content) and event.data.type == "function_approval_request": _print_function_approval_request(event.data, event.request_id) elif isinstance(event, WorkflowOutputEvent): @@ -256,9 +259,9 @@ async def resume_with_responses( async for event in workflow.run_stream(checkpoint_id=latest_checkpoint.checkpoint_id): # type: ignore[attr-defined] if isinstance(event, RequestInfoEvent): restored_requests.append(event) - if isinstance(event.data, HandoffUserInputRequest): + if isinstance(event.data, HandoffAgentUserRequest): _print_handoff_request(event.data, event.request_id) - elif isinstance(event.data, FunctionApprovalRequestContent): + elif isinstance(event.data, Content) and event.data.type == "function_approval_request": _print_function_approval_request(event.data, event.request_id) if not restored_requests: @@ -289,9 +292,9 @@ async def resume_with_responses( elif isinstance(event, RequestInfoEvent): new_pending_requests.append(event) - if isinstance(event.data, HandoffUserInputRequest): + if isinstance(event.data, HandoffAgentUserRequest): _print_handoff_request(event.data, event.request_id) - elif isinstance(event.data, FunctionApprovalRequestContent): + elif isinstance(event.data, Content) and event.data.type == "function_approval_request": _print_function_approval_request(event.data, event.request_id) return new_pending_requests, latest_checkpoint.checkpoint_id @@ -302,7 +305,7 @@ async def main() -> None: Demonstrate the checkpoint-based pause/resume pattern for handoff workflows. This sample shows: - 1. Starting a workflow and getting a HandoffUserInputRequest + 1. Starting a workflow and getting a HandoffAgentUserRequest 2. Pausing (checkpoint is saved automatically) 3. Resuming from checkpoint with a user response or tool approval (two-step pattern) 4. Continuing the conversation until completion @@ -361,8 +364,10 @@ async def main() -> None: print("\n>>> Simulating process restart...\n") workflow_step, _, _, _ = create_workflow(checkpoint_storage=storage) - needs_user_input = any(isinstance(req.data, HandoffUserInputRequest) for req in pending_requests) - needs_tool_approval = any(isinstance(req.data, FunctionApprovalRequestContent) for req in pending_requests) + needs_user_input = any(isinstance(req.data, HandoffAgentUserRequest) for req in pending_requests) + needs_tool_approval = any( + isinstance(req.data, Content) and req.data.type == "function_approval_request" for req in pending_requests + ) user_response = None if needs_user_input: From 5c6f195aeb1b7cbd8489d1f99e3b5efc2c26d64f Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Tue, 3 Feb 2026 10:22:57 +0900 Subject: [PATCH 8/9] Fix naming issue --- python/packages/core/agent_framework/_workflows/_magentic.py | 2 +- .../core/agent_framework/_workflows/_orchestrator_helpers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_magentic.py b/python/packages/core/agent_framework/_workflows/_magentic.py index 8503aae2ce..eff87fd5f0 100644 --- a/python/packages/core/agent_framework/_workflows/_magentic.py +++ b/python/packages/core/agent_framework/_workflows/_magentic.py @@ -1367,7 +1367,7 @@ class MagenticBuilder: - `.with_plan_review()` - Review and approve/revise plans before execution - `.with_human_input_on_stall()` - Intervene when workflow stalls - - Tool approval via `FunctionApprovalRequestContent` - Approve individual tool calls + - Tool approval via `function_approval_request` - Approve individual tool calls These emit `MagenticHumanInterventionRequest` events that provide structured decision options (APPROVE, REVISE, CONTINUE, REPLAN, GUIDANCE) appropriate diff --git a/python/packages/core/agent_framework/_workflows/_orchestrator_helpers.py b/python/packages/core/agent_framework/_workflows/_orchestrator_helpers.py index 09f118a6c6..82f6532ea2 100644 --- a/python/packages/core/agent_framework/_workflows/_orchestrator_helpers.py +++ b/python/packages/core/agent_framework/_workflows/_orchestrator_helpers.py @@ -23,7 +23,7 @@ def clean_conversation_for_handoff(conversation: list[ChatMessage]) -> list[Chat This creates a cleaned copy removing ALL tool-related content. Removes: - - FunctionApprovalRequestContent and FunctionCallContent from assistant messages + - function_approval_request and function_call from assistant messages - Tool response messages (Role.TOOL) - Messages with only tool calls and no text From d90f536bd1c20060a4081c45f01141b32d296731 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Tue, 3 Feb 2026 10:41:40 +0900 Subject: [PATCH 9/9] Fix mypy --- .../_workflows/_request_info_mixin.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py index 336f78535a..ac7132e2fe 100644 --- a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py +++ b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py @@ -9,7 +9,7 @@ from builtins import type as builtin_type from collections.abc import Awaitable, Callable from types import UnionType -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, cast from ._typing_utils import is_instance_of, is_type_compatible, normalize_type_to_list, resolve_type_annotation from ._workflow_context import WorkflowContext, validate_workflow_context_annotation @@ -255,9 +255,16 @@ def decorator( ctx_annotation = None else: # Use introspection - all types from annotations - final_request_type, final_response_type, ctx_annotation, final_output_types, final_workflow_output_types = ( - _validate_response_handler_signature(func) - ) + ( + inferred_request_type, + inferred_response_type, + ctx_annotation, + final_output_types, + final_workflow_output_types, + ) = _validate_response_handler_signature(func) + # In introspection mode, validation ensures these are not None (raises ValueError if missing) + final_request_type = cast(type, inferred_request_type) + final_response_type = cast(type, inferred_response_type) # Get signature for preservation sig = inspect.signature(func)