Skip to content

Commit 361152a

Browse files
committed
updated field validators extraction logic
1 parent 878d864 commit 361152a

File tree

3 files changed

+71
-64
lines changed

3 files changed

+71
-64
lines changed

fastapi_jsonapi/schema_builder.py

Lines changed: 26 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from pydantic import BaseModel as PydanticBaseModel
2121
from pydantic.class_validators import (
2222
ROOT_VALIDATOR_CONFIG_KEY,
23-
VALIDATOR_CONFIG_KEY,
2423
extract_validators,
2524
inherit_validators,
2625
)
@@ -297,7 +296,10 @@ def _get_info_from_schema_for_building(
297296
# works both for to-one and to-many
298297
included_schemas.append((name, field.type_, relationship.resource_type))
299298
elif name == "id":
300-
id_validators = self._extract_field_validators(schema, target_field_name="id")
299+
id_validators = self._extract_field_validators(
300+
schema,
301+
include_for_field_names={"id"},
302+
)
301303
resource_id_field = (*(resource_id_field[:-1]), id_validators)
302304

303305
if not field.field_info.extra.get("client_can_set_id"):
@@ -398,7 +400,13 @@ def deduplicate_field_validators(self, validators: dict) -> dict:
398400

399401
return result_validators
400402

401-
def prepare_validators(self, model: Type[BaseModel]):
403+
def _extract_field_validators(
404+
self,
405+
model: Type[BaseModel],
406+
*,
407+
include_for_field_names: Set[str] = None,
408+
exclude_for_field_names: Set[str] = None,
409+
):
402410
validators = inherit_validators(
403411
extract_validators(model.__dict__),
404412
deepcopy(model.__validators__),
@@ -411,8 +419,21 @@ def prepare_validators(self, model: Type[BaseModel]):
411419
"check_fields",
412420
)
413421

422+
exclude_for_field_names = exclude_for_field_names or set()
423+
424+
if include_for_field_names and exclude_for_field_names:
425+
exclude_for_field_names = include_for_field_names.difference(
426+
exclude_for_field_names,
427+
)
428+
414429
result_validators = {}
415430
for field_name, field_validators in validators.items():
431+
if field_name in exclude_for_field_names:
432+
continue
433+
434+
if include_for_field_names and field_name not in include_for_field_names:
435+
continue
436+
416437
field_validator: Validator
417438
for field_validator in field_validators:
418439
validator_name = f"{field_name}_{field_validator.func.__name__}_validator"
@@ -452,7 +473,7 @@ def _unpack_validators(self, model: Type[BaseModel], validator_config_key: str)
452473
:param validator_config_key: Choice field, available options are pydantic consts
453474
VALIDATOR_CONFIG_KEY, ROOT_VALIDATOR_CONFIG_KEY
454475
"""
455-
root_validator_class_methods = {
476+
validator_class_methods = {
456477
# validators only
457478
attr_name: value
458479
for attr_name, value in model.__dict__.items()
@@ -461,7 +482,7 @@ def _unpack_validators(self, model: Type[BaseModel], validator_config_key: str)
461482

462483
return {
463484
validator_name: getattr(validator_method, validator_config_key)
464-
for validator_name, validator_method in root_validator_class_methods.items()
485+
for validator_name, validator_method in validator_class_methods.items()
465486
}
466487

467488
def _extract_root_validators(self, model: Type[BaseModel]) -> Dict[str, Callable]:
@@ -477,50 +498,6 @@ def _extract_root_validators(self, model: Type[BaseModel]) -> Dict[str, Callable
477498

478499
return validators
479500

480-
def _extract_field_validators(
481-
self,
482-
model: Type[BaseModel],
483-
target_field_name: str = None,
484-
exclude_for_field_names: Set[str] = None,
485-
) -> Dict[str, Callable]:
486-
"""
487-
:param model: Type[BaseModel]
488-
:param target_field_name: Name of field for which validators will be returned.
489-
If not set the function will return validators for all fields.
490-
"""
491-
validators = {}
492-
validator_origin_param_keys = ("pre", "each_item", "always", "check_fields")
493-
494-
unpacked_validators = self._unpack_validators(model, VALIDATOR_CONFIG_KEY)
495-
for validator_name, (field_names, validator_instance) in unpacked_validators.items():
496-
if target_field_name and target_field_name not in field_names:
497-
continue
498-
elif target_field_name:
499-
field_names = [target_field_name] # noqa: PLW2901
500-
501-
if exclude_for_field_names:
502-
field_names = [ # noqa: PLW2901
503-
# filter names
504-
field_name
505-
for field_name in field_names
506-
if field_name not in exclude_for_field_names
507-
]
508-
509-
if not field_names:
510-
continue
511-
512-
validators[validator_name] = validator(
513-
*field_names,
514-
allow_reuse=True,
515-
**{
516-
# copy origin params
517-
param_name: getattr(validator_instance, param_name)
518-
for param_name in validator_origin_param_keys
519-
},
520-
)(validator_instance.func)
521-
522-
return validators
523-
524501
def _extract_validators(
525502
self,
526503
model: Type[BaseModel],

tests/test_api/test_api_sqla_with_includes.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from copy import deepcopy
33
from itertools import chain, zip_longest
44
from json import dumps
5-
from typing import Dict, List
5+
from typing import Dict, List, Type
66
from uuid import UUID, uuid4
77

88
from fastapi import FastAPI, status
@@ -2050,6 +2050,11 @@ def _refresh_caches(self) -> None:
20502050

20512051
RoutersJSONAPI.all_jsonapi_routers = all_jsonapi_routers
20522052

2053+
def _clear_cache(self):
2054+
SchemaBuilder.object_schemas_cache = {}
2055+
SchemaBuilder.relationship_schema_cache = {}
2056+
SchemaBuilder.base_jsonapi_object_schemas_cache = {}
2057+
20532058
def build_app(self, schema) -> FastAPI:
20542059
return build_app_custom(
20552060
model=User,
@@ -2059,6 +2064,12 @@ def build_app(self, schema) -> FastAPI:
20592064
resource_type=self.resource_type,
20602065
)
20612066

2067+
def inherit(self, schema: Type[BaseModel]) -> Type[BaseModel]:
2068+
class InheritedSchema(schema):
2069+
pass
2070+
2071+
return InheritedSchema
2072+
20622073
async def execute_request_and_check_response(
20632074
self,
20642075
app: FastAPI,
@@ -2082,6 +2093,28 @@ async def execute_request_and_check_response(
20822093
},
20832094
}
20842095

2096+
async def execute_request_twice_and_check_response(
2097+
self,
2098+
schema: Type[BaseModel],
2099+
body: Dict,
2100+
expected_detail: str,
2101+
):
2102+
"""
2103+
Makes two requests for check schema inheritance
2104+
"""
2105+
apps = [
2106+
self.build_app(schema),
2107+
self.build_app(self.inherit(schema)),
2108+
]
2109+
2110+
for app in apps:
2111+
await self.execute_request_and_check_response(
2112+
app=app,
2113+
body=body,
2114+
expected_detail=expected_detail,
2115+
)
2116+
self._clear_cache()
2117+
20852118
async def test_field_validator_call(self):
20862119
"""
20872120
Basic check to ensure that field validator called
@@ -2103,8 +2136,8 @@ class Config:
21032136
attrs = {"name": fake.name()}
21042137
create_user_body = {"data": {"attributes": attrs}}
21052138

2106-
await self.execute_request_and_check_response(
2107-
app=self.build_app(UserSchemaWithValidator),
2139+
await self.execute_request_twice_and_check_response(
2140+
schema=UserSchemaWithValidator,
21082141
body=create_user_body,
21092142
expected_detail="Check validator",
21102143
)
@@ -2124,8 +2157,8 @@ class Config:
21242157
attrs = {"names": ["good_name", "bad_name"]}
21252158
create_user_body = {"data": {"attributes": attrs}}
21262159

2127-
await self.execute_request_and_check_response(
2128-
app=self.build_app(UserSchemaWithValidator),
2160+
await self.execute_request_twice_and_check_response(
2161+
schema=UserSchemaWithValidator,
21292162
body=create_user_body,
21302163
expected_detail="Bad name not allowed",
21312164
)
@@ -2148,8 +2181,8 @@ class Config:
21482181
attrs = {"name": fake.name()}
21492182
create_user_body = {"data": {"attributes": attrs}}
21502183

2151-
await self.execute_request_and_check_response(
2152-
app=self.build_app(UserSchemaWithValidator),
2184+
await self.execute_request_twice_and_check_response(
2185+
schema=UserSchemaWithValidator,
21532186
body=create_user_body,
21542187
expected_detail="Pre validator called",
21552188
)
@@ -2167,8 +2200,8 @@ class Config:
21672200

21682201
create_user_body = {"data": {"attributes": {}}}
21692202

2170-
await self.execute_request_and_check_response(
2171-
app=self.build_app(UserSchemaWithValidator),
2203+
await self.execute_request_twice_and_check_response(
2204+
schema=UserSchemaWithValidator,
21722205
body=create_user_body,
21732206
expected_detail="Called always validator",
21742207
)
@@ -2274,8 +2307,8 @@ class Config:
22742307
},
22752308
}
22762309

2277-
await self.execute_request_and_check_response(
2278-
app=self.build_app(UserSchemaWithValidator),
2310+
await self.execute_request_twice_and_check_response(
2311+
schema=UserSchemaWithValidator,
22792312
body=create_user_body,
22802313
expected_detail="Check validator",
22812314
)

tests/test_api/test_validators.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@ async def task_with_none_ids(
1818
async_session.add(task)
1919
await async_session.commit()
2020

21-
yield task
22-
23-
await async_session.delete(task)
24-
await async_session.commit()
21+
return task
2522

2623

2724
@pytest.fixture()

0 commit comments

Comments
 (0)