Skip to content

Commit 39655d5

Browse files
committed
atomic operations: single session, rollback on error
1 parent 2de8568 commit 39655d5

File tree

6 files changed

+117
-8
lines changed

6 files changed

+117
-8
lines changed

fastapi_jsonapi/atomic/atomic.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
OperationRelationshipSchema,
2121
)
2222
from fastapi_jsonapi.utils.dependency_helper import DependencyHelper
23+
from fastapi_jsonapi.views.list_view import ListViewBase
2324
from fastapi_jsonapi.views.utils import HTTPMethodConfig
2425
from fastapi_jsonapi.views.view_base import ViewBase
2526

2627
if TYPE_CHECKING:
2728
from fastapi_jsonapi.data_layers.base import BaseDataLayer
28-
from fastapi_jsonapi.views.list_view import ListViewBase
2929

3030

3131
@dataclass
@@ -105,15 +105,18 @@ async def view_atomic(
105105

106106
results = []
107107

108+
previous_dl: Optional["BaseDataLayer"] = None
108109
for operation in prepared_operations:
109110
dl = operation.data_layer
111+
await dl.atomic_start(previous_dl=previous_dl)
112+
previous_dl = dl
110113
if operation.action == "add":
111114
data = operation.jsonapi.schema_in_post(data=operation.data)
112115
created_object = await dl.create_object(
113116
data_create=data.data,
114117
view_kwargs={},
115118
)
116-
# assert isinstance(operation.view, ListViewBase)
119+
assert isinstance(operation.view, ListViewBase)
117120
view: "ListViewBase" = operation.view
118121
response = await view.response_for_created_object(
119122
dl=operation.data_layer,
@@ -129,6 +132,9 @@ async def view_atomic(
129132
msg = f"unknown action {operation.action!r}"
130133
raise ValueError(msg)
131134

135+
if previous_dl:
136+
await previous_dl.atomic_end(success=True)
137+
132138
return {"atomic:results": results}
133139

134140
def _register_view(self):

fastapi_jsonapi/data_layers/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ def __init__(
6363
self.id_name_field = id_name_field
6464
self.disable_collection_count: bool = disable_collection_count
6565
self.default_collection_count: int = default_collection_count
66+
self.is_atomic = False
67+
68+
async def atomic_start(self, previous_dl: Optional["BaseDataLayer"] = None):
69+
self.is_atomic = True
70+
71+
async def atomic_end(self, success: bool = True):
72+
raise NotImplementedError
6673

6774
async def create_object(self, data_create: BaseJSONAPIItemInSchema, view_kwargs: dict) -> TypeModel:
6875
"""

fastapi_jsonapi/data_layers/sqla_orm.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type
44

55
from sqlalchemy import delete, func, select
6-
from sqlalchemy.exc import DBAPIError, NoResultFound
7-
from sqlalchemy.ext.asyncio import AsyncSession
6+
from sqlalchemy.exc import DBAPIError, IntegrityError, NoResultFound
7+
from sqlalchemy.ext.asyncio import AsyncSession, AsyncSessionTransaction
88
from sqlalchemy.inspection import inspect
99
from sqlalchemy.orm import joinedload, selectinload
1010
from sqlalchemy.orm.attributes import InstrumentedAttribute
1111
from sqlalchemy.orm.collections import InstrumentedList
1212

13+
from fastapi_jsonapi import BadRequest
1314
from fastapi_jsonapi.data_layers.base import BaseDataLayer
1415
from fastapi_jsonapi.data_layers.filtering.sqlalchemy import create_filters
1516
from fastapi_jsonapi.data_layers.sorting.sqlalchemy import create_sorts
@@ -86,6 +87,30 @@ def __init__(
8687
self.eagerload_includes_ = eagerload_includes
8788
self._query = query
8889
self.auto_convert_id_to_column_type = auto_convert_id_to_column_type
90+
self.transaction: Optional[AsyncSessionTransaction] = None
91+
92+
async def atomic_start(self, previous_dl: Optional["SqlalchemyDataLayer"] = None):
93+
self.is_atomic = True
94+
if previous_dl:
95+
self.session = previous_dl.session
96+
if previous_dl.transaction:
97+
self.transaction = previous_dl.transaction
98+
return
99+
100+
self.transaction = self.session.begin()
101+
await self.transaction.start()
102+
103+
async def atomic_end(self, success: bool = True):
104+
if success:
105+
await self.transaction.commit()
106+
else:
107+
await self.transaction.rollback()
108+
109+
async def save(self):
110+
if self.is_atomic:
111+
await self.session.flush()
112+
else:
113+
await self.session.commit()
89114

90115
def prepare_id_value(self, col: InstrumentedAttribute, value: Any) -> Any:
91116
"""
@@ -199,7 +224,11 @@ async def create_object(self, data_create: BaseJSONAPIItemInSchema, view_kwargs:
199224

200225
self.session.add(obj)
201226
try:
202-
await self.session.commit()
227+
await self.save()
228+
except IntegrityError:
229+
log.exception("Could not create object with data create %s", data_create)
230+
msg = "Object creation error"
231+
raise BadRequest(msg, pointer="/data")
203232
except DBAPIError:
204233
log.exception("Could not create object with data create %s", data_create)
205234
msg = "Object creation error"
@@ -331,7 +360,7 @@ async def update_object(
331360
setattr(obj, field_name, new_value)
332361
has_updated = True
333362
try:
334-
await self.session.commit()
363+
await self.save()
335364
except DBAPIError as e:
336365
await self.session.rollback()
337366

@@ -357,7 +386,7 @@ async def delete_object(self, obj: TypeModel, view_kwargs: dict):
357386
await self.before_delete_object(obj, view_kwargs)
358387
try:
359388
await self.session.delete(obj)
360-
await self.session.commit()
389+
await self.save()
361390
except DBAPIError as e:
362391
await self.session.rollback()
363392

@@ -377,7 +406,7 @@ async def delete_objects(self, objects: List[TypeModel], view_kwargs: dict):
377406

378407
try:
379408
await self.session.execute(query)
380-
await self.session.commit()
409+
await self.save()
381410
except DBAPIError as e:
382411
await self.session.rollback()
383412
raise InternalServerError(

tests/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@
5454

5555
def configure_logging():
5656
logging.getLogger("faker.factory").setLevel(logging.INFO)
57+
logging.getLogger("aiosqlite").setLevel(logging.INFO)
58+
# logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)
59+
logging.basicConfig(level=logging.DEBUG)
5760

5861

5962
configure_logging()

tests/fixtures/app.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def add_routers(app_plain: FastAPI):
132132
schema_in_post=ChildInSchema,
133133
model=Child,
134134
)
135+
135136
RoutersJSONAPI(
136137
router=router,
137138
path="/computers",

tests/test_atomic/test_create_objects.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sqlalchemy import and_, or_, select
66
from sqlalchemy.engine import Result
77
from sqlalchemy.ext.asyncio import AsyncSession
8+
from sqlalchemy.sql.functions import count
89
from starlette import status
910

1011
from tests.misc.utils import fake
@@ -135,3 +136,65 @@ async def test_create_two_objects(
135136
"type": "user",
136137
},
137138
}
139+
140+
async def test_atomic_rollback_on_create_error(
141+
self,
142+
client: AsyncClient,
143+
async_session: AsyncSession,
144+
):
145+
"""
146+
User name is unique
147+
148+
- create first user
149+
- create second user with the same name
150+
- catch exc
151+
- rollback all changes
152+
153+
:param client:
154+
:param async_session:
155+
:return:
156+
"""
157+
user_data_1 = UserAttributesBaseSchema(
158+
name=fake.name(),
159+
age=fake.pyint(min_value=13, max_value=99),
160+
email=fake.email(),
161+
)
162+
user_data_2 = UserAttributesBaseSchema(
163+
name=user_data_1.name,
164+
age=fake.pyint(min_value=13, max_value=99),
165+
email=fake.email(),
166+
)
167+
users_data = [user_data_1, user_data_2]
168+
data_atomic_request = {
169+
"atomic:operations": [
170+
{
171+
"op": "add",
172+
"data": {
173+
"type": "user",
174+
"attributes": user_data.dict(),
175+
},
176+
}
177+
for user_data in users_data
178+
],
179+
}
180+
response = await client.post("/operations", json=data_atomic_request)
181+
assert response.status_code == status.HTTP_400_BAD_REQUEST, response.text
182+
response_data = response.json()
183+
assert "errors" in response_data, response_data
184+
errors = response_data["errors"]
185+
assert errors, response_data
186+
error = errors[0]
187+
assert error == {
188+
"detail": "Object creation error",
189+
"source": {"pointer": "/data"},
190+
"status_code": status.HTTP_400_BAD_REQUEST,
191+
"title": "Bad Request",
192+
}
193+
stmt = select(count(User.id)).where(
194+
or_(
195+
User.name == user_data_1.name,
196+
User.name == user_data_2.name,
197+
),
198+
)
199+
result: Result = await async_session.execute(stmt)
200+
assert result.scalar_one() == 0

0 commit comments

Comments
 (0)