7 changes: 7 additions & 0 deletions src/dstack/_internal/core/models/compute_groups.py
@@ -12,6 +12,13 @@ class ComputeGroupStatus(str, enum.Enum):
     RUNNING = "running"
     TERMINATED = "terminated"
 
+    @classmethod
+    def finished_statuses(cls) -> List["ComputeGroupStatus"]:
+        return [cls.TERMINATED]
+
+    def is_finished(self):
+        return self in self.finished_statuses()
+
 
 class ComputeGroupProvisioningData(CoreModel):
     compute_group_id: str
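A quick usage sketch for the new helpers (illustrative only; it assumes nothing beyond the enum shown above):

from dstack._internal.core.models.compute_groups import ComputeGroupStatus

assert ComputeGroupStatus.TERMINATED.is_finished()
assert not ComputeGroupStatus.RUNNING.is_finished()
assert ComputeGroupStatus.finished_statuses() == [ComputeGroupStatus.TERMINATED]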
18 changes: 14 additions & 4 deletions src/dstack/_internal/server/app.py
@@ -23,8 +23,9 @@
 from dstack._internal.proxy.lib.deps import get_injector_from_app
 from dstack._internal.proxy.lib.routers import model_proxy
 from dstack._internal.server import settings
-from dstack._internal.server.background import start_background_tasks
-from dstack._internal.server.background.tasks.process_probes import PROBES_SCHEDULER
+from dstack._internal.server.background.pipeline_tasks import start_pipeline_tasks
+from dstack._internal.server.background.scheduled_tasks import start_scheduled_tasks
+from dstack._internal.server.background.scheduled_tasks.probes import PROBES_SCHEDULER
 from dstack._internal.server.db import get_db, get_session_ctx, migrate
 from dstack._internal.server.routers import (
     auth,
@@ -163,8 +164,11 @@ async def lifespan(app: FastAPI):
     if settings.SERVER_S3_BUCKET is not None or settings.SERVER_GCS_BUCKET is not None:
         init_default_storage()
     scheduler = None
+    pipeline_manager = None
     if settings.SERVER_BACKGROUND_PROCESSING_ENABLED:
-        scheduler = start_background_tasks()
+        scheduler = start_scheduled_tasks()
+        pipeline_manager = start_pipeline_tasks()
+        app.state.pipeline_manager = pipeline_manager
     else:
         logger.info("Background processing is disabled")
     PROBES_SCHEDULER.start()
@@ -189,9 +193,15 @@ async def lifespan(app: FastAPI):
     for func in _ON_STARTUP_HOOKS:
         await func(app)
     yield
-    PROBES_SCHEDULER.shutdown(wait=False)
+    if pipeline_manager is not None:
+        pipeline_manager.shutdown()
     if scheduler is not None:
+        # Note: Scheduler does not cancel currently running jobs, so scheduled tasks cannot do cleanup.
+        # TODO: Track and cancel scheduled tasks.
         scheduler.shutdown()
+    PROBES_SCHEDULER.shutdown(wait=False)
+    if pipeline_manager is not None:
+        await pipeline_manager.drain()
     await gateway_connections_pool.remove_all()
     service_conn_pool = await get_injector_from_app(app).get_service_connection_pool()
     await service_conn_pool.remove_all()
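The shutdown sequence above is two-phase: pipeline_manager.shutdown() synchronously stops accepting new work before the schedulers stop, and pipeline_manager.drain() is awaited at the end so in-flight work can finish. A minimal sketch of that pattern, using a hypothetical worker class (this is not the actual PipelineManager implementation):

import asyncio

class TwoPhaseWorker:
    # Illustrative only: demonstrates the shutdown()/drain() split used above.
    def __init__(self) -> None:
        self._stopping = False
        self._tasks: set[asyncio.Task] = set()

    def submit(self, coro) -> None:
        if self._stopping:
            raise RuntimeError("worker is shutting down")
        task = asyncio.ensure_future(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    def shutdown(self) -> None:
        # Phase 1 (sync): refuse new work; in-flight tasks keep running.
        self._stopping = True

    async def drain(self) -> None:
        # Phase 2 (async): wait for in-flight tasks so they can finish cleanup.
        if self._tasks:
            await asyncio.gather(*self._tasks, return_exceptions=True)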
142 changes: 0 additions & 142 deletions src/dstack/_internal/server/background/__init__.py
@@ -1,142 +0,0 @@
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.interval import IntervalTrigger

from dstack._internal.server import settings
from dstack._internal.server.background.tasks.process_compute_groups import process_compute_groups
from dstack._internal.server.background.tasks.process_events import delete_events
from dstack._internal.server.background.tasks.process_fleets import process_fleets
from dstack._internal.server.background.tasks.process_gateways import (
process_gateways,
process_gateways_connections,
)
from dstack._internal.server.background.tasks.process_idle_volumes import process_idle_volumes
from dstack._internal.server.background.tasks.process_instances import (
delete_instance_health_checks,
process_instances,
)
from dstack._internal.server.background.tasks.process_metrics import (
collect_metrics,
delete_metrics,
)
from dstack._internal.server.background.tasks.process_placement_groups import (
process_placement_groups,
)
from dstack._internal.server.background.tasks.process_probes import process_probes
from dstack._internal.server.background.tasks.process_prometheus_metrics import (
collect_prometheus_metrics,
delete_prometheus_metrics,
)
from dstack._internal.server.background.tasks.process_running_jobs import process_running_jobs
from dstack._internal.server.background.tasks.process_runs import process_runs
from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs
from dstack._internal.server.background.tasks.process_terminating_jobs import (
process_terminating_jobs,
)
from dstack._internal.server.background.tasks.process_volumes import process_submitted_volumes

_scheduler = AsyncIOScheduler()


def get_scheduler() -> AsyncIOScheduler:
return _scheduler


def start_background_tasks() -> AsyncIOScheduler:
# Background processing is implemented via in-memory locks on SQLite
# and SELECT FOR UPDATE on Postgres. Locks may be held for a long time.
    # This is currently the main bottleneck for scaling dstack processing,
    # since handling more resources requires more DB connections.
# TODO: Make background processing efficient by committing locks to DB
# and processing outside of DB transactions.
#
# Now we just try to process as many resources as possible without exhausting DB connections.
#
# Quick tasks can process multiple resources per transaction.
# Potentially long tasks process one resource per transaction
# to avoid holding locks for all the resources if one is slow to process.
    # Still, the next batch won't be processed until all resources in the current
    # batch are processed, so larger batches do not increase the processing rate linearly.
#
    # The interval, batch_size, and max_instances parameters determine background task processing rates.
# By default, one server replica can handle:
#
# * 150 active jobs with 2 minutes processing latency
# * 150 active runs with 2 minutes processing latency
# * 150 active instances with 2 minutes processing latency
#
# These latency numbers do not account for provisioning time,
# so it may be slower if a backend is slow to provision.
#
    # Users can set SERVER_BACKGROUND_PROCESSING_FACTOR to process more resources per replica.
    # They also need to increase the max DB connections on both the client side and the DB side.
#
    # In-memory locking via locksets does not guarantee
    # that the first task waiting for a lock will acquire it.
# The jitter is needed to give all tasks a chance to acquire locks.
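    #
    # A rough throughput check (illustrative arithmetic, not taken from the source):
    # process_running_jobs below runs every ~4 seconds with batch_size=5 and
    # max_instances=2, i.e. up to 5 * 2 / 4 = 2.5 jobs/second. Re-processing 150
    # active jobs within a 2-minute latency budget needs 150 / 120 = 1.25 jobs/second,
    # which fits within that bound with headroom for slow batches.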

_scheduler.add_job(process_probes, IntervalTrigger(seconds=3, jitter=1))
_scheduler.add_job(collect_metrics, IntervalTrigger(seconds=10), max_instances=1)
_scheduler.add_job(delete_metrics, IntervalTrigger(minutes=5), max_instances=1)
_scheduler.add_job(delete_events, IntervalTrigger(minutes=7), max_instances=1)
if settings.ENABLE_PROMETHEUS_METRICS:
_scheduler.add_job(
collect_prometheus_metrics, IntervalTrigger(seconds=10), max_instances=1
)
_scheduler.add_job(delete_prometheus_metrics, IntervalTrigger(minutes=5), max_instances=1)
_scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15))
_scheduler.add_job(process_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5)
_scheduler.add_job(
process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5
)
_scheduler.add_job(
process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1
)
_scheduler.add_job(process_placement_groups, IntervalTrigger(seconds=30, jitter=5))
_scheduler.add_job(
process_fleets,
IntervalTrigger(seconds=10, jitter=2),
max_instances=1,
)
_scheduler.add_job(delete_instance_health_checks, IntervalTrigger(minutes=5), max_instances=1)
for replica in range(settings.SERVER_BACKGROUND_PROCESSING_FACTOR):
        # Add multiple copies of tasks if requested.
        # Additional copies use max_instances=1 to avoid running too many task instances.
        # Move other tasks here when they need per-replica scaling.
_scheduler.add_job(
process_submitted_jobs,
IntervalTrigger(seconds=4, jitter=2),
kwargs={"batch_size": 5},
max_instances=4 if replica == 0 else 1,
)
_scheduler.add_job(
process_running_jobs,
IntervalTrigger(seconds=4, jitter=2),
kwargs={"batch_size": 5},
max_instances=2 if replica == 0 else 1,
)
_scheduler.add_job(
process_terminating_jobs,
IntervalTrigger(seconds=4, jitter=2),
kwargs={"batch_size": 5},
max_instances=2 if replica == 0 else 1,
)
_scheduler.add_job(
process_runs,
IntervalTrigger(seconds=2, jitter=1),
kwargs={"batch_size": 5},
max_instances=2 if replica == 0 else 1,
)
_scheduler.add_job(
process_instances,
IntervalTrigger(seconds=4, jitter=2),
kwargs={"batch_size": 5},
max_instances=2 if replica == 0 else 1,
)
_scheduler.add_job(
process_compute_groups,
IntervalTrigger(seconds=15, jitter=2),
kwargs={"batch_size": 1},
max_instances=2 if replica == 0 else 1,
)
_scheduler.start()
return _scheduler
69 changes: 69 additions & 0 deletions src/dstack/_internal/server/background/pipeline_tasks/__init__.py
@@ -0,0 +1,69 @@
import asyncio

from dstack._internal.server.background.pipeline_tasks.base import Pipeline
from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupPipeline
from dstack._internal.server.background.pipeline_tasks.placement_groups import (
PlacementGroupPipeline,
)
from dstack._internal.settings import FeatureFlags
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)


class PipelineManager:
def __init__(self) -> None:
self._pipelines: list[Pipeline] = []
if FeatureFlags.PIPELINE_PROCESSING_ENABLED:
self._pipelines += [
ComputeGroupPipeline(),
PlacementGroupPipeline(),
]
self._hinter = PipelineHinter(self._pipelines)

def start(self):
for pipeline in self._pipelines:
pipeline.start()

def shutdown(self):
for pipeline in self._pipelines:
pipeline.shutdown()

async def drain(self):
results = await asyncio.gather(
*[p.drain() for p in self._pipelines], return_exceptions=True
)
for pipeline, result in zip(self._pipelines, results):
if isinstance(result, BaseException):
logger.error(
"Unexpected exception when draining pipeline %r",
pipeline,
exc_info=(type(result), result, result.__traceback__),
)

@property
def hinter(self):
return self._hinter


class PipelineHinter:
def __init__(self, pipelines: list[Pipeline]) -> None:
self._pipelines = pipelines
self._hint_fetch_map = {p.hint_fetch_model_name: p for p in self._pipelines}

def hint_fetch(self, model_name: str):
pipeline = self._hint_fetch_map.get(model_name)
if pipeline is None:
logger.warning("Model %s not registered for fetch hints", model_name)
return
pipeline.hint_fetch()


def start_pipeline_tasks() -> PipelineManager:
"""
    Start tasks processed by fetch/worker pipelines backed by the DB and in-memory queues.
    Suitable for tasks that run frequently and need to lock rows for a long time.
"""
pipeline_manager = PipelineManager()
pipeline_manager.start()
return pipeline_manager
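The Pipeline base class lives in pipeline_tasks/base.py and is not part of this diff. Under that caveat, the sketch below shows the fetch/worker shape the docstring describes: a fetcher polls the DB and feeds an in-memory queue consumed by a worker, and hint_fetch() wakes the fetcher early. Apart from start/shutdown/drain/hint_fetch/hint_fetch_model_name, which appear above, every name and detail here is hypothetical:

import asyncio
from typing import Optional

class SketchPipeline:
    # Illustrative only -- not the actual Pipeline base class from this PR.
    hint_fetch_model_name = "ExampleModel"  # hypothetical model name

    def __init__(self) -> None:
        self._queue: asyncio.Queue[Optional[int]] = asyncio.Queue()
        self._wakeup = asyncio.Event()
        self._stopping = False
        self._tasks: list[asyncio.Task] = []

    def start(self) -> None:
        loop = asyncio.get_running_loop()
        self._tasks = [
            loop.create_task(self._fetch_loop()),
            loop.create_task(self._worker_loop()),
        ]

    def hint_fetch(self) -> None:
        # Wake the fetcher now instead of letting it wait out its poll interval.
        self._wakeup.set()

    def shutdown(self) -> None:
        self._stopping = True
        self._wakeup.set()
        self._queue.put_nowait(None)  # sentinel unblocks the worker

    async def drain(self) -> None:
        await asyncio.gather(*self._tasks, return_exceptions=True)

    async def _fetch_loop(self) -> None:
        while not self._stopping:
            self._wakeup.clear()
            await self._queue.put(1)  # stand-in for fetching ready rows from the DB
            try:
                await asyncio.wait_for(self._wakeup.wait(), timeout=10.0)
            except asyncio.TimeoutError:
                pass

    async def _worker_loop(self) -> None:
        while True:
            item = await self._queue.get()
            if item is None:  # sentinel from shutdown()
                return
            # Stand-in for processing one row while holding its lock.
            await asyncio.sleep(0)

With this shape, a request handler could nudge processing right after committing a new row with something like app.state.pipeline_manager.hinter.hint_fetch("ExampleModel") (the model name is hypothetical).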