jobs.py: reserve() now claims a job with one conditional UPDATE

Summary of what this patch shows:

* The import from `.condition` gains `make_condition`; the new reserve()
  body uses it to build the primary-key part of the WHERE clause.
* Old reserve(): fetched the row restricted to status='pending' and
  scheduled_time <= CURRENT_TIMESTAMP(3), read the server time with a
  separate `SELECT CURRENT_TIMESTAMP`, then called `update1()`; any
  exception from `update1()` was caught and reported as `False`.
* New reserve(): issues one `UPDATE ... SET ... WHERE <pk> AND
  status='pending' AND scheduled_time <= CURRENT_TIMESTAMP(3)` and returns
  `cursor.rowcount == 1`. `reserved_time` is assigned in SQL as
  `CURRENT_TIMESTAMP(3)`; the other six columns are passed as query args.
* The new body has no try/except, so an exception raised by
  `self.connection.query(...)` is no longer converted to `False` inside
  reserve().

NOTE(review): the line layout below was restored from the hunk line counts
(7/7, 8/10, 33/26). Leading indentation inside the hunks is reconstructed
assuming standard 4-space Python indentation - confirm against the
repository before applying.

diff --git a/src/datajoint/jobs.py b/src/datajoint/jobs.py
index e5499eb8e..5a0eb2a86 100644
--- a/src/datajoint/jobs.py
+++ b/src/datajoint/jobs.py
@@ -13,7 +13,7 @@
 import platform
 import subprocess
 
-from .condition import AndList, Not
+from .condition import AndList, Not, make_condition
 from .errors import DataJointError, DuplicateError
 from .heading import Heading
 from .table import Table
@@ -431,8 +431,10 @@ def reserve(self, key: dict) -> bool:
         """
         Attempt to reserve a pending job for processing.
 
-        Updates status to ``'reserved'`` if currently ``'pending'`` and
-        ``scheduled_time <= now``.
+        Atomically updates status to ``'reserved'`` if currently ``'pending'``
+        and ``scheduled_time <= now``, using a single UPDATE with a WHERE clause
+        that includes the status check. This prevents race conditions where
+        multiple workers could reserve the same job simultaneously.
 
         Parameters
         ----------
@@ -444,33 +446,26 @@ def reserve(self, key: dict) -> bool:
         bool
             True if reservation successful, False if job not available.
         """
-        # Check if job is pending and scheduled (use CURRENT_TIMESTAMP(3) for datetime(3) precision)
-        job = (self & key & "status='pending'" & "scheduled_time <= CURRENT_TIMESTAMP(3)").to_dicts()
-
-        if not job:
-            return False
-
-        # Get server time for reserved_time
-        server_now = self.connection.query("SELECT CURRENT_TIMESTAMP").fetchone()[0]
-
-        # Build update row with primary key and new values
         pk = self._get_pk(key)
-        update_row = {
-            **pk,
-            "status": "reserved",
-            "reserved_time": server_now,
-            "host": platform.node(),
-            "pid": os.getpid(),
-            "connection_id": self.connection.connection_id,
-            "user": self.connection.get_user(),
-            "version": _get_job_version(),
-        }
-
-        try:
-            self.update1(update_row)
-            return True
-        except Exception:
-            return False
+        where = make_condition(self, pk, set())
+        qi = self.adapter.quote_identifier
+        assignments = ", ".join(f"{qi(k)}=%s" for k in ("status", "host", "pid", "connection_id", "user", "version"))
+        query = (
+            f"UPDATE {self.full_table_name} "
+            f"SET {assignments}, {qi('reserved_time')}=CURRENT_TIMESTAMP(3) "
+            f"WHERE {where} AND {qi('status')}='pending' "
+            f"AND {qi('scheduled_time')} <= CURRENT_TIMESTAMP(3)"
+        )
+        args = [
+            "reserved",
+            platform.node(),
+            os.getpid(),
+            self.connection.connection_id,
+            self.connection.get_user(),
+            _get_job_version(),
+        ]
+        cursor = self.connection.query(query, args=args)
+        return cursor.rowcount == 1
 
     def complete(self, key: dict, duration: float | None = None) -> None:
         """