Skip to content

Commit 9368829

Browse files
authored
Merge branch 'main' into feature/fix-output-type
2 parents 199e6d3 + 444a86e commit 9368829

File tree

9 files changed

+208
-78
lines changed

9 files changed

+208
-78
lines changed

examples/api_for_sqlalchemy/api/user.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from fastapi import Depends
88
from sqlalchemy import select, desc
99
from sqlalchemy.ext.asyncio import AsyncSession
10+
from sqlalchemy.sql import Select
1011
from tortoise.exceptions import DoesNotExist
11-
from tortoise.queryset import QuerySet
1212

1313
from examples.api_for_sqlalchemy.extensions.sqlalchemy import Connector
1414
from examples.api_for_sqlalchemy.helpers.factories.meta_base import FactoryUseMode
@@ -19,7 +19,6 @@
1919
from examples.api_for_sqlalchemy.models.pydantic.user import UserInSchema
2020
from examples.api_for_sqlalchemy.models.sqlalchemy import User
2121
from fastapi_rest_jsonapi import SqlalchemyEngine
22-
2322
from fastapi_rest_jsonapi.exceptions import (
2423
BadRequest,
2524
HTTPException,
@@ -76,7 +75,7 @@ async def patch(cls, obj_id, data: UserPatchSchema, query_params: QueryStringMan
7675

7776
class UserList:
7877
@classmethod
79-
async def get(cls, query_params: QueryStringManager, session: AsyncSession = Depends(Connector.get_session)) -> Union[QuerySet, JSONAPIResultListSchema]:
78+
async def get(cls, query_params: QueryStringManager, session: AsyncSession = Depends(Connector.get_session)) -> Union[Select, JSONAPIResultListSchema]:
8079
user_query = select(User).order_by(desc(User.id))
8180
dl = SqlalchemyEngine(query=user_query, schema=UserSchema, model=User, session=session)
8281
count, users_db = await dl.get_collection(qs=query_params)

fastapi_rest_jsonapi/api.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
List,
77
Optional,
88
Type,
9-
Union,
9+
Union, TypeVar,
1010
)
1111

1212
import pydantic
@@ -21,28 +21,31 @@
2121
get_detail_jsonapi,
2222
get_list_jsonapi,
2323
patch_detail_jsonapi,
24-
post_list_jsonapi,
24+
post_list_jsonapi, delete_list_jsonapi,
2525
)
2626
from fastapi_rest_jsonapi.schema import BasePatchJSONAPISchema, BasePostJSONAPISchema, JSONAPIObjectSchema, \
2727
JSONAPIResultDetailSchema
2828

2929
JSON_API_RESPONSE_TYPE = Optional[Dict[Union[int, str], Dict[str, Any]]]
3030

31+
TypeAPIRouter = TypeVar("TypeAPIRouter", bound=APIRouter)
32+
TypeSchema = TypeVar("TypeSchema", bound=BaseModel)
33+
3134

3235
class RoutersJSONAPI:
3336
"""API Router interface for JSON API endpoints in web-services."""
3437

3538
def __init__( # noqa: WPS211
3639
self,
37-
routers: APIRouter,
40+
routers: TypeAPIRouter,
3841
path: Union[str, List[str]],
3942
tags: List[str],
4043
class_detail: Any,
4144
class_list: Any,
42-
schema: Type[BaseModel],
45+
schema: Type[TypeSchema],
4346
type_resource: str,
44-
schema_in_patch: Type[BaseModel],
45-
schema_in_post: Type[BaseModel],
47+
schema_in_patch: Type[TypeSchema],
48+
schema_in_post: Type[TypeSchema],
4649
model: Type[TypeModel],
4750
engine: DBORMType = DBORMType.sqlalchemy,
4851
) -> None:
@@ -166,6 +169,21 @@ def _add_routers(self, path: str):
166169
)(self.class_list.post)
167170
)
168171

172+
if hasattr(self.class_list, "delete"):
173+
self._routers.delete(
174+
path,
175+
tags=self._tags,
176+
summary=f"Delete list objects of type `{self._type}`"
177+
)(
178+
delete_list_jsonapi(
179+
schema=self._schema,
180+
model=self._model,
181+
engine=self._engine,
182+
)(
183+
self.class_list.delete
184+
)
185+
)
186+
169187
if hasattr(self.class_detail, "get"):
170188
self._routers.get(
171189
path + "/{obj_id}",
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import TypeVar
22

3-
from pydantic import BaseModel
43

54
TypeQuery = TypeVar("TypeQuery")
65
TypeModel = TypeVar("TypeModel")
7-
TypeSchema = TypeVar("TypeSchema", bound=BaseModel)
6+
TypeSchema = TypeVar("TypeSchema")

fastapi_rest_jsonapi/data_layers/filtering/sqlalchemy.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sqlalchemy.sql.elements import BinaryExpression
99

1010
from fastapi_rest_jsonapi.data_layers.shared import create_filters_or_sorts
11-
from fastapi_rest_jsonapi.exceptions import InvalidFilters
11+
from fastapi_rest_jsonapi.exceptions import InvalidFilters, InvalidType
1212

1313
from fastapi_rest_jsonapi.data_layers.data_typing import TypeSchema, TypeModel
1414
from fastapi_rest_jsonapi.schema import get_relationships, get_model_field
@@ -78,8 +78,24 @@ def create_filter(self, schema_field: ModelField, model_column, operator, value)
7878
)
7979
# Here we have to deserialize and validate fields, that are used in filtering,
8080
# so the Enum fields are loaded correctly
81-
value = schema_field.type_(value)
82-
return getattr(model_column, self.operator)(value)
81+
82+
if schema_field.sub_fields:
83+
# Для случаев когда в схеме тип Union
84+
fields = [i for i in schema_field.sub_fields]
85+
else:
86+
fields = [schema_field]
87+
types = [i.type_ for i in fields]
88+
clear_value = None
89+
errors: List[str] = []
90+
for i_type in types:
91+
try:
92+
clear_value = i_type(value)
93+
except (TypeError, ValueError) as ex:
94+
errors.append(str(ex))
95+
# Если None, при этом поле обязательное (среди типов в аннотации нет None, то кидаем ошибку)
96+
if clear_value is None and not any([not i_f.required for i_f in fields]):
97+
raise InvalidType(detail=", ".join(errors))
98+
return getattr(model_column, self.operator)(clear_value)
8399

84100
def resolve(self) -> FilterAndJoins:
85101
"""Create filter for a particular node of the filter tree"""
@@ -105,13 +121,22 @@ def resolve(self) -> FilterAndJoins:
105121
return self._relationship_filtering(value)
106122

107123
schema_field: ModelField = self.schema.__fields__[self.name]
108-
if issubclass(schema_field.type_, BaseModel):
109-
value = {
110-
'name': self.name,
111-
'op': self.filter_['op'],
112-
'val': value,
113-
}
114-
return self._relationship_filtering(value)
124+
if schema_field.sub_fields:
125+
# Для случаев когда в схеме тип Union
126+
types = [i.type_ for i in schema_field.sub_fields]
127+
else:
128+
types = [schema_field.type_]
129+
for i_type in types:
130+
try:
131+
if issubclass(i_type, BaseModel):
132+
value = {
133+
'name': self.name,
134+
'op': self.filter_['op'],
135+
'val': value,
136+
}
137+
return self._relationship_filtering(value)
138+
except (TypeError, ValueError):
139+
pass
115140

116141
return self.create_filter(
117142
schema_field=schema_field,

fastapi_rest_jsonapi/data_layers/sqlalchemy_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def paginate_query(self, query: Select, paginate_info: PaginationQueryStringMana
339339
:params paginate_info: pagination information.
340340
:return: the paginated query
341341
"""
342-
if paginate_info.size == 0:
342+
if paginate_info.size == 0 or paginate_info.size is None:
343343
return query
344344

345345
query = query.limit(paginate_info.size)

fastapi_rest_jsonapi/methods.py

Lines changed: 100 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,7 @@
2323
from fastapi_rest_jsonapi.data_layers.tortoise_orm_engine import TortoiseORMEngine
2424
from fastapi_rest_jsonapi.exceptions.json_api import UnsupportedFeatureORM
2525
from fastapi_rest_jsonapi.querystring import QueryStringManager
26-
from fastapi_rest_jsonapi.signature import (
27-
is_necessary_request,
28-
update_signature,
29-
)
26+
from fastapi_rest_jsonapi.signature import update_signature
3027

3128

3229
def get_detail_jsonapi(
@@ -41,12 +38,16 @@ def get_detail_jsonapi(
4138
def inner(func: Callable) -> Callable:
4239
async def wrapper(request: Request, obj_id: int, **kwargs):
4340
query_params = QueryStringManager(request=request, schema=schema)
44-
data_dict: dict = dict(query_params=query_params, obj_id=obj_id)
45-
if is_necessary_request(func):
46-
data_dict["request"] = request
47-
48-
params_function = OrderedDict(signature(func).parameters)
49-
data_dict.update({i_k: i_v for i_k, i_v in kwargs.items() if i_k in params_function})
41+
data_dict = {"obj_id": obj_id}
42+
func_signature = signature(func).parameters
43+
for i_name, i_type in OrderedDict(func_signature).items():
44+
if i_type.annotation is Request:
45+
data_dict[i_name] = request
46+
elif i_type.annotation is QueryStringManager:
47+
data_dict[i_name] = query_params
48+
49+
data_dict.update({i_k: i_v for i_k, i_v in kwargs.items() if i_k in func_signature})
50+
data_dict = {i_k: i_v for i_k, i_v in data_dict.items() if i_k in func_signature}
5051
data_schema: Any = await func(**data_dict)
5152
return schema_resp(
5253
data={
@@ -85,12 +86,18 @@ def patch_detail_jsonapi(
8586
def inner(func: Callable) -> Callable:
8687
async def wrapper(request: Request, obj_id: int, data: schema_in, **kwargs): # type: ignore
8788
query_params = QueryStringManager(request=request, schema=schema)
88-
data_dict: dict = dict(query_params=query_params, obj_id=obj_id, data=getattr(data, "attributes", data))
89-
if is_necessary_request(func):
90-
data_dict["request"] = request
91-
92-
params_function = OrderedDict(signature(func).parameters)
93-
data_dict.update({i_k: i_v for i_k, i_v in kwargs.items() if i_k in params_function})
89+
data_dict = {"obj_id": obj_id}
90+
func_signature = signature(func).parameters
91+
for i_name, i_type in OrderedDict(func_signature).items():
92+
if i_type.annotation is schema_in.__fields__["attributes"].type_:
93+
data_dict[i_name] = getattr(data, 'attributes', data)
94+
elif i_type.annotation is Request:
95+
data_dict[i_name] = request
96+
elif i_type.annotation is QueryStringManager:
97+
data_dict[i_name] = query_params
98+
99+
data_dict.update({i_k: i_v for i_k, i_v in kwargs.items() if i_k in func_signature})
100+
data_dict = {i_k: i_v for i_k, i_v in data_dict.items() if i_k in func_signature}
94101
data_schema: Any = await func(**data_dict)
95102
return schema_resp(
96103
data={
@@ -120,18 +127,68 @@ def delete_detail_jsonapi(
120127
def inner(func: Callable) -> Callable:
121128
async def wrapper(request: Request, obj_id: int, **kwargs): # type: ignore
122129
query_params = QueryStringManager(request=request, schema=schema)
123-
data_dict: dict = dict(query_params=query_params, obj_id=obj_id)
124-
if is_necessary_request(func):
125-
data_dict["request"] = request
130+
data_dict = {"obj_id": obj_id}
131+
func_signature = signature(func).parameters
132+
for i_name, i_type in OrderedDict(func_signature).items():
133+
if i_type.annotation is Request:
134+
data_dict[i_name] = request
135+
elif i_type.annotation is QueryStringManager:
136+
data_dict[i_name] = query_params
137+
138+
data_dict.update({i_k: i_v for i_k, i_v in kwargs.items() if i_k in func_signature})
139+
data_dict = {i_k: i_v for i_k, i_v in data_dict.items() if i_k in func_signature}
140+
await func(**data_dict)
141+
return Response(status_code=status.HTTP_204_NO_CONTENT)
142+
143+
# mypy ругается что нет метода __signature__, как это обойти красиво- не знаю
144+
wrapper.__signature__ = update_signature( # type: ignore
145+
sig=signature(wrapper),
146+
other=OrderedDict(signature(func).parameters),
147+
)
148+
return wrapper
149+
150+
return inner
151+
152+
153+
def delete_list_jsonapi(
154+
schema: Type[BaseModel],
155+
model: Type[TypeModel],
156+
engine: DBORMType,
157+
) -> Callable:
158+
"""DELETE method router (Decorator for JSON API)."""
159+
160+
def inner(func: Callable) -> Callable:
161+
async def wrapper(
162+
request: Request,
163+
filters_list: Optional[str] = Query(
164+
None,
165+
alias="filter",
166+
description="[Filtering docs](https://flask-combo-jsonapi.readthedocs.io/en/latest/filtering.html)"
167+
"\nExamples:\n* filter for timestamp interval: "
168+
'`[{"name": "timestamp", "op": "ge", "val": "2020-07-16T11:35:33.383"},'
169+
'{"name": "timestamp", "op": "le", "val": "2020-07-21T11:35:33.383"}]`',
170+
),
171+
**kwargs,
172+
):
173+
query_params = QueryStringManager(request=request, schema=schema)
174+
data_dict = {}
175+
func_signature = signature(func).parameters
176+
for i_name, i_type in OrderedDict(func_signature).items():
177+
if i_type.annotation is Request:
178+
data_dict[i_name] = request
179+
elif i_type.annotation is QueryStringManager:
180+
data_dict[i_name] = query_params
126181

127182
params_function = OrderedDict(signature(func).parameters)
128183
data_dict.update({i_k: i_v for i_k, i_v in kwargs.items() if i_k in params_function})
184+
data_dict = {i_k: i_v for i_k, i_v in data_dict.items() if i_k in params_function}
129185
await func(**data_dict)
130186
return Response(status_code=status.HTTP_204_NO_CONTENT)
131187

132188
# mypy ругается что нет метода __signature__, как это обойти красиво- не знаю
133189
wrapper.__signature__ = update_signature( # type: ignore
134190
sig=signature(wrapper),
191+
schema=schema,
135192
other=OrderedDict(signature(func).parameters),
136193
)
137194
return wrapper
@@ -193,21 +250,21 @@ async def wrapper(
193250
**kwargs,
194251
):
195252
query_params = QueryStringManager(request=request, schema=schema)
196-
data = {
197-
i_name: query_params
198-
for i_name, i_param in OrderedDict(signature(func).parameters).items()
199-
if i_param.annotation is QueryStringManager
200-
}
201-
if is_necessary_request(func):
202-
data["request"] = request
203-
204-
params_function = OrderedDict(signature(func).parameters)
205-
data.update({i_k: i_v for i_k, i_v in kwargs.items() if i_k in params_function})
206-
query = await func(**data)
253+
data_dict = {}
254+
func_signature = signature(func).parameters
255+
for i_name, i_type in OrderedDict(func_signature).items():
256+
if i_type.annotation is Request:
257+
data_dict[i_name] = request
258+
elif i_type.annotation is QueryStringManager:
259+
data_dict[i_name] = query_params
260+
261+
data_dict.update({i_k: i_v for i_k, i_v in kwargs.items() if i_k in func_signature})
262+
data_dict = {i_k: i_v for i_k, i_v in data_dict.items() if i_k in func_signature}
263+
query = await func(**data_dict)
207264

208265
if engine is DBORMType.sqlalchemy:
209266
# Для SQLAlchemy нужно указывать session, для Tortoise достаточно модели
210-
session_list = [i_v for i_k, i_v in params_function.items() if isinstance(i_v, AsyncSession)]
267+
session_list = [i_v for i_k, i_v in func_signature.items() if isinstance(i_v, AsyncSession)]
211268
session: Optional[AsyncSession] = session_list and session_list[0] or None
212269
else:
213270
session = None
@@ -257,12 +314,19 @@ def post_list_jsonapi(
257314
def inner(func: Callable) -> Callable:
258315
async def wrapper(request: Request, data: schema_in, **kwargs): # type: ignore
259316
query_params = QueryStringManager(request=request, schema=schema)
260-
data_dict: dict = dict(query_params=query_params, data=getattr(data, 'attributes', data))
261-
if is_necessary_request(func):
262-
data_dict["request"] = request
263-
264-
params_function = OrderedDict(signature(func).parameters)
317+
data_dict = {}
318+
func_signature = signature(func).parameters
319+
for i_name, i_type in OrderedDict(func_signature).items():
320+
if i_type.annotation is schema_in.__fields__["attributes"].type_:
321+
data_dict[i_name] = getattr(data, 'attributes', data)
322+
elif i_type.annotation is Request:
323+
data_dict[i_name] = request
324+
elif i_type.annotation is QueryStringManager:
325+
data_dict[i_name] = query_params
326+
327+
params_function = OrderedDict(func_signature)
265328
data_dict.update({i_k: i_v for i_k, i_v in kwargs.items() if i_k in params_function})
329+
data_dict = {i_k: i_v for i_k, i_v in data_dict.items() if i_k in params_function}
266330
data_pydantic: Any = await func(**data_dict)
267331
return schema_resp(
268332
data={

fastapi_rest_jsonapi/schema.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Helpers to deal with marshmallow schemas. Base JSON:API schemas."""
2+
import uuid
23
from typing import (
34
Dict,
45
Type,
56
List,
67
Optional,
7-
Sequence, TYPE_CHECKING,
8+
Sequence, TYPE_CHECKING, Union,
89
)
910

1011
from fastapi import FastAPI
@@ -179,7 +180,13 @@ def get_relationships(schema: Type["TypeSchema"], model_field: bool = False) ->
179180
:param schema: a pydantic schema
180181
:param model_field: list of relationship fields of a schema
181182
"""
182-
relationships = [i_name for i_name, i_type in schema.__fields__.items() if issubclass(i_type.type_, BaseModel)]
183+
relationships: List[str] = []
184+
for i_name, i_type in schema.__fields__.items():
185+
try:
186+
if issubclass(i_type.type_, BaseModel):
187+
relationships.append(i_name)
188+
except TypeError:
189+
pass
183190

184191
if model_field is True:
185192
relationships = [get_model_field(schema, key) for key in relationships]

0 commit comments

Comments
 (0)