Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import sys
import types
from dataclasses import dataclass
from typing import Any, cast

Expand Down Expand Up @@ -110,7 +111,7 @@ def output_response(self) -> bool:
return self._output_response

@property
def workflow_output_types(self) -> list[type[Any]]:
def workflow_output_types(self) -> list[type[Any] | types.UnionType]:
# Override to declare AgentResponse as a possible output type only if enabled.
if self._output_response:
return [AgentResponse]
Expand Down
218 changes: 163 additions & 55 deletions python/packages/core/agent_framework/_workflows/_executor.py

Large diffs are not rendered by default.

122 changes: 112 additions & 10 deletions python/packages/core/agent_framework/_workflows/_function_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
import asyncio
import inspect
import sys
import types
import typing
from collections.abc import Awaitable, Callable
from typing import Any

from ._executor import Executor
from ._typing_utils import normalize_type_to_list, resolve_type_annotation
from ._workflow_context import WorkflowContext, validate_workflow_context_annotation

if sys.version_info >= (3, 11):
Expand All @@ -41,12 +43,32 @@ class FunctionExecutor(Executor):
blocking the event loop.
"""

def __init__(self, func: Callable[..., Any], id: str | None = None):
def __init__(
self,
func: Callable[..., Any],
id: str | None = None,
*,
input: type | types.UnionType | str | None = None,
output: type | types.UnionType | str | None = None,
workflow_output: type | types.UnionType | str | None = None,
):
"""Initialize the FunctionExecutor with a user-defined function.

Args:
func: The function to wrap as an executor (can be sync or async)
id: Optional executor ID. If None, uses the function name.
input: Optional explicit input type(s) for this executor. Supports union types
(e.g., ``str | int``) and string forward references (e.g., ``"MyType | int"``).
When provided, takes precedence over introspection from the function's message
parameter annotation.
output: Optional explicit output type(s) that can be sent via ``ctx.send_message()``.
Supports union types (e.g., ``str | int``) and string forward references.
When provided, takes precedence over introspection from the ``WorkflowContext``
first generic parameter (T_Out).
workflow_output: Optional explicit output type(s) that can be yielded via
``ctx.yield_output()``. Supports union types (e.g., ``str | int``) and string
forward references. When provided, takes precedence over introspection from the
``WorkflowContext`` second generic parameter (T_W_Out).

Raises:
ValueError: If func is a staticmethod or classmethod (use @handler on instance methods instead)
Expand All @@ -60,8 +82,37 @@ def __init__(self, func: Callable[..., Any], id: str | None = None):
f"or create an Executor subclass and use @handler on instance methods instead."
)

# Resolve string forward references using the function's globals
resolved_input_type = resolve_type_annotation(input, func.__globals__) if input is not None else None
resolved_output_type = resolve_type_annotation(output, func.__globals__) if output is not None else None
resolved_workflow_output_type = (
resolve_type_annotation(workflow_output, func.__globals__) if workflow_output is not None else None
)

# Validate function signature and extract types
message_type, ctx_annotation, output_types, workflow_output_types = _validate_function_signature(func)
introspected_message_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = (
_validate_function_signature(func, skip_message_annotation=resolved_input_type is not None)
)

# Use explicit types if provided, otherwise fall back to introspection
message_type = resolved_input_type if resolved_input_type is not None else introspected_message_type
output_types: list[type[Any] | types.UnionType] = (
normalize_type_to_list(resolved_output_type)
if resolved_output_type is not None
else list(inferred_output_types)
)
final_workflow_output_types: list[type[Any] | types.UnionType] = (
normalize_type_to_list(resolved_workflow_output_type)
if resolved_workflow_output_type is not None
else list(inferred_workflow_output_types)
)

# Validate that we have a message type - provides a clear error if type information is missing
if message_type is None:
raise ValueError(
f"Function {func.__name__} requires either a message parameter type annotation "
"or an explicit input_type parameter"
)

# Store the original function
self._original_func = func
Expand Down Expand Up @@ -106,7 +157,7 @@ async def wrapped_func(message: Any, ctx: WorkflowContext[Any]) -> Any:
message_type=message_type,
ctx_annotation=ctx_annotation,
output_types=output_types,
workflow_output_types=workflow_output_types,
workflow_output_types=final_workflow_output_types,
)

# Now we can safely call _discover_handlers (it won't find any class-level handlers)
Expand All @@ -127,11 +178,22 @@ def executor(func: Callable[..., Any]) -> FunctionExecutor: ...


@overload
def executor(*, id: str | None = None) -> Callable[[Callable[..., Any]], FunctionExecutor]: ...
def executor(
*,
id: str | None = None,
input: type | types.UnionType | str | None = None,
output: type | types.UnionType | str | None = None,
workflow_output: type | types.UnionType | str | None = None,
) -> Callable[[Callable[..., Any]], FunctionExecutor]: ...


def executor(
func: Callable[..., Any] | None = None, *, id: str | None = None
func: Callable[..., Any] | None = None,
*,
id: str | None = None,
input: type | types.UnionType | str | None = None,
output: type | types.UnionType | str | None = None,
workflow_output: type | types.UnionType | str | None = None,
) -> Callable[[Callable[..., Any]], FunctionExecutor] | FunctionExecutor:
"""Decorator that converts a standalone function into a FunctionExecutor instance.

Expand Down Expand Up @@ -162,6 +224,25 @@ def process_data(data: str):
return data.upper()


# Using explicit types (takes precedence over introspection):
# Note: No type annotations on function parameters when using explicit types
@executor(id="my_executor", input=str | int, output=bool)
async def process(message, ctx):
await ctx.send_message(True)


# Using string forward references:
@executor(input="MyCustomType | int", output="ResponseType")
async def process(message, ctx): ...


# Specifying both output types (send_message and yield_output):
@executor(input=str, output=int, workflow_output=bool)
async def process(message, ctx):
await ctx.send_message(42) # int - matches output
await ctx.yield_output(True) # bool - matches workflow_output


# For class-based executors, use @handler instead:
class MyExecutor(Executor):
def __init__(self):
Expand All @@ -174,6 +255,18 @@ async def process(self, data: str, ctx: WorkflowContext[str]):
Args:
func: The function to decorate (when used without parentheses)
id: Optional custom ID for the executor. If None, uses the function name.
input: Optional explicit input type(s) for this executor. Supports union types
(e.g., ``str | int``) and string forward references (e.g., ``"MyType | int"``).
When provided, takes precedence over introspection from the function's message
parameter annotation.
output: Optional explicit output type(s) that can be sent via ``ctx.send_message()``.
Supports union types (e.g., ``str | int``) and string forward references.
When provided, takes precedence over introspection from the ``WorkflowContext``
first generic parameter (T_Out).
workflow_output: Optional explicit output type(s) that can be yielded via
``ctx.yield_output()``. Supports union types (e.g., ``str | int``) and string
forward references. When provided, takes precedence over introspection from the
``WorkflowContext`` second generic parameter (T_W_Out).

Returns:
A FunctionExecutor instance that can be wired into a Workflow.
Expand All @@ -183,7 +276,7 @@ async def process(self, data: str, ctx: WorkflowContext[str]):
"""

def wrapper(func: Callable[..., Any]) -> FunctionExecutor:
return FunctionExecutor(func, id=id)
return FunctionExecutor(func, id=id, input=input, output=output, workflow_output=workflow_output)

# If func is provided, this means @executor was used without parentheses
if func is not None:
Expand All @@ -198,14 +291,21 @@ def wrapper(func: Callable[..., Any]) -> FunctionExecutor:
# region Function Validation


def _validate_function_signature(func: Callable[..., Any]) -> tuple[type, Any, list[type[Any]], list[type[Any]]]:
def _validate_function_signature(
func: Callable[..., Any],
*,
skip_message_annotation: bool = False,
) -> tuple[type | None, Any, list[type[Any] | types.UnionType], list[type[Any] | types.UnionType]]:
"""Validate function signature for executor functions.

Args:
func: The function to validate
skip_message_annotation: If True, skip validation that message parameter has a type
annotation. Used when input is explicitly provided to the @executor decorator.

Returns:
Tuple of (message_type, ctx_annotation, output_types, workflow_output_types)
Tuple of (message_type, ctx_annotation, output_types, workflow_output_types).
message_type may be None if skip_message_annotation is True and no annotation exists.

Raises:
ValueError: If the function signature is invalid
Expand All @@ -220,13 +320,15 @@ def _validate_function_signature(func: Callable[..., Any]) -> tuple[type, Any, l
f"Function instance {func.__name__} must have {param_description}. Got {len(params)} parameters."
)

# Check message parameter has type annotation
# Check message parameter has type annotation (unless skipped)
message_param = params[0]
if message_param.annotation == inspect.Parameter.empty:
if not skip_message_annotation and message_param.annotation == inspect.Parameter.empty:
raise ValueError(f"Function instance {func.__name__} must have a type annotation for the message parameter")

type_hints = typing.get_type_hints(func)
message_type = type_hints.get(message_param.name, message_param.annotation)
if message_type == inspect.Parameter.empty:
message_type = None

# Check if there's a context parameter
if len(params) == 2:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1367,7 +1367,7 @@ class MagenticBuilder:

- `.with_plan_review()` - Review and approve/revise plans before execution
- `.with_human_input_on_stall()` - Intervene when workflow stalls
- Tool approval via `FunctionApprovalRequestContent` - Approve individual tool calls
- Tool approval via `function_approval_request` - Approve individual tool calls

These emit `MagenticHumanInterventionRequest` events that provide structured
decision options (APPROVE, REVISE, CONTINUE, REPLAN, GUIDANCE) appropriate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def clean_conversation_for_handoff(conversation: list[ChatMessage]) -> list[Chat
This creates a cleaned copy removing ALL tool-related content.

Removes:
- FunctionApprovalRequestContent and FunctionCallContent from assistant messages
- function_approval_request and function_call from assistant messages
- Tool response messages (Role.TOOL)
- Messages with only tool calls and no text

Expand Down
Loading
Loading