Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ def __init__(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
enable_validation: bool | None = None,
custom_response_validation_http_code: HTTPStatus | None = None,
middlewares: list[Callable[..., Response]] | None = None,
):
Expand Down Expand Up @@ -420,6 +421,8 @@ def __init__(
Additional OpenAPI extensions as a dictionary.
deprecated: bool
Whether or not to mark this route as deprecated in the OpenAPI schema
enable_validation: bool | None, optional
Enable or disable validation for this specific route. If None, inherits from resolver setting.
custom_response_validation_http_code: int | HTTPStatus | None, optional
Whether to have custom http status code for this route if response validation fails
middlewares: list[Callable[..., Response]] | None
Expand Down Expand Up @@ -449,6 +452,7 @@ def __init__(
self.middlewares = middlewares or []
self.operation_id = operation_id or self._generate_operation_id()
self.deprecated = deprecated
self.enable_validation = enable_validation

# _middleware_stack_built is used to ensure the middleware stack is only built once.
self._middleware_stack_built = False
Expand Down Expand Up @@ -535,15 +539,21 @@ def _build_middleware_stack(self, router_middlewares: list[Callable[..., Any]],

all_middlewares = []

# Determine if validation should be enabled for this route
# If route has explicit enable_validation setting, use it; otherwise, use resolver's global setting
route_validation_enabled = (
self.enable_validation if self.enable_validation is not None else app._enable_validation
)

# Add request validation middleware first if validation is enabled
if hasattr(app, "_request_validation_middleware"):
if route_validation_enabled and hasattr(app, "_request_validation_middleware"):
all_middlewares.append(app._request_validation_middleware)

# Add user middlewares in the middle
all_middlewares.extend(router_middlewares + self.middlewares)

# Add response validation middleware before the route handler if validation is enabled
if hasattr(app, "_response_validation_middleware"):
if route_validation_enabled and hasattr(app, "_response_validation_middleware"):
all_middlewares.append(app._response_validation_middleware)

logger.debug(f"Building middleware stack: {all_middlewares}")
Expand Down Expand Up @@ -1132,6 +1142,7 @@ def route(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
enable_validation: bool | None = None,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
Expand Down Expand Up @@ -1194,6 +1205,7 @@ def get(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
enable_validation: bool | None = None,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
Expand Down Expand Up @@ -1235,6 +1247,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
deprecated,
enable_validation,
custom_response_validation_http_code,
middlewares,
)
Expand All @@ -1255,6 +1268,7 @@ def post(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
enable_validation: bool | None = None,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
Expand Down Expand Up @@ -1297,6 +1311,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
deprecated,
enable_validation,
custom_response_validation_http_code,
middlewares,
)
Expand All @@ -1317,6 +1332,7 @@ def put(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
enable_validation: bool | None = None,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
Expand Down Expand Up @@ -1359,6 +1375,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
deprecated,
enable_validation,
custom_response_validation_http_code,
middlewares,
)
Expand All @@ -1379,6 +1396,7 @@ def delete(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
enable_validation: bool | None = None,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
Expand Down Expand Up @@ -1420,6 +1438,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
deprecated,
enable_validation,
custom_response_validation_http_code,
middlewares,
)
Expand All @@ -1440,6 +1459,7 @@ def patch(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
enable_validation: bool | None = None,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
Expand Down Expand Up @@ -1484,6 +1504,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
deprecated,
enable_validation,
custom_response_validation_http_code,
middlewares,
)
Expand All @@ -1504,6 +1525,7 @@ def head(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
enable_validation: bool | None = None,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
Expand Down Expand Up @@ -1547,6 +1569,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
deprecated,
enable_validation,
custom_response_validation_http_code,
middlewares,
)
Expand Down Expand Up @@ -2568,6 +2591,7 @@ def route(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
enable_validation: bool | None = None,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
Expand Down Expand Up @@ -2602,6 +2626,7 @@ def register_resolver(func: AnyCallableT) -> AnyCallableT:
security,
openapi_extensions,
deprecated,
enable_validation,
custom_response_validation_http_code,
middlewares,
)
Expand Down Expand Up @@ -3117,6 +3142,7 @@ def route(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
enable_validation: bool | None = None,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
Expand Down Expand Up @@ -3144,6 +3170,7 @@ def register_route(func: AnyCallableT) -> AnyCallableT:
frozen_security,
frozen_openapi_extensions,
deprecated,
enable_validation,
custom_response_validation_http_code,
)

Expand Down Expand Up @@ -3233,6 +3260,7 @@ def route(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
enable_validation: bool | None = None,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[AnyCallableT], AnyCallableT]:
Expand All @@ -3253,6 +3281,7 @@ def route(
security,
openapi_extensions,
deprecated,
enable_validation,
custom_response_validation_http_code,
middlewares,
)
Expand Down
10 changes: 10 additions & 0 deletions aws_lambda_powertools/event_handler/bedrock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def get( # type: ignore[override]
include_in_schema: bool = True,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
enable_validation: bool | None = None,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
Expand All @@ -144,6 +145,7 @@ def get( # type: ignore[override]
security,
openapi_extensions,
deprecated,
enable_validation,
custom_response_validation_http_code,
middlewares,
)
Expand All @@ -165,6 +167,7 @@ def post( # type: ignore[override]
include_in_schema: bool = True,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
enable_validation: bool | None = None,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
):
Expand All @@ -185,6 +188,7 @@ def post( # type: ignore[override]
security,
openapi_extensions,
deprecated,
enable_validation,
custom_response_validation_http_code,
middlewares,
)
Expand All @@ -206,6 +210,7 @@ def put( # type: ignore[override]
include_in_schema: bool = True,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
enable_validation: bool | None = None,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
):
Expand All @@ -226,6 +231,7 @@ def put( # type: ignore[override]
security,
openapi_extensions,
deprecated,
enable_validation,
custom_response_validation_http_code,
middlewares,
)
Expand All @@ -247,6 +253,7 @@ def patch( # type: ignore[override]
include_in_schema: bool = True,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
enable_validation: bool | None = None,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable] | None = None,
):
Expand All @@ -267,6 +274,7 @@ def patch( # type: ignore[override]
security,
openapi_extensions,
deprecated,
enable_validation,
custom_response_validation_http_code,
middlewares,
)
Expand All @@ -288,6 +296,7 @@ def delete( # type: ignore[override]
include_in_schema: bool = True,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
enable_validation: bool | None = None,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
):
Expand All @@ -308,6 +317,7 @@ def delete( # type: ignore[override]
security,
openapi_extensions,
deprecated,
enable_validation,
custom_response_validation_http_code,
middlewares,
)
Expand Down
10 changes: 8 additions & 2 deletions aws_lambda_powertools/event_handler/http_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,18 @@ async def _run_middleware_chain_async(self, route: Route) -> Response:
# Build middleware list
all_middlewares: list[Callable[..., Any]] = []

if hasattr(self, "_request_validation_middleware"):
# Determine if validation should be enabled for this route
# If route has explicit enable_validation setting, use it; otherwise, use resolver's global setting
route_validation_enabled = (
route.enable_validation if route.enable_validation is not None else self._enable_validation
)

if route_validation_enabled and hasattr(self, "_request_validation_middleware"):
all_middlewares.append(self._request_validation_middleware)

all_middlewares.extend(self._router_middlewares + route.middlewares)

if hasattr(self, "_response_validation_middleware"):
if route_validation_enabled and hasattr(self, "_response_validation_middleware"):
all_middlewares.append(self._response_validation_middleware)

# Create the final handler that calls the route function
Expand Down
Loading