diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index c6f7de3e0cc..6555e0b8eb0 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -378,6 +378,7 @@ def __init__( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + enable_validation: bool | None = None, custom_response_validation_http_code: HTTPStatus | None = None, middlewares: list[Callable[..., Response]] | None = None, ): @@ -420,6 +421,8 @@ def __init__( Additional OpenAPI extensions as a dictionary. deprecated: bool Whether or not to mark this route as deprecated in the OpenAPI schema + enable_validation: bool | None, optional + Enable or disable validation for this specific route. If None, inherits from resolver setting. custom_response_validation_http_code: int | HTTPStatus | None, optional Whether to have custom http status code for this route if response validation fails middlewares: list[Callable[..., Response]] | None @@ -449,6 +452,7 @@ def __init__( self.middlewares = middlewares or [] self.operation_id = operation_id or self._generate_operation_id() self.deprecated = deprecated + self.enable_validation = enable_validation # _middleware_stack_built is used to ensure the middleware stack is only built once. self._middleware_stack_built = False @@ -535,15 +539,21 @@ def _build_middleware_stack(self, router_middlewares: list[Callable[..., Any]], all_middlewares = [] + # Determine if validation should be enabled for this route + # If route has explicit enable_validation setting, use it; otherwise, use resolver's global setting + route_validation_enabled = ( + self.enable_validation if self.enable_validation is not None else app._enable_validation + ) + # Add request validation middleware first if validation is enabled - if hasattr(app, "_request_validation_middleware"): + if route_validation_enabled and hasattr(app, "_request_validation_middleware"): all_middlewares.append(app._request_validation_middleware) # Add user middlewares in the middle all_middlewares.extend(router_middlewares + self.middlewares) # Add response validation middleware before the route handler if validation is enabled - if hasattr(app, "_response_validation_middleware"): + if route_validation_enabled and hasattr(app, "_response_validation_middleware"): all_middlewares.append(app._response_validation_middleware) logger.debug(f"Building middleware stack: {all_middlewares}") @@ -1132,6 +1142,7 @@ def route( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + enable_validation: bool | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: @@ -1194,6 +1205,7 @@ def get( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + enable_validation: bool | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: @@ -1235,6 +1247,7 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, + enable_validation, custom_response_validation_http_code, middlewares, ) @@ -1255,6 +1268,7 @@ def post( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + enable_validation: bool | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: @@ -1297,6 +1311,7 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, + enable_validation, custom_response_validation_http_code, middlewares, ) @@ -1317,6 +1332,7 @@ def put( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + enable_validation: bool | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: @@ -1359,6 +1375,7 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, + enable_validation, custom_response_validation_http_code, middlewares, ) @@ -1379,6 +1396,7 @@ def delete( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + enable_validation: bool | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: @@ -1420,6 +1438,7 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, + enable_validation, custom_response_validation_http_code, middlewares, ) @@ -1440,6 +1459,7 @@ def patch( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + enable_validation: bool | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: @@ -1484,6 +1504,7 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, + enable_validation, custom_response_validation_http_code, middlewares, ) @@ -1504,6 +1525,7 @@ def head( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + enable_validation: bool | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: @@ -1547,6 +1569,7 @@ def lambda_handler(event, context): security, openapi_extensions, deprecated, + enable_validation, custom_response_validation_http_code, middlewares, ) @@ -2568,6 +2591,7 @@ def route( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + enable_validation: bool | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: @@ -2602,6 +2626,7 @@ def register_resolver(func: AnyCallableT) -> AnyCallableT: security, openapi_extensions, deprecated, + enable_validation, custom_response_validation_http_code, middlewares, ) @@ -3117,6 +3142,7 @@ def route( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + enable_validation: bool | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: @@ -3144,6 +3170,7 @@ def register_route(func: AnyCallableT) -> AnyCallableT: frozen_security, frozen_openapi_extensions, deprecated, + enable_validation, custom_response_validation_http_code, ) @@ -3233,6 +3260,7 @@ def route( security: list[dict[str, list[str]]] | None = None, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + enable_validation: bool | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[AnyCallableT], AnyCallableT]: @@ -3253,6 +3281,7 @@ def route( security, openapi_extensions, deprecated, + enable_validation, custom_response_validation_http_code, middlewares, ) diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py index 7a9e0cde972..4593715e88d 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -124,6 +124,7 @@ def get( # type: ignore[override] include_in_schema: bool = True, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + enable_validation: bool | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: @@ -144,6 +145,7 @@ def get( # type: ignore[override] security, openapi_extensions, deprecated, + enable_validation, custom_response_validation_http_code, middlewares, ) @@ -165,6 +167,7 @@ def post( # type: ignore[override] include_in_schema: bool = True, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + enable_validation: bool | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ): @@ -185,6 +188,7 @@ def post( # type: ignore[override] security, openapi_extensions, deprecated, + enable_validation, custom_response_validation_http_code, middlewares, ) @@ -206,6 +210,7 @@ def put( # type: ignore[override] include_in_schema: bool = True, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + enable_validation: bool | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ): @@ -226,6 +231,7 @@ def put( # type: ignore[override] security, openapi_extensions, deprecated, + enable_validation, custom_response_validation_http_code, middlewares, ) @@ -247,6 +253,7 @@ def patch( # type: ignore[override] include_in_schema: bool = True, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + enable_validation: bool | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable] | None = None, ): @@ -267,6 +274,7 @@ def patch( # type: ignore[override] security, openapi_extensions, deprecated, + enable_validation, custom_response_validation_http_code, middlewares, ) @@ -288,6 +296,7 @@ def delete( # type: ignore[override] include_in_schema: bool = True, openapi_extensions: dict[str, Any] | None = None, deprecated: bool = False, + enable_validation: bool | None = None, custom_response_validation_http_code: int | HTTPStatus | None = None, middlewares: list[Callable[..., Any]] | None = None, ): @@ -308,6 +317,7 @@ def delete( # type: ignore[override] security, openapi_extensions, deprecated, + enable_validation, custom_response_validation_http_code, middlewares, ) diff --git a/aws_lambda_powertools/event_handler/http_resolver.py b/aws_lambda_powertools/event_handler/http_resolver.py index 5b6ff3f5adf..0be443bd200 100644 --- a/aws_lambda_powertools/event_handler/http_resolver.py +++ b/aws_lambda_powertools/event_handler/http_resolver.py @@ -290,12 +290,18 @@ async def _run_middleware_chain_async(self, route: Route) -> Response: # Build middleware list all_middlewares: list[Callable[..., Any]] = [] - if hasattr(self, "_request_validation_middleware"): + # Determine if validation should be enabled for this route + # If route has explicit enable_validation setting, use it; otherwise, use resolver's global setting + route_validation_enabled = ( + route.enable_validation if route.enable_validation is not None else self._enable_validation + ) + + if route_validation_enabled and hasattr(self, "_request_validation_middleware"): all_middlewares.append(self._request_validation_middleware) all_middlewares.extend(self._router_middlewares + route.middlewares) - if hasattr(self, "_response_validation_middleware"): + if route_validation_enabled and hasattr(self, "_response_validation_middleware"): all_middlewares.append(self._response_validation_middleware) # Create the final handler that calls the route function diff --git a/examples/event_handler_rest/src/per_route_validation.py b/examples/event_handler_rest/src/per_route_validation.py new file mode 100644 index 00000000000..bdeac13f4ff --- /dev/null +++ b/examples/event_handler_rest/src/per_route_validation.py @@ -0,0 +1,135 @@ +from typing import List + +from pydantic import BaseModel, Field + +from aws_lambda_powertools import Logger +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.utilities.typing import LambdaContext + +logger = Logger() +# Enable validation globally +app = APIGatewayRestResolver(enable_validation=True) + + +class Task(BaseModel): + """Task model with validation""" + + id: int = Field(ge=1, description="Task ID must be positive") + title: str = Field(min_length=1, max_length=100, description="Task title") + completed: bool = Field(default=False, description="Task completion status") + + +class LegacyResponse(BaseModel): + """Response model used by legacy endpoints""" + + status: str + data: dict + + +@app.get("/tasks/") +def get_task(task_id: int) -> Task: + """ + This route inherits global validation (enable_validation=True from resolver). + Request and response will be validated against OpenAPI schema. + """ + logger.info(f"Getting task {task_id}") + return Task(id=task_id, title="Sample Task", completed=False) + + +@app.post("/tasks") +def create_task(task: Task) -> Task: + """ + This route also inherits global validation. + Request body will be validated and parsed into Task model. + """ + logger.info(f"Creating task: {task.title}") + return task + + +@app.get("/legacy/status", enable_validation=False) +def legacy_status_check(): + """ + This route explicitly disables validation even though resolver has it enabled. + Useful for legacy endpoints that don't conform to your OpenAPI schema yet. + + The response can be any dict - no validation will occur. + """ + logger.info("Legacy status check - no validation") + # This response doesn't match any model - that's OK with validation disabled + return { + "status": "ok", + "timestamp": "2024-01-01", + "extra_field": "not in schema", + "nested": {"arbitrary": "data"}, + } + + +@app.get("/legacy/info", enable_validation=False) +def legacy_info() -> dict: + """ + Another legacy endpoint with validation disabled. + Can return arbitrary structure without validation. + """ + return { + "version": "1.0", + "mode": "legacy", + "features": ["one", "two", "three"], + } + + +@app.get("/tasks") +def list_tasks() -> List[Task]: + """ + This route has validation enabled (inherited from resolver). + Response will be validated to ensure it's a list of Task objects. + """ + logger.info("Listing all tasks") + return [ + Task(id=1, title="First Task", completed=True), + Task(id=2, title="Second Task", completed=False), + ] + + +@app.delete("/tasks/", enable_validation=False) +def delete_task(task_id: str): + """ + Validation disabled for this endpoint - maybe it's being migrated. + Notice task_id is a str here (not int) - validation would normally catch this. + """ + logger.info(f"Deleting task (no validation): {task_id}") + return {"message": f"Task {task_id} deleted"} + + +def lambda_handler(event: dict, context: LambdaContext) -> dict: + return app.resolve(event, context) + + +""" +Benefits of per-route validation: + +1. **Gradual Migration**: Enable validation globally, then disable it for legacy routes + that need more time to be updated. + +2. **Mixed Workloads**: Validate critical business logic endpoints while allowing + flexibility for internal/admin endpoints. + +3. **Performance**: Disable validation for high-throughput endpoints where you trust + the input and want to minimize overhead. + +4. **Development**: Enable validation for new features while keeping old code working. + +Example requests: + +# Validated endpoint (will check task_id is int, response matches Task model) +GET /tasks/123 + +# Legacy endpoint (no validation, returns any structure) +GET /legacy/status + +# Validated POST (request body must match Task model) +POST /tasks +{"id": 1, "title": "New Task", "completed": false} + +# Legacy delete (no validation, task_id can be any string) +DELETE /tasks/abc123 +""" diff --git a/tests/functional/event_handler/_pydantic/test_per_route_validation.py b/tests/functional/event_handler/_pydantic/test_per_route_validation.py new file mode 100644 index 00000000000..bd5c33ae0b3 --- /dev/null +++ b/tests/functional/event_handler/_pydantic/test_per_route_validation.py @@ -0,0 +1,276 @@ +from typing import cast + +from pydantic import BaseModel + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from tests.functional.utils import load_event + + +class TodoItem(BaseModel): + name: str + completed: bool = False + + +def test_per_route_validation_enabled_on_single_route(): + # GIVEN APIGatewayRestResolver with global enable_validation + # AND one route with explicit enable_validation=True + # AND one route without explicit validation (inherits global) + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/explicitly-validated", enable_validation=True) + def explicitly_validated_route() -> TodoItem: + return TodoItem(name="test", completed=True) + + @app.get("/inherit-validated") + def inherit_validated_route() -> TodoItem: + return TodoItem(name="inherit", completed=False) + + # WHEN calling the explicitly validated route + event = load_event("apiGatewayProxyEvent.json") + event["path"] = "/explicitly-validated" + event["httpMethod"] = "GET" + + result = app(event, {}) + + # THEN response should be validated and successful + assert result["statusCode"] == 200 + assert '"name":"test"' in result["body"] + + # WHEN calling the route that inherits validation + event["path"] = "/inherit-validated" + result = app(event, {}) + + # THEN response should also be validated + assert result["statusCode"] == 200 + assert "inherit" in result["body"] + + +def test_per_route_validation_disabled_on_single_route(): + # GIVEN APIGatewayRestResolver with global enable_validation=True + # AND one route with enable_validation=False + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/validated") + def validated_route() -> TodoItem: + return TodoItem(name="test", completed=True) + + @app.get("/not-validated", enable_validation=False) + def not_validated_route() -> dict: + # This returns invalid data that doesn't match TodoItem but should not fail + return {"invalid": "data", "extra": "field"} + + # WHEN calling the validated route + event = load_event("apiGatewayProxyEvent.json") + event["path"] = "/validated" + event["httpMethod"] = "GET" + + result = app(event, {}) + + # THEN response should be validated and successful + assert result["statusCode"] == 200 + assert '"name":"test"' in result["body"] + + # WHEN calling the non-validated route with invalid response + event["path"] = "/not-validated" + result = app(event, {}) + + # THEN response should bypass validation + assert result["statusCode"] == 200 + assert "invalid" in result["body"] + + +def test_per_route_validation_request_body_validation(): + # GIVEN APIGatewayRestResolver WITH global validation enabled + # AND routes with different validation settings + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/create") + def create_item(item: TodoItem) -> TodoItem: + return item + + @app.post("/create-no-validation", enable_validation=False) + def create_item_no_validation() -> dict: + # Without validation, we manually parse the body + body = app.current_event.json_body + return body + + # WHEN calling validated route with valid body + event = load_event("apiGatewayProxyEvent.json") + event["path"] = "/create" + event["httpMethod"] = "POST" + event["body"] = '{"name": "New Task", "completed": false}' + + result = app(event, {}) + + # THEN request should be validated and successful + assert result["statusCode"] == 200 + assert "New Task" in result["body"] + + # WHEN calling validated route with invalid body + event["body"] = '{"invalid": "data"}' + result = app(event, {}) + + # THEN validation should fail with 422 + assert result["statusCode"] == 422 + + # WHEN calling non-validated route with any body + event["path"] = "/create-no-validation" + event["body"] = '{"invalid": "data"}' + result = app(event, {}) + + # THEN should succeed without validation + assert result["statusCode"] == 200 + + +def test_per_route_validation_inherits_from_resolver(): + # GIVEN APIGatewayRestResolver with global enable_validation=True + # AND routes without explicit enable_validation setting + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/route1") + def route1() -> TodoItem: + return TodoItem(name="test", completed=True) + + @app.post("/route2") + def route2(item: TodoItem) -> TodoItem: + return item + + # WHEN calling routes without explicit validation setting + event = load_event("apiGatewayProxyEvent.json") + event["path"] = "/route1" + event["httpMethod"] = "GET" + + result = app(event, {}) + + # THEN they should inherit global validation setting + assert result["statusCode"] == 200 + + # WHEN calling POST with invalid body + event["path"] = "/route2" + event["httpMethod"] = "POST" + event["body"] = '{"invalid": "data"}' + + result = app(event, {}) + + # THEN validation should be applied (422 error) + assert result["statusCode"] == 422 + + +def test_per_route_validation_mixed_routes(): + # GIVEN APIGatewayRestResolver with mixed validation settings + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/always-validated") + def always_validated() -> TodoItem: + return TodoItem(name="validated", completed=True) + + @app.get("/never-validated", enable_validation=False) + def never_validated(): + # Return invalid TodoItem structure + return {"wrong": "structure"} + + @app.get("/inherit-global") + def inherit_global() -> TodoItem: + return TodoItem(name="inherit", completed=False) + + event = load_event("apiGatewayProxyEvent.json") + event["httpMethod"] = "GET" + + # WHEN calling route with global validation (enable_validation not set) + event["path"] = "/inherit-global" + result = app(event, {}) + assert result["statusCode"] == 200 + assert "inherit" in result["body"] + + # WHEN calling route with explicit validation=False returning invalid data + event["path"] = "/never-validated" + result = app(event, {}) + # THEN should succeed without validation + assert result["statusCode"] == 200 + assert "wrong" in result["body"] + + # WHEN calling route with inherited validation + event["path"] = "/always-validated" + result = app(event, {}) + assert result["statusCode"] == 200 + assert "validated" in result["body"] + + +def test_per_route_validation_with_resolver_disabled(): + # GIVEN APIGatewayRestResolver with global validation disabled (default) + # Note: Per-route enable_validation=True requires the resolver to have + # enable_validation=True for the middleware to exist. This test documents + # that you can't opt-in to validation per-route without global validation. + app = APIGatewayRestResolver() # enable_validation=False by default + + @app.get("/no-explicit-setting") + def default_route() -> TodoItem: + return TodoItem(name="test", completed=True) + + event = load_event("apiGatewayProxyEvent.json") + event["httpMethod"] = "GET" + + # WHEN calling route without explicit setting (inherits False) + event["path"] = "/no-explicit-setting" + result = app(event, {}) + + # THEN should not be validated (returns as-is) + assert result["statusCode"] == 200 + assert "test" in result["body"] + + +def test_per_route_validation_response_error_code(): + # GIVEN APIGatewayRestResolver with custom response_validation_error_http_code + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/invalid-response") + def invalid_response() -> TodoItem: + # Return dict that doesn't match TodoItem model to test validation error handling + return cast(TodoItem, {"bad": "response"}) + + # WHEN calling route that returns invalid response + event = load_event("apiGatewayProxyEvent.json") + event["path"] = "/invalid-response" + event["httpMethod"] = "GET" + + result = app(event, {}) + + # THEN should return 422 Unprocessable Entity (default response validation error code) + assert result["statusCode"] == 422 + + +def test_per_route_validation_with_pydantic_v2(): + """Test that per-route validation works correctly with Pydantic v2 models""" + # GIVEN APIGatewayRestResolver with mixed validation + app = APIGatewayRestResolver() + + class Task(BaseModel): + title: str + priority: int + + @app.get("/task", enable_validation=True) + def get_task() -> Task: + return Task(title="Important", priority=1) + + @app.get("/unvalidated-task") + def get_unvalidated_task(): + return {"title": "Anything", "extra": "field"} + + event = load_event("apiGatewayProxyEvent.json") + event["httpMethod"] = "GET" + + # WHEN calling validated route + event["path"] = "/task" + result = app(event, {}) + + # THEN should validate and serialize correctly + assert result["statusCode"] == 200 + assert "Important" in result["body"] + + # WHEN calling unvalidated route + event["path"] = "/unvalidated-task" + result = app(event, {}) + + # THEN should return as-is without validation + assert result["statusCode"] == 200 + assert "extra" in result["body"]