From ac024cf6ffb4c39232c162b280e7d0b9a334840f Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 13 Feb 2026 11:41:51 +0500 Subject: [PATCH 01/15] Rename tasks/ to scheduled_tasks/ --- src/dstack/_internal/server/app.py | 6 +- .../_internal/server/background/__init__.py | 142 ---------------- .../background/scheduled_tasks/__init__.py | 156 ++++++++++++++++++ .../{tasks => scheduled_tasks}/common.py | 0 .../compute_groups.py} | 0 .../events.py} | 0 .../fleets.py} | 0 .../gateways.py} | 0 .../idle_volumes.py} | 0 .../instances.py} | 2 +- .../metrics.py} | 0 .../placement_groups.py} | 0 .../probes.py} | 0 .../prometheus_metrics.py} | 0 .../running_jobs.py} | 2 +- .../runs.py} | 0 .../submitted_jobs.py} | 4 +- .../terminating_jobs.py} | 0 .../volumes.py} | 0 .../services/jobs/configurators/base.py | 2 +- .../background/scheduled_tasks}/__init__.py | 0 .../test_process_compute_groups.py | 2 +- .../test_process_events.py | 2 +- .../test_process_fleets.py | 2 +- .../test_process_gateways.py | 2 +- .../test_process_idle_volumes.py | 2 +- .../test_process_instances.py | 42 ++--- .../test_process_metrics.py | 2 +- .../test_process_placement_groups.py | 2 +- .../test_process_probes.py | 6 +- .../test_process_prometheus_metrics.py | 2 +- .../test_process_running_jobs.py | 2 +- .../test_process_runs.py | 12 +- .../test_process_submitted_jobs.py | 2 +- .../test_process_submitted_volumes.py | 2 +- .../test_process_terminating_jobs.py | 2 +- .../server/background/tasks/__init__.py | 0 .../_internal/server/routers/test_runs.py | 2 +- 38 files changed, 208 insertions(+), 192 deletions(-) create mode 100644 src/dstack/_internal/server/background/scheduled_tasks/__init__.py rename src/dstack/_internal/server/background/{tasks => scheduled_tasks}/common.py (100%) rename src/dstack/_internal/server/background/{tasks/process_compute_groups.py => scheduled_tasks/compute_groups.py} (100%) rename src/dstack/_internal/server/background/{tasks/process_events.py => scheduled_tasks/events.py} (100%) rename src/dstack/_internal/server/background/{tasks/process_fleets.py => scheduled_tasks/fleets.py} (100%) rename src/dstack/_internal/server/background/{tasks/process_gateways.py => scheduled_tasks/gateways.py} (100%) rename src/dstack/_internal/server/background/{tasks/process_idle_volumes.py => scheduled_tasks/idle_volumes.py} (100%) rename src/dstack/_internal/server/background/{tasks/process_instances.py => scheduled_tasks/instances.py} (99%) rename src/dstack/_internal/server/background/{tasks/process_metrics.py => scheduled_tasks/metrics.py} (100%) rename src/dstack/_internal/server/background/{tasks/process_placement_groups.py => scheduled_tasks/placement_groups.py} (100%) rename src/dstack/_internal/server/background/{tasks/process_probes.py => scheduled_tasks/probes.py} (100%) rename src/dstack/_internal/server/background/{tasks/process_prometheus_metrics.py => scheduled_tasks/prometheus_metrics.py} (100%) rename src/dstack/_internal/server/background/{tasks/process_running_jobs.py => scheduled_tasks/running_jobs.py} (99%) rename src/dstack/_internal/server/background/{tasks/process_runs.py => scheduled_tasks/runs.py} (100%) rename src/dstack/_internal/server/background/{tasks/process_submitted_jobs.py => scheduled_tasks/submitted_jobs.py} (99%) rename src/dstack/_internal/server/background/{tasks/process_terminating_jobs.py => scheduled_tasks/terminating_jobs.py} (100%) rename src/dstack/_internal/server/background/{tasks/process_volumes.py => scheduled_tasks/volumes.py} (100%) rename 
src/{dstack/_internal/server/background/tasks => tests/_internal/server/background/scheduled_tasks}/__init__.py (100%) rename src/tests/_internal/server/background/{tasks => scheduled_tasks}/test_process_compute_groups.py (97%) rename src/tests/_internal/server/background/{tasks => scheduled_tasks}/test_process_events.py (94%) rename src/tests/_internal/server/background/{tasks => scheduled_tasks}/test_process_fleets.py (98%) rename src/tests/_internal/server/background/{tasks => scheduled_tasks}/test_process_gateways.py (98%) rename src/tests/_internal/server/background/{tasks => scheduled_tasks}/test_process_idle_volumes.py (98%) rename src/tests/_internal/server/background/{tasks => scheduled_tasks}/test_process_instances.py (96%) rename src/tests/_internal/server/background/{tasks => scheduled_tasks}/test_process_metrics.py (98%) rename src/tests/_internal/server/background/{tasks => scheduled_tasks}/test_process_placement_groups.py (94%) rename src/tests/_internal/server/background/{tasks => scheduled_tasks}/test_process_probes.py (96%) rename src/tests/_internal/server/background/{tasks => scheduled_tasks}/test_process_prometheus_metrics.py (98%) rename src/tests/_internal/server/background/{tasks => scheduled_tasks}/test_process_running_jobs.py (99%) rename src/tests/_internal/server/background/{tasks => scheduled_tasks}/test_process_runs.py (98%) rename src/tests/_internal/server/background/{tasks => scheduled_tasks}/test_process_submitted_jobs.py (99%) rename src/tests/_internal/server/background/{tasks => scheduled_tasks}/test_process_submitted_volumes.py (96%) rename src/tests/_internal/server/background/{tasks => scheduled_tasks}/test_process_terminating_jobs.py (99%) delete mode 100644 src/tests/_internal/server/background/tasks/__init__.py diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index dbea6f777b..fbeddfd182 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -23,8 +23,8 @@ from dstack._internal.proxy.lib.deps import get_injector_from_app from dstack._internal.proxy.lib.routers import model_proxy from dstack._internal.server import settings -from dstack._internal.server.background import start_background_tasks -from dstack._internal.server.background.tasks.process_probes import PROBES_SCHEDULER +from dstack._internal.server.background.scheduled_tasks import start_scheduled_tasks +from dstack._internal.server.background.scheduled_tasks.probes import PROBES_SCHEDULER from dstack._internal.server.db import get_db, get_session_ctx, migrate from dstack._internal.server.routers import ( auth, @@ -164,7 +164,7 @@ async def lifespan(app: FastAPI): init_default_storage() scheduler = None if settings.SERVER_BACKGROUND_PROCESSING_ENABLED: - scheduler = start_background_tasks() + scheduler = start_scheduled_tasks() else: logger.info("Background processing is disabled") PROBES_SCHEDULER.start() diff --git a/src/dstack/_internal/server/background/__init__.py b/src/dstack/_internal/server/background/__init__.py index 8577cce6f1..e69de29bb2 100644 --- a/src/dstack/_internal/server/background/__init__.py +++ b/src/dstack/_internal/server/background/__init__.py @@ -1,142 +0,0 @@ -from apscheduler.schedulers.asyncio import AsyncIOScheduler -from apscheduler.triggers.interval import IntervalTrigger - -from dstack._internal.server import settings -from dstack._internal.server.background.tasks.process_compute_groups import process_compute_groups -from dstack._internal.server.background.tasks.process_events import 
delete_events -from dstack._internal.server.background.tasks.process_fleets import process_fleets -from dstack._internal.server.background.tasks.process_gateways import ( - process_gateways, - process_gateways_connections, -) -from dstack._internal.server.background.tasks.process_idle_volumes import process_idle_volumes -from dstack._internal.server.background.tasks.process_instances import ( - delete_instance_health_checks, - process_instances, -) -from dstack._internal.server.background.tasks.process_metrics import ( - collect_metrics, - delete_metrics, -) -from dstack._internal.server.background.tasks.process_placement_groups import ( - process_placement_groups, -) -from dstack._internal.server.background.tasks.process_probes import process_probes -from dstack._internal.server.background.tasks.process_prometheus_metrics import ( - collect_prometheus_metrics, - delete_prometheus_metrics, -) -from dstack._internal.server.background.tasks.process_running_jobs import process_running_jobs -from dstack._internal.server.background.tasks.process_runs import process_runs -from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs -from dstack._internal.server.background.tasks.process_terminating_jobs import ( - process_terminating_jobs, -) -from dstack._internal.server.background.tasks.process_volumes import process_submitted_volumes - -_scheduler = AsyncIOScheduler() - - -def get_scheduler() -> AsyncIOScheduler: - return _scheduler - - -def start_background_tasks() -> AsyncIOScheduler: - # Background processing is implemented via in-memory locks on SQLite - # and SELECT FOR UPDATE on Postgres. Locks may be held for a long time. - # This is currently the main bottleneck for scaling dstack processing - # as processing more resources requires more DB connections. - # TODO: Make background processing efficient by committing locks to DB - # and processing outside of DB transactions. - # - # Now we just try to process as many resources as possible without exhausting DB connections. - # - # Quick tasks can process multiple resources per transaction. - # Potentially long tasks process one resource per transaction - # to avoid holding locks for all the resources if one is slow to process. - # Still, the next batch won't be processed unless all resources are processed, - # so larger batches do not increase processing rate linearly. - # - # The interval, batch_size, and max_instances determine background tasks processing rates. - # By default, one server replica can handle: - # - # * 150 active jobs with 2 minutes processing latency - # * 150 active runs with 2 minutes processing latency - # * 150 active instances with 2 minutes processing latency - # - # These latency numbers do not account for provisioning time, - # so it may be slower if a backend is slow to provision. - # - # Users can set SERVER_BACKGROUND_PROCESSING_FACTOR to process more resources per replica. - # They also need to increase max db connections on the client side and db side. - # - # In-memory locking via locksets does not guarantee - # that the first waiting for the lock will acquire it. - # The jitter is needed to give all tasks a chance to acquire locks. 
- - _scheduler.add_job(process_probes, IntervalTrigger(seconds=3, jitter=1)) - _scheduler.add_job(collect_metrics, IntervalTrigger(seconds=10), max_instances=1) - _scheduler.add_job(delete_metrics, IntervalTrigger(minutes=5), max_instances=1) - _scheduler.add_job(delete_events, IntervalTrigger(minutes=7), max_instances=1) - if settings.ENABLE_PROMETHEUS_METRICS: - _scheduler.add_job( - collect_prometheus_metrics, IntervalTrigger(seconds=10), max_instances=1 - ) - _scheduler.add_job(delete_prometheus_metrics, IntervalTrigger(minutes=5), max_instances=1) - _scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15)) - _scheduler.add_job(process_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5) - _scheduler.add_job( - process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5 - ) - _scheduler.add_job( - process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1 - ) - _scheduler.add_job(process_placement_groups, IntervalTrigger(seconds=30, jitter=5)) - _scheduler.add_job( - process_fleets, - IntervalTrigger(seconds=10, jitter=2), - max_instances=1, - ) - _scheduler.add_job(delete_instance_health_checks, IntervalTrigger(minutes=5), max_instances=1) - for replica in range(settings.SERVER_BACKGROUND_PROCESSING_FACTOR): - # Add multiple copies of tasks if requested. - # max_instances=1 for additional copies to avoid running too many tasks. - # Move other tasks here when they need per-replica scaling. - _scheduler.add_job( - process_submitted_jobs, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=4 if replica == 0 else 1, - ) - _scheduler.add_job( - process_running_jobs, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, - ) - _scheduler.add_job( - process_terminating_jobs, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, - ) - _scheduler.add_job( - process_runs, - IntervalTrigger(seconds=2, jitter=1), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, - ) - _scheduler.add_job( - process_instances, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, - ) - _scheduler.add_job( - process_compute_groups, - IntervalTrigger(seconds=15, jitter=2), - kwargs={"batch_size": 1}, - max_instances=2 if replica == 0 else 1, - ) - _scheduler.start() - return _scheduler diff --git a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py new file mode 100644 index 0000000000..3bc17ecc40 --- /dev/null +++ b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py @@ -0,0 +1,156 @@ +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from apscheduler.triggers.interval import IntervalTrigger + +from dstack._internal.server import settings +from dstack._internal.server.background.scheduled_tasks.compute_groups import ( + process_compute_groups, +) +from dstack._internal.server.background.scheduled_tasks.events import delete_events +from dstack._internal.server.background.scheduled_tasks.fleets import process_fleets +from dstack._internal.server.background.scheduled_tasks.gateways import ( + process_gateways, + process_gateways_connections, +) +from dstack._internal.server.background.scheduled_tasks.idle_volumes import ( + process_idle_volumes, +) +from dstack._internal.server.background.scheduled_tasks.instances import ( 
+ delete_instance_health_checks, + process_instances, +) +from dstack._internal.server.background.scheduled_tasks.metrics import ( + collect_metrics, + delete_metrics, +) +from dstack._internal.server.background.scheduled_tasks.placement_groups import ( + process_placement_groups, +) +from dstack._internal.server.background.scheduled_tasks.probes import process_probes +from dstack._internal.server.background.scheduled_tasks.prometheus_metrics import ( + collect_prometheus_metrics, + delete_prometheus_metrics, +) +from dstack._internal.server.background.scheduled_tasks.running_jobs import ( + process_running_jobs, +) +from dstack._internal.server.background.scheduled_tasks.runs import process_runs +from dstack._internal.server.background.scheduled_tasks.submitted_jobs import ( + process_submitted_jobs, +) +from dstack._internal.server.background.scheduled_tasks.terminating_jobs import ( + process_terminating_jobs, +) +from dstack._internal.server.background.scheduled_tasks.volumes import ( + process_submitted_volumes, +) + +_scheduler = AsyncIOScheduler() + + +def get_scheduler() -> AsyncIOScheduler: + return _scheduler + + +def start_scheduled_tasks() -> AsyncIOScheduler: + """ + Start periodic tasks triggered by `apscheduler` at specific times/intervals. + Suitable for tasks that run infrequently and don't need to lock rows for a long time. + """ + # Background processing is implemented via in-memory locks on SQLite + # and SELECT FOR UPDATE on Postgres. Locks may be held for a long time. + # This is currently the main bottleneck for scaling dstack processing + # as processing more resources requires more DB connections. + # TODO: Make background processing efficient by committing locks to DB + # and processing outside of DB transactions. + # + # Now we just try to process as many resources as possible without exhausting DB connections. + # + # Quick tasks can process multiple resources per transaction. + # Potentially long tasks process one resource per transaction + # to avoid holding locks for all the resources if one is slow to process. + # Still, the next batch won't be processed unless all resources are processed, + # so larger batches do not increase processing rate linearly. + # + # The interval, batch_size, and max_instances determine background tasks processing rates. + # By default, one server replica can handle: + # + # * 150 active jobs with 2 minutes processing latency + # * 150 active runs with 2 minutes processing latency + # * 150 active instances with 2 minutes processing latency + # + # These latency numbers do not account for provisioning time, + # so it may be slower if a backend is slow to provision. + # + # Users can set SERVER_BACKGROUND_PROCESSING_FACTOR to process more resources per replica. + # They also need to increase max db connections on the client side and db side. + # + # In-memory locking via locksets does not guarantee + # that the first waiting for the lock will acquire it. + # The jitter is needed to give all tasks a chance to acquire locks. 
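+    # For example (rough, illustrative arithmetic based on the defaults below):
+    # process_running_jobs runs every ~4 seconds with batch_size=5, i.e. about
+    # 1.25 jobs per second, so a backlog of 150 active jobs is revisited in
+    # roughly 150 / 1.25 = 120 seconds, which is the 2-minute latency quoted above.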
+ + _scheduler.add_job(process_probes, IntervalTrigger(seconds=3, jitter=1)) + _scheduler.add_job(collect_metrics, IntervalTrigger(seconds=10), max_instances=1) + _scheduler.add_job(delete_metrics, IntervalTrigger(minutes=5), max_instances=1) + _scheduler.add_job(delete_events, IntervalTrigger(minutes=7), max_instances=1) + if settings.ENABLE_PROMETHEUS_METRICS: + _scheduler.add_job( + collect_prometheus_metrics, IntervalTrigger(seconds=10), max_instances=1 + ) + _scheduler.add_job(delete_prometheus_metrics, IntervalTrigger(minutes=5), max_instances=1) + _scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15)) + _scheduler.add_job(process_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5) + _scheduler.add_job( + process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5 + ) + _scheduler.add_job( + process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1 + ) + _scheduler.add_job(process_placement_groups, IntervalTrigger(seconds=30, jitter=5)) + _scheduler.add_job( + process_fleets, + IntervalTrigger(seconds=10, jitter=2), + max_instances=1, + ) + _scheduler.add_job(delete_instance_health_checks, IntervalTrigger(minutes=5), max_instances=1) + for replica in range(settings.SERVER_BACKGROUND_PROCESSING_FACTOR): + # Add multiple copies of tasks if requested. + # max_instances=1 for additional copies to avoid running too many tasks. + # Move other tasks here when they need per-replica scaling. + _scheduler.add_job( + process_submitted_jobs, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=4 if replica == 0 else 1, + ) + _scheduler.add_job( + process_running_jobs, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) + _scheduler.add_job( + process_terminating_jobs, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) + _scheduler.add_job( + process_runs, + IntervalTrigger(seconds=2, jitter=1), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) + _scheduler.add_job( + process_instances, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) + _scheduler.add_job( + process_compute_groups, + IntervalTrigger(seconds=15, jitter=2), + kwargs={"batch_size": 1}, + max_instances=2 if replica == 0 else 1, + ) + _scheduler.start() + return _scheduler diff --git a/src/dstack/_internal/server/background/tasks/common.py b/src/dstack/_internal/server/background/scheduled_tasks/common.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/common.py rename to src/dstack/_internal/server/background/scheduled_tasks/common.py diff --git a/src/dstack/_internal/server/background/tasks/process_compute_groups.py b/src/dstack/_internal/server/background/scheduled_tasks/compute_groups.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_compute_groups.py rename to src/dstack/_internal/server/background/scheduled_tasks/compute_groups.py diff --git a/src/dstack/_internal/server/background/tasks/process_events.py b/src/dstack/_internal/server/background/scheduled_tasks/events.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_events.py rename to src/dstack/_internal/server/background/scheduled_tasks/events.py diff --git a/src/dstack/_internal/server/background/tasks/process_fleets.py 
b/src/dstack/_internal/server/background/scheduled_tasks/fleets.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_fleets.py rename to src/dstack/_internal/server/background/scheduled_tasks/fleets.py diff --git a/src/dstack/_internal/server/background/tasks/process_gateways.py b/src/dstack/_internal/server/background/scheduled_tasks/gateways.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_gateways.py rename to src/dstack/_internal/server/background/scheduled_tasks/gateways.py diff --git a/src/dstack/_internal/server/background/tasks/process_idle_volumes.py b/src/dstack/_internal/server/background/scheduled_tasks/idle_volumes.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_idle_volumes.py rename to src/dstack/_internal/server/background/scheduled_tasks/idle_volumes.py diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/scheduled_tasks/instances.py similarity index 99% rename from src/dstack/_internal/server/background/tasks/process_instances.py rename to src/dstack/_internal/server/background/scheduled_tasks/instances.py index da47cf16ed..196f347c4f 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/instances.py @@ -59,7 +59,7 @@ JobProvisioningData, ) from dstack._internal.server import settings as server_settings -from dstack._internal.server.background.tasks.common import get_provisioning_timeout +from dstack._internal.server.background.scheduled_tasks.common import get_provisioning_timeout from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( FleetModel, diff --git a/src/dstack/_internal/server/background/tasks/process_metrics.py b/src/dstack/_internal/server/background/scheduled_tasks/metrics.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_metrics.py rename to src/dstack/_internal/server/background/scheduled_tasks/metrics.py diff --git a/src/dstack/_internal/server/background/tasks/process_placement_groups.py b/src/dstack/_internal/server/background/scheduled_tasks/placement_groups.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_placement_groups.py rename to src/dstack/_internal/server/background/scheduled_tasks/placement_groups.py diff --git a/src/dstack/_internal/server/background/tasks/process_probes.py b/src/dstack/_internal/server/background/scheduled_tasks/probes.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_probes.py rename to src/dstack/_internal/server/background/scheduled_tasks/probes.py diff --git a/src/dstack/_internal/server/background/tasks/process_prometheus_metrics.py b/src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_prometheus_metrics.py rename to src/dstack/_internal/server/background/scheduled_tasks/prometheus_metrics.py diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py similarity index 99% rename from src/dstack/_internal/server/background/tasks/process_running_jobs.py rename to src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index 7275106ceb..f413edf44b 100644 --- 
a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -37,7 +37,7 @@ RunStatus, ) from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint -from dstack._internal.server.background.tasks.common import get_provisioning_timeout +from dstack._internal.server.background.scheduled_tasks.common import get_provisioning_timeout from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( FleetModel, diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_runs.py rename to src/dstack/_internal/server/background/scheduled_tasks/runs.py diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py similarity index 99% rename from src/dstack/_internal/server/background/tasks/process_submitted_jobs.py rename to src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py index a021096613..79746e9338 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py @@ -57,7 +57,9 @@ from dstack._internal.core.models.volumes import Volume from dstack._internal.core.services.profiles import get_termination from dstack._internal.server import settings -from dstack._internal.server.background.tasks.process_compute_groups import ComputeGroupStatus +from dstack._internal.server.background.scheduled_tasks.compute_groups import ( + ComputeGroupStatus, +) from dstack._internal.server.db import ( get_db, get_session_ctx, diff --git a/src/dstack/_internal/server/background/tasks/process_terminating_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_terminating_jobs.py rename to src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py diff --git a/src/dstack/_internal/server/background/tasks/process_volumes.py b/src/dstack/_internal/server/background/scheduled_tasks/volumes.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/process_volumes.py rename to src/dstack/_internal/server/background/scheduled_tasks/volumes.py diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index 3b6038ccd9..d28bb3577b 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -78,7 +78,7 @@ def get_default_python_verison() -> str: def get_default_image(nvcc: bool = False) -> str: """ Note: May be overridden by dstack (e.g., EFA-enabled version for AWS EFA-capable instances). - See `dstack._internal.server.background.tasks.process_running_jobs._patch_base_image_for_aws_efa` for details. + See `dstack._internal.server.background.scheduled_tasks.running_jobs._patch_base_image_for_aws_efa` for details. Args: nvcc: If True, returns 'devel' variant, otherwise 'base'. 
diff --git a/src/dstack/_internal/server/background/tasks/__init__.py b/src/tests/_internal/server/background/scheduled_tasks/__init__.py similarity index 100% rename from src/dstack/_internal/server/background/tasks/__init__.py rename to src/tests/_internal/server/background/scheduled_tasks/__init__.py diff --git a/src/tests/_internal/server/background/tasks/test_process_compute_groups.py b/src/tests/_internal/server/background/scheduled_tasks/test_process_compute_groups.py similarity index 97% rename from src/tests/_internal/server/background/tasks/test_process_compute_groups.py rename to src/tests/_internal/server/background/scheduled_tasks/test_process_compute_groups.py index 11ce734606..5ca81b88a6 100644 --- a/src/tests/_internal/server/background/tasks/test_process_compute_groups.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_process_compute_groups.py @@ -6,7 +6,7 @@ from dstack._internal.core.errors import BackendError from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.server.background.tasks.process_compute_groups import ( +from dstack._internal.server.background.scheduled_tasks.compute_groups import ( ComputeGroupStatus, process_compute_groups, ) diff --git a/src/tests/_internal/server/background/tasks/test_process_events.py b/src/tests/_internal/server/background/scheduled_tasks/test_process_events.py similarity index 94% rename from src/tests/_internal/server/background/tasks/test_process_events.py rename to src/tests/_internal/server/background/scheduled_tasks/test_process_events.py index 21043e0bae..91eb066f58 100644 --- a/src/tests/_internal/server/background/tasks/test_process_events.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_process_events.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.server import settings -from dstack._internal.server.background.tasks.process_events import delete_events +from dstack._internal.server.background.scheduled_tasks.events import delete_events from dstack._internal.server.services import events from dstack._internal.server.testing.common import create_user, list_events diff --git a/src/tests/_internal/server/background/tasks/test_process_fleets.py b/src/tests/_internal/server/background/scheduled_tasks/test_process_fleets.py similarity index 98% rename from src/tests/_internal/server/background/tasks/test_process_fleets.py rename to src/tests/_internal/server/background/scheduled_tasks/test_process_fleets.py index ae7155c3ca..2ef1b27ab9 100644 --- a/src/tests/_internal/server/background/tasks/test_process_fleets.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_process_fleets.py @@ -6,7 +6,7 @@ from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.runs import RunStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole -from dstack._internal.server.background.tasks.process_fleets import process_fleets +from dstack._internal.server.background.scheduled_tasks.fleets import process_fleets from dstack._internal.server.models import InstanceModel from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( diff --git a/src/tests/_internal/server/background/tasks/test_process_gateways.py b/src/tests/_internal/server/background/scheduled_tasks/test_process_gateways.py similarity index 98% rename from src/tests/_internal/server/background/tasks/test_process_gateways.py rename to 
src/tests/_internal/server/background/scheduled_tasks/test_process_gateways.py index b280b8948d..5f19d2cfcd 100644 --- a/src/tests/_internal/server/background/tasks/test_process_gateways.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_process_gateways.py @@ -5,7 +5,7 @@ from dstack._internal.core.errors import BackendError from dstack._internal.core.models.gateways import GatewayProvisioningData, GatewayStatus -from dstack._internal.server.background.tasks.process_gateways import process_gateways +from dstack._internal.server.background.scheduled_tasks.gateways import process_gateways from dstack._internal.server.testing.common import ( AsyncContextManager, ComputeMockSpec, diff --git a/src/tests/_internal/server/background/tasks/test_process_idle_volumes.py b/src/tests/_internal/server/background/scheduled_tasks/test_process_idle_volumes.py similarity index 98% rename from src/tests/_internal/server/background/tasks/test_process_idle_volumes.py rename to src/tests/_internal/server/background/scheduled_tasks/test_process_idle_volumes.py index 9d73afbb78..6a7acf0c43 100644 --- a/src/tests/_internal/server/background/tasks/test_process_idle_volumes.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_process_idle_volumes.py @@ -6,7 +6,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.volumes import VolumeStatus -from dstack._internal.server.background.tasks.process_idle_volumes import ( +from dstack._internal.server.background.scheduled_tasks.idle_volumes import ( _get_idle_time, _should_delete_volume, process_idle_volumes, diff --git a/src/tests/_internal/server/background/tasks/test_process_instances.py b/src/tests/_internal/server/background/scheduled_tasks/test_process_instances.py similarity index 96% rename from src/tests/_internal/server/background/tasks/test_process_instances.py rename to src/tests/_internal/server/background/scheduled_tasks/test_process_instances.py index 8d94ee059b..1b9789953e 100644 --- a/src/tests/_internal/server/background/tasks/test_process_instances.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_process_instances.py @@ -39,7 +39,7 @@ JobProvisioningData, JobStatus, ) -from dstack._internal.server.background.tasks.process_instances import ( +from dstack._internal.server.background.scheduled_tasks.instances import ( delete_instance_health_checks, process_instances, ) @@ -101,7 +101,7 @@ async def test_check_shim_transitions_provisioning_on_ready( await session.commit() with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck(reachable=True) await process_instances() @@ -130,7 +130,7 @@ async def test_check_shim_transitions_provisioning_on_terminating( health_reason = "Shim problem" with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck(reachable=False, message=health_reason) await process_instances() @@ -177,7 +177,7 @@ async def test_check_shim_transitions_provisioning_on_busy( await session.commit() with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) 
as healthcheck: healthcheck.return_value = InstanceCheck(reachable=True) await process_instances() @@ -202,7 +202,7 @@ async def test_check_shim_start_termination_deadline(self, test_db, session: Asy ) health_status = "SSH connection fail" with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck(reachable=False, message=health_status) await process_instances() @@ -232,7 +232,7 @@ async def test_check_shim_does_not_start_termination_deadline_with_ssh_instance( ) health_status = "SSH connection fail" with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck(reachable=False, message=health_status) await process_instances() @@ -257,7 +257,7 @@ async def test_check_shim_stop_termination_deadline(self, test_db, session: Asyn await session.commit() with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck(reachable=True) await process_instances() @@ -283,7 +283,7 @@ async def test_check_shim_terminate_instance_by_deadline(self, test_db, session: health_status = "Not ok" with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck(reachable=False, message=health_status) await process_instances() @@ -347,7 +347,7 @@ async def test_check_shim_process_ureachable_state( await session.commit() with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck(reachable=True) await process_instances() @@ -378,7 +378,7 @@ async def test_check_shim_switch_to_unreachable_state( ) with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck(reachable=False) await process_instances() @@ -412,7 +412,7 @@ async def test_check_shim_check_instance_health(self, test_db, session: AsyncSes ) with patch( - "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + "dstack._internal.server.background.scheduled_tasks.instances._check_instance_inner" ) as healthcheck: healthcheck.return_value = InstanceCheck( reachable=True, health_response=health_response @@ -440,7 +440,7 @@ class TestRemoveDanglingTasks: @pytest.fixture def disable_maybe_install_components(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( - "dstack._internal.server.background.tasks.process_instances._maybe_install_components", + "dstack._internal.server.background.scheduled_tasks.instances._maybe_install_components", Mock(return_value=None), ) @@ -607,7 +607,7 @@ def mock_terminate_in_backend(error: Optional[Exception] = None): if error is not None: terminate_instance.side_effect = error with patch( - 
"dstack._internal.server.background.tasks.process_instances.backends_services.get_project_backend_by_type" + "dstack._internal.server.background.scheduled_tasks.instances.backends_services.get_project_backend_by_type" ) as get_backend: get_backend.return_value = backend yield terminate_instance @@ -1153,7 +1153,7 @@ def host_info(self) -> dict: def deploy_instance_mock(self, monkeypatch: pytest.MonkeyPatch, host_info: dict): mock = Mock(return_value=(InstanceCheck(reachable=True), host_info, GoArchType.AMD64)) monkeypatch.setattr( - "dstack._internal.server.background.tasks.process_instances._deploy_instance", mock + "dstack._internal.server.background.scheduled_tasks.instances._deploy_instance", mock ) return mock @@ -1262,7 +1262,7 @@ def component_list(self) -> ComponentList: def debug_task_log(self, caplog: pytest.LogCaptureFixture) -> pytest.LogCaptureFixture: caplog.set_level( level=logging.DEBUG, - logger="dstack._internal.server.background.tasks.process_instances", + logger="dstack._internal.server.background.scheduled_tasks.instances", ) return caplog @@ -1308,7 +1308,7 @@ def component_list(self) -> ComponentList: def get_dstack_runner_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: mock = Mock(return_value=self.EXPECTED_VERSION) monkeypatch.setattr( - "dstack._internal.server.background.tasks.process_instances.get_dstack_runner_version", + "dstack._internal.server.background.scheduled_tasks.instances.get_dstack_runner_version", mock, ) return mock @@ -1317,7 +1317,7 @@ def get_dstack_runner_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Moc def get_dstack_runner_download_url_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: mock = Mock(return_value="https://example.com/runner") monkeypatch.setattr( - "dstack._internal.server.background.tasks.process_instances.get_dstack_runner_download_url", + "dstack._internal.server.background.scheduled_tasks.instances.get_dstack_runner_download_url", mock, ) return mock @@ -1424,7 +1424,7 @@ def component_list(self) -> ComponentList: def get_dstack_shim_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: mock = Mock(return_value=self.EXPECTED_VERSION) monkeypatch.setattr( - "dstack._internal.server.background.tasks.process_instances.get_dstack_shim_version", + "dstack._internal.server.background.scheduled_tasks.instances.get_dstack_shim_version", mock, ) return mock @@ -1433,7 +1433,7 @@ def get_dstack_shim_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: def get_dstack_shim_download_url_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: mock = Mock(return_value="https://example.com/shim") monkeypatch.setattr( - "dstack._internal.server.background.tasks.process_instances.get_dstack_shim_download_url", + "dstack._internal.server.background.scheduled_tasks.instances.get_dstack_shim_download_url", mock, ) return mock @@ -1547,7 +1547,7 @@ def component_list(self) -> ComponentList: def maybe_install_runner_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: mock = Mock(return_value=False) monkeypatch.setattr( - "dstack._internal.server.background.tasks.process_instances._maybe_install_runner", + "dstack._internal.server.background.scheduled_tasks.instances._maybe_install_runner", mock, ) return mock @@ -1556,7 +1556,7 @@ def maybe_install_runner_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: def maybe_install_shim_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: mock = Mock(return_value=False) monkeypatch.setattr( - 
"dstack._internal.server.background.tasks.process_instances._maybe_install_shim", + "dstack._internal.server.background.scheduled_tasks.instances._maybe_install_shim", mock, ) return mock diff --git a/src/tests/_internal/server/background/tasks/test_process_metrics.py b/src/tests/_internal/server/background/scheduled_tasks/test_process_metrics.py similarity index 98% rename from src/tests/_internal/server/background/tasks/test_process_metrics.py rename to src/tests/_internal/server/background/scheduled_tasks/test_process_metrics.py index 0be650a223..df52dd88e2 100644 --- a/src/tests/_internal/server/background/tasks/test_process_metrics.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_process_metrics.py @@ -10,7 +10,7 @@ from dstack._internal.core.models.runs import JobStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole from dstack._internal.server import settings -from dstack._internal.server.background.tasks.process_metrics import ( +from dstack._internal.server.background.scheduled_tasks.metrics import ( collect_metrics, delete_metrics, ) diff --git a/src/tests/_internal/server/background/tasks/test_process_placement_groups.py b/src/tests/_internal/server/background/scheduled_tasks/test_process_placement_groups.py similarity index 94% rename from src/tests/_internal/server/background/tasks/test_process_placement_groups.py rename to src/tests/_internal/server/background/scheduled_tasks/test_process_placement_groups.py index a45051a48e..14b9d2189d 100644 --- a/src/tests/_internal/server/background/tasks/test_process_placement_groups.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_process_placement_groups.py @@ -3,7 +3,7 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.server.background.tasks.process_placement_groups import ( +from dstack._internal.server.background.scheduled_tasks.placement_groups import ( process_placement_groups, ) from dstack._internal.server.testing.common import ( diff --git a/src/tests/_internal/server/background/tasks/test_process_probes.py b/src/tests/_internal/server/background/scheduled_tasks/test_process_probes.py similarity index 96% rename from src/tests/_internal/server/background/tasks/test_process_probes.py rename to src/tests/_internal/server/background/scheduled_tasks/test_process_probes.py index 928709dd7f..bfd569ab1b 100644 --- a/src/tests/_internal/server/background/tasks/test_process_probes.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_process_probes.py @@ -8,7 +8,7 @@ from dstack._internal.core.models.configurations import ProbeConfig, ServiceConfiguration from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.runs import JobStatus -from dstack._internal.server.background.tasks.process_probes import ( +from dstack._internal.server.background.scheduled_tasks.probes import ( PROCESSING_OVERHEAD_TIMEOUT, SSH_CONNECT_TIMEOUT, process_probes, @@ -140,7 +140,7 @@ async def test_schedules_probe_execution(self, test_db, session: AsyncSession) - processing_time = datetime(2025, 1, 1, 0, 0, 1, tzinfo=timezone.utc) with freeze_time(processing_time): with patch( - "dstack._internal.server.background.tasks.process_probes.PROBES_SCHEDULER" + "dstack._internal.server.background.scheduled_tasks.probes.PROBES_SCHEDULER" ) as scheduler_mock: await process_probes() assert scheduler_mock.add_job.call_count == 2 @@ -210,7 +210,7 @@ async def 
test_deactivates_probe_when_until_ready_and_ready_after_reached( probe_regular = await create_probe(session, job, probe_num=1, success_streak=3) with patch( - "dstack._internal.server.background.tasks.process_probes.PROBES_SCHEDULER" + "dstack._internal.server.background.scheduled_tasks.probes.PROBES_SCHEDULER" ) as scheduler_mock: await process_probes() diff --git a/src/tests/_internal/server/background/tasks/test_process_prometheus_metrics.py b/src/tests/_internal/server/background/scheduled_tasks/test_process_prometheus_metrics.py similarity index 98% rename from src/tests/_internal/server/background/tasks/test_process_prometheus_metrics.py rename to src/tests/_internal/server/background/scheduled_tasks/test_process_prometheus_metrics.py index 7c59b6dd1f..80961d5c10 100644 --- a/src/tests/_internal/server/background/tasks/test_process_prometheus_metrics.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_process_prometheus_metrics.py @@ -11,7 +11,7 @@ from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.runs import JobStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole -from dstack._internal.server.background.tasks.process_prometheus_metrics import ( +from dstack._internal.server.background.scheduled_tasks.prometheus_metrics import ( collect_prometheus_metrics, delete_prometheus_metrics, ) diff --git a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_process_running_jobs.py similarity index 99% rename from src/tests/_internal/server/background/tasks/test_process_running_jobs.py rename to src/tests/_internal/server/background/scheduled_tasks/test_process_running_jobs.py index 12edeec208..0d748f4e91 100644 --- a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_process_running_jobs.py @@ -37,7 +37,7 @@ VolumeStatus, ) from dstack._internal.server import settings as server_settings -from dstack._internal.server.background.tasks.process_running_jobs import ( +from dstack._internal.server.background.scheduled_tasks.running_jobs import ( _patch_base_image_for_aws_efa, process_running_jobs, ) diff --git a/src/tests/_internal/server/background/tasks/test_process_runs.py b/src/tests/_internal/server/background/scheduled_tasks/test_process_runs.py similarity index 98% rename from src/tests/_internal/server/background/tasks/test_process_runs.py rename to src/tests/_internal/server/background/scheduled_tasks/test_process_runs.py index b9420d8e9a..ffb63de358 100644 --- a/src/tests/_internal/server/background/tasks/test_process_runs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_process_runs.py @@ -8,7 +8,7 @@ from pydantic import parse_obj_as from sqlalchemy.ext.asyncio import AsyncSession -import dstack._internal.server.background.tasks.process_runs as process_runs +import dstack._internal.server.background.scheduled_tasks.runs as process_runs from dstack._internal.core.models.configurations import ( ProbeConfig, ServiceConfiguration, @@ -100,7 +100,7 @@ async def test_submitted_to_provisioning(self, test_db, session: AsyncSession): expected_duration = (current_time - run.submitted_at).total_seconds() with patch( - "dstack._internal.server.background.tasks.process_runs.run_metrics" + "dstack._internal.server.background.scheduled_tasks.runs.run_metrics" ) as mock_run_metrics: await process_runs.process_runs() @@ -131,7 
+131,7 @@ async def test_keep_provisioning(self, test_db, session: AsyncSession): await create_job(session=session, run=run, status=JobStatus.PULLING) with patch( - "dstack._internal.server.background.tasks.process_runs.run_metrics" + "dstack._internal.server.background.scheduled_tasks.runs.run_metrics" ) as mock_run_metrics: await process_runs.process_runs() @@ -198,7 +198,7 @@ async def test_retry_running_to_pending(self, test_db, session: AsyncSession): with ( patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock, patch( - "dstack._internal.server.background.tasks.process_runs.run_metrics" + "dstack._internal.server.background.scheduled_tasks.runs.run_metrics" ) as mock_run_metrics, ): datetime_mock.return_value = run.submitted_at + datetime.timedelta(minutes=3) @@ -297,7 +297,7 @@ async def test_submitted_to_provisioning_if_any(self, test_db, session: AsyncSes expected_duration = (current_time - run.submitted_at).total_seconds() with patch( - "dstack._internal.server.background.tasks.process_runs.run_metrics" + "dstack._internal.server.background.scheduled_tasks.runs.run_metrics" ) as mock_run_metrics: await process_runs.process_runs() @@ -351,7 +351,7 @@ async def test_all_no_capacity_to_pending(self, test_db, session: AsyncSession): with ( patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock, patch( - "dstack._internal.server.background.tasks.process_runs.run_metrics" + "dstack._internal.server.background.scheduled_tasks.runs.run_metrics" ) as mock_run_metrics, ): datetime_mock.return_value = run.submitted_at + datetime.timedelta(minutes=3) diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_process_submitted_jobs.py similarity index 99% rename from src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py rename to src/tests/_internal/server/background/scheduled_tasks/test_process_submitted_jobs.py index 8a3a4b1d57..b06eb50ec2 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_process_submitted_jobs.py @@ -27,7 +27,7 @@ VolumeMountPoint, VolumeStatus, ) -from dstack._internal.server.background.tasks.process_submitted_jobs import ( +from dstack._internal.server.background.scheduled_tasks.submitted_jobs import ( _prepare_job_runtime_data, process_submitted_jobs, ) diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_volumes.py b/src/tests/_internal/server/background/scheduled_tasks/test_process_submitted_volumes.py similarity index 96% rename from src/tests/_internal/server/background/tasks/test_process_submitted_volumes.py rename to src/tests/_internal/server/background/scheduled_tasks/test_process_submitted_volumes.py index dfeef1e42e..8c9a6bf3cf 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_volumes.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_process_submitted_volumes.py @@ -5,7 +5,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.volumes import VolumeProvisioningData, VolumeStatus -from dstack._internal.server.background.tasks.process_volumes import process_submitted_volumes +from dstack._internal.server.background.scheduled_tasks.volumes import process_submitted_volumes from dstack._internal.server.testing.common import ( ComputeMockSpec, create_project, diff --git 
a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_process_terminating_jobs.py similarity index 99% rename from src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py rename to src/tests/_internal/server/background/scheduled_tasks/test_process_terminating_jobs.py index 1d1c143d4f..d2b4d2d318 100644 --- a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_process_terminating_jobs.py @@ -10,7 +10,7 @@ from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.runs import JobStatus, JobTerminationReason from dstack._internal.core.models.volumes import VolumeStatus -from dstack._internal.server.background.tasks.process_terminating_jobs import ( +from dstack._internal.server.background.scheduled_tasks.terminating_jobs import ( process_terminating_jobs, ) from dstack._internal.server.models import InstanceModel, JobModel, VolumeAttachmentModel diff --git a/src/tests/_internal/server/background/tasks/__init__.py b/src/tests/_internal/server/background/tasks/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index ad8ad878d1..21b6e8f28f 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -67,7 +67,7 @@ list_events, ) from dstack._internal.server.testing.matchers import SomeUUID4Str -from tests._internal.server.background.tasks.test_process_running_jobs import settings +from tests._internal.server.background.scheduled_tasks.test_process_running_jobs import settings pytestmark = pytest.mark.usefixtures("image_config_mock") From fd7830afa8f8ead05c4ab7c3edd8791a0cc50c1b Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 13 Feb 2026 12:45:18 +0500 Subject: [PATCH 02/15] Support pipeline tasks --- src/dstack/_internal/server/app.py | 5 + .../background/pipeline_tasks/__init__.py | 45 +++ .../server/background/pipeline_tasks/base.py | 259 ++++++++++++++++++ .../_internal/server/services/pipelines.py | 12 + 4 files changed, 321 insertions(+) create mode 100644 src/dstack/_internal/server/background/pipeline_tasks/__init__.py create mode 100644 src/dstack/_internal/server/background/pipeline_tasks/base.py create mode 100644 src/dstack/_internal/server/services/pipelines.py diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index fbeddfd182..0fbff6d383 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -23,6 +23,7 @@ from dstack._internal.proxy.lib.deps import get_injector_from_app from dstack._internal.proxy.lib.routers import model_proxy from dstack._internal.server import settings +from dstack._internal.server.background.pipeline_tasks import start_pipeline_tasks from dstack._internal.server.background.scheduled_tasks import start_scheduled_tasks from dstack._internal.server.background.scheduled_tasks.probes import PROBES_SCHEDULER from dstack._internal.server.db import get_db, get_session_ctx, migrate @@ -163,8 +164,10 @@ async def lifespan(app: FastAPI): if settings.SERVER_S3_BUCKET is not None or settings.SERVER_GCS_BUCKET is not None: init_default_storage() scheduler = None + pipeline_manager = None if settings.SERVER_BACKGROUND_PROCESSING_ENABLED: scheduler = start_scheduled_tasks() + 
pipeline_manager = start_pipeline_tasks() else: logger.info("Background processing is disabled") PROBES_SCHEDULER.start() @@ -191,6 +194,8 @@ async def lifespan(app: FastAPI): yield if scheduler is not None: scheduler.shutdown() + if pipeline_manager is not None: + pipeline_manager.shutdown() PROBES_SCHEDULER.shutdown(wait=False) await gateway_connections_pool.remove_all() service_conn_pool = await get_injector_from_app(app).get_service_connection_pool() diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py new file mode 100644 index 0000000000..758a89568e --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -0,0 +1,45 @@ +from dstack._internal.server.background.pipeline_tasks.base import Pipeline +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +class PipelineManager: + def __init__(self) -> None: + self._pipelines: list[Pipeline] = [] # Pipelines will go here + self._hinter = PipelineHinter(self._pipelines) + + def start(self): + for pipeline in self._pipelines: + pipeline.start() + + def shutdown(self): + for pipeline in self._pipelines: + pipeline.shutdown() + + @property + def hinter(self): + return self._hinter + + +class PipelineHinter: + def __init__(self, pipelines: list[Pipeline]) -> None: + self._pipelines = pipelines + self._hint_fetch_map = {p.hint_fetch_model_name: p for p in self._pipelines} + + def hint_fetch(self, model_name: str): + pipeline = self._hint_fetch_map.get(model_name) + if pipeline is None: + logger.warning("Model %s not registered for fetch hints", model_name) + return + pipeline.hint_fetch() + + +def start_pipeline_tasks() -> PipelineManager: + """ + Start tasks processed by fetch-workers pipelines based on db + in-memory queues. + Suitable for tasks that run frequently and need to lock rows for a long time. 
+ """ + pipeline_manager = PipelineManager() + pipeline_manager.start() + return pipeline_manager diff --git a/src/dstack/_internal/server/background/pipeline_tasks/base.py b/src/dstack/_internal/server/background/pipeline_tasks/base.py new file mode 100644 index 0000000000..3704a593cf --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/base.py @@ -0,0 +1,259 @@ +import asyncio +import math +import random +import uuid +from abc import ABC, abstractmethod +from datetime import datetime, timedelta +from typing import Any, ClassVar, Generic, Optional, Protocol, Sequence, TypeVar + +from sqlalchemy import and_, or_, update +from sqlalchemy.orm import Mapped + +from dstack._internal.server.db import get_session_ctx +from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +class PipelineItem(Protocol): + id: uuid.UUID + lock_expires_at: datetime + lock_token: uuid.UUID + + +class PipelineModel(Protocol): + id: Mapped[uuid.UUID] + lock_expires_at: Mapped[Optional[datetime]] + lock_token: Mapped[Optional[uuid.UUID]] + + __mapper__: ClassVar[Any] + __table__: ClassVar[Any] + + +class Pipeline(ABC): + def __init__( + self, + workers_num: int, + queue_lower_limit_factor: float, + queue_upper_limit_factor: float, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeat_trigger: timedelta, + ) -> None: + self._workers_num = workers_num + self._queue_lower_limit_factor = queue_lower_limit_factor + self._queue_upper_limit_factor = queue_upper_limit_factor + self._queue_desired_minsize = math.ceil(workers_num * queue_lower_limit_factor) + self._queue_maxsize = math.ceil(workers_num * queue_upper_limit_factor) + self._min_processing_interval = min_processing_interval + self._lock_timeout = lock_timeout + self._heartbeat_trigger = heartbeat_trigger + self._queue = asyncio.Queue[PipelineItem](maxsize=self._queue_maxsize) + + def start(self): + asyncio.create_task(self._heartbeater.start()) + for worker in self._workers: + asyncio.create_task(worker.start()) + asyncio.create_task(self._fetcher.start()) + + def shutdown(self): + self._fetcher.shutdown() + self._heartbeater.shutdown() + + def hint_fetch(self): + self._fetcher.hint() + + @property + @abstractmethod + def hint_fetch_model_name(self) -> str: + pass + + @property + @abstractmethod + def _heartbeater(self) -> "Heartbeater": + pass + + @property + @abstractmethod + def _fetcher(self) -> "Fetcher": + pass + + @property + @abstractmethod + def _workers(self) -> Sequence["Worker"]: + pass + + +ModelT = TypeVar("ModelT", bound=PipelineModel) + + +class Heartbeater(Generic[ModelT]): + def __init__( + self, + model_type: type[ModelT], + lock_timeout: timedelta, + heartbeat_trigger: timedelta, + heartbeat_delay: float = 1.0, + ) -> None: + self._model_type = model_type + self._lock_timeout = lock_timeout + self._hearbeat_margin = heartbeat_trigger + self._items: dict[uuid.UUID, PipelineItem] = {} + self._untrack_lock = asyncio.Lock() + self._heartbeat_delay = heartbeat_delay + self._running = False + + async def start(self): + self._running = True + while self._running: + try: + await self.heartbeat() + except Exception: + logger.exception("Unexpected exception when running heartbeat") + await asyncio.sleep(self._heartbeat_delay) + + def shutdown(self): + self._running = False + + async def track(self, item: PipelineItem): + self._items[item.id] = item + + async def untrack(self, item: PipelineItem): + async with 
self._untrack_lock: + tracked = self._items.get(item.id) + # Prevent expired fetch iteration to unlock item processed by new iteration. + if tracked is not None and tracked.lock_token == item.lock_token: + del self._items[item.id] + + async def heartbeat(self): + updated_items: list[PipelineItem] = [] + now = get_current_datetime() + items = list(self._items.values()) + for item in items: + if item.lock_expires_at < now: + logger.warning( + "Failed to heartbeat item %s in time." + " The item is expected to be processed on another fetch iteration.", + item.id, + ) + await self.untrack(item) + elif item.lock_expires_at < now + self._hearbeat_margin: + updated_items.append(item) + if len(updated_items) == 0: + return + logger.debug("Updating lock_expires_at for items: %s", [str(r.id) for r in updated_items]) + async with get_session_ctx() as session: + per_item_filters = [ + and_( + self._model_type.id == item.id, self._model_type.lock_token == item.lock_token + ) + for item in updated_items + ] + res = await session.execute( + update(self._model_type) + .where(or_(*per_item_filters)) + .values(lock_expires_at=now + self._lock_timeout) + ) + if res.rowcount == 0: # pyright: ignore[reportAttributeAccessIssue] + logger.warning( + "Failed to update lock_expires_at: lock_token changed." + " The item is expected to be processed and updated on another fetch iteration." + ) + return + for item in updated_items: + item.lock_expires_at = now + self._lock_timeout + + +class Fetcher(ABC): + _DEFAULT_FETCH_DELAYS = [0.5, 1, 2, 5] + + def __init__( + self, + queue: asyncio.Queue[PipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeater: Heartbeater, + queue_check_delay: float = 1.0, + fetch_delays: Optional[list[float]] = None, + ) -> None: + self._queue = queue + self._queue_desired_minsize = queue_desired_minsize + self._min_processing_interval = min_processing_interval + self._lock_timeout = lock_timeout + self._heartbeater = heartbeater + self._queue_check_delay = queue_check_delay + if fetch_delays is None: + fetch_delays = self._DEFAULT_FETCH_DELAYS + self._fetch_delays = fetch_delays + self._running = False + self._fetch_event = asyncio.Event() + + async def start(self): + self._running = True + empty_fetch_count = 0 + while self._running: + if self._queue.qsize() >= self._queue_desired_minsize: + await asyncio.sleep(self._queue_check_delay) + continue + fetch_limit = self._queue.maxsize - self._queue.qsize() + try: + items = await self.fetch(limit=fetch_limit) + except Exception: + logger.exception("Unexpected exception when fetching new items") + items = [] + if len(items) == 0: + try: + await asyncio.wait_for( + self._fetch_event.wait(), + timeout=self._next_fetch_delay(empty_fetch_count), + ) + except TimeoutError: + pass + empty_fetch_count += 1 + self._fetch_event.clear() + continue + else: + empty_fetch_count = 0 + for item in items: + self._queue.put_nowait(item) # should never raise + await self._heartbeater.track(item) + + def shutdown(self): + self._running = False + + def hint(self): + self._fetch_event.set() + + @abstractmethod + async def fetch(self, limit: int) -> list[PipelineItem]: + pass + + def _next_fetch_delay(self, empty_fetch_count: int) -> float: + next_delay = self._fetch_delays[min(empty_fetch_count, len(self._fetch_delays) - 1)] + jitter = random.random() * 0.4 - 0.2 + return next_delay * (1 + jitter) + + +class Worker(ABC): + def __init__( + self, + queue: asyncio.Queue[PipelineItem], + heartbeater: 
Heartbeater, + ) -> None: + self._queue = queue + self._heartbeater = heartbeater + + async def start(self): + while True: + item = await self._queue.get() + try: + await self.process(item) + except Exception: + logger.exception("Unexpected exception when processing item") + await self._heartbeater.untrack(item) + + @abstractmethod + async def process(self, item: PipelineItem): + pass diff --git a/src/dstack/_internal/server/services/pipelines.py b/src/dstack/_internal/server/services/pipelines.py new file mode 100644 index 0000000000..19f4df902d --- /dev/null +++ b/src/dstack/_internal/server/services/pipelines.py @@ -0,0 +1,12 @@ +from typing import Protocol + +from fastapi import Request + + +class PipelineHinterProtocol(Protocol): + def hint_fetch(self, model_name: str) -> None: + pass + + +def get_pipeline_hinter(request: Request) -> PipelineHinterProtocol: + return request.app.state.pipeline_manager.hinter From 6056dba522c52a2417037be024b2de2cc374e789 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 13 Feb 2026 14:33:13 +0500 Subject: [PATCH 03/15] Stop workers on pipeline shutdown --- src/dstack/_internal/server/app.py | 6 +++--- .../server/background/pipeline_tasks/base.py | 19 ++++++++++++++----- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 0fbff6d383..7a754785bd 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -192,11 +192,11 @@ async def lifespan(app: FastAPI): for func in _ON_STARTUP_HOOKS: await func(app) yield - if scheduler is not None: - scheduler.shutdown() + PROBES_SCHEDULER.shutdown(wait=False) if pipeline_manager is not None: pipeline_manager.shutdown() - PROBES_SCHEDULER.shutdown(wait=False) + if scheduler is not None: + scheduler.shutdown() await gateway_connections_pool.remove_all() service_conn_pool = await get_injector_from_app(app).get_service_connection_pool() await service_conn_pool.remove_all() diff --git a/src/dstack/_internal/server/background/pipeline_tasks/base.py b/src/dstack/_internal/server/background/pipeline_tasks/base.py index 3704a593cf..5123494106 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/base.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/base.py @@ -50,15 +50,18 @@ def __init__( self._lock_timeout = lock_timeout self._heartbeat_trigger = heartbeat_trigger self._queue = asyncio.Queue[PipelineItem](maxsize=self._queue_maxsize) + self._tasks: list[asyncio.Task] = [] def start(self): - asyncio.create_task(self._heartbeater.start()) + self._tasks.append(asyncio.create_task(self._heartbeater.start())) for worker in self._workers: - asyncio.create_task(worker.start()) - asyncio.create_task(self._fetcher.start()) + self._tasks.append(asyncio.create_task(worker.start())) + self._tasks.append(asyncio.create_task(self._fetcher.start())) def shutdown(self): self._fetcher.shutdown() + for worker in self._workers: + worker.shutdown() self._heartbeater.shutdown() def hint_fetch(self): @@ -244,15 +247,21 @@ def __init__( ) -> None: self._queue = queue self._heartbeater = heartbeater + self._running = False async def start(self): - while True: + self._running = True + while self._running: item = await self._queue.get() try: await self.process(item) except Exception: logger.exception("Unexpected exception when processing item") - await self._heartbeater.untrack(item) + finally: + await self._heartbeater.untrack(item) + + def shutdown(self): + self._running = False 
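# A minimal concrete worker sketch (ExampleWorker and its log message are
# illustrative assumptions, not part of the patch): the base start() loop pops one
# locked item from the shared queue, calls process(), and always untracks the item
# afterwards so the heartbeater stops extending its lock.
class ExampleWorker(Worker):
    async def process(self, item: PipelineItem):
        # item.id and item.lock_token identify the row locked by the fetcher.
        logger.info("Processing example item %s", item.id)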
@abstractmethod async def process(self, item: PipelineItem): From b7346521f649d5441b5af21723cd6ac93ed50458 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 13 Feb 2026 15:43:05 +0500 Subject: [PATCH 04/15] Add pipeline draining --- src/dstack/_internal/server/app.py | 2 + .../background/pipeline_tasks/__init__.py | 14 +++++ .../server/background/pipeline_tasks/base.py | 53 ++++++++++++++++--- 3 files changed, 63 insertions(+), 6 deletions(-) diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 7a754785bd..dba3873c4b 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -197,6 +197,8 @@ async def lifespan(app: FastAPI): pipeline_manager.shutdown() if scheduler is not None: scheduler.shutdown() + if pipeline_manager is not None: + await pipeline_manager.drain() await gateway_connections_pool.remove_all() service_conn_pool = await get_injector_from_app(app).get_service_connection_pool() await service_conn_pool.remove_all() diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py index 758a89568e..c3a5c8ab29 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -1,3 +1,5 @@ +import asyncio + from dstack._internal.server.background.pipeline_tasks.base import Pipeline from dstack._internal.utils.logging import get_logger @@ -17,6 +19,18 @@ def shutdown(self): for pipeline in self._pipelines: pipeline.shutdown() + async def drain(self): + results = await asyncio.gather( + *[p.drain() for p in self._pipelines], return_exceptions=True + ) + for pipeline, result in zip(self._pipelines, results): + if isinstance(result, BaseException): + logger.error( + "Unexpected exception when draining pipeline %r", + pipeline, + exc_info=(type(result), result, result.__traceback__), + ) + @property def hinter(self): return self._hinter diff --git a/src/dstack/_internal/server/background/pipeline_tasks/base.py b/src/dstack/_internal/server/background/pipeline_tasks/base.py index 5123494106..e926b82a30 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/base.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/base.py @@ -31,6 +31,10 @@ class PipelineModel(Protocol): __table__: ClassVar[Any] +class PipelineError(Exception): + pass + + class Pipeline(ABC): def __init__( self, @@ -51,18 +55,55 @@ def __init__( self._heartbeat_trigger = heartbeat_trigger self._queue = asyncio.Queue[PipelineItem](maxsize=self._queue_maxsize) self._tasks: list[asyncio.Task] = [] + self._running = False + self._shutdown = False def start(self): + """ + Starts all pipeline tasks. + """ + if self._running: + return + if self._shutdown: + raise PipelineError("Cannot start pipeline after shutdown.") + self._running = True self._tasks.append(asyncio.create_task(self._heartbeater.start())) for worker in self._workers: self._tasks.append(asyncio.create_task(worker.start())) self._tasks.append(asyncio.create_task(self._fetcher.start())) def shutdown(self): - self._fetcher.shutdown() + """ + Stops the pipeline from processing new items and signals running tasks to cancel. 
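        A typical lifecycle sketch (assuming a started pipeline `p`):

            p.shutdown()     # stop fetcher, workers, and heartbeater; cancel their tasks
            await p.drain()  # await the cancelled tasks so cleanup finishes before exit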
+ """ + if self._shutdown: + return + self._shutdown = True + self._running = False + self._fetcher.stop() for worker in self._workers: - worker.shutdown() - self._heartbeater.shutdown() + worker.stop() + self._heartbeater.stop() + for task in self._tasks: + if not task.done(): + task.cancel() + + async def drain(self): + """ + Waits for all pipeline tasks to finish cleanup after shutdown. + """ + if not self._shutdown: + raise PipelineError("Cannot drain running pipeline. Call `shutdown()` first.") + results = await asyncio.gather(*self._tasks, return_exceptions=True) + for task, result in zip(self._tasks, results): + if isinstance(result, BaseException) and not isinstance( + result, asyncio.CancelledError + ): + logger.error( + "Unexpected exception when draining pipeline task %r", + task, + exc_info=(type(result), result, result.__traceback__), + ) def hint_fetch(self): self._fetcher.hint() @@ -116,7 +157,7 @@ async def start(self): logger.exception("Unexpected exception when running heartbeat") await asyncio.sleep(self._heartbeat_delay) - def shutdown(self): + def stop(self): self._running = False async def track(self, item: PipelineItem): @@ -223,7 +264,7 @@ async def start(self): self._queue.put_nowait(item) # should never raise await self._heartbeater.track(item) - def shutdown(self): + def stop(self): self._running = False def hint(self): @@ -260,7 +301,7 @@ async def start(self): finally: await self._heartbeater.untrack(item) - def shutdown(self): + def stop(self): self._running = False @abstractmethod From 1ff631b074a287b3fe533707e2f8d3d9ced7e36e Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 16 Feb 2026 15:02:30 +0500 Subject: [PATCH 05/15] Use returning instead of rowcount for heartbeat --- src/dstack/_internal/server/app.py | 2 + .../server/background/pipeline_tasks/base.py | 51 ++++++++++++------- 2 files changed, 35 insertions(+), 18 deletions(-) diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index dba3873c4b..27593d3f9e 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -196,6 +196,8 @@ async def lifespan(app: FastAPI): if pipeline_manager is not None: pipeline_manager.shutdown() if scheduler is not None: + # Note: Scheduler does not cancel currently running jobs, so scheduled tasks cannot do cleanup. + # TODO: Track and cancel scheduled tasks. scheduler.shutdown() if pipeline_manager is not None: await pipeline_manager.drain() diff --git a/src/dstack/_internal/server/background/pipeline_tasks/base.py b/src/dstack/_internal/server/background/pipeline_tasks/base.py index e926b82a30..8bd8ec49a9 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/base.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/base.py @@ -27,6 +27,7 @@ class PipelineModel(Protocol): lock_expires_at: Mapped[Optional[datetime]] lock_token: Mapped[Optional[uuid.UUID]] + __tablename__: str __mapper__: ClassVar[Any] __table__: ClassVar[Any] @@ -171,42 +172,56 @@ async def untrack(self, item: PipelineItem): del self._items[item.id] async def heartbeat(self): - updated_items: list[PipelineItem] = [] + items_to_update: list[PipelineItem] = [] now = get_current_datetime() items = list(self._items.values()) + failed_to_heartbeat_count = 0 for item in items: if item.lock_expires_at < now: - logger.warning( - "Failed to heartbeat item %s in time." 
- " The item is expected to be processed on another fetch iteration.", - item.id, - ) + failed_to_heartbeat_count += 1 await self.untrack(item) elif item.lock_expires_at < now + self._hearbeat_margin: - updated_items.append(item) - if len(updated_items) == 0: + items_to_update.append(item) + if failed_to_heartbeat_count > 0: + logger.warning( + "Failed to heartbeat %d %s items in time." + " The items are expected to be processed on another fetch iteration.", + failed_to_heartbeat_count, + self._model_type.__tablename__, + ) + if len(items_to_update) == 0: return - logger.debug("Updating lock_expires_at for items: %s", [str(r.id) for r in updated_items]) + logger.debug( + "Updating lock_expires_at for items: %s", [str(r.id) for r in items_to_update] + ) async with get_session_ctx() as session: per_item_filters = [ and_( self._model_type.id == item.id, self._model_type.lock_token == item.lock_token ) - for item in updated_items + for item in items_to_update ] res = await session.execute( update(self._model_type) .where(or_(*per_item_filters)) .values(lock_expires_at=now + self._lock_timeout) + .returning(self._model_type.id) + ) + updated_ids = set(res.scalars().all()) + failed_to_update_count = 0 + for item in items_to_update: + if item.id in updated_ids: + item.lock_expires_at = now + self._lock_timeout + else: + failed_to_update_count += 1 + await self.untrack(item) + if failed_to_update_count > 0: + logger.warning( + "Failed to update %s lock_expires_at of %d items: lock_token changed." + " The items are expected to be processed and updated on another fetch iteration.", + self._model_type.__tablename__, + failed_to_update_count, ) - if res.rowcount == 0: # pyright: ignore[reportAttributeAccessIssue] - logger.warning( - "Failed to update lock_expires_at: lock_token changed." - " The item is expected to be processed and updated on another fetch iteration." 
- ) - return - for item in updated_items: - item.lock_expires_at = now + self._lock_timeout class Fetcher(ABC): From e6526facd5a2c624163714788615a5cdf84ad520 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 17 Feb 2026 12:36:57 +0500 Subject: [PATCH 06/15] Add ComputeGroupPipeline and PlacementGroupPipeline --- .../_internal/core/models/compute_groups.py | 7 + .../background/pipeline_tasks/__init__.py | 6 +- .../server/background/pipeline_tasks/base.py | 19 ++ .../pipeline_tasks/compute_groups.py | 322 ++++++++++++++++++ .../pipeline_tasks/placement_groups.py | 250 ++++++++++++++ .../versions/351efa3432d9_pipelines.py | 65 ++++ src/dstack/_internal/server/models.py | 10 +- 7 files changed, 676 insertions(+), 3 deletions(-) create mode 100644 src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py create mode 100644 src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py create mode 100644 src/dstack/_internal/server/migrations/versions/351efa3432d9_pipelines.py diff --git a/src/dstack/_internal/core/models/compute_groups.py b/src/dstack/_internal/core/models/compute_groups.py index 66e1292eff..3fa967494d 100644 --- a/src/dstack/_internal/core/models/compute_groups.py +++ b/src/dstack/_internal/core/models/compute_groups.py @@ -12,6 +12,13 @@ class ComputeGroupStatus(str, enum.Enum): RUNNING = "running" TERMINATED = "terminated" + @classmethod + def finished_statuses(cls) -> List["ComputeGroupStatus"]: + return [cls.TERMINATED] + + def is_finished(self): + return self in self.finished_statuses() + class ComputeGroupProvisioningData(CoreModel): compute_group_id: str diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py index c3a5c8ab29..467e671c33 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -8,7 +8,11 @@ class PipelineManager: def __init__(self) -> None: - self._pipelines: list[Pipeline] = [] # Pipelines will go here + self._pipelines: list[Pipeline] = [ + # Pipelines go here: + # PlacementGroupPipeline(), + # ComputeGroupPipeline(), + ] self._hinter = PipelineHinter(self._pipelines) def start(self): diff --git a/src/dstack/_internal/server/background/pipeline_tasks/base.py b/src/dstack/_internal/server/background/pipeline_tasks/base.py index 8bd8ec49a9..ac2e510fa7 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/base.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/base.py @@ -21,6 +21,8 @@ class PipelineItem(Protocol): lock_expires_at: datetime lock_token: uuid.UUID + __tablename__: str + class PipelineModel(Protocol): id: Mapped[uuid.UUID] @@ -309,12 +311,14 @@ async def start(self): self._running = True while self._running: item = await self._queue.get() + logger.debug("Processing %s item %s", item.__tablename__, item.id) try: await self.process(item) except Exception: logger.exception("Unexpected exception when processing item") finally: await self._heartbeater.untrack(item) + logger.debug("Processed %s item %s", item.__tablename__, item.id) def stop(self): self._running = False @@ -322,3 +326,18 @@ def stop(self): @abstractmethod async def process(self, item: PipelineItem): pass + + +UpdateMap = dict[str, Any] + + +def get_unlock_update_map() -> UpdateMap: + return { + "lock_expires_at": None, + "lock_token": None, + "lock_owner": None, + } + + +def get_processed_update_map() -> UpdateMap: + return 
{"last_processed_at": get_current_datetime()} diff --git a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py new file mode 100644 index 0000000000..f14b432fbf --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py @@ -0,0 +1,322 @@ +import asyncio +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Sequence, cast + +from sqlalchemy import or_, select, update +from sqlalchemy.orm import joinedload, load_only + +from dstack._internal.core.backends.base.compute import ComputeWithGroupProvisioningSupport +from dstack._internal.core.errors import BackendError +from dstack._internal.core.models.compute_groups import ComputeGroupStatus +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.server.background.pipeline_tasks.base import ( + Fetcher, + Heartbeater, + Pipeline, + PipelineItem, + UpdateMap, + Worker, + get_processed_update_map, + get_unlock_update_map, +) +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import ComputeGroupModel, InstanceModel, ProjectModel +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services.compute_groups import compute_group_model_to_compute_group +from dstack._internal.server.services.instances import switch_instance_status +from dstack._internal.server.services.locking import get_locker +from dstack._internal.utils.common import get_current_datetime, run_async +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + +TERMINATION_RETRY_TIMEOUT = timedelta(seconds=60) +TERMINATION_RETRY_MAX_DURATION = timedelta(minutes=15) + + +class ComputeGroupPipeline(Pipeline): + def __init__( + self, + workers_num: int = 10, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = timedelta(seconds=15), + lock_timeout: timedelta = timedelta(seconds=30), + heartbeat_trigger: timedelta = timedelta(seconds=15), + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[ComputeGroupModel]( + model_type=ComputeGroupModel, + lock_timeout=self._lock_timeout, + heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = ComputeGroupFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + heartbeater=self._heartbeater, + ) + self.__workers = [ + ComputeGroupWorker(queue=self._queue, heartbeater=self._heartbeater) + for _ in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return ComputeGroupModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater: + return self.__heartbeater + + @property + def _fetcher(self) -> Fetcher: + return self.__fetcher + + @property + def _workers(self) -> Sequence["ComputeGroupWorker"]: + return self.__workers + + +class ComputeGroupFetcher(Fetcher): + def __init__( + self, + queue: asyncio.Queue[PipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: 
timedelta, + heartbeater: Heartbeater[ComputeGroupModel], + queue_check_delay: float = 1.0, + ) -> None: + super().__init__( + queue=queue, + queue_desired_minsize=queue_desired_minsize, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeater=heartbeater, + queue_check_delay=queue_check_delay, + ) + + async def fetch(self, limit: int) -> list[PipelineItem]: + compute_group_lock, _ = get_locker(get_db().dialect_name).get_lockset( + ComputeGroupModel.__tablename__ + ) + async with compute_group_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(ComputeGroupModel) + .where( + ComputeGroupModel.status.not_in(ComputeGroupStatus.finished_statuses()), + ComputeGroupModel.last_processed_at <= now - self._min_processing_interval, + or_( + ComputeGroupModel.lock_expires_at.is_(None), + ComputeGroupModel.lock_expires_at < now, + ), + or_( + ComputeGroupModel.lock_owner.is_(None), + ComputeGroupModel.lock_owner == ComputeGroupPipeline.__name__, + ), + ) + .order_by(ComputeGroupModel.last_processed_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True, key_share=True, of=ComputeGroupModel) + .options( + load_only( + ComputeGroupModel.id, + ComputeGroupModel.lock_token, + ComputeGroupModel.lock_expires_at, + ) + ) + ) + compute_group_models = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + for compute_group_model in compute_group_models: + compute_group_model.lock_expires_at = lock_expires_at + compute_group_model.lock_token = lock_token + compute_group_model.lock_owner = ComputeGroupPipeline.__name__ + await session.commit() + return [cast(PipelineItem, r) for r in compute_group_models] + + +class ComputeGroupWorker(Worker): + def __init__( + self, + queue: asyncio.Queue[PipelineItem], + heartbeater: Heartbeater[ComputeGroupModel], + ) -> None: + super().__init__( + queue=queue, + heartbeater=heartbeater, + ) + + async def process(self, item: PipelineItem): + async with get_session_ctx() as session: + res = await session.execute( + select(ComputeGroupModel) + .where( + ComputeGroupModel.id == item.id, + ComputeGroupModel.lock_token == item.lock_token, + ) + # Terminating instances belonging to a compute group are locked implicitly by locking the compute group. + .options( + joinedload(ComputeGroupModel.instances), + joinedload(ComputeGroupModel.project).joinedload(ProjectModel.backends), + ) + ) + compute_group_model = res.unique().scalar_one_or_none() + if compute_group_model is None: + logger.warning( + "Failed to process %s item %s: lock_token mismatch." + " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + return + + terminate_result = _TerminateResult() + # TODO: Fetch only compute groups with all instances terminating. 
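        # Write-back sketch (illustration, using the helpers defined above): when the
        # group is not ready to terminate, only last_processed_at is refreshed and the
        # row is unlocked, so the fetcher can pick it up again after
        # min_processing_interval:
        #   get_processed_update_map() | get_unlock_update_map()
        #   -> {"last_processed_at": <now>, "lock_expires_at": None,
        #       "lock_token": None, "lock_owner": None}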
+ if all(i.status == InstanceStatus.TERMINATING for i in compute_group_model.instances): + terminate_result = await _terminate_compute_group(compute_group_model) + if terminate_result.compute_group_update_map: + logger.info("Terminated compute group %s", compute_group_model.id) + else: + terminate_result.compute_group_update_map = get_processed_update_map() + + terminate_result.compute_group_update_map |= get_unlock_update_map() + + async with get_session_ctx() as session: + res = await session.execute( + update(ComputeGroupModel) + .where( + ComputeGroupModel.id == compute_group_model.id, + ComputeGroupModel.lock_token == compute_group_model.lock_token, + ) + .values(**terminate_result.compute_group_update_map) + .returning(ComputeGroupModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + logger.warning( + "Failed to update %s item %s after processing: lock_token changed." + " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + return + if not terminate_result.instances_update_map: + return + instances_ids = [i.id for i in compute_group_model.instances] + res = await session.execute( + update(InstanceModel) + .where(InstanceModel.id.in_(instances_ids)) + .values(**terminate_result.instances_update_map) + ) + for instance_model in compute_group_model.instances: + switch_instance_status(session, instance_model, InstanceStatus.TERMINATED) + + +@dataclass +class _TerminateResult: + compute_group_update_map: UpdateMap = field(default_factory=dict) + instances_update_map: UpdateMap = field(default_factory=dict) + + +async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _TerminateResult: + result = _TerminateResult() + if ( + compute_group_model.last_termination_retry_at is not None + and _next_termination_retry_at(compute_group_model) > get_current_datetime() + ): + return result + compute_group = compute_group_model_to_compute_group(compute_group_model) + cgpd = compute_group.provisioning_data + backend = await backends_services.get_project_backend_by_type( + project=compute_group_model.project, + backend_type=cgpd.backend, + ) + if backend is None: + logger.error( + "Failed to terminate compute group %s. Backend %s not available." + " Please terminate it manually to avoid unexpected charges.", + compute_group.name, + cgpd.backend, + ) + return _get_terminated_result() + logger.debug("Terminating compute group %s", compute_group.name) + compute = backend.compute() + assert isinstance(compute, ComputeWithGroupProvisioningSupport) + try: + await run_async( + compute.terminate_compute_group, + compute_group, + ) + except Exception as e: + if compute_group_model.first_termination_retry_at is None: + result.compute_group_update_map["first_termination_retry_at"] = get_current_datetime() + result.compute_group_update_map["last_termination_retry_at"] = get_current_datetime() + if _next_termination_retry_at(compute_group_model) < _get_termination_deadline( + compute_group_model + ): + logger.warning( + "Failed to terminate compute group %s. Will retry. Error: %r", + compute_group.name, + e, + exc_info=not isinstance(e, BackendError), + ) + return result + logger.error( + "Failed all attempts to terminate compute group %s." + " Please terminate it manually to avoid unexpected charges." 
+ " Error: %r", + compute_group.name, + e, + exc_info=not isinstance(e, BackendError), + ) + terminated_result = _get_terminated_result() + return _TerminateResult( + compute_group_update_map=result.compute_group_update_map + | terminated_result.compute_group_update_map, + instances_update_map=result.instances_update_map + | terminated_result.compute_group_update_map, + ) + + +def _next_termination_retry_at(compute_group_model: ComputeGroupModel) -> datetime: + assert compute_group_model.last_termination_retry_at is not None + return compute_group_model.last_termination_retry_at + TERMINATION_RETRY_TIMEOUT + + +def _get_termination_deadline(compute_group_model: ComputeGroupModel) -> datetime: + assert compute_group_model.first_termination_retry_at is not None + return compute_group_model.first_termination_retry_at + TERMINATION_RETRY_MAX_DURATION + + +def _get_terminated_result() -> _TerminateResult: + now = get_current_datetime() + return _TerminateResult( + compute_group_update_map={ + "last_processed_at": now, + "deleted": True, + "deleted_at": now, + "status": ComputeGroupStatus.TERMINATED, + }, + instances_update_map={ + "last_processed_at": now, + "deleted": True, + "deleted_at": now, + "finished_at": now, + "status": InstanceStatus.TERMINATED, + }, + ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py new file mode 100644 index 0000000000..ad4666d0a4 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py @@ -0,0 +1,250 @@ +import asyncio +import uuid +from datetime import timedelta +from typing import Sequence, cast + +from sqlalchemy import or_, select, update +from sqlalchemy.orm import joinedload, load_only + +from dstack._internal.core.backends.base.compute import ComputeWithPlacementGroupSupport +from dstack._internal.core.errors import PlacementGroupInUseError +from dstack._internal.server.background.pipeline_tasks.base import ( + Fetcher, + Heartbeater, + Pipeline, + PipelineItem, + UpdateMap, + Worker, + get_processed_update_map, + get_unlock_update_map, +) +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import ( + PlacementGroupModel, + ProjectModel, +) +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.placement import placement_group_model_to_placement_group +from dstack._internal.utils.common import get_current_datetime, run_async +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +class PlacementGroupPipeline(Pipeline): + def __init__( + self, + workers_num: int = 10, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = timedelta(seconds=15), + lock_timeout: timedelta = timedelta(seconds=30), + heartbeat_trigger: timedelta = timedelta(seconds=15), + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[PlacementGroupModel]( + model_type=PlacementGroupModel, + lock_timeout=self._lock_timeout, + heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = 
PlacementGroupFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + heartbeater=self._heartbeater, + ) + self.__workers = [ + PlacementGroupWorker(queue=self._queue, heartbeater=self._heartbeater) + for _ in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return PlacementGroupModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater: + return self.__heartbeater + + @property + def _fetcher(self) -> Fetcher: + return self.__fetcher + + @property + def _workers(self) -> Sequence["PlacementGroupWorker"]: + return self.__workers + + +class PlacementGroupFetcher(Fetcher): + def __init__( + self, + queue: asyncio.Queue[PipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeater: Heartbeater[PlacementGroupModel], + queue_check_delay: float = 1.0, + ) -> None: + super().__init__( + queue=queue, + queue_desired_minsize=queue_desired_minsize, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeater=heartbeater, + queue_check_delay=queue_check_delay, + ) + + async def fetch(self, limit: int) -> list[PipelineItem]: + placement_group_lock, _ = get_locker(get_db().dialect_name).get_lockset( + PlacementGroupModel.__tablename__ + ) + async with placement_group_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(PlacementGroupModel) + .where( + PlacementGroupModel.fleet_deleted == True, + PlacementGroupModel.deleted == False, + or_( + PlacementGroupModel.lock_expires_at.is_(None), + PlacementGroupModel.lock_expires_at < now, + ), + or_( + PlacementGroupModel.lock_owner.is_(None), + PlacementGroupModel.lock_owner == PlacementGroupPipeline.__name__, + ), + ) + .order_by(PlacementGroupModel.last_processed_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True, key_share=True) + .options( + load_only( + PlacementGroupModel.id, + PlacementGroupModel.lock_token, + PlacementGroupModel.lock_expires_at, + ) + ) + ) + placement_group_models = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + for placement_group_model in placement_group_models: + placement_group_model.lock_expires_at = lock_expires_at + placement_group_model.lock_token = lock_token + placement_group_model.lock_owner = PlacementGroupPipeline.__name__ + await session.commit() + return [cast(PipelineItem, r) for r in placement_group_models] + + +class PlacementGroupWorker(Worker): + def __init__( + self, + queue: asyncio.Queue[PipelineItem], + heartbeater: Heartbeater[PlacementGroupModel], + ) -> None: + super().__init__( + queue=queue, + heartbeater=heartbeater, + ) + + async def process(self, item: PipelineItem): + async with get_session_ctx() as session: + res = await session.execute( + select(PlacementGroupModel) + .where( + PlacementGroupModel.id == item.id, + PlacementGroupModel.lock_token == item.lock_token, + ) + .options(joinedload(PlacementGroupModel.project).joinedload(ProjectModel.backends)) + ) + placement_group_model = res.unique().scalar_one_or_none() + if placement_group_model is None: + logger.warning( + "Failed to process %s item %s: lock_token mismatch." 
+ " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + return + + update_map = await _delete_placement_group(placement_group_model) + if update_map: + logger.info("Deleted placement group %s", placement_group_model.name) + else: + update_map = get_processed_update_map() + + update_map |= get_unlock_update_map() + + async with get_session_ctx() as session: + res = await session.execute( + update(PlacementGroupModel) + .where( + PlacementGroupModel.id == placement_group_model.id, + PlacementGroupModel.lock_token == placement_group_model.lock_token, + ) + .values(**update_map) + .returning(PlacementGroupModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + logger.warning( + "Failed to update %s item %s after processing: lock_token changed." + " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + + +async def _delete_placement_group(placement_group_model: PlacementGroupModel) -> UpdateMap: + placement_group = placement_group_model_to_placement_group(placement_group_model) + if placement_group.provisioning_data is None: + logger.error( + "Failed to delete placement group %s. provisioning_data is None.", placement_group.name + ) + return _get_deleted_update_map() + backend = await backends_services.get_project_backend_by_type( + project=placement_group_model.project, + backend_type=placement_group.provisioning_data.backend, + ) + if backend is None: + logger.error( + "Failed to delete placement group %s. Backend not available. Please delete it manually.", + placement_group.name, + ) + return _get_deleted_update_map() + compute = backend.compute() + assert isinstance(compute, ComputeWithPlacementGroupSupport) + try: + await run_async(compute.delete_placement_group, placement_group) + except PlacementGroupInUseError: + logger.info( + "Placement group %s is still in use. Skipping deletion for now.", placement_group.name + ) + return {} + except Exception: + logger.exception( + "Got exception when deleting placement group %s. Please delete it manually.", + placement_group.name, + ) + return _get_deleted_update_map() + + return _get_deleted_update_map() + + +def _get_deleted_update_map() -> UpdateMap: + now = get_current_datetime() + return { + "last_processed_at": now, + "deleted": True, + "deleted_at": now, + } diff --git a/src/dstack/_internal/server/migrations/versions/351efa3432d9_pipelines.py b/src/dstack/_internal/server/migrations/versions/351efa3432d9_pipelines.py new file mode 100644 index 0000000000..ade236d04c --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/351efa3432d9_pipelines.py @@ -0,0 +1,65 @@ +"""Pipelines + +Revision ID: 351efa3432d9 +Revises: 706e0acc3a7d +Create Date: 2026-02-17 10:45:48.754096 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "351efa3432d9" +down_revision = "706e0acc3a7d" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("compute_groups", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "lock_expires_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ) + ) + batch_op.add_column( + sa.Column( + "lock_token", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True + ) + ) + batch_op.add_column(sa.Column("lock_owner", sa.String(length=100), nullable=True)) + + with op.batch_alter_table("placement_groups", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "lock_expires_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ) + ) + batch_op.add_column( + sa.Column( + "lock_token", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True + ) + ) + batch_op.add_column(sa.Column("lock_owner", sa.String(length=100), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("placement_groups", schema=None) as batch_op: + batch_op.drop_column("lock_owner") + batch_op.drop_column("lock_token") + batch_op.drop_column("lock_expires_at") + + with op.batch_alter_table("compute_groups", schema=None) as batch_op: + batch_op.drop_column("lock_owner") + batch_op.drop_column("lock_token") + batch_op.drop_column("lock_expires_at") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 7e9db282d1..8c611e9447 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -196,6 +196,12 @@ class BaseModel(DeclarativeBase): metadata = MetaData(naming_convention=constraint_naming_convention) +class PipelineModelMixin: + lock_expires_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) + lock_token: Mapped[Optional[uuid.UUID]] = mapped_column(UUIDType(binary=False)) + lock_owner: Mapped[Optional[str]] = mapped_column(String(100)) + + class UserModel(BaseModel): __tablename__ = "users" @@ -768,7 +774,7 @@ class VolumeAttachmentModel(BaseModel): attachment_data: Mapped[Optional[str]] = mapped_column(Text) -class PlacementGroupModel(BaseModel): +class PlacementGroupModel(PipelineModelMixin, BaseModel): __tablename__ = "placement_groups" id: Mapped[uuid.UUID] = mapped_column( @@ -795,7 +801,7 @@ class PlacementGroupModel(BaseModel): provisioning_data: Mapped[Optional[str]] = mapped_column(Text) -class ComputeGroupModel(BaseModel): +class ComputeGroupModel(PipelineModelMixin, BaseModel): __tablename__ = "compute_groups" id: Mapped[uuid.UUID] = mapped_column( From de75b4212bb38ec508ba89a6e47c3ee49ab56763 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 17 Feb 2026 15:07:36 +0500 Subject: [PATCH 07/15] Add TestComputeGroupWorker --- .../pipeline_tasks/compute_groups.py | 21 +++-- .../background/pipeline_tasks/__init__.py | 0 .../pipeline_tasks/test_compute_groups.py | 92 +++++++++++++++++++ .../test_process_compute_groups.py | 2 +- 4 files changed, 105 insertions(+), 10 deletions(-) create mode 100644 src/tests/_internal/server/background/pipeline_tasks/__init__.py create mode 100644 src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py diff --git a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py index f14b432fbf..f65adc4efa 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py +++ 
b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py @@ -237,7 +237,8 @@ async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _T result = _TerminateResult() if ( compute_group_model.last_termination_retry_at is not None - and _next_termination_retry_at(compute_group_model) > get_current_datetime() + and _next_termination_retry_at(compute_group_model.last_termination_retry_at) + > get_current_datetime() ): return result compute_group = compute_group_model_to_compute_group(compute_group_model) @@ -266,8 +267,12 @@ async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _T if compute_group_model.first_termination_retry_at is None: result.compute_group_update_map["first_termination_retry_at"] = get_current_datetime() result.compute_group_update_map["last_termination_retry_at"] = get_current_datetime() - if _next_termination_retry_at(compute_group_model) < _get_termination_deadline( - compute_group_model + if _next_termination_retry_at( + result.compute_group_update_map["last_termination_retry_at"] + ) < _get_termination_deadline( + result.compute_group_update_map.get( + "first_termination_retry_at", compute_group_model.first_termination_retry_at + ) ): logger.warning( "Failed to terminate compute group %s. Will retry. Error: %r", @@ -293,14 +298,12 @@ async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _T ) -def _next_termination_retry_at(compute_group_model: ComputeGroupModel) -> datetime: - assert compute_group_model.last_termination_retry_at is not None - return compute_group_model.last_termination_retry_at + TERMINATION_RETRY_TIMEOUT +def _next_termination_retry_at(last_termination_retry_at: datetime) -> datetime: + return last_termination_retry_at + TERMINATION_RETRY_TIMEOUT -def _get_termination_deadline(compute_group_model: ComputeGroupModel) -> datetime: - assert compute_group_model.first_termination_retry_at is not None - return compute_group_model.first_termination_retry_at + TERMINATION_RETRY_MAX_DURATION +def _get_termination_deadline(first_termination_retry_at: datetime) -> datetime: + return first_termination_retry_at + TERMINATION_RETRY_MAX_DURATION def _get_terminated_result() -> _TerminateResult: diff --git a/src/tests/_internal/server/background/pipeline_tasks/__init__.py b/src/tests/_internal/server/background/pipeline_tasks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py b/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py new file mode 100644 index 0000000000..0fd75b981f --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py @@ -0,0 +1,92 @@ +from datetime import datetime, timezone +from typing import cast +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import BackendError +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.compute_groups import ComputeGroupStatus +from dstack._internal.server.background.pipeline_tasks.base import PipelineItem +from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupWorker +from dstack._internal.server.testing.common import ( + ComputeMockSpec, + create_compute_group, + create_fleet, + create_project, +) + + +@pytest.fixture +def worker() -> ComputeGroupWorker: + return ComputeGroupWorker(queue=Mock(), 
heartbeater=Mock()) + + +class TestComputeGroupWorker: + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_terminates_compute_group( + self, test_db, session: AsyncSession, worker: ComputeGroupWorker + ): + project = await create_project(session) + fleet = await create_fleet(session=session, project=project) + compute_group = await create_compute_group( + session=session, + project=project, + fleet=fleet, + ) + with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: + backend_mock = Mock() + compute_mock = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value = compute_mock + m.return_value = backend_mock + backend_mock.TYPE = BackendType.RUNPOD + await worker.process(cast(PipelineItem, compute_group)) + compute_mock.terminate_compute_group.assert_called_once() + await session.refresh(compute_group) + assert compute_group.status == ComputeGroupStatus.TERMINATED + assert compute_group.deleted + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_retries_compute_group_termination( + self, test_db, session: AsyncSession, worker: ComputeGroupWorker + ): + project = await create_project(session) + fleet = await create_fleet(session=session, project=project) + compute_group = await create_compute_group( + session=session, + project=project, + fleet=fleet, + last_processed_at=datetime(2023, 1, 2, 3, 0, tzinfo=timezone.utc), + ) + with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: + backend_mock = Mock() + compute_mock = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value = compute_mock + m.return_value = backend_mock + backend_mock.TYPE = BackendType.RUNPOD + compute_mock.terminate_compute_group.side_effect = BackendError() + await worker.process(cast(PipelineItem, compute_group)) + compute_mock.terminate_compute_group.assert_called_once() + await session.refresh(compute_group) + assert compute_group.status != ComputeGroupStatus.TERMINATED + assert compute_group.first_termination_retry_at is not None + assert compute_group.last_termination_retry_at is not None + # Simulate termination deadline exceeded + compute_group.first_termination_retry_at = datetime(2023, 1, 2, 3, 0, tzinfo=timezone.utc) + compute_group.last_termination_retry_at = datetime(2023, 1, 2, 4, 0, tzinfo=timezone.utc) + compute_group.last_processed_at = datetime(2023, 1, 2, 4, 0, tzinfo=timezone.utc) + await session.commit() + with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: + backend_mock = Mock() + compute_mock = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value = compute_mock + m.return_value = backend_mock + backend_mock.TYPE = BackendType.RUNPOD + compute_mock.terminate_compute_group.side_effect = BackendError() + await worker.process(cast(PipelineItem, compute_group)) + compute_mock.terminate_compute_group.assert_called_once() + await session.refresh(compute_group) + assert compute_group.status == ComputeGroupStatus.TERMINATED diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_process_compute_groups.py b/src/tests/_internal/server/background/scheduled_tasks/test_process_compute_groups.py index 5ca81b88a6..b2b1920199 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_process_compute_groups.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_process_compute_groups.py @@ -6,8 +6,8 @@ from 
dstack._internal.core.errors import BackendError from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.compute_groups import ComputeGroupStatus from dstack._internal.server.background.scheduled_tasks.compute_groups import ( - ComputeGroupStatus, process_compute_groups, ) from dstack._internal.server.testing.common import ( From b714410559e523f7af602511fd6b19a5b4d113b5 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 17 Feb 2026 15:12:39 +0500 Subject: [PATCH 08/15] Add TestPlacementGroupWorker --- .../pipeline_tasks/test_placement_groups.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py new file mode 100644 index 0000000000..3eadf9a733 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py @@ -0,0 +1,47 @@ +from typing import cast +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.server.background.pipeline_tasks.base import PipelineItem +from dstack._internal.server.background.pipeline_tasks.placement_groups import PlacementGroupWorker +from dstack._internal.server.testing.common import ( + ComputeMockSpec, + create_fleet, + create_placement_group, + create_project, +) + + +@pytest.fixture +def worker() -> PlacementGroupWorker: + return PlacementGroupWorker(queue=Mock(), heartbeater=Mock()) + + +class TestPlacementGroupWorker: + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_deletes_placement_group( + self, test_db, session: AsyncSession, worker: PlacementGroupWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + ) + placement_group = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="test1-pg", + fleet_deleted=True, + ) + with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: + aws_mock = Mock() + m.return_value = aws_mock + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) + await worker.process(cast(PipelineItem, placement_group)) + aws_mock.compute.return_value.delete_placement_group.assert_called_once() + await session.refresh(placement_group) + assert placement_group.deleted From 44f38e9be77a13cbc47a7888ac3b546fd4f0f7d7 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 17 Feb 2026 15:19:50 +0500 Subject: [PATCH 09/15] Add DSTACK_FF_PIPELINE_PROCESSING_ENABLED --- .../background/pipeline_tasks/__init__.py | 16 +++++++++++----- .../background/scheduled_tasks/__init__.py | 17 ++++++++++------- src/dstack/_internal/settings.py | 3 +++ 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py index 467e671c33..355e042476 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -1,6 +1,11 @@ import asyncio from dstack._internal.server.background.pipeline_tasks.base import Pipeline +from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupPipeline 
+from dstack._internal.server.background.pipeline_tasks.placement_groups import ( + PlacementGroupPipeline, +) +from dstack._internal.settings import FeatureFlags from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -8,11 +13,12 @@ class PipelineManager: def __init__(self) -> None: - self._pipelines: list[Pipeline] = [ - # Pipelines go here: - # PlacementGroupPipeline(), - # ComputeGroupPipeline(), - ] + self._pipelines: list[Pipeline] = [] + if FeatureFlags.PIPELINE_PROCESSING_ENABLED: + self._pipelines += [ + ComputeGroupPipeline(), + PlacementGroupPipeline(), + ] self._hinter = PipelineHinter(self._pipelines) def start(self): diff --git a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py index 3bc17ecc40..c4baf96c58 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py @@ -43,6 +43,7 @@ from dstack._internal.server.background.scheduled_tasks.volumes import ( process_submitted_volumes, ) +from dstack._internal.settings import FeatureFlags _scheduler = AsyncIOScheduler() @@ -105,7 +106,8 @@ def start_scheduled_tasks() -> AsyncIOScheduler: _scheduler.add_job( process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1 ) - _scheduler.add_job(process_placement_groups, IntervalTrigger(seconds=30, jitter=5)) + if not FeatureFlags.PIPELINE_PROCESSING_ENABLED: + _scheduler.add_job(process_placement_groups, IntervalTrigger(seconds=30, jitter=5)) _scheduler.add_job( process_fleets, IntervalTrigger(seconds=10, jitter=2), @@ -146,11 +148,12 @@ def start_scheduled_tasks() -> AsyncIOScheduler: kwargs={"batch_size": 5}, max_instances=2 if replica == 0 else 1, ) - _scheduler.add_job( - process_compute_groups, - IntervalTrigger(seconds=15, jitter=2), - kwargs={"batch_size": 1}, - max_instances=2 if replica == 0 else 1, - ) + if not FeatureFlags.PIPELINE_PROCESSING_ENABLED: + _scheduler.add_job( + process_compute_groups, + IntervalTrigger(seconds=15, jitter=2), + kwargs={"batch_size": 1}, + max_instances=2 if replica == 0 else 1, + ) _scheduler.start() return _scheduler diff --git a/src/dstack/_internal/settings.py b/src/dstack/_internal/settings.py index 6089e37c07..d94bb56547 100644 --- a/src/dstack/_internal/settings.py +++ b/src/dstack/_internal/settings.py @@ -47,3 +47,6 @@ class FeatureFlags: # DSTACK_FF_AUTOCREATED_FLEETS_ENABLED enables legacy autocreated fleets: # If there are no fleet suitable for the run, a new fleet is created automatically instead of an error. AUTOCREATED_FLEETS_ENABLED = os.getenv("DSTACK_FF_AUTOCREATED_FLEETS_ENABLED") is not None + # DSTACK_FF_PIPELINE_PROCESSING_ENABLED enables new pipeline-based processing tasks (background/pipeline_tasks/) + # instead of scheduler-based processing tasks (background/scheduled_tasks/) for tasks that implement pipelines. 
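    # Example (illustrative): the flag is enabled by the variable's presence rather
    # than its value, e.g. `DSTACK_FF_PIPELINE_PROCESSING_ENABLED=1 dstack server`.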
+ PIPELINE_PROCESSING_ENABLED = os.getenv("DSTACK_FF_PIPELINE_PROCESSING_ENABLED") is not None From 175f882bc0156ab9d7523d2f107f5900235e46ec Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 17 Feb 2026 15:30:16 +0500 Subject: [PATCH 10/15] Fixes --- src/dstack/_internal/server/app.py | 1 + .../server/background/pipeline_tasks/compute_groups.py | 3 +-- .../server/background/pipeline_tasks/placement_groups.py | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 27593d3f9e..209679f0ef 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -168,6 +168,7 @@ async def lifespan(app: FastAPI): if settings.SERVER_BACKGROUND_PROCESSING_ENABLED: scheduler = start_scheduled_tasks() pipeline_manager = start_pipeline_tasks() + app.state.pipeline_manager = pipeline_manager else: logger.info("Background processing is disabled") PROBES_SCHEDULER.start() diff --git a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py index f65adc4efa..b023779aa0 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py @@ -293,8 +293,7 @@ async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _T return _TerminateResult( compute_group_update_map=result.compute_group_update_map | terminated_result.compute_group_update_map, - instances_update_map=result.instances_update_map - | terminated_result.compute_group_update_map, + instances_update_map=result.instances_update_map | terminated_result.instances_update_map, ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py index ad4666d0a4..98e0d17849 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py @@ -115,6 +115,8 @@ async def fetch(self, limit: int) -> list[PipelineItem]: .where( PlacementGroupModel.fleet_deleted == True, PlacementGroupModel.deleted == False, + PlacementGroupModel.last_processed_at + <= now - self._min_processing_interval, or_( PlacementGroupModel.lock_expires_at.is_(None), PlacementGroupModel.lock_expires_at < now, From 45a2e46315101ece0eb5c2323b88c7ca27a91446 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 17 Feb 2026 15:53:05 +0500 Subject: [PATCH 11/15] Rename scheduled_tasks tests --- .../{test_process_compute_groups.py => test_compute_groups.py} | 0 .../scheduled_tasks/{test_process_events.py => test_events.py} | 0 .../scheduled_tasks/{test_process_fleets.py => test_fleets.py} | 0 .../{test_process_gateways.py => test_gateways.py} | 0 .../{test_process_idle_volumes.py => test_idle_volumes.py} | 0 .../{test_process_instances.py => test_instances.py} | 0 .../{test_process_metrics.py => test_metrics.py} | 0 ...est_process_placement_groups.py => test_placement_groups.py} | 0 .../scheduled_tasks/{test_process_probes.py => test_probes.py} | 0 ...process_prometheus_metrics.py => test_prometheus_metrics.py} | 0 .../{test_process_running_jobs.py => test_running_jobs.py} | 0 .../scheduled_tasks/{test_process_runs.py => test_runs.py} | 0 .../{test_process_submitted_jobs.py => test_submitted_jobs.py} | 0 ...t_process_submitted_volumes.py => test_submitted_volumes.py} | 0 
...est_process_terminating_jobs.py => test_terminating_jobs.py} | 0 src/tests/_internal/server/routers/test_runs.py | 2 +- 16 files changed, 1 insertion(+), 1 deletion(-) rename src/tests/_internal/server/background/scheduled_tasks/{test_process_compute_groups.py => test_compute_groups.py} (100%) rename src/tests/_internal/server/background/scheduled_tasks/{test_process_events.py => test_events.py} (100%) rename src/tests/_internal/server/background/scheduled_tasks/{test_process_fleets.py => test_fleets.py} (100%) rename src/tests/_internal/server/background/scheduled_tasks/{test_process_gateways.py => test_gateways.py} (100%) rename src/tests/_internal/server/background/scheduled_tasks/{test_process_idle_volumes.py => test_idle_volumes.py} (100%) rename src/tests/_internal/server/background/scheduled_tasks/{test_process_instances.py => test_instances.py} (100%) rename src/tests/_internal/server/background/scheduled_tasks/{test_process_metrics.py => test_metrics.py} (100%) rename src/tests/_internal/server/background/scheduled_tasks/{test_process_placement_groups.py => test_placement_groups.py} (100%) rename src/tests/_internal/server/background/scheduled_tasks/{test_process_probes.py => test_probes.py} (100%) rename src/tests/_internal/server/background/scheduled_tasks/{test_process_prometheus_metrics.py => test_prometheus_metrics.py} (100%) rename src/tests/_internal/server/background/scheduled_tasks/{test_process_running_jobs.py => test_running_jobs.py} (100%) rename src/tests/_internal/server/background/scheduled_tasks/{test_process_runs.py => test_runs.py} (100%) rename src/tests/_internal/server/background/scheduled_tasks/{test_process_submitted_jobs.py => test_submitted_jobs.py} (100%) rename src/tests/_internal/server/background/scheduled_tasks/{test_process_submitted_volumes.py => test_submitted_volumes.py} (100%) rename src/tests/_internal/server/background/scheduled_tasks/{test_process_terminating_jobs.py => test_terminating_jobs.py} (100%) diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_process_compute_groups.py b/src/tests/_internal/server/background/scheduled_tasks/test_compute_groups.py similarity index 100% rename from src/tests/_internal/server/background/scheduled_tasks/test_process_compute_groups.py rename to src/tests/_internal/server/background/scheduled_tasks/test_compute_groups.py diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_process_events.py b/src/tests/_internal/server/background/scheduled_tasks/test_events.py similarity index 100% rename from src/tests/_internal/server/background/scheduled_tasks/test_process_events.py rename to src/tests/_internal/server/background/scheduled_tasks/test_events.py diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_process_fleets.py b/src/tests/_internal/server/background/scheduled_tasks/test_fleets.py similarity index 100% rename from src/tests/_internal/server/background/scheduled_tasks/test_process_fleets.py rename to src/tests/_internal/server/background/scheduled_tasks/test_fleets.py diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_process_gateways.py b/src/tests/_internal/server/background/scheduled_tasks/test_gateways.py similarity index 100% rename from src/tests/_internal/server/background/scheduled_tasks/test_process_gateways.py rename to src/tests/_internal/server/background/scheduled_tasks/test_gateways.py diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_process_idle_volumes.py 
b/src/tests/_internal/server/background/scheduled_tasks/test_idle_volumes.py similarity index 100% rename from src/tests/_internal/server/background/scheduled_tasks/test_process_idle_volumes.py rename to src/tests/_internal/server/background/scheduled_tasks/test_idle_volumes.py diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_process_instances.py b/src/tests/_internal/server/background/scheduled_tasks/test_instances.py similarity index 100% rename from src/tests/_internal/server/background/scheduled_tasks/test_process_instances.py rename to src/tests/_internal/server/background/scheduled_tasks/test_instances.py diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_process_metrics.py b/src/tests/_internal/server/background/scheduled_tasks/test_metrics.py similarity index 100% rename from src/tests/_internal/server/background/scheduled_tasks/test_process_metrics.py rename to src/tests/_internal/server/background/scheduled_tasks/test_metrics.py diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_process_placement_groups.py b/src/tests/_internal/server/background/scheduled_tasks/test_placement_groups.py similarity index 100% rename from src/tests/_internal/server/background/scheduled_tasks/test_process_placement_groups.py rename to src/tests/_internal/server/background/scheduled_tasks/test_placement_groups.py diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_process_probes.py b/src/tests/_internal/server/background/scheduled_tasks/test_probes.py similarity index 100% rename from src/tests/_internal/server/background/scheduled_tasks/test_process_probes.py rename to src/tests/_internal/server/background/scheduled_tasks/test_probes.py diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_process_prometheus_metrics.py b/src/tests/_internal/server/background/scheduled_tasks/test_prometheus_metrics.py similarity index 100% rename from src/tests/_internal/server/background/scheduled_tasks/test_process_prometheus_metrics.py rename to src/tests/_internal/server/background/scheduled_tasks/test_prometheus_metrics.py diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py similarity index 100% rename from src/tests/_internal/server/background/scheduled_tasks/test_process_running_jobs.py rename to src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_process_runs.py b/src/tests/_internal/server/background/scheduled_tasks/test_runs.py similarity index 100% rename from src/tests/_internal/server/background/scheduled_tasks/test_process_runs.py rename to src/tests/_internal/server/background/scheduled_tasks/test_runs.py diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py similarity index 100% rename from src/tests/_internal/server/background/scheduled_tasks/test_process_submitted_jobs.py rename to src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_process_submitted_volumes.py b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_volumes.py similarity index 100% rename from src/tests/_internal/server/background/scheduled_tasks/test_process_submitted_volumes.py rename to 
src/tests/_internal/server/background/scheduled_tasks/test_submitted_volumes.py diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_process_terminating_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_terminating_jobs.py similarity index 100% rename from src/tests/_internal/server/background/scheduled_tasks/test_process_terminating_jobs.py rename to src/tests/_internal/server/background/scheduled_tasks/test_terminating_jobs.py diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 21b6e8f28f..af322828b2 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -12,6 +12,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from dstack._internal import settings from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import ApplyAction from dstack._internal.core.models.configurations import ( @@ -67,7 +68,6 @@ list_events, ) from dstack._internal.server.testing.matchers import SomeUUID4Str -from tests._internal.server.background.scheduled_tasks.test_process_running_jobs import settings pytestmark = pytest.mark.usefixtures("image_config_mock") From acbb72185f560572107ae3c0a9157e8369f5cc06 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 18 Feb 2026 10:49:01 +0500 Subject: [PATCH 12/15] Add TestHeartbeater --- .../background/pipeline_tasks/test_base.py | 177 ++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 src/tests/_internal/server/background/pipeline_tasks/test_base.py diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_base.py b/src/tests/_internal/server/background/pipeline_tasks/test_base.py new file mode 100644 index 0000000000..49dc47ba91 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_base.py @@ -0,0 +1,177 @@ +import uuid +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import cast +from unittest.mock import patch + +import pytest +from sqlalchemy import update +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.server.background.pipeline_tasks.base import Heartbeater, PipelineItem +from dstack._internal.server.models import PlacementGroupModel +from dstack._internal.server.testing.common import ( + create_fleet, + create_placement_group, + create_project, +) + + +@dataclass +class DummyPipelineItem: + id: uuid.UUID + lock_token: uuid.UUID + lock_expires_at: datetime + __tablename__: str = PlacementGroupModel.__tablename__ + + +@pytest.fixture +def now() -> datetime: + return datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + + +@pytest.fixture +def heartbeater() -> Heartbeater[PlacementGroupModel]: + return Heartbeater( + model_type=PlacementGroupModel, + lock_timeout=timedelta(seconds=30), + heartbeat_trigger=timedelta(seconds=5), + ) + + +async def _create_locked_placement_group( + session: AsyncSession, + now: datetime, + lock_expires_in: timedelta, +) -> PlacementGroupModel: + project = await create_project(session) + fleet = await create_fleet(session=session, project=project) + placement_group = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="test-pg", + ) + placement_group.lock_token = uuid.uuid4() + placement_group.lock_expires_at = now + lock_expires_in + await session.commit() + return placement_group + + +class TestHeartbeater: + 
@pytest.mark.asyncio + async def test_untrack_preserves_item_when_lock_token_mismatches( + self, heartbeater: Heartbeater[PlacementGroupModel], now: datetime + ): + item = DummyPipelineItem( + id=uuid.uuid4(), + lock_token=uuid.uuid4(), + lock_expires_at=now + timedelta(seconds=10), + ) + await heartbeater.track(item) + + stale_item = DummyPipelineItem( + id=item.id, + lock_token=uuid.uuid4(), + lock_expires_at=item.lock_expires_at, + ) + await heartbeater.untrack(stale_item) + + assert item.id in heartbeater._items + await heartbeater.untrack(item) + assert item.id not in heartbeater._items + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_heartbeat_extends_locks_close_to_expiration( + self, + test_db, + session: AsyncSession, + heartbeater: Heartbeater[PlacementGroupModel], + now: datetime, + ): + placement_group = await _create_locked_placement_group( + session=session, + now=now, + lock_expires_in=timedelta(seconds=2), + ) + await heartbeater.track(cast(PipelineItem, placement_group)) + + with patch( + "dstack._internal.server.background.pipeline_tasks.base.get_current_datetime", + return_value=now, + ): + await heartbeater.heartbeat() + + expected_lock_expires_at = now + timedelta(seconds=30) + assert placement_group.lock_expires_at == expected_lock_expires_at + assert placement_group.id in heartbeater._items + + await session.refresh(placement_group) + assert placement_group.lock_expires_at == expected_lock_expires_at + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_heartbeat_untracks_expired_items_without_db_update( + self, + test_db, + session: AsyncSession, + heartbeater: Heartbeater[PlacementGroupModel], + now: datetime, + ): + original_lock_expires_at = now - timedelta(seconds=1) + placement_group = await _create_locked_placement_group( + session=session, + now=now, + lock_expires_in=timedelta(seconds=-1), + ) + await heartbeater.track(cast(PipelineItem, placement_group)) + + with patch( + "dstack._internal.server.background.pipeline_tasks.base.get_current_datetime", + return_value=now, + ): + await heartbeater.heartbeat() + + assert placement_group.id not in heartbeater._items + + await session.refresh(placement_group) + assert placement_group.lock_expires_at == original_lock_expires_at + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_heartbeat_untracks_item_when_lock_token_changed_in_db( + self, + test_db, + session: AsyncSession, + heartbeater: Heartbeater[PlacementGroupModel], + now: datetime, + ): + original_lock_expires_at = now + timedelta(seconds=2) + placement_group = await _create_locked_placement_group( + session=session, + now=now, + lock_expires_in=timedelta(seconds=2), + ) + await heartbeater.track(cast(PipelineItem, placement_group)) + + new_lock_token = uuid.uuid4() + await session.execute( + update(PlacementGroupModel) + .where(PlacementGroupModel.id == placement_group.id) + .values(lock_token=new_lock_token) + .execution_options(synchronize_session=False) + ) + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.base.get_current_datetime", + return_value=now, + ): + await heartbeater.heartbeat() + + assert placement_group.id not in heartbeater._items + + await session.refresh(placement_group) + assert placement_group.lock_token == new_lock_token + assert placement_group.lock_expires_at == original_lock_expires_at 
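The TestHeartbeater cases above pin down the locking contract: heartbeat() extends tracked locks that are close to expiration, drops items whose lock has already expired, and drops items whose lock_token no longer matches the row in the database. A minimal driver sketch under those assumptions follows; only the Heartbeater constructor arguments and the async track/untrack/heartbeat calls are taken from the code under test, while the loop, its interval, and the surrounding names are illustrative:

    import asyncio
    from datetime import timedelta

    from dstack._internal.server.background.pipeline_tasks.base import Heartbeater
    from dstack._internal.server.models import PlacementGroupModel

    # Shared heartbeater: workers call track()/untrack() around processing an item.
    heartbeater = Heartbeater(
        model_type=PlacementGroupModel,
        lock_timeout=timedelta(seconds=30),      # each heartbeat pushes lock_expires_at this far out
        heartbeat_trigger=timedelta(seconds=5),  # presumably only locks expiring within this window are renewed
    )

    async def heartbeat_loop(interval_seconds: float = 1.0) -> None:
        # Illustrative periodic driver; the real scheduling lives in the pipeline machinery.
        while True:
            await heartbeater.heartbeat()
            await asyncio.sleep(interval_seconds)

Note that the expired-lock and changed-token cases assert that no database update is issued, so a stale tracker cannot overwrite a lock that has since been taken over elsewhere.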
From abc7cfd5c51d1282f90b6f6cffda2ecb8fd9eaae Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 18 Feb 2026 11:20:12 +0500 Subject: [PATCH 13/15] Split pipeline migration in two --- ...add_computegroupmodel_pipeline_columns.py} | 26 ++-------- ...dd_placementgroupmodel_pipeline_columns.py | 47 +++++++++++++++++++ 2 files changed, 51 insertions(+), 22 deletions(-) rename src/dstack/_internal/server/migrations/versions/{351efa3432d9_pipelines.py => 57cff3ec86ce_add_computegroupmodel_pipeline_columns.py} (59%) create mode 100644 src/dstack/_internal/server/migrations/versions/9c2a227b0154_add_placementgroupmodel_pipeline_columns.py diff --git a/src/dstack/_internal/server/migrations/versions/351efa3432d9_pipelines.py b/src/dstack/_internal/server/migrations/versions/57cff3ec86ce_add_computegroupmodel_pipeline_columns.py similarity index 59% rename from src/dstack/_internal/server/migrations/versions/351efa3432d9_pipelines.py rename to src/dstack/_internal/server/migrations/versions/57cff3ec86ce_add_computegroupmodel_pipeline_columns.py index ade236d04c..e341b3b4a4 100644 --- a/src/dstack/_internal/server/migrations/versions/351efa3432d9_pipelines.py +++ b/src/dstack/_internal/server/migrations/versions/57cff3ec86ce_add_computegroupmodel_pipeline_columns.py @@ -1,8 +1,8 @@ -"""Pipelines +"""Add ComputeGroupModel pipeline columns -Revision ID: 351efa3432d9 +Revision ID: 57cff3ec86ce Revises: 706e0acc3a7d -Create Date: 2026-02-17 10:45:48.754096 +Create Date: 2026-02-18 11:07:48.686185 """ @@ -13,7 +13,7 @@ import dstack._internal.server.models # revision identifiers, used by Alembic. -revision = "351efa3432d9" +revision = "57cff3ec86ce" down_revision = "706e0acc3a7d" branch_labels = None depends_on = None @@ -34,29 +34,11 @@ def upgrade() -> None: ) batch_op.add_column(sa.Column("lock_owner", sa.String(length=100), nullable=True)) - with op.batch_alter_table("placement_groups", schema=None) as batch_op: - batch_op.add_column( - sa.Column( - "lock_expires_at", dstack._internal.server.models.NaiveDateTime(), nullable=True - ) - ) - batch_op.add_column( - sa.Column( - "lock_token", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True - ) - ) - batch_op.add_column(sa.Column("lock_owner", sa.String(length=100), nullable=True)) - # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table("placement_groups", schema=None) as batch_op: - batch_op.drop_column("lock_owner") - batch_op.drop_column("lock_token") - batch_op.drop_column("lock_expires_at") - with op.batch_alter_table("compute_groups", schema=None) as batch_op: batch_op.drop_column("lock_owner") batch_op.drop_column("lock_token") diff --git a/src/dstack/_internal/server/migrations/versions/9c2a227b0154_add_placementgroupmodel_pipeline_columns.py b/src/dstack/_internal/server/migrations/versions/9c2a227b0154_add_placementgroupmodel_pipeline_columns.py new file mode 100644 index 0000000000..56297fde36 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/9c2a227b0154_add_placementgroupmodel_pipeline_columns.py @@ -0,0 +1,47 @@ +"""Add PlacementGroupModel pipeline columns + +Revision ID: 9c2a227b0154 +Revises: 57cff3ec86ce +Create Date: 2026-02-18 11:08:57.860277 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. 
+revision = "9c2a227b0154" +down_revision = "57cff3ec86ce" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("placement_groups", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "lock_expires_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ) + ) + batch_op.add_column( + sa.Column( + "lock_token", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True + ) + ) + batch_op.add_column(sa.Column("lock_owner", sa.String(length=100), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("placement_groups", schema=None) as batch_op: + batch_op.drop_column("lock_owner") + batch_op.drop_column("lock_token") + batch_op.drop_column("lock_expires_at") + + # ### end Alembic commands ### From faf10ae44b54422d08e33cd5304657ad0204d926 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 18 Feb 2026 11:37:53 +0500 Subject: [PATCH 14/15] Add pipeline indexes for compute and placement groups --- ...0_add_pipeline_indexes_for_compute_and_.py | 57 +++++++++++++++++++ src/dstack/_internal/server/models.py | 18 ++++++ 2 files changed, 75 insertions(+) create mode 100644 src/dstack/_internal/server/migrations/versions/a8ed24fd7f90_add_pipeline_indexes_for_compute_and_.py diff --git a/src/dstack/_internal/server/migrations/versions/a8ed24fd7f90_add_pipeline_indexes_for_compute_and_.py b/src/dstack/_internal/server/migrations/versions/a8ed24fd7f90_add_pipeline_indexes_for_compute_and_.py new file mode 100644 index 0000000000..ad35a23d06 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/a8ed24fd7f90_add_pipeline_indexes_for_compute_and_.py @@ -0,0 +1,57 @@ +"""Add pipeline indexes for compute and placement groups + +Revision ID: a8ed24fd7f90 +Revises: 9c2a227b0154 +Create Date: 2026-02-18 11:22:25.972000 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "a8ed24fd7f90" +down_revision = "9c2a227b0154" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.get_context().autocommit_block(): + op.create_index( + "ix_compute_groups_pipeline_fetch_q", + "compute_groups", + [sa.literal_column("last_processed_at ASC")], + unique=False, + postgresql_where=sa.text("(status NOT IN ('TERMINATED'))"), + sqlite_where=sa.text("(status NOT IN ('TERMINATED'))"), + postgresql_concurrently=True, + ) + op.create_index( + "ix_placement_groups_pipeline_fetch_q", + "placement_groups", + [sa.literal_column("last_processed_at ASC")], + unique=False, + postgresql_where=sa.text("deleted IS FALSE"), + sqlite_where=sa.text("deleted = 0"), + postgresql_concurrently=True, + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_placement_groups_pipeline_fetch_q", + "placement_groups", + postgresql_concurrently=True, + ) + op.drop_index( + "ix_compute_groups_pipeline_fetch_q", + "compute_groups", + postgresql_concurrently=True, + ) + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 8c611e9447..a837137a10 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -800,6 +800,15 @@ class PlacementGroupModel(PipelineModelMixin, BaseModel): configuration: Mapped[str] = mapped_column(Text) provisioning_data: Mapped[Optional[str]] = mapped_column(Text) + __table_args__ = ( + Index( + "ix_placement_groups_pipeline_fetch_q", + last_processed_at.asc(), + postgresql_where=deleted == false(), + sqlite_where=deleted == false(), + ), + ) + class ComputeGroupModel(PipelineModelMixin, BaseModel): __tablename__ = "compute_groups" @@ -829,6 +838,15 @@ class ComputeGroupModel(PipelineModelMixin, BaseModel): instances: Mapped[List["InstanceModel"]] = relationship(back_populates="compute_group") + __table_args__ = ( + Index( + "ix_compute_groups_pipeline_fetch_q", + last_processed_at.asc(), + postgresql_where=status.not_in(ComputeGroupStatus.finished_statuses()), + sqlite_where=status.not_in(ComputeGroupStatus.finished_statuses()), + ), + ) + class JobMetricsPoint(BaseModel): __tablename__ = "job_metrics_points" From 84fe002824e5db29dca69fbcf7d59af78bd764d3 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 18 Feb 2026 12:28:39 +0500 Subject: [PATCH 15/15] Make PipelineItem a dataclass --- .../server/background/pipeline_tasks/base.py | 15 +++---- .../pipeline_tasks/compute_groups.py | 15 ++++++- .../pipeline_tasks/placement_groups.py | 15 ++++++- .../background/pipeline_tasks/test_base.py | 40 +++++++++++-------- .../pipeline_tasks/test_compute_groups.py | 29 ++++++++++++-- .../pipeline_tasks/test_placement_groups.py | 20 +++++++++- 6 files changed, 100 insertions(+), 34 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/base.py b/src/dstack/_internal/server/background/pipeline_tasks/base.py index ac2e510fa7..30be480bf9 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/base.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/base.py @@ -3,6 +3,7 @@ import random import uuid from abc import ABC, abstractmethod +from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, ClassVar, Generic, Optional, Protocol, Sequence, TypeVar @@ -16,22 +17,22 @@ logger = get_logger(__name__) -class PipelineItem(Protocol): +@dataclass +class PipelineItem: + __tablename__: str id: uuid.UUID lock_expires_at: datetime lock_token: uuid.UUID - - __tablename__: str + prev_lock_expired: bool class PipelineModel(Protocol): - id: Mapped[uuid.UUID] - lock_expires_at: Mapped[Optional[datetime]] - lock_token: Mapped[Optional[uuid.UUID]] - __tablename__: str __mapper__: ClassVar[Any] __table__: ClassVar[Any] + id: Mapped[uuid.UUID] + lock_expires_at: Mapped[Optional[datetime]] + lock_token: Mapped[Optional[uuid.UUID]] class PipelineError(Exception): diff --git a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py index b023779aa0..685c5205a8 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py +++ 
b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py @@ -2,7 +2,7 @@ import uuid from dataclasses import dataclass, field from datetime import datetime, timedelta -from typing import Sequence, cast +from typing import Sequence from sqlalchemy import or_, select, update from sqlalchemy.orm import joinedload, load_only @@ -142,12 +142,23 @@ async def fetch(self, limit: int) -> list[PipelineItem]: compute_group_models = list(res.scalars().all()) lock_expires_at = get_current_datetime() + self._lock_timeout lock_token = uuid.uuid4() + items = [] for compute_group_model in compute_group_models: + prev_lock_expired = compute_group_model.lock_expires_at is not None compute_group_model.lock_expires_at = lock_expires_at compute_group_model.lock_token = lock_token compute_group_model.lock_owner = ComputeGroupPipeline.__name__ + items.append( + PipelineItem( + __tablename__=ComputeGroupModel.__tablename__, + id=compute_group_model.id, + lock_expires_at=lock_expires_at, + lock_token=lock_token, + prev_lock_expired=prev_lock_expired, + ) + ) await session.commit() - return [cast(PipelineItem, r) for r in compute_group_models] + return items class ComputeGroupWorker(Worker): diff --git a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py index 98e0d17849..9fac5665a5 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py @@ -1,7 +1,7 @@ import asyncio import uuid from datetime import timedelta -from typing import Sequence, cast +from typing import Sequence from sqlalchemy import or_, select, update from sqlalchemy.orm import joinedload, load_only @@ -140,12 +140,23 @@ async def fetch(self, limit: int) -> list[PipelineItem]: placement_group_models = list(res.scalars().all()) lock_expires_at = get_current_datetime() + self._lock_timeout lock_token = uuid.uuid4() + items = [] for placement_group_model in placement_group_models: + prev_lock_expired = placement_group_model.lock_expires_at is not None placement_group_model.lock_expires_at = lock_expires_at placement_group_model.lock_token = lock_token placement_group_model.lock_owner = PlacementGroupPipeline.__name__ + items.append( + PipelineItem( + __tablename__=PlacementGroupModel.__tablename__, + id=placement_group_model.id, + lock_expires_at=lock_expires_at, + lock_token=lock_token, + prev_lock_expired=prev_lock_expired, + ) + ) await session.commit() - return [cast(PipelineItem, r) for r in placement_group_models] + return items class PlacementGroupWorker(Worker): diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_base.py b/src/tests/_internal/server/background/pipeline_tasks/test_base.py index 49dc47ba91..7e84d9f80d 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_base.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_base.py @@ -1,7 +1,5 @@ import uuid -from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from typing import cast from unittest.mock import patch import pytest @@ -17,14 +15,6 @@ ) -@dataclass -class DummyPipelineItem: - id: uuid.UUID - lock_token: uuid.UUID - lock_expires_at: datetime - __tablename__: str = PlacementGroupModel.__tablename__ - - @pytest.fixture def now() -> datetime: return datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) @@ -58,22 +48,38 @@ async def _create_locked_placement_group( return placement_group 
+def _placement_group_to_pipeline_item(placement_group: PlacementGroupModel) -> PipelineItem: + assert placement_group.lock_token is not None + assert placement_group.lock_expires_at is not None + return PipelineItem( + __tablename__=PlacementGroupModel.__tablename__, + id=placement_group.id, + lock_token=placement_group.lock_token, + lock_expires_at=placement_group.lock_expires_at, + prev_lock_expired=False, + ) + + class TestHeartbeater: @pytest.mark.asyncio async def test_untrack_preserves_item_when_lock_token_mismatches( self, heartbeater: Heartbeater[PlacementGroupModel], now: datetime ): - item = DummyPipelineItem( + item = PipelineItem( + __tablename__=PlacementGroupModel.__tablename__, id=uuid.uuid4(), lock_token=uuid.uuid4(), lock_expires_at=now + timedelta(seconds=10), + prev_lock_expired=True, ) await heartbeater.track(item) - stale_item = DummyPipelineItem( + stale_item = PipelineItem( + __tablename__=PlacementGroupModel.__tablename__, id=item.id, lock_token=uuid.uuid4(), lock_expires_at=item.lock_expires_at, + prev_lock_expired=False, ) await heartbeater.untrack(stale_item) @@ -95,7 +101,7 @@ async def test_heartbeat_extends_locks_close_to_expiration( now=now, lock_expires_in=timedelta(seconds=2), ) - await heartbeater.track(cast(PipelineItem, placement_group)) + await heartbeater.track(_placement_group_to_pipeline_item(placement_group)) with patch( "dstack._internal.server.background.pipeline_tasks.base.get_current_datetime", @@ -104,8 +110,8 @@ async def test_heartbeat_extends_locks_close_to_expiration( await heartbeater.heartbeat() expected_lock_expires_at = now + timedelta(seconds=30) - assert placement_group.lock_expires_at == expected_lock_expires_at - assert placement_group.id in heartbeater._items + tracked_item = heartbeater._items[placement_group.id] + assert tracked_item.lock_expires_at == expected_lock_expires_at await session.refresh(placement_group) assert placement_group.lock_expires_at == expected_lock_expires_at @@ -125,7 +131,7 @@ async def test_heartbeat_untracks_expired_items_without_db_update( now=now, lock_expires_in=timedelta(seconds=-1), ) - await heartbeater.track(cast(PipelineItem, placement_group)) + await heartbeater.track(_placement_group_to_pipeline_item(placement_group)) with patch( "dstack._internal.server.background.pipeline_tasks.base.get_current_datetime", @@ -153,7 +159,7 @@ async def test_heartbeat_untracks_item_when_lock_token_changed_in_db( now=now, lock_expires_in=timedelta(seconds=2), ) - await heartbeater.track(cast(PipelineItem, placement_group)) + await heartbeater.track(_placement_group_to_pipeline_item(placement_group)) new_lock_token = uuid.uuid4() await session.execute( diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py b/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py index 0fd75b981f..6d24669f7c 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py @@ -1,5 +1,5 @@ +import uuid from datetime import datetime, timezone -from typing import cast from unittest.mock import Mock, patch import pytest @@ -10,6 +10,7 @@ from dstack._internal.core.models.compute_groups import ComputeGroupStatus from dstack._internal.server.background.pipeline_tasks.base import PipelineItem from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupWorker +from dstack._internal.server.models import ComputeGroupModel from 
dstack._internal.server.testing.common import ( ComputeMockSpec, create_compute_group, @@ -23,6 +24,18 @@ def worker() -> ComputeGroupWorker: return ComputeGroupWorker(queue=Mock(), heartbeater=Mock()) +def _compute_group_to_pipeline_item(compute_group: ComputeGroupModel) -> PipelineItem: + assert compute_group.lock_token is not None + assert compute_group.lock_expires_at is not None + return PipelineItem( + __tablename__=compute_group.__tablename__, + id=compute_group.id, + lock_token=compute_group.lock_token, + lock_expires_at=compute_group.lock_expires_at, + prev_lock_expired=False, + ) + + class TestComputeGroupWorker: @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -36,13 +49,16 @@ async def test_terminates_compute_group( project=project, fleet=fleet, ) + compute_group.lock_token = uuid.uuid4() + compute_group.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: backend_mock = Mock() compute_mock = Mock(spec=ComputeMockSpec) backend_mock.compute.return_value = compute_mock m.return_value = backend_mock backend_mock.TYPE = BackendType.RUNPOD - await worker.process(cast(PipelineItem, compute_group)) + await worker.process(_compute_group_to_pipeline_item(compute_group)) compute_mock.terminate_compute_group.assert_called_once() await session.refresh(compute_group) assert compute_group.status == ComputeGroupStatus.TERMINATED @@ -61,6 +77,9 @@ async def test_retries_compute_group_termination( fleet=fleet, last_processed_at=datetime(2023, 1, 2, 3, 0, tzinfo=timezone.utc), ) + compute_group.lock_token = uuid.uuid4() + compute_group.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: backend_mock = Mock() compute_mock = Mock(spec=ComputeMockSpec) @@ -68,7 +87,7 @@ async def test_retries_compute_group_termination( m.return_value = backend_mock backend_mock.TYPE = BackendType.RUNPOD compute_mock.terminate_compute_group.side_effect = BackendError() - await worker.process(cast(PipelineItem, compute_group)) + await worker.process(_compute_group_to_pipeline_item(compute_group)) compute_mock.terminate_compute_group.assert_called_once() await session.refresh(compute_group) assert compute_group.status != ComputeGroupStatus.TERMINATED @@ -78,6 +97,8 @@ async def test_retries_compute_group_termination( compute_group.first_termination_retry_at = datetime(2023, 1, 2, 3, 0, tzinfo=timezone.utc) compute_group.last_termination_retry_at = datetime(2023, 1, 2, 4, 0, tzinfo=timezone.utc) compute_group.last_processed_at = datetime(2023, 1, 2, 4, 0, tzinfo=timezone.utc) + compute_group.lock_token = uuid.uuid4() + compute_group.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) await session.commit() with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: backend_mock = Mock() @@ -86,7 +107,7 @@ async def test_retries_compute_group_termination( m.return_value = backend_mock backend_mock.TYPE = BackendType.RUNPOD compute_mock.terminate_compute_group.side_effect = BackendError() - await worker.process(cast(PipelineItem, compute_group)) + await worker.process(_compute_group_to_pipeline_item(compute_group)) compute_mock.terminate_compute_group.assert_called_once() await session.refresh(compute_group) assert compute_group.status == ComputeGroupStatus.TERMINATED diff 
--git a/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py index 3eadf9a733..87cab83e12 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py @@ -1,4 +1,5 @@ -from typing import cast +import uuid +from datetime import datetime, timezone from unittest.mock import Mock, patch import pytest @@ -19,6 +20,18 @@ def worker() -> PlacementGroupWorker: return PlacementGroupWorker(queue=Mock(), heartbeater=Mock()) +def _placement_group_to_pipeline_item(placement_group) -> PipelineItem: + assert placement_group.lock_token is not None + assert placement_group.lock_expires_at is not None + return PipelineItem( + __tablename__=placement_group.__tablename__, + id=placement_group.id, + lock_token=placement_group.lock_token, + lock_expires_at=placement_group.lock_expires_at, + prev_lock_expired=False, + ) + + class TestPlacementGroupWorker: @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -37,11 +50,14 @@ async def test_deletes_placement_group( name="test1-pg", fleet_deleted=True, ) + placement_group.lock_token = uuid.uuid4() + placement_group.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: aws_mock = Mock() m.return_value = aws_mock aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) - await worker.process(cast(PipelineItem, placement_group)) + await worker.process(_placement_group_to_pipeline_item(placement_group)) aws_mock.compute.return_value.delete_placement_group.assert_called_once() await session.refresh(placement_group) assert placement_group.deleted
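With PipelineItem turned into a dataclass, fetch() now hands workers a detached snapshot of the locked row rather than the ORM instance itself, and the tests build the same snapshots through small helpers. A construction sketch under the field set introduced in this patch; the concrete values below are placeholders, only the field names come from the diff:

    import uuid
    from datetime import datetime, timedelta, timezone

    from dstack._internal.server.background.pipeline_tasks.base import PipelineItem
    from dstack._internal.server.models import PlacementGroupModel

    now = datetime.now(timezone.utc)

    # Snapshot handed from a fetcher to a worker; prev_lock_expired is set by fetch()
    # when the row already carried a lock whose lock_expires_at had passed.
    item = PipelineItem(
        __tablename__=PlacementGroupModel.__tablename__,
        id=uuid.uuid4(),  # placeholder; fetch() copies the model's id
        lock_expires_at=now + timedelta(seconds=30),
        lock_token=uuid.uuid4(),
        prev_lock_expired=False,
    )

As with the rest of the pipeline code in this series, none of this runs unless DSTACK_FF_PIPELINE_PROCESSING_ENABLED is set; with the flag set, ComputeGroupPipeline and PlacementGroupPipeline replace the corresponding scheduled tasks (patch 09).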