diff --git a/ninja/patch_dict.py b/ninja/patch_dict.py index 96fb65a85..4ec8f64b1 100644 --- a/ninja/patch_dict.py +++ b/ninja/patch_dict.py @@ -1,6 +1,8 @@ +import copy from typing import ( TYPE_CHECKING, Any, + Callable, Dict, Generic, Optional, @@ -9,6 +11,7 @@ ) from pydantic import BaseModel +from pydantic.fields import FieldInfo from pydantic_core import core_schema from ninja import Body @@ -16,6 +19,12 @@ from ninja.schema import Schema from ninja.utils import is_optional_type +try: + copy_field_info: Callable[[FieldInfo], FieldInfo] = FieldInfo._copy +except AttributeError: + # Fallback for Pydantic<2.11.0 + copy_field_info = copy.copy + class ModelToDict(dict): _wrapped_model: Any = None @@ -45,15 +54,20 @@ def get_schema_annotations(schema_cls: Type[Any]) -> Dict[str, Any]: return annotations -def create_patch_schema(schema_cls: Type[Any]) -> Type[ModelToDict]: +def create_patch_schema(schema_cls: Type[BaseModel]) -> Type[ModelToDict]: schema_annotations = get_schema_annotations(schema_cls) - values, annotations = {}, {} - # assert False, f"{schema_cls} - {schema_cls.model_fields}" - for f in schema_cls.model_fields.keys(): - t = schema_annotations[f] - if not is_optional_type(t): - values[f] = getattr(schema_cls, f, None) - annotations[f] = Optional[t] + values: Dict[str, Any] = {} + annotations = {} + + for name, field in schema_cls.model_fields.items(): + annotation = schema_annotations[name] + if is_optional_type(annotation): + continue + patch_field = copy_field_info(field) + patch_field.default = None + patch_field.default_factory = None + values[name] = patch_field + annotations[name] = Optional[annotation] values["__annotations__"] = annotations OptionalSchema = type(f"{schema_cls.__name__}Patch", (schema_cls,), values) @@ -65,7 +79,7 @@ class OptionalDictSchema(ModelToDict): class PatchDictUtil: - def __getitem__(self, schema_cls: Any) -> Any: + def __getitem__(self, schema_cls: Type[BaseModel]) -> Any: new_cls = create_patch_schema(schema_cls) return Body[new_cls] # type: ignore diff --git a/tests/test_patch_dict.py b/tests/test_patch_dict.py index c376bb4aa..195312afc 100644 --- a/tests/test_patch_dict.py +++ b/tests/test_patch_dict.py @@ -2,7 +2,7 @@ import pytest -from ninja import NinjaAPI, Schema +from ninja import Field, NinjaAPI, Schema from ninja.patch_dict import PatchDict from ninja.testing import TestClient @@ -15,6 +15,7 @@ class SomeSchema(Schema): name: str age: int category: Optional[str] = None + identifier: str = Field(max_length=32) class OtherSchema(SomeSchema): @@ -47,6 +48,11 @@ def test_patch_calls(input: dict, output: dict): assert response.json() == {"payload": output, "type": ""} +def test_patch_calls_bad_request(): + response = client.patch("/patch", json={"identifier": "0" * 100}) + assert response.status_code == 422 + + def test_schema(): "Checking that json schema properties are all optional" schema = api.get_openapi_schema() @@ -66,6 +72,13 @@ def test_schema(): "anyOf": [{"type": "string"}, {"type": "null"}], "title": "Category", }, + "identifier": { + "anyOf": [ + {"maxLength": 32, "type": "string"}, + {"type": "null"}, + ], + "title": "Identifier", + }, }, } @@ -93,6 +106,13 @@ def test_inherited_schema(): "anyOf": [{"type": "integer"}, {"type": "null"}], "title": "Age", }, + "identifier": { + "anyOf": [ + {"maxLength": 32, "type": "string"}, + {"type": "null"}, + ], + "title": "Identifier", + }, "other": { "anyOf": [{"type": "string"}, {"type": "null"}], "title": "Other",