33from typing import TYPE_CHECKING , Any , Iterable , List , Optional , Tuple , Type
44
55from sqlalchemy import delete , func , select
6- from sqlalchemy .exc import DBAPIError , NoResultFound
7- from sqlalchemy .ext .asyncio import AsyncSession
6+ from sqlalchemy .exc import DBAPIError , IntegrityError , NoResultFound
7+ from sqlalchemy .ext .asyncio import AsyncSession , AsyncSessionTransaction
88from sqlalchemy .inspection import inspect
99from sqlalchemy .orm import joinedload , selectinload
1010from sqlalchemy .orm .attributes import InstrumentedAttribute
1111from sqlalchemy .orm .collections import InstrumentedList
1212
13+ from fastapi_jsonapi import BadRequest
1314from fastapi_jsonapi .data_layers .base import BaseDataLayer
1415from fastapi_jsonapi .data_layers .filtering .sqlalchemy import create_filters
1516from fastapi_jsonapi .data_layers .sorting .sqlalchemy import create_sorts
@@ -86,6 +87,30 @@ def __init__(
8687 self .eagerload_includes_ = eagerload_includes
8788 self ._query = query
8889 self .auto_convert_id_to_column_type = auto_convert_id_to_column_type
90+ self .transaction : Optional [AsyncSessionTransaction ] = None
91+
92+ async def atomic_start (self , previous_dl : Optional ["SqlalchemyDataLayer" ] = None ):
93+ self .is_atomic = True
94+ if previous_dl :
95+ self .session = previous_dl .session
96+ if previous_dl .transaction :
97+ self .transaction = previous_dl .transaction
98+ return
99+
100+ self .transaction = self .session .begin ()
101+ await self .transaction .start ()
102+
103+ async def atomic_end (self , success : bool = True ):
104+ if success :
105+ await self .transaction .commit ()
106+ else :
107+ await self .transaction .rollback ()
108+
109+ async def save (self ):
110+ if self .is_atomic :
111+ await self .session .flush ()
112+ else :
113+ await self .session .commit ()
89114
90115 def prepare_id_value (self , col : InstrumentedAttribute , value : Any ) -> Any :
91116 """
@@ -199,7 +224,11 @@ async def create_object(self, data_create: BaseJSONAPIItemInSchema, view_kwargs:
199224
200225 self .session .add (obj )
201226 try :
202- await self .session .commit ()
227+ await self .save ()
228+ except IntegrityError :
229+ log .exception ("Could not create object with data create %s" , data_create )
230+ msg = "Object creation error"
231+ raise BadRequest (msg , pointer = "/data" )
203232 except DBAPIError :
204233 log .exception ("Could not create object with data create %s" , data_create )
205234 msg = "Object creation error"
@@ -331,7 +360,7 @@ async def update_object(
331360 setattr (obj , field_name , new_value )
332361 has_updated = True
333362 try :
334- await self .session . commit ()
363+ await self .save ()
335364 except DBAPIError as e :
336365 await self .session .rollback ()
337366
@@ -357,7 +386,7 @@ async def delete_object(self, obj: TypeModel, view_kwargs: dict):
357386 await self .before_delete_object (obj , view_kwargs )
358387 try :
359388 await self .session .delete (obj )
360- await self .session . commit ()
389+ await self .save ()
361390 except DBAPIError as e :
362391 await self .session .rollback ()
363392
@@ -377,7 +406,7 @@ async def delete_objects(self, objects: List[TypeModel], view_kwargs: dict):
377406
378407 try :
379408 await self .session .execute (query )
380- await self .session . commit ()
409+ await self .save ()
381410 except DBAPIError as e :
382411 await self .session .rollback ()
383412 raise InternalServerError (
0 commit comments