|
1 | 1 | """JSON API schemas builder class.""" |
| 2 | +from copy import deepcopy |
2 | 3 | from dataclasses import dataclass |
3 | 4 | from typing import ( |
4 | 5 | Any, |
|
17 | 18 | import pydantic |
18 | 19 | from pydantic import BaseConfig, root_validator, validator |
19 | 20 | from pydantic import BaseModel as PydanticBaseModel |
20 | | -from pydantic.class_validators import ROOT_VALIDATOR_CONFIG_KEY, VALIDATOR_CONFIG_KEY |
| 21 | +from pydantic.class_validators import ( |
| 22 | + ROOT_VALIDATOR_CONFIG_KEY, |
| 23 | + VALIDATOR_CONFIG_KEY, |
| 24 | + extract_validators, |
| 25 | + inherit_validators, |
| 26 | +) |
21 | 27 | from pydantic.fields import FieldInfo, ModelField, Validator |
22 | 28 |
|
23 | 29 | from fastapi_jsonapi.data_typing import TypeSchema |
@@ -378,6 +384,51 @@ def create_relationship_data_schema( |
378 | 384 | self.relationship_schema_cache[cache_key] = relationship_data_schema |
379 | 385 | return relationship_data_schema |
380 | 386 |
|
| 387 | + def deduplicate_field_validators(self, validators: dict) -> dict: |
| 388 | + result_validators = {} |
| 389 | + |
| 390 | + for field_name, field_validators in validators.items(): |
| 391 | + result_validators[field_name] = list( |
| 392 | + { |
| 393 | + # override in definition order |
| 394 | + field_validator.func.__name__: field_validator |
| 395 | + for field_validator in field_validators |
| 396 | + }.values(), |
| 397 | + ) |
| 398 | + |
| 399 | + return result_validators |
| 400 | + |
| 401 | + def prepare_validators(self, model: Type[BaseModel]): |
| 402 | + validators = inherit_validators( |
| 403 | + extract_validators(model.__dict__), |
| 404 | + deepcopy(model.__validators__), |
| 405 | + ) |
| 406 | + validators = self.deduplicate_field_validators(validators) |
| 407 | + validator_origin_param_keys = ( |
| 408 | + "pre", |
| 409 | + "each_item", |
| 410 | + "always", |
| 411 | + "check_fields", |
| 412 | + ) |
| 413 | + |
| 414 | + result_validators = {} |
| 415 | + for field_name, field_validators in validators.items(): |
| 416 | + field_validator: Validator |
| 417 | + for field_validator in field_validators: |
| 418 | + validator_name = f"{field_name}_{field_validator.func.__name__}_validator" |
| 419 | + validator_params = { |
| 420 | + # copy validator params |
| 421 | + param_key: getattr(field_validator, param_key) |
| 422 | + for param_key in validator_origin_param_keys |
| 423 | + } |
| 424 | + result_validators[validator_name] = validator( |
| 425 | + field_name, |
| 426 | + **validator_params, |
| 427 | + allow_reuse=True, |
| 428 | + )(field_validator.func) |
| 429 | + |
| 430 | + return result_validators |
| 431 | + |
381 | 432 | def _is_target_validator(self, attr_name: str, value: Any, validator_config_key: str) -> bool: |
382 | 433 | """ |
383 | 434 | True if passed object is validator of type identified by "validator_config_key" arg |
|
0 commit comments