Skip to content

Commit 4a7e1e8

Browse files
Merge pull request #1399 from datajoint/fix/job-reserve-race-1398
fix: Atomic job reservation to prevent race condition
2 parents 6327a82 + 2bd8b26 commit 4a7e1e8

File tree

1 file changed

+24
-29
lines changed

1 file changed

+24
-29
lines changed

src/datajoint/jobs.py

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import platform
1414
import subprocess
1515

16-
from .condition import AndList, Not
16+
from .condition import AndList, Not, make_condition
1717
from .errors import DataJointError, DuplicateError
1818
from .heading import Heading
1919
from .table import Table
@@ -431,8 +431,10 @@ def reserve(self, key: dict) -> bool:
431431
"""
432432
Attempt to reserve a pending job for processing.
433433
434-
Updates status to ``'reserved'`` if currently ``'pending'`` and
435-
``scheduled_time <= now``.
434+
Atomically updates status to ``'reserved'`` if currently ``'pending'``
435+
and ``scheduled_time <= now``, using a single UPDATE with a WHERE clause
436+
that includes the status check. This prevents race conditions where
437+
multiple workers could reserve the same job simultaneously.
436438
437439
Parameters
438440
----------
@@ -444,33 +446,26 @@ def reserve(self, key: dict) -> bool:
444446
bool
445447
True if reservation successful, False if job not available.
446448
"""
447-
# Check if job is pending and scheduled (use CURRENT_TIMESTAMP(3) for datetime(3) precision)
448-
job = (self & key & "status='pending'" & "scheduled_time <= CURRENT_TIMESTAMP(3)").to_dicts()
449-
450-
if not job:
451-
return False
452-
453-
# Get server time for reserved_time
454-
server_now = self.connection.query("SELECT CURRENT_TIMESTAMP").fetchone()[0]
455-
456-
# Build update row with primary key and new values
457449
pk = self._get_pk(key)
458-
update_row = {
459-
**pk,
460-
"status": "reserved",
461-
"reserved_time": server_now,
462-
"host": platform.node(),
463-
"pid": os.getpid(),
464-
"connection_id": self.connection.connection_id,
465-
"user": self.connection.get_user(),
466-
"version": _get_job_version(),
467-
}
468-
469-
try:
470-
self.update1(update_row)
471-
return True
472-
except Exception:
473-
return False
450+
where = make_condition(self, pk, set())
451+
qi = self.adapter.quote_identifier
452+
assignments = ", ".join(f"{qi(k)}=%s" for k in ("status", "host", "pid", "connection_id", "user", "version"))
453+
query = (
454+
f"UPDATE {self.full_table_name} "
455+
f"SET {assignments}, {qi('reserved_time')}=CURRENT_TIMESTAMP(3) "
456+
f"WHERE {where} AND {qi('status')}='pending' "
457+
f"AND {qi('scheduled_time')} <= CURRENT_TIMESTAMP(3)"
458+
)
459+
args = [
460+
"reserved",
461+
platform.node(),
462+
os.getpid(),
463+
self.connection.connection_id,
464+
self.connection.get_user(),
465+
_get_job_version(),
466+
]
467+
cursor = self.connection.query(query, args=args)
468+
return cursor.rowcount == 1
474469

475470
def complete(self, key: dict, duration: float | None = None) -> None:
476471
"""

0 commit comments

Comments
 (0)