Skip to content

Commit a06ad87

Browse files
authored
Enable job insertion on both transactions and top level pool (#9)
A problem with the API currently is that is that inserts on a top-level SQLAlchemy engine aren't supported. In the test suite, the driver is always initialized with a transaction: with engine.begin() as tx: yield riversqlalchemy.Driver(tx) But if trying to initialize the driver with the engine itself, we'd see failures because it doesn't support the necessary SQL execution methods required to insert a job. Additionally, when checking for a unique job, the driver always opens a transaction with `self.conn.begin_nested()`, which will error unless in a transaction already. In the Ruby client there's only a single `River#insert` method because convention in that language is to have a lot of global stuff going on, and the current transaction is added to local thread storage by Sequel or ActiveRecord so that we can determine whether or not we're already in a transaction from within the client. SQLAlchemy is a little different. After opening a transaction, you're expected to track the session object for perform operations in it: with session.begin(): session.add(some_object()) session.add(some_other_object()) So here I'm proposing a client API that looks a little more like the Go API, with a pair of insert functions for `#insert` and `#insert_tx`, the former doing an insert on the connection originally passed to driver, and the latter doing an insert on a transaction sent as argument. So this style of top-level insertion is now possible: engine = sqlalchemy.create_engine(database_url) client = riverqueue.Client(riversqlalchemy.Driver(engine)) insert_res = client.insert( MyJobArgs(), insert_opts=riverqueue.InsertOpts( unique_opts=riverqueue.UniqueOpts(by_period=900) ), ) The client opens a new transaction on the driver's engine, and inserts the job there. Insertion on a particular transaction continues to be possible, and now without having to reinitialize the client/driver: with engine.begin() as session: insert_res = client.insert_tx( session, MyJobArgs(), insert_opts=riverqueue.InsertOpts( unique_opts=riverqueue.UniqueOpts(by_period=900) ), ) print(insert_res)
1 parent 52b6850 commit a06ad87

File tree

6 files changed

+192
-69
lines changed

6 files changed

+192
-69
lines changed

src/riverqueue/client.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from datetime import datetime, timezone, timedelta
33
from typing import Any, Optional, Protocol, Tuple, List, Callable
44

5-
from .driver import GetParams, JobInsertParams, DriverProtocol
5+
from .driver import GetParams, JobInsertParams, DriverProtocol, ExecutorProtocol
66
from .model import InsertResult
77
from .fnv import fnv1_hash
88

@@ -53,29 +53,56 @@ def __init__(self, driver: DriverProtocol, advisory_lock_prefix=None):
5353
def insert(
5454
self, args: Args, insert_opts: Optional[InsertOpts] = None
5555
) -> InsertResult:
56+
with self.driver.executor() as exec:
57+
if not insert_opts:
58+
insert_opts = InsertOpts()
59+
insert_params, unique_opts = self.__make_insert_params(args, insert_opts)
60+
61+
def insert():
62+
return InsertResult(exec.job_insert(insert_params))
63+
64+
return self.__check_unique_job(exec, insert_params, unique_opts, insert)
65+
66+
def insert_tx(
67+
self, tx, args: Args, insert_opts: Optional[InsertOpts] = None
68+
) -> InsertResult:
69+
exec = self.driver.unwrap_executor(tx)
5670
if not insert_opts:
5771
insert_opts = InsertOpts()
5872
insert_params, unique_opts = self.__make_insert_params(args, insert_opts)
5973

6074
def insert():
61-
print(self.driver)
62-
return InsertResult(self.driver.job_insert(insert_params))
75+
return InsertResult(exec.job_insert(insert_params))
6376

64-
return self.__check_unique_job(insert_params, unique_opts, insert)
77+
return self.__check_unique_job(exec, insert_params, unique_opts, insert)
6578

6679
def insert_many(self, args: List[Args]) -> List[InsertResult]:
67-
all_params = [
80+
with self.driver.executor() as exec:
81+
return [
82+
InsertResult(x)
83+
for x in exec.job_insert_many(self.__make_insert_params_many(args))
84+
]
85+
86+
def insert_many_tx(self, tx, args: List[Args]) -> List[InsertResult]:
87+
exec = self.driver.unwrap_executor(tx)
88+
return [
89+
InsertResult(x)
90+
for x in exec.job_insert_many(self.__make_insert_params_many(args))
91+
]
92+
93+
def __make_insert_params_many(self, args: List[Args]) -> List[JobInsertParams]:
94+
return [
6895
self.__make_insert_params(
6996
arg.args, arg.insert_opts or InsertOpts(), is_insert_many=True
7097
)[0]
7198
if isinstance(arg, InsertManyParams)
7299
else self.__make_insert_params(arg, InsertOpts(), is_insert_many=True)[0]
73100
for arg in args
74101
]
75-
return [InsertResult(x) for x in self.driver.job_insert_many(all_params)]
76102

77103
def __check_unique_job(
78104
self,
105+
exec: ExecutorProtocol,
79106
insert_params: JobInsertParams,
80107
unique_opts: Optional[UniqueOpts],
81108
insert_func: Callable[[], InsertResult],
@@ -125,25 +152,27 @@ def __check_unique_job(
125152
if not any_unique_opts:
126153
return insert_func()
127154

128-
with self.driver.transaction():
155+
with exec.transaction():
129156
if self.advisory_lock_prefix is None:
130157
lock_key = fnv1_hash(lock_str.encode("utf-8"), 64)
131158
else:
132159
prefix = self.advisory_lock_prefix
133160
lock_key = (prefix << 32) | fnv1_hash(lock_str.encode("utf-8"), 32)
134161

135162
lock_key = self.__uint64_to_int64(lock_key)
136-
self.driver.advisory_lock(lock_key)
163+
exec.advisory_lock(lock_key)
137164

138-
existing_job = self.driver.job_get_by_kind_and_unique_properties(get_params)
165+
existing_job = exec.job_get_by_kind_and_unique_properties(get_params)
139166
if existing_job:
140167
return InsertResult(existing_job, unique_skipped_as_duplicated=True)
141168

142169
return insert_func()
143170

144171
@staticmethod
145172
def __make_insert_params(
146-
args: Args, insert_opts: InsertOpts, is_insert_many: bool = False
173+
args: Args,
174+
insert_opts: InsertOpts,
175+
is_insert_many: bool = False,
147176
) -> Tuple[JobInsertParams, Optional[UniqueOpts]]:
148177
if not hasattr(args, "kind"):
149178
raise Exception("args should respond to `kind`")

src/riverqueue/driver/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Reexport for more ergonomic use in calling code.
22
from .driver_protocol import (
3+
ExecutorProtocol as ExecutorProtocol,
34
GetParams as GetParams,
45
JobInsertParams as JobInsertParams,
56
DriverProtocol as DriverProtocol,

src/riverqueue/driver/driver_protocol.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from contextlib import _GeneratorContextManager, contextmanager
12
from dataclasses import dataclass, field
23
from datetime import datetime
3-
from typing import Any, List, ContextManager, Optional, Protocol
4+
from typing import Any, Iterator, List, Optional, Protocol
45

56
from ..model import Job
67

@@ -34,11 +35,11 @@ class JobInsertParams:
3435
finalized_at: Optional[datetime] = None
3536

3637

37-
class DriverProtocol(Protocol):
38+
class ExecutorProtocol(Protocol):
3839
def advisory_lock(self, lock: int) -> None:
3940
pass
4041

41-
def job_insert(self, insert_params: JobInsertParams) -> Optional[Job]:
42+
def job_insert(self, insert_params: JobInsertParams) -> Job:
4243
pass
4344

4445
def job_insert_many(self, all_params) -> List[Job]:
@@ -49,5 +50,14 @@ def job_get_by_kind_and_unique_properties(
4950
) -> Optional[Job]:
5051
pass
5152

52-
def transaction(self) -> ContextManager:
53+
def transaction(self) -> _GeneratorContextManager:
54+
pass
55+
56+
57+
class DriverProtocol(Protocol):
58+
@contextmanager
59+
def executor(self) -> Iterator[ExecutorProtocol]:
60+
pass
61+
62+
def unwrap_executor(self, tx) -> ExecutorProtocol:
5363
pass
Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,64 @@
11
from contextlib import contextmanager
2-
from typing import Optional, List, Generator
2+
from sqlalchemy import Engine
3+
from sqlalchemy.engine import Connection
4+
from typing import Iterator, Optional, List, Generator, cast
35

4-
from ...driver import DriverProtocol, GetParams, JobInsertParams
6+
from ...driver import DriverProtocol, ExecutorProtocol, GetParams, JobInsertParams
57
from ...model import Job
68
from . import river_job, pg_misc
79

810

9-
class Driver(DriverProtocol):
10-
def __init__(self, session):
11-
self.session = session
12-
self.pg_misc_querier = pg_misc.Querier(session)
13-
self.job_querier = river_job.Querier(session)
11+
class Executor(ExecutorProtocol):
12+
def __init__(self, conn: Connection):
13+
self.conn = conn
14+
self.pg_misc_querier = pg_misc.Querier(conn)
15+
self.job_querier = river_job.Querier(conn)
1416

1517
def advisory_lock(self, key: int) -> None:
1618
self.pg_misc_querier.pg_advisory_xact_lock(key=key)
1719

18-
def job_insert(self, insert_params: JobInsertParams) -> Optional[Job]:
19-
return self.job_querier.job_insert_fast(insert_params)
20+
def job_insert(self, insert_params: JobInsertParams) -> Job:
21+
return cast(
22+
Job,
23+
self.job_querier.job_insert_fast(
24+
cast(river_job.JobInsertFastParams, insert_params)
25+
),
26+
)
2027

2128
def job_insert_many(self, all_params) -> List[Job]:
2229
raise NotImplementedError("sqlc doesn't implement copy in python yet")
2330

2431
def job_get_by_kind_and_unique_properties(
2532
self, get_params: GetParams
2633
) -> Optional[Job]:
27-
return self.job_querier.job_get_by_kind_and_unique_properties(get_params)
34+
return cast(
35+
Optional[Job],
36+
self.job_querier.job_get_by_kind_and_unique_properties(
37+
cast(river_job.JobGetByKindAndUniquePropertiesParams, get_params)
38+
),
39+
)
2840

2941
@contextmanager
3042
def transaction(self) -> Generator:
31-
session = self.session
32-
with session.begin_nested():
33-
yield
43+
if self.conn.in_transaction():
44+
with self.conn.begin_nested():
45+
yield
46+
else:
47+
with self.conn.begin():
48+
yield
49+
50+
51+
class Driver(DriverProtocol):
52+
def __init__(self, conn: Connection | Engine):
53+
self.conn = conn
54+
55+
@contextmanager
56+
def executor(self) -> Iterator[ExecutorProtocol]:
57+
if isinstance(self.conn, Engine):
58+
with self.conn.begin() as tx:
59+
yield Executor(tx)
60+
else:
61+
yield Executor(self.conn)
62+
63+
def unwrap_executor(self, tx) -> ExecutorProtocol:
64+
return Executor(tx)

0 commit comments

Comments
 (0)