Skip to content

Commit f5e2db3

Browse files
jim-obrien-origwwwjfy
authored andcommitted
factory method gino.ext.starlette (#538)
* factory method starlette. * Test for factory method gino.ext.starlette. * remove unused import * Chgs to PR rec by @wwwjfy, rm chgs Pipfile.lock, .gitignore. * Revert "Chgs to PR rec by @wwwjfy, rm chgs Pipfile.lock, .gitignore." This reverts commit c75f26a. * changes from suggestions on PR
1 parent 04e84ec commit f5e2db3

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

gino/ext/starlette.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# noinspection PyPackageRequirements
2-
from starlette.applications import Starlette
2+
from sqlalchemy.engine.url import URL
33
# noinspection PyPackageRequirements
4-
from starlette.types import Message, Receive, Scope, Send
4+
from starlette import status
55
# noinspection PyPackageRequirements
66
from starlette.exceptions import HTTPException
77
# noinspection PyPackageRequirements
8-
from starlette import status
9-
from sqlalchemy.engine.url import URL
8+
from starlette.types import Message, Receive, Scope, Send
109

1110
from ..api import Gino as _Gino, GinoExecutor as _Executor
1211
from ..engine import GinoConnection as _Connection, GinoEngine as _Engine
@@ -94,6 +93,7 @@ async def receiver() -> Message:
9493
elif message["type"] == "lifespan.shutdown":
9594
await self.db.pop_bind().close()
9695
return message
96+
9797
await self.app(scope, receiver, send)
9898
return
9999

@@ -146,7 +146,7 @@ class Gino(_Gino):
146146
model_base_classes = _Gino.model_base_classes + (StarletteModelMixin,)
147147
query_executor = GinoExecutor
148148

149-
def __init__(self, app: Starlette, *args, **kwargs):
149+
def __init__(self, app=None, *args, **kwargs):
150150
self.config = dict()
151151
if 'dsn' in kwargs:
152152
self.config['dsn'] = kwargs.pop('dsn')
@@ -168,7 +168,10 @@ def __init__(self, app: Starlette, *args, **kwargs):
168168
self.config['kwargs'] = kwargs.pop('kwargs', dict())
169169

170170
super().__init__(*args, **kwargs)
171+
if app is not None:
172+
self.init_app(app)
171173

174+
def init_app(self, app):
172175
app.add_middleware(_Middleware, db=self)
173176

174177
async def first_or_404(self, *args, **kwargs):

tests/test_starlette.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@ async def _app(**kwargs):
2828
max_inactive_connection_lifetime=_MAX_INACTIVE_CONNECTION_LIFETIME,
2929
),
3030
})
31-
db = Gino(app, **kwargs)
31+
factory = kwargs.pop('factory', False)
32+
33+
if factory:
34+
db = Gino(**kwargs)
35+
db.init_app(app)
36+
else:
37+
db = Gino(app, **kwargs)
3238

3339
class User(db.Model):
3440
__tablename__ = 'gino_users'
@@ -95,6 +101,18 @@ async def app():
95101
)
96102

97103

104+
@pytest.fixture
105+
@async_generator
106+
async def app_factory():
107+
await _app(
108+
factory=True,
109+
host=DB_ARGS['host'],
110+
port=DB_ARGS['port'],
111+
user=DB_ARGS['user'],
112+
password=DB_ARGS['password'],
113+
database=DB_ARGS['database'],
114+
)
115+
98116
@pytest.fixture
99117
@async_generator
100118
async def app_ssl(ssl_ctx):
@@ -157,3 +175,6 @@ def test_ssl(app_ssl):
157175

158176
def test_dsn(app_dsn):
159177
_test(app_dsn)
178+
179+
def test_app_factory(app_factory):
180+
_test(app_factory)

0 commit comments

Comments
 (0)