Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions ninja/patch_dict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import copy
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
Optional,
Expand All @@ -9,13 +11,20 @@
)

from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import core_schema

from ninja import Body
from ninja.orm import ModelSchema
from ninja.schema import Schema
from ninja.utils import is_optional_type

try:
copy_field_info: Callable[[FieldInfo], FieldInfo] = FieldInfo._copy
except AttributeError:
# Fallback for Pydantic<2.11.0
copy_field_info = copy.copy


class ModelToDict(dict):
_wrapped_model: Any = None
Expand Down Expand Up @@ -45,15 +54,20 @@ def get_schema_annotations(schema_cls: Type[Any]) -> Dict[str, Any]:
return annotations


def create_patch_schema(schema_cls: Type[Any]) -> Type[ModelToDict]:
def create_patch_schema(schema_cls: Type[BaseModel]) -> Type[ModelToDict]:
schema_annotations = get_schema_annotations(schema_cls)
values, annotations = {}, {}
# assert False, f"{schema_cls} - {schema_cls.model_fields}"
for f in schema_cls.model_fields.keys():
t = schema_annotations[f]
if not is_optional_type(t):
values[f] = getattr(schema_cls, f, None)
annotations[f] = Optional[t]
values: Dict[str, Any] = {}
annotations = {}

for name, field in schema_cls.model_fields.items():
annotation = schema_annotations[name]
if is_optional_type(annotation):
continue
patch_field = copy_field_info(field)
patch_field.default = None
patch_field.default_factory = None
values[name] = patch_field
annotations[name] = Optional[annotation]
values["__annotations__"] = annotations
OptionalSchema = type(f"{schema_cls.__name__}Patch", (schema_cls,), values)

Expand All @@ -65,7 +79,7 @@ class OptionalDictSchema(ModelToDict):


class PatchDictUtil:
def __getitem__(self, schema_cls: Any) -> Any:
def __getitem__(self, schema_cls: Type[BaseModel]) -> Any:
new_cls = create_patch_schema(schema_cls)
return Body[new_cls] # type: ignore

Expand Down
22 changes: 21 additions & 1 deletion tests/test_patch_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from ninja import NinjaAPI, Schema
from ninja import Field, NinjaAPI, Schema
from ninja.patch_dict import PatchDict
from ninja.testing import TestClient

Expand All @@ -15,6 +15,7 @@ class SomeSchema(Schema):
name: str
age: int
category: Optional[str] = None
identifier: str = Field(max_length=32)


class OtherSchema(SomeSchema):
Expand Down Expand Up @@ -47,6 +48,11 @@ def test_patch_calls(input: dict, output: dict):
assert response.json() == {"payload": output, "type": "<class 'dict'>"}


def test_patch_calls_bad_request():
response = client.patch("/patch", json={"identifier": "0" * 100})
assert response.status_code == 422


def test_schema():
"Checking that json schema properties are all optional"
schema = api.get_openapi_schema()
Expand All @@ -66,6 +72,13 @@ def test_schema():
"anyOf": [{"type": "string"}, {"type": "null"}],
"title": "Category",
},
"identifier": {
"anyOf": [
{"maxLength": 32, "type": "string"},
{"type": "null"},
],
"title": "Identifier",
},
},
}

Expand Down Expand Up @@ -93,6 +106,13 @@ def test_inherited_schema():
"anyOf": [{"type": "integer"}, {"type": "null"}],
"title": "Age",
},
"identifier": {
"anyOf": [
{"maxLength": 32, "type": "string"},
{"type": "null"},
],
"title": "Identifier",
},
"other": {
"anyOf": [{"type": "string"}, {"type": "null"}],
"title": "Other",
Expand Down