|
1 | 1 | """Helper to create sqlalchemy filters according to filter querystring parameter""" |
2 | | -from typing import Any, List, Tuple, Type, Union |
3 | | - |
4 | | -from pydantic import BaseModel |
| 2 | +from typing import ( |
| 3 | + Any, |
| 4 | + Callable, |
| 5 | + Dict, |
| 6 | + List, |
| 7 | + Optional, |
| 8 | + Tuple, |
| 9 | + Type, |
| 10 | + Union, |
| 11 | +) |
| 12 | + |
| 13 | +from pydantic import BaseConfig, BaseModel |
5 | 14 | from pydantic.fields import ModelField |
| 15 | +from pydantic.validators import _VALIDATORS, find_validators |
6 | 16 | from sqlalchemy import and_, not_, or_ |
7 | 17 | from sqlalchemy.orm import InstrumentedAttribute, aliased |
8 | 18 | from sqlalchemy.sql.elements import BinaryExpression |
|
22 | 32 | List[Join], |
23 | 33 | ] |
24 | 34 |
|
| 35 | +# The mapping with validators using by to cast raw value to instance of target type |
| 36 | +REGISTERED_PYDANTIC_TYPES: Dict[Type, List[Callable]] = dict(_VALIDATORS) |
| 37 | + |
25 | 38 |
|
26 | 39 | def create_filters(model: Type[TypeModel], filter_info: Union[list, dict], schema: Type[TypeSchema]): |
27 | 40 | """ |
@@ -78,20 +91,83 @@ def create_filter(self, schema_field: ModelField, model_column, operator, value) |
78 | 91 | types = [i.type_ for i in fields] |
79 | 92 | clear_value = None |
80 | 93 | errors: List[str] = [] |
81 | | - for i_type in types: |
82 | | - try: |
83 | | - if isinstance(value, list): # noqa: SIM108 |
84 | | - clear_value = [i_type(item) for item in value] |
85 | | - else: |
86 | | - # pass |
87 | | - clear_value = i_type(value) |
88 | | - except (TypeError, ValueError) as ex: |
89 | | - errors.append(str(ex)) |
| 94 | + |
| 95 | + pydantic_types, userspace_types = self._separate_types(types) |
| 96 | + |
| 97 | + if pydantic_types: |
| 98 | + if isinstance(value, list): |
| 99 | + clear_value, errors = self._cast_iterable_with_pydantic(pydantic_types, value) |
| 100 | + else: |
| 101 | + clear_value, errors = self._cast_value_with_pydantic(pydantic_types, value) |
| 102 | + |
| 103 | + if clear_value is None and userspace_types: |
| 104 | + for i_type in types: |
| 105 | + try: |
| 106 | + if isinstance(value, list): # noqa: SIM108 |
| 107 | + clear_value = [i_type(item) for item in value] |
| 108 | + else: |
| 109 | + clear_value = i_type(value) |
| 110 | + except (TypeError, ValueError) as ex: |
| 111 | + errors.append(str(ex)) |
| 112 | + |
90 | 113 | # Если None, при этом поле обязательное (среди типов в аннотации нет None, то кидаем ошибку) |
91 | 114 | if clear_value is None and not any(not i_f.required for i_f in fields): |
92 | 115 | raise InvalidType(detail=", ".join(errors)) |
93 | 116 | return getattr(model_column, self.operator)(clear_value) |
94 | 117 |
|
| 118 | + def _separate_types(self, types: List[Type]) -> Tuple[List[Type], List[Type]]: |
| 119 | + """ |
| 120 | + Separates the types into two kinds. The first are those for which |
| 121 | + there are already validators defined by pydantic - str, int, datetime |
| 122 | + and some other built-in types. The second are all other types for which |
| 123 | + the `arbitrary_types_allowed` config is applied when defining the pydantic model |
| 124 | + """ |
| 125 | + pydantic_types = filter(lambda type_: type_ in REGISTERED_PYDANTIC_TYPES, types) |
| 126 | + userspace_types_types = filter(lambda type_: type_ not in REGISTERED_PYDANTIC_TYPES, types) |
| 127 | + return list(pydantic_types), list(userspace_types_types) |
| 128 | + |
| 129 | + def _cast_value_with_pydantic( |
| 130 | + self, |
| 131 | + types: List[Type], |
| 132 | + value: Any, |
| 133 | + ) -> Tuple[Optional[Any], List[str]]: |
| 134 | + result_value, errors = None, [] |
| 135 | + |
| 136 | + for type_to_cast in types: |
| 137 | + for validator in find_validators(type_to_cast, BaseConfig): |
| 138 | + try: |
| 139 | + result_value = validator(value) |
| 140 | + return result_value, errors |
| 141 | + except Exception as ex: |
| 142 | + errors.append(str(ex)) |
| 143 | + |
| 144 | + return None, errors |
| 145 | + |
| 146 | + def _cast_iterable_with_pydantic(self, types: List[Type], values: List) -> Tuple[List, List[str]]: |
| 147 | + type_cast_failed = False |
| 148 | + failed_values = [] |
| 149 | + |
| 150 | + result_values: List[Any] = [] |
| 151 | + errors: List[str] = [] |
| 152 | + |
| 153 | + for value in values: |
| 154 | + casted_value, cast_errors = self._cast_value_with_pydantic(types, value) |
| 155 | + errors.extend(cast_errors) |
| 156 | + |
| 157 | + if casted_value is None: |
| 158 | + type_cast_failed = True |
| 159 | + failed_values.append(value) |
| 160 | + |
| 161 | + continue |
| 162 | + |
| 163 | + result_values.append(casted_value) |
| 164 | + |
| 165 | + if type_cast_failed: |
| 166 | + msg = f"Can't parse items {failed_values} of value {values}" |
| 167 | + raise InvalidFilters(msg) |
| 168 | + |
| 169 | + return result_values, errors |
| 170 | + |
95 | 171 | def resolve(self) -> FilterAndJoins: # noqa: PLR0911 |
96 | 172 | """Create filter for a particular node of the filter tree""" |
97 | 173 | if "or" in self.filter_: |
|
0 commit comments