Skip to content

Commit 39a3032

Browse files
committed
updated type cast logic in filters
1 parent 65aadb8 commit 39a3032

File tree

2 files changed

+99
-21
lines changed

2 files changed

+99
-21
lines changed

fastapi_jsonapi/data_layers/filtering/sqlalchemy.py

Lines changed: 88 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
"""Helper to create sqlalchemy filters according to filter querystring parameter"""
2-
from typing import Any, List, Tuple, Type, Union
3-
4-
from pydantic import BaseModel
2+
from typing import (
3+
Any,
4+
Callable,
5+
Dict,
6+
List,
7+
Optional,
8+
Tuple,
9+
Type,
10+
Union,
11+
)
12+
13+
from pydantic import BaseConfig, BaseModel
514
from pydantic.fields import ModelField
15+
from pydantic.validators import _VALIDATORS, find_validators
616
from sqlalchemy import and_, not_, or_
717
from sqlalchemy.orm import InstrumentedAttribute, aliased
818
from sqlalchemy.sql.elements import BinaryExpression
@@ -22,6 +32,9 @@
2232
List[Join],
2333
]
2434

35+
# The mapping with validators using by to cast raw value to instance of target type
36+
REGISTERED_PYDANTIC_TYPES: Dict[Type, List[Callable]] = dict(_VALIDATORS)
37+
2538

2639
def create_filters(model: Type[TypeModel], filter_info: Union[list, dict], schema: Type[TypeSchema]):
2740
"""
@@ -78,20 +91,83 @@ def create_filter(self, schema_field: ModelField, model_column, operator, value)
7891
types = [i.type_ for i in fields]
7992
clear_value = None
8093
errors: List[str] = []
81-
for i_type in types:
82-
try:
83-
if isinstance(value, list): # noqa: SIM108
84-
clear_value = [i_type(item) for item in value]
85-
else:
86-
# pass
87-
clear_value = i_type(value)
88-
except (TypeError, ValueError) as ex:
89-
errors.append(str(ex))
94+
95+
pydantic_types, userspace_types = self._separate_types(types)
96+
97+
if pydantic_types:
98+
if isinstance(value, list):
99+
clear_value, errors = self._cast_iterable_with_pydantic(pydantic_types, value)
100+
else:
101+
clear_value, errors = self._cast_value_with_pydantic(pydantic_types, value)
102+
103+
if clear_value is None and userspace_types:
104+
for i_type in types:
105+
try:
106+
if isinstance(value, list): # noqa: SIM108
107+
clear_value = [i_type(item) for item in value]
108+
else:
109+
clear_value = i_type(value)
110+
except (TypeError, ValueError) as ex:
111+
errors.append(str(ex))
112+
90113
# Если None, при этом поле обязательное (среди типов в аннотации нет None, то кидаем ошибку)
91114
if clear_value is None and not any(not i_f.required for i_f in fields):
92115
raise InvalidType(detail=", ".join(errors))
93116
return getattr(model_column, self.operator)(clear_value)
94117

118+
def _separate_types(self, types: List[Type]) -> Tuple[List[Type], List[Type]]:
119+
"""
120+
Separates the types into two kinds. The first are those for which
121+
there are already validators defined by pydantic - str, int, datetime
122+
and some other built-in types. The second are all other types for which
123+
the `arbitrary_types_allowed` config is applied when defining the pydantic model
124+
"""
125+
pydantic_types = filter(lambda type_: type_ in REGISTERED_PYDANTIC_TYPES, types)
126+
userspace_types_types = filter(lambda type_: type_ not in REGISTERED_PYDANTIC_TYPES, types)
127+
return list(pydantic_types), list(userspace_types_types)
128+
129+
def _cast_value_with_pydantic(
130+
self,
131+
types: List[Type],
132+
value: Any,
133+
) -> Tuple[Optional[Any], List[str]]:
134+
result_value, errors = None, []
135+
136+
for type_to_cast in types:
137+
for validator in find_validators(type_to_cast, BaseConfig):
138+
try:
139+
result_value = validator(value)
140+
return result_value, errors
141+
except Exception as ex:
142+
errors.append(str(ex))
143+
144+
return None, errors
145+
146+
def _cast_iterable_with_pydantic(self, types: List[Type], values: List) -> Tuple[List, List[str]]:
147+
type_cast_failed = False
148+
failed_values = []
149+
150+
result_values: List[Any] = []
151+
errors: List[str] = []
152+
153+
for value in values:
154+
casted_value, cast_errors = self._cast_value_with_pydantic(types, value)
155+
errors.extend(cast_errors)
156+
157+
if casted_value is None:
158+
type_cast_failed = True
159+
failed_values.append(value)
160+
161+
continue
162+
163+
result_values.append(casted_value)
164+
165+
if type_cast_failed:
166+
msg = f"Can't parse items {failed_values} of value {values}"
167+
raise InvalidFilters(msg)
168+
169+
return result_values, errors
170+
95171
def resolve(self) -> FilterAndJoins: # noqa: PLR0911
96172
"""Create filter for a particular node of the filter tree"""
97173
if "or" in self.filter_:

tests/test_api/test_api_sqla_with_includes.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,7 +1262,7 @@ class ContainsTimestampAttrsSchema(BaseModel):
12621262
stms = select(ContainsTimestamp).where(ContainsTimestamp.id == int(entity_id))
12631263
entity_model: Optional[ContainsTimestamp] = (await async_session.execute(stms)).scalar_one_or_none()
12641264
assert entity_model
1265-
assert entity_model.timestamp == create_timestamp
1265+
assert entity_model.timestamp.isoformat() == create_timestamp.replace(tzinfo=None).isoformat()
12661266

12671267
params = {
12681268
"filter": json.dumps(
@@ -1285,21 +1285,23 @@ class ContainsTimestampAttrsSchema(BaseModel):
12851285
"data": [
12861286
{
12871287
"type": "contains_timestamp_model",
1288-
"attributes": {"timestamp": create_timestamp.isoformat()},
1288+
"attributes": {"timestamp": create_timestamp.replace(tzinfo=None).isoformat()},
12891289
"id": entity_id,
12901290
},
12911291
],
12921292
}
12931293

12941294
# check filter really work
12951295
params = {
1296-
"filter": [
1297-
{
1298-
"name": "timestamp",
1299-
"op": "eq",
1300-
"val": datetime.now(tz=timezone.utc).isoformat(),
1301-
},
1302-
],
1296+
"filter": json.dumps(
1297+
[
1298+
{
1299+
"name": "timestamp",
1300+
"op": "eq",
1301+
"val": datetime.now(tz=timezone.utc).isoformat(),
1302+
},
1303+
],
1304+
),
13031305
}
13041306
res = await client.get(url, params=params)
13051307
assert res.status_code == status.HTTP_200_OK, res.text

0 commit comments

Comments
 (0)