Skip to content

Commit 0d928e7

Browse files
committed
updated extracting of root validators, added tests
1 parent 361152a commit 0d928e7

File tree

2 files changed

+153
-60
lines changed

2 files changed

+153
-60
lines changed

fastapi_jsonapi/schema_builder.py

Lines changed: 36 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,19 @@
1616
)
1717

1818
import pydantic
19-
from pydantic import BaseConfig, root_validator, validator
19+
from pydantic import (
20+
BaseConfig,
21+
root_validator,
22+
validator,
23+
)
2024
from pydantic import BaseModel as PydanticBaseModel
2125
from pydantic.class_validators import (
22-
ROOT_VALIDATOR_CONFIG_KEY,
26+
extract_root_validators,
2327
extract_validators,
2428
inherit_validators,
2529
)
2630
from pydantic.fields import FieldInfo, ModelField, Validator
31+
from pydantic.utils import unique_list
2732

2833
from fastapi_jsonapi.data_typing import TypeSchema
2934
from fastapi_jsonapi.schema import (
@@ -386,7 +391,34 @@ def create_relationship_data_schema(
386391
self.relationship_schema_cache[cache_key] = relationship_data_schema
387392
return relationship_data_schema
388393

389-
def deduplicate_field_validators(self, validators: dict) -> dict:
394+
def _extract_root_validators(self, model: Type[BaseModel]) -> Dict[str, Callable]:
395+
pre_rv_new, post_rv_new = extract_root_validators(model.__dict__)
396+
pre_root_validators = unique_list(
397+
model.__pre_root_validators__ + pre_rv_new,
398+
name_factory=lambda v: v.__name__,
399+
)
400+
post_root_validators = unique_list(
401+
model.__post_root_validators__ + post_rv_new,
402+
name_factory=lambda skip_on_failure_and_v: skip_on_failure_and_v[1].__name__,
403+
)
404+
405+
result_validators = {}
406+
407+
for validator_func in pre_root_validators:
408+
result_validators[validator_func.__name__] = root_validator(
409+
pre=True,
410+
allow_reuse=True,
411+
)(validator_func)
412+
413+
for skip_on_failure, validator_func in post_root_validators:
414+
result_validators[validator_func.__name__] = root_validator(
415+
allow_reuse=True,
416+
skip_on_failure=skip_on_failure,
417+
)(validator_func)
418+
419+
return result_validators
420+
421+
def _deduplicate_field_validators(self, validators: dict) -> dict:
390422
result_validators = {}
391423

392424
for field_name, field_validators in validators.items():
@@ -411,7 +443,7 @@ def _extract_field_validators(
411443
extract_validators(model.__dict__),
412444
deepcopy(model.__validators__),
413445
)
414-
validators = self.deduplicate_field_validators(validators)
446+
validators = self._deduplicate_field_validators(validators)
415447
validator_origin_param_keys = (
416448
"pre",
417449
"each_item",
@@ -450,54 +482,6 @@ def _extract_field_validators(
450482

451483
return result_validators
452484

453-
def _is_target_validator(self, attr_name: str, value: Any, validator_config_key: str) -> bool:
454-
"""
455-
True if passed object is validator of type identified by "validator_config_key" arg
456-
457-
:param attr_name:
458-
:param value:
459-
:param validator_config_key: Choice field, available options are pydantic consts
460-
VALIDATOR_CONFIG_KEY, ROOT_VALIDATOR_CONFIG_KEY
461-
"""
462-
return (
463-
# also with private items
464-
not attr_name.startswith("__")
465-
and getattr(value, validator_config_key, None)
466-
)
467-
468-
def _unpack_validators(self, model: Type[BaseModel], validator_config_key: str) -> Dict[str, Validator]:
469-
"""
470-
Selects all validators from model attrs and unpack them from class methods
471-
472-
:param model: Type[BaseModel]
473-
:param validator_config_key: Choice field, available options are pydantic consts
474-
VALIDATOR_CONFIG_KEY, ROOT_VALIDATOR_CONFIG_KEY
475-
"""
476-
validator_class_methods = {
477-
# validators only
478-
attr_name: value
479-
for attr_name, value in model.__dict__.items()
480-
if self._is_target_validator(attr_name, value, validator_config_key)
481-
}
482-
483-
return {
484-
validator_name: getattr(validator_method, validator_config_key)
485-
for validator_name, validator_method in validator_class_methods.items()
486-
}
487-
488-
def _extract_root_validators(self, model: Type[BaseModel]) -> Dict[str, Callable]:
489-
validators = {}
490-
491-
unpacked_validators = self._unpack_validators(model, ROOT_VALIDATOR_CONFIG_KEY)
492-
for validator_name, validator_instance in unpacked_validators.items():
493-
validators[validator_name] = root_validator(
494-
pre=validator_instance.pre,
495-
skip_on_failure=validator_instance.skip_on_failure,
496-
allow_reuse=True,
497-
)(validator_instance.func)
498-
499-
return validators
500-
501485
def _extract_validators(
502486
self,
503487
model: Type[BaseModel],

tests/test_api/test_api_sqla_with_includes.py

Lines changed: 117 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2313,11 +2313,18 @@ class Config:
23132313
expected_detail="Check validator",
23142314
)
23152315

2316-
async def test_field_validator_can_change_value(self):
2316+
@mark.parametrize(
2317+
"inherit",
2318+
[
2319+
param(True, id="inherited_true"),
2320+
param(False, id="inherited_false"),
2321+
],
2322+
)
2323+
async def test_field_validator_can_change_value(self, inherit: bool):
23172324
class UserSchemaWithValidator(BaseModel):
23182325
name: str
23192326

2320-
@validator("name")
2327+
@validator("name", allow_reuse=True)
23212328
def fix_title(cls, v):
23222329
return v.title()
23232330

@@ -2327,7 +2334,11 @@ class Config:
23272334
attrs = {"name": "john doe"}
23282335
create_user_body = {"data": {"attributes": attrs}}
23292336

2330-
app = self.build_app(UserSchemaWithValidator)
2337+
if inherit:
2338+
app = self.build_app(self.inherit(UserSchemaWithValidator))
2339+
else:
2340+
app = self.build_app(UserSchemaWithValidator)
2341+
23312342
async with AsyncClient(app=app, base_url="http://test") as client:
23322343
url = app.url_path_for(f"get_{self.resource_type}_list")
23332344
res = await client.post(url, json=create_user_body)
@@ -2392,17 +2403,24 @@ class Config:
23922403
attrs = {"name": name}
23932404
create_user_body = {"data": {"attributes": attrs}}
23942405

2395-
await self.execute_request_and_check_response(
2396-
app=self.build_app(UserSchemaWithValidator),
2406+
await self.execute_request_twice_and_check_response(
2407+
schema=UserSchemaWithValidator,
23972408
body=create_user_body,
23982409
expected_detail=expected_detail,
23992410
)
24002411

2401-
async def test_root_validator_can_change_value(self):
2412+
@mark.parametrize(
2413+
"inherit",
2414+
[
2415+
param(True, id="inherited_true"),
2416+
param(False, id="inherited_false"),
2417+
],
2418+
)
2419+
async def test_root_validator_can_change_value(self, inherit: bool):
24022420
class UserSchemaWithValidator(BaseModel):
24032421
name: str
24042422

2405-
@root_validator
2423+
@root_validator(allow_reuse=True)
24062424
def fix_title(cls, v):
24072425
v["name"] = v["name"].title()
24082426
return v
@@ -2413,7 +2431,11 @@ class Config:
24132431
attrs = {"name": "john doe"}
24142432
create_user_body = {"data": {"attributes": attrs}}
24152433

2416-
app = self.build_app(UserSchemaWithValidator)
2434+
if inherit:
2435+
app = self.build_app(self.inherit(UserSchemaWithValidator))
2436+
else:
2437+
app = self.build_app(UserSchemaWithValidator)
2438+
24172439
async with AsyncClient(app=app, base_url="http://test") as client:
24182440
url = app.url_path_for(f"get_{self.resource_type}_list")
24192441
res = await client.post(url, json=create_user_body)
@@ -2431,5 +2453,92 @@ class Config:
24312453
"meta": None,
24322454
}
24332455

2456+
@mark.parametrize(
2457+
("name", "expected_detail"),
2458+
[
2459+
param("check_pre_1", "check_pre_1", id="check_1_pre_validator"),
2460+
param("check_pre_2", "check_pre_2", id="check_2_pre_validator"),
2461+
param("check_post_1", "check_post_1", id="check_1_post_validator"),
2462+
param("check_post_2", "check_post_2", id="check_2_post_validator"),
2463+
],
2464+
)
2465+
async def test_root_validator_inheritance(self, name: str, expected_detail: str):
2466+
class UserSchemaWithValidatorBase(BaseModel):
2467+
name: str
2468+
2469+
@root_validator(pre=True, allow_reuse=True)
2470+
def validator_pre_1(cls, values):
2471+
if values["name"] == "check_pre_1":
2472+
raise BadRequest(detail="Base check_pre_1")
2473+
2474+
return values
2475+
2476+
@root_validator(pre=True, allow_reuse=True)
2477+
def validator_pre_2(cls, values):
2478+
if values["name"] == "check_pre_2":
2479+
raise BadRequest(detail="Base check_pre_2")
2480+
2481+
return values
2482+
2483+
@root_validator(allow_reuse=True)
2484+
def validator_post_1(cls, values):
2485+
if values["name"] == "check_post_1":
2486+
raise BadRequest(detail="Base check_post_1")
2487+
2488+
return values
2489+
2490+
@root_validator(allow_reuse=True)
2491+
def validator_post_2(cls, values):
2492+
if values["name"] == "check_post_2":
2493+
raise BadRequest(detail="Base check_post_2")
2494+
2495+
return values
2496+
2497+
class Config:
2498+
orm_mode = True
2499+
2500+
class UserSchemaWithValidator(UserSchemaWithValidatorBase):
2501+
name: str
2502+
2503+
@root_validator(pre=True, allow_reuse=True)
2504+
def validator_pre_1(cls, values):
2505+
if values["name"] == "check_pre_1":
2506+
raise BadRequest(detail="check_pre_1")
2507+
2508+
return values
2509+
2510+
@root_validator(pre=True, allow_reuse=True)
2511+
def validator_pre_2(cls, values):
2512+
if values["name"] == "check_pre_2":
2513+
raise BadRequest(detail="check_pre_2")
2514+
2515+
return values
2516+
2517+
@root_validator(allow_reuse=True)
2518+
def validator_post_1(cls, values):
2519+
if values["name"] == "check_post_1":
2520+
raise BadRequest(detail="check_post_1")
2521+
2522+
return values
2523+
2524+
@root_validator(allow_reuse=True)
2525+
def validator_post_2(cls, values):
2526+
if values["name"] == "check_post_2":
2527+
raise BadRequest(detail="check_post_2")
2528+
2529+
return values
2530+
2531+
class Config:
2532+
orm_mode = True
2533+
2534+
attrs = {"name": name}
2535+
create_user_body = {"data": {"attributes": attrs}}
2536+
2537+
await self.execute_request_and_check_response(
2538+
app=self.build_app(UserSchemaWithValidator),
2539+
body=create_user_body,
2540+
expected_detail=expected_detail,
2541+
)
2542+
24342543

24352544
# todo: test errors

0 commit comments

Comments
 (0)