From 44ecd1b7a5b64916dc436f5640995dc363d91b2e Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Fri, 30 Jan 2026 14:22:54 +0545 Subject: [PATCH 01/13] Initial PD disaggregation implementation Test2 Internal IP Test Add worker with internal_ip Check status and register Add Status Ready Log Add Prefill-Decode Add PD to dstack Test register worker without poll Add router config in service config Update remove worker Clean Up router code Clean Up Further Cleanup --- gateway/pyproject.toml | 4 +- .../_internal/core/backends/base/compute.py | 3 +- .../core/backends/kubernetes/compute.py | 5 ++ .../_internal/core/models/configurations.py | 9 +++ src/dstack/_internal/core/models/routers.py | 21 ++++- .../_internal/proxy/gateway/repo/state_v1.py | 1 + .../proxy/gateway/routers/registry.py | 1 + .../proxy/gateway/schemas/registry.py | 1 + .../gateway/services/model_routers/base.py | 24 +++++- .../gateway/services/model_routers/sglang.py | 79 +++++++++++++++--- .../_internal/proxy/gateway/services/nginx.py | 80 ++++++++++--------- .../proxy/gateway/services/registry.py | 9 ++- src/dstack/_internal/proxy/lib/models.py | 1 + .../server/services/gateways/client.py | 1 + .../server/services/services/__init__.py | 6 +- 15 files changed, 189 insertions(+), 56 deletions(-) diff --git a/gateway/pyproject.toml b/gateway/pyproject.toml index c40a37b7f5..243f37b939 100644 --- a/gateway/pyproject.toml +++ b/gateway/pyproject.toml @@ -11,11 +11,11 @@ requires-python = ">=3.10" dynamic = ["version"] dependencies = [ # release builds of dstack-gateway depend on a PyPI version of dstack instead - "dstack[gateway] @ https://github.com/dstackai/dstack/archive/refs/heads/master.tar.gz", + "dstack[gateway] @ https://github.com/Bihan/dstack/archive/refs/heads/pd_design_test.tar.gz", ] [project.optional-dependencies] -sglang = ["sglang-router==0.2.1"] +sglang = ["sglang-router==0.3.2"] [tool.setuptools.package-data] "dstack.gateway" = [ diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 49513e3211..b07c510e9c 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -1042,7 +1042,8 @@ def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = Non if build == "latest": build = _fetch_version(f"{base_url}/latest-version") or "latest" logger.debug("Found the latest gateway build: %s", build) - wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" + # wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" + wheel = "https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.1-py3-none-any.whl" # Build package spec with extras if router is specified if router: return f"dstack-gateway[{router.type}] @ {wheel}" diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index 7f8ef9123f..f25b87d9cc 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -330,6 +330,11 @@ def update_provisioning_data( if not pod_ip: return provisioning_data.internal_ip = pod_ip + logger.debug( + "Replica pod %s internal_ip=%s (cluster_ip will be set from Service)", + provisioning_data.instance_id, + pod_ip, + ) service = self.api.read_namespaced_service( name=_get_pod_service_name(provisioning_data.instance_id), namespace=self.config.namespace, diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 9c8b40b6ec..7055b14073 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -28,6 +28,7 @@ parse_off_duration, ) from dstack._internal.core.models.resources import Range, ResourcesSpec +from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.core.models.services import AnyModel, OpenAIChatModel from dstack._internal.core.models.unix import UnixUser from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point @@ -885,6 +886,14 @@ class ServiceConfigurationParams(CoreModel): ) ), ] = None + router_config: Annotated[ + Optional[AnyRouterConfig], + Field( + description=( + "Router configuration for the service (e.g. routing policy and pd_disaggregation). " + ), + ), + ] = None @validator("port") def convert_port(cls, v) -> PortMapping: diff --git a/src/dstack/_internal/core/models/routers.py b/src/dstack/_internal/core/models/routers.py index e07631e12e..d80ad35873 100644 --- a/src/dstack/_internal/core/models/routers.py +++ b/src/dstack/_internal/core/models/routers.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Literal +from typing import Literal, Optional from pydantic import Field from typing_extensions import Annotated @@ -19,6 +19,25 @@ class SGLangRouterConfig(CoreModel): description="The routing policy. Options: `random`, `round_robin`, `cache_aware`, `power_of_two`" ), ] = "cache_aware" + pd_disaggregation: Annotated[ + bool, + Field(description="Enable PD disaggregation mode for the SGLang router"), + ] = False AnyRouterConfig = SGLangRouterConfig + + +def merge_router_config_for_service( + gateway_router: Optional[AnyRouterConfig], + service_router_config: Optional[AnyRouterConfig], +) -> Optional[AnyRouterConfig]: + if gateway_router is None: + return None + if service_router_config is None: + return gateway_router + return SGLangRouterConfig( + type=gateway_router.type, + policy=service_router_config.policy, + pd_disaggregation=service_router_config.pd_disaggregation, + ) diff --git a/src/dstack/_internal/proxy/gateway/repo/state_v1.py b/src/dstack/_internal/proxy/gateway/repo/state_v1.py index b49550ef5f..2f4ec79418 100644 --- a/src/dstack/_internal/proxy/gateway/repo/state_v1.py +++ b/src/dstack/_internal/proxy/gateway/repo/state_v1.py @@ -98,6 +98,7 @@ def parse_replica(replica: dict) -> Replica: ssh_destination=replica["ssh_host"], ssh_port=replica["ssh_port"], ssh_proxy=ssh_proxy, + internal_ip=replica.get("internal_ip"), ) diff --git a/src/dstack/_internal/proxy/gateway/routers/registry.py b/src/dstack/_internal/proxy/gateway/routers/registry.py index dd4f63f325..c5f4cf8a1a 100644 --- a/src/dstack/_internal/proxy/gateway/routers/registry.py +++ b/src/dstack/_internal/proxy/gateway/routers/registry.py @@ -80,6 +80,7 @@ async def register_replica( ssh_proxy=body.ssh_proxy, ssh_head_proxy=body.ssh_head_proxy, ssh_head_proxy_private_key=body.ssh_head_proxy_private_key, + internal_ip=body.internal_ip, repo=repo, nginx=nginx, service_conn_pool=service_conn_pool, diff --git a/src/dstack/_internal/proxy/gateway/schemas/registry.py b/src/dstack/_internal/proxy/gateway/schemas/registry.py index 53a29f68ca..967e9d9960 100644 --- a/src/dstack/_internal/proxy/gateway/schemas/registry.py +++ b/src/dstack/_internal/proxy/gateway/schemas/registry.py @@ -56,6 +56,7 @@ class RegisterReplicaRequest(BaseModel): ssh_proxy: Optional[SSHConnectionParams] ssh_head_proxy: Optional[SSHConnectionParams] ssh_head_proxy_private_key: Optional[str] + internal_ip: Optional[str] = None class RegisterEntrypointRequest(BaseModel): diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py index 867591ca13..129562e497 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py @@ -79,7 +79,7 @@ def remove_replicas(self, replica_urls: List[str]) -> None: ... @abstractmethod - def update_replicas(self, replica_urls: List[str]) -> None: + async def update_replicas(self, replica_urls: List[str]) -> None: """Update replicas for service, replacing the current set. Args: @@ -89,3 +89,25 @@ def update_replicas(self, replica_urls: List[str]) -> None: Exception: If updating replicas fails. """ ... + + def add_worker_to_router( + self, + url: str, + worker_type: str = "regular", + bootstrap_port: Optional[int] = None, + ) -> bool: + """Add a worker to the router. + + Args: + url: Worker URL (e.g. http://10.0.5.134:8000). + worker_type: Type of worker ("regular", "prefill", or "decode"). + bootstrap_port: Bootstrap port for prefill workers (optional). + + Returns: + True if the worker was accepted, False otherwise. + """ + raise NotImplementedError + + async def register_worker(self, url: str) -> bool: + """Register worker with one attempt (no polling). Returns True if ready and added.""" + raise NotImplementedError diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py index c3a0dfaae9..7d976b0059 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py @@ -2,7 +2,6 @@ import subprocess import sys import time -import urllib.parse from typing import List, Optional import httpx @@ -10,6 +9,7 @@ from dstack._internal.core.models.routers import RouterType, SGLangRouterConfig from dstack._internal.proxy.lib.errors import UnexpectedProxyError +from dstack._internal.utils.common import run_async from dstack._internal.utils.logging import get_logger from .base import Router, RouterContext @@ -68,6 +68,8 @@ def start(self) -> None: "--policy", self.config.policy, ] + if self.config.pd_disaggregation: + cmd.append("--pd-disaggregation") subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) @@ -140,7 +142,7 @@ def remove_replicas(self, replica_urls: List[str]) -> None: for replica_url in replica_urls: self._remove_worker_from_router(replica_url) - def update_replicas(self, replica_urls: List[str]) -> None: + async def update_replicas(self, replica_urls: List[str]) -> None: """Update replicas for service, replacing the current set.""" # Query router to get current worker URLs current_workers = self._get_router_workers() @@ -172,9 +174,9 @@ def update_replicas(self, replica_urls: List[str]) -> None: self.context.port, ) - # Add workers + # Add workers: poll /server_info and register with discovered type for worker_url in sorted(workers_to_add): - success = self._add_worker_to_router(worker_url) + success = await self.register_worker(worker_url) if not success: logger.warning("Failed to add worker %s, continuing with others", worker_url) @@ -197,9 +199,16 @@ def _get_router_workers(self) -> List[dict]: logger.exception("Error getting sglang router workers") return [] - def _add_worker_to_router(self, worker_url: str) -> bool: + def add_worker_to_router( + self, + url: str, + worker_type: str = "regular", + bootstrap_port: Optional[int] = None, + ) -> bool: try: - payload = {"url": worker_url, "worker_type": "regular"} + payload: dict = {"url": url, "worker_type": worker_type} + if bootstrap_port is not None: + payload["bootstrap_port"] = bootstrap_port with httpx.Client(timeout=5.0) as client: response = client.post( f"http://{self.context.host}:{self.context.port}/workers", @@ -209,8 +218,9 @@ def _add_worker_to_router(self, worker_url: str) -> bool: response_data = response.json() if response_data.get("status") == "accepted": logger.info( - "Worker %s accepted by sglang router on port %s", - worker_url, + "Worker %s (type=%s) accepted by sglang router on port %s", + url, + worker_type, self.context.port, ) return True @@ -224,21 +234,66 @@ def _add_worker_to_router(self, worker_url: str) -> bool: else: logger.error( "Failed to add worker %s: status %d, %s", - worker_url, + url, response.status_code, response.text, ) return False except Exception: - logger.exception("Error adding worker %s", worker_url) + logger.exception("Error adding worker %s", url) + return False + + async def register_worker(self, url: str) -> bool: + server_info_url = f"{url}/server_info" + try: + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.get(server_info_url) + if resp.status_code != 200: + return False + data = resp.json() + if data.get("status") != "ready": + return False + disaggregation_mode = data.get("disaggregation_mode", "") + if disaggregation_mode == "prefill": + worker_type = "prefill" + bootstrap_port = data.get("disaggregation_bootstrap_port") + elif disaggregation_mode == "decode": + worker_type = "decode" + bootstrap_port = None + else: + worker_type = "regular" + bootstrap_port = None + logger.info( + "Registering worker %s (type=%s)", + url, + worker_type, + ) + return await run_async( + self.add_worker_to_router, + url, + worker_type, + bootstrap_port, + ) + except Exception: + logger.exception("Error registering worker %s", url) return False def _remove_worker_from_router(self, worker_url: str) -> bool: try: - encoded_url = urllib.parse.quote(worker_url, safe="") + current_workers = self._get_router_workers() + worker_id = None + for worker in current_workers: + url = worker.get("url") + if url and isinstance(url, str) and url == worker_url: + worker_id = worker.get("id") + if worker_id and isinstance(worker_id, str): + break + if not worker_id: + logger.exception("No worker id found for url %s", worker_url) + return False with httpx.Client(timeout=5.0) as client: response = client.delete( - f"http://{self.context.host}:{self.context.port}/workers/{encoded_url}" + f"http://{self.context.host}:{self.context.port}/workers/{worker_id}" ) if response.status_code == 202: response_data = response.json() diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index bbda92d91b..85473923f7 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -5,6 +5,7 @@ from asyncio import Lock from pathlib import Path from typing import Dict, Optional +from urllib.parse import urlparse import jinja2 from pydantic import BaseModel @@ -43,6 +44,8 @@ def render(self) -> str: class ReplicaConfig(BaseModel): id: str socket: Path + port: int + internal_ip: Optional[str] = None class LimitReqZoneConfig(BaseModel): @@ -95,7 +98,8 @@ def __init__(self, conf_dir: Path = Path("/etc/nginx/sites-enabled")) -> None: self._next_router_port: int = self._ROUTER_PORT_MIN # Tracking of worker ports to avoid conflicts across router instances self._allocated_worker_ports: set[int] = set() - self._domain_to_worker_ports: Dict[str, list[int]] = {} + # Domain -> list of worker URLs (used for remove_replicas; non-PD URLs are gateway-local) + self._domain_to_worker_urls: Dict[str, list[str]] = {} self._next_worker_port: int = self._WORKER_PORT_MIN async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: @@ -144,33 +148,37 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: del self._domain_to_router[conf.domain] raise - allocated_ports = self._allocate_worker_ports(len(conf.replicas)) - replica_urls = [ - f"http://{router.context.host}:{port}" for port in allocated_ports - ] - - # Write router workers config - try: + if conf.router.pd_disaggregation: + # PD path: replica_urls from internal_ip (router talks directly to workers) + replica_urls = [ + f"http://{replica.internal_ip}:{replica.port}" + for replica in conf.replicas + if replica.internal_ip + ] + self._domain_to_worker_urls[conf.domain] = replica_urls + else: + # Non-PD path: allocate gateway-local ports, nginx proxies to replica sockets + allocated_ports = self._allocate_worker_ports(len(conf.replicas)) + replica_urls = [ + f"http://{router.context.host}:{port}" for port in allocated_ports + ] if conf.replicas: - await run_async(self.write_router_workers_conf, conf, allocated_ports) - # Discard old worker ports if domain already has allocated ports (required for scaling case) - if conf.domain in self._domain_to_worker_ports: - old_worker_ports = self._domain_to_worker_ports[conf.domain] - for port in old_worker_ports: - self._allocated_worker_ports.discard(port) - self._domain_to_worker_ports[conf.domain] = allocated_ports - except Exception as e: - logger.exception( - "write_router_workers_conf failed for domain=%s: %s", conf.domain, e - ) - raise + await run_async( + self.write_router_workers_conf, + conf, + allocated_ports, + ) + if conf.domain in self._domain_to_worker_urls: + self._discard_ports(self._domain_to_worker_urls[conf.domain]) + self._domain_to_worker_urls[conf.domain] = replica_urls - # Update replicas to router (actual HTTP API calls to add workers) try: - await run_async(router.update_replicas, replica_urls) + await router.update_replicas(replica_urls) except Exception as e: logger.exception( - "Failed to add replicas to router for domain=%s: %s", conf.domain, e + "Failed to add replicas to router for domain=%s: %s", + conf.domain, + e, ) raise @@ -189,12 +197,12 @@ async def unregister(self, domain: str) -> None: if domain in self._domain_to_router: router = self._domain_to_router[domain] # Remove all workers for this domain - if domain in self._domain_to_worker_ports: - worker_ports = self._domain_to_worker_ports[domain] - replica_urls = [ - f"http://{router.context.host}:{port}" for port in worker_ports - ] - await run_async(router.remove_replicas, replica_urls) + if domain in self._domain_to_worker_urls: + worker_urls = self._domain_to_worker_urls[domain] + await run_async(router.remove_replicas, worker_urls) + self._discard_ports(worker_urls) + del self._domain_to_worker_urls[domain] + logger.debug("Removed worker URLs for domain %s", domain) # Stop and kill the router await run_async(router.stop) # Remove from mappings @@ -203,14 +211,6 @@ async def unregister(self, domain: str) -> None: del self._router_port_to_domain[router_port] del self._domain_to_router[domain] - # Discard worker ports for this domain - if domain in self._domain_to_worker_ports: - worker_ports = self._domain_to_worker_ports[domain] - for port in worker_ports: - self._allocated_worker_ports.discard(port) - del self._domain_to_worker_ports[domain] - logger.debug("Freed worker ports %s for domain %s", worker_ports, domain) - # Remove workers config file workers_conf_path = self._conf_dir / f"router-workers.{domain}.conf" if workers_conf_path.exists(): @@ -403,6 +403,12 @@ def _allocate_worker_ports(self, num_ports: int) -> list[int]: return allocated + def _discard_ports(self, urls: list[str]) -> None: + for u in urls: + parsed = urlparse(u) + if parsed.port is not None and parsed.port in self._allocated_worker_ports: + self._allocated_worker_ports.discard(parsed.port) + def write_global_conf(self) -> None: conf = read_package_resource("00-log-format.conf") self.write_conf(conf, "00-log-format.conf") diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index 636d8c38ec..061ee8bec8 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -136,6 +136,7 @@ async def register_replica( repo: GatewayProxyRepo, nginx: Nginx, service_conn_pool: ServiceConnectionPool, + internal_ip: Optional[str] = None, ) -> None: replica = models.Replica( id=replica_id, @@ -145,6 +146,7 @@ async def register_replica( ssh_proxy=ssh_proxy, ssh_head_proxy=ssh_head_proxy, ssh_head_proxy_private_key=ssh_head_proxy_private_key, + internal_ip=internal_ip, ) async with lock: @@ -258,7 +260,12 @@ async def apply_service( service, repo, service_conn_pool ) replica_configs = [ - ReplicaConfig(id=replica.id, socket=conn.app_socket_path) + ReplicaConfig( + id=replica.id, + socket=conn.app_socket_path, + port=replica.app_port, + internal_ip=replica.internal_ip, + ) for replica, conn in replica_conns.items() ] service_config = await get_nginx_service_config(service, replica_configs) diff --git a/src/dstack/_internal/proxy/lib/models.py b/src/dstack/_internal/proxy/lib/models.py index bf37e0b5aa..9ae6b4d2ad 100644 --- a/src/dstack/_internal/proxy/lib/models.py +++ b/src/dstack/_internal/proxy/lib/models.py @@ -27,6 +27,7 @@ class Replica(ImmutableModel): # Optional outer proxy, a head node/bastion ssh_head_proxy: Optional[SSHConnectionParams] = None ssh_head_proxy_private_key: Optional[str] = None + internal_ip: Optional[str] = None class IPAddressPartitioningKey(ImmutableModel): diff --git a/src/dstack/_internal/server/services/gateways/client.py b/src/dstack/_internal/server/services/gateways/client.py index d4f1c831e8..e68b874728 100644 --- a/src/dstack/_internal/server/services/gateways/client.py +++ b/src/dstack/_internal/server/services/gateways/client.py @@ -99,6 +99,7 @@ async def register_replica( assert jpd is not None assert jpd.hostname is not None assert jpd.ssh_port is not None + payload["internal_ip"] = jpd.internal_ip if not jpd.dockerized: payload.update( { diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 06aa5b0ef0..272f66d6b3 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -26,6 +26,7 @@ ) from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus from dstack._internal.core.models.instances import SSHConnectionParams +from dstack._internal.core.models.routers import merge_router_config_for_service from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec from dstack._internal.server import settings from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel @@ -92,7 +93,10 @@ async def _register_service_in_gateway( gateway_configuration = get_gateway_configuration(gateway) service_https = _get_service_https(run_spec, gateway_configuration) - router = gateway_configuration.router + router = merge_router_config_for_service( + gateway_configuration.router, + run_spec.configuration.router_config, + ) service_protocol = "https" if service_https else "http" if service_https and gateway_configuration.certificate is None: From a90173b7612b5b695c9e380874c10cfe8e7d0f48 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Tue, 10 Feb 2026 13:14:36 +0545 Subject: [PATCH 02/13] Add pd disaggregation service --- gateway/pyproject.toml | 2 +- src/dstack/_internal/core/backends/base/compute.py | 3 +-- .../_internal/core/backends/kubernetes/compute.py | 5 ----- src/dstack/_internal/core/models/configurations.py | 2 +- src/dstack/_internal/core/models/routers.py | 5 +++++ .../proxy/gateway/services/model_routers/base.py | 2 +- .../proxy/gateway/services/model_routers/sglang.py | 12 +++++------- src/dstack/_internal/proxy/gateway/services/nginx.py | 1 - 8 files changed, 14 insertions(+), 18 deletions(-) diff --git a/gateway/pyproject.toml b/gateway/pyproject.toml index 243f37b939..6c4d406a6f 100644 --- a/gateway/pyproject.toml +++ b/gateway/pyproject.toml @@ -11,7 +11,7 @@ requires-python = ">=3.10" dynamic = ["version"] dependencies = [ # release builds of dstack-gateway depend on a PyPI version of dstack instead - "dstack[gateway] @ https://github.com/Bihan/dstack/archive/refs/heads/pd_design_test.tar.gz", + "dstack[gateway] @ https://github.com/dstackai/dstack/archive/refs/heads/master.tar.gz", ] [project.optional-dependencies] diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index b07c510e9c..49513e3211 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -1042,8 +1042,7 @@ def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = Non if build == "latest": build = _fetch_version(f"{base_url}/latest-version") or "latest" logger.debug("Found the latest gateway build: %s", build) - # wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" - wheel = "https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.1-py3-none-any.whl" + wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" # Build package spec with extras if router is specified if router: return f"dstack-gateway[{router.type}] @ {wheel}" diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index f25b87d9cc..7f8ef9123f 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -330,11 +330,6 @@ def update_provisioning_data( if not pod_ip: return provisioning_data.internal_ip = pod_ip - logger.debug( - "Replica pod %s internal_ip=%s (cluster_ip will be set from Service)", - provisioning_data.instance_id, - pod_ip, - ) service = self.api.read_namespaced_service( name=_get_pod_service_name(provisioning_data.instance_id), namespace=self.config.namespace, diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 7055b14073..955f90657c 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -890,7 +890,7 @@ class ServiceConfigurationParams(CoreModel): Optional[AnyRouterConfig], Field( description=( - "Router configuration for the service (e.g. routing policy and pd_disaggregation). " + "Router configuration for the service. Currently supports routing policy and pd_disaggregation. " ), ), ] = None diff --git a/src/dstack/_internal/core/models/routers.py b/src/dstack/_internal/core/models/routers.py index d80ad35873..aeabadc31a 100644 --- a/src/dstack/_internal/core/models/routers.py +++ b/src/dstack/_internal/core/models/routers.py @@ -32,6 +32,11 @@ def merge_router_config_for_service( gateway_router: Optional[AnyRouterConfig], service_router_config: Optional[AnyRouterConfig], ) -> Optional[AnyRouterConfig]: + """Merge gateway and service router config. + + Gateway router config supplies the router type; service router config supplies + policy and pd_disaggregation. The result combines both. + """ if gateway_router is None: return None if service_router_config is None: diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py index 129562e497..c8704e0cee 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py @@ -90,7 +90,7 @@ async def update_replicas(self, replica_urls: List[str]) -> None: """ ... - def add_worker_to_router( + async def add_worker_to_router( self, url: str, worker_type: str = "regular", diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py index 7d976b0059..b187a0699f 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py @@ -9,7 +9,6 @@ from dstack._internal.core.models.routers import RouterType, SGLangRouterConfig from dstack._internal.proxy.lib.errors import UnexpectedProxyError -from dstack._internal.utils.common import run_async from dstack._internal.utils.logging import get_logger from .base import Router, RouterContext @@ -174,7 +173,7 @@ async def update_replicas(self, replica_urls: List[str]) -> None: self.context.port, ) - # Add workers: poll /server_info and register with discovered type + # Add workers for worker_url in sorted(workers_to_add): success = await self.register_worker(worker_url) if not success: @@ -199,7 +198,7 @@ def _get_router_workers(self) -> List[dict]: logger.exception("Error getting sglang router workers") return [] - def add_worker_to_router( + async def add_worker_to_router( self, url: str, worker_type: str = "regular", @@ -209,8 +208,8 @@ def add_worker_to_router( payload: dict = {"url": url, "worker_type": worker_type} if bootstrap_port is not None: payload["bootstrap_port"] = bootstrap_port - with httpx.Client(timeout=5.0) as client: - response = client.post( + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( f"http://{self.context.host}:{self.context.port}/workers", json=payload, ) @@ -268,8 +267,7 @@ async def register_worker(self, url: str) -> bool: url, worker_type, ) - return await run_async( - self.add_worker_to_router, + return await self.add_worker_to_router( url, worker_type, bootstrap_port, diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index 85473923f7..fdd6249adb 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -98,7 +98,6 @@ def __init__(self, conf_dir: Path = Path("/etc/nginx/sites-enabled")) -> None: self._next_router_port: int = self._ROUTER_PORT_MIN # Tracking of worker ports to avoid conflicts across router instances self._allocated_worker_ports: set[int] = set() - # Domain -> list of worker URLs (used for remove_replicas; non-PD URLs are gateway-local) self._domain_to_worker_urls: Dict[str, list[str]] = {} self._next_worker_port: int = self._WORKER_PORT_MIN From 5ec1f97a35ea7616c2fc1e9429f03616dd710514 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Fri, 13 Feb 2026 08:12:45 +0545 Subject: [PATCH 03/13] Move router configuration to service --- .../_internal/core/backends/base/compute.py | 9 ++++---- .../core/backends/kubernetes/compute.py | 5 +---- .../_internal/core/models/configurations.py | 4 ++-- src/dstack/_internal/core/models/gateways.py | 8 +++---- src/dstack/_internal/core/models/routers.py | 22 +------------------ .../server/services/services/__init__.py | 6 +---- 6 files changed, 13 insertions(+), 41 deletions(-) diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 49513e3211..5939751132 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -39,7 +39,6 @@ SSHKey, ) from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData -from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import ( Volume, @@ -924,7 +923,7 @@ def get_run_shim_script( ] -def get_gateway_user_data(authorized_key: str, router: Optional[AnyRouterConfig] = None) -> str: +def get_gateway_user_data(authorized_key: str, router: Optional[str] = None) -> str: return get_cloud_config( package_update=True, packages=[ @@ -1036,7 +1035,7 @@ def get_latest_runner_build() -> Optional[str]: return None -def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = None) -> str: +def get_dstack_gateway_wheel(build: str, router: Optional[str] = None) -> str: channel = "release" if settings.DSTACK_RELEASE else "stgn" base_url = f"https://dstack-gateway-downloads.s3.amazonaws.com/{channel}" if build == "latest": @@ -1045,11 +1044,11 @@ def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = Non wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" # Build package spec with extras if router is specified if router: - return f"dstack-gateway[{router.type}] @ {wheel}" + return f"dstack-gateway[{router}] @ {wheel}" return f"dstack-gateway @ {wheel}" -def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]: +def get_dstack_gateway_commands(router: Optional[str] = None) -> List[str]: build = get_dstack_runner_version() or "latest" gateway_package = get_dstack_gateway_wheel(build, router) return [ diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index 7f8ef9123f..e9692344d9 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -50,7 +50,6 @@ ) from dstack._internal.core.models.placement import PlacementGroup from dstack._internal.core.models.resources import CPUSpec, GPUSpec, Memory -from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import Volume from dstack._internal.utils.common import get_or_error, parse_memory @@ -1002,9 +1001,7 @@ def _add_authorized_key_to_jump_pod( ) -def _get_gateway_commands( - authorized_keys: List[str], router: Optional[AnyRouterConfig] = None -) -> List[str]: +def _get_gateway_commands(authorized_keys: List[str], router: Optional[str] = None) -> List[str]: authorized_keys_content = "\n".join(authorized_keys).strip() gateway_commands = " && ".join(get_dstack_gateway_commands(router=router)) quoted_gateway_commands = shlex.quote(gateway_commands) diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 955f90657c..801f87daae 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -886,11 +886,11 @@ class ServiceConfigurationParams(CoreModel): ) ), ] = None - router_config: Annotated[ + router: Annotated[ Optional[AnyRouterConfig], Field( description=( - "Router configuration for the service. Currently supports routing policy and pd_disaggregation. " + "Router configuration for the service. Requires a gateway with matching router enabled. " ), ), ] = None diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index b342c0a73b..e9581b83f0 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -8,7 +8,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import RouterType from dstack._internal.utils.tags import tags_validator @@ -63,8 +63,8 @@ class GatewayConfiguration(CoreModel): ), ] = None router: Annotated[ - Optional[AnyRouterConfig], - Field(description="The router configuration"), + Optional[RouterType], + Field(description="The router type enabled on this gateway. E.g. 'sglang'."), ] = None domain: Annotated[ Optional[str], Field(description="The gateway domain, e.g. `example.com`") @@ -134,7 +134,7 @@ class GatewayComputeConfiguration(CoreModel): ssh_key_pub: str certificate: Optional[AnyGatewayCertificate] = None tags: Optional[Dict[str, str]] = None - router: Optional[AnyRouterConfig] = None + router: Optional[RouterType] = None class GatewayProvisioningData(CoreModel): diff --git a/src/dstack/_internal/core/models/routers.py b/src/dstack/_internal/core/models/routers.py index aeabadc31a..e42cd9976d 100644 --- a/src/dstack/_internal/core/models/routers.py +++ b/src/dstack/_internal/core/models/routers.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Literal, Optional +from typing import Literal from pydantic import Field from typing_extensions import Annotated @@ -26,23 +26,3 @@ class SGLangRouterConfig(CoreModel): AnyRouterConfig = SGLangRouterConfig - - -def merge_router_config_for_service( - gateway_router: Optional[AnyRouterConfig], - service_router_config: Optional[AnyRouterConfig], -) -> Optional[AnyRouterConfig]: - """Merge gateway and service router config. - - Gateway router config supplies the router type; service router config supplies - policy and pd_disaggregation. The result combines both. - """ - if gateway_router is None: - return None - if service_router_config is None: - return gateway_router - return SGLangRouterConfig( - type=gateway_router.type, - policy=service_router_config.policy, - pd_disaggregation=service_router_config.pd_disaggregation, - ) diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 272f66d6b3..8f3d10c9dc 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -26,7 +26,6 @@ ) from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus from dstack._internal.core.models.instances import SSHConnectionParams -from dstack._internal.core.models.routers import merge_router_config_for_service from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec from dstack._internal.server import settings from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel @@ -93,10 +92,7 @@ async def _register_service_in_gateway( gateway_configuration = get_gateway_configuration(gateway) service_https = _get_service_https(run_spec, gateway_configuration) - router = merge_router_config_for_service( - gateway_configuration.router, - run_spec.configuration.router_config, - ) + router = run_spec.configuration.router service_protocol = "https" if service_https else "http" if service_https and gateway_configuration.certificate is None: From 6e7dbe7abbe9f8cdc640d148885ccc3a4fdaf751 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Fri, 13 Feb 2026 13:02:07 +0545 Subject: [PATCH 04/13] Resolve major comments --- .../gateway/services/model_routers/base.py | 6 +++--- .../gateway/services/model_routers/sglang.py | 18 +++++++++--------- .../_internal/proxy/gateway/services/nginx.py | 16 ++++++++++++---- .../proxy/gateway/services/registry.py | 2 +- .../server/services/services/__init__.py | 18 ++++++++++++++++++ 5 files changed, 43 insertions(+), 17 deletions(-) diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py index c8704e0cee..a9b54347e4 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py @@ -79,7 +79,7 @@ def remove_replicas(self, replica_urls: List[str]) -> None: ... @abstractmethod - async def update_replicas(self, replica_urls: List[str]) -> None: + def update_replicas(self, replica_urls: List[str]) -> None: """Update replicas for service, replacing the current set. Args: @@ -90,7 +90,7 @@ async def update_replicas(self, replica_urls: List[str]) -> None: """ ... - async def add_worker_to_router( + def add_worker_to_router( self, url: str, worker_type: str = "regular", @@ -108,6 +108,6 @@ async def add_worker_to_router( """ raise NotImplementedError - async def register_worker(self, url: str) -> bool: + def register_worker(self, url: str) -> bool: """Register worker with one attempt (no polling). Returns True if ready and added.""" raise NotImplementedError diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py index b187a0699f..5214bb8e93 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py @@ -141,7 +141,7 @@ def remove_replicas(self, replica_urls: List[str]) -> None: for replica_url in replica_urls: self._remove_worker_from_router(replica_url) - async def update_replicas(self, replica_urls: List[str]) -> None: + def update_replicas(self, replica_urls: List[str]) -> None: """Update replicas for service, replacing the current set.""" # Query router to get current worker URLs current_workers = self._get_router_workers() @@ -175,7 +175,7 @@ async def update_replicas(self, replica_urls: List[str]) -> None: # Add workers for worker_url in sorted(workers_to_add): - success = await self.register_worker(worker_url) + success = self.register_worker(worker_url) if not success: logger.warning("Failed to add worker %s, continuing with others", worker_url) @@ -198,7 +198,7 @@ def _get_router_workers(self) -> List[dict]: logger.exception("Error getting sglang router workers") return [] - async def add_worker_to_router( + def add_worker_to_router( self, url: str, worker_type: str = "regular", @@ -208,8 +208,8 @@ async def add_worker_to_router( payload: dict = {"url": url, "worker_type": worker_type} if bootstrap_port is not None: payload["bootstrap_port"] = bootstrap_port - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( + with httpx.Client(timeout=5.0) as client: + response = client.post( f"http://{self.context.host}:{self.context.port}/workers", json=payload, ) @@ -242,11 +242,11 @@ async def add_worker_to_router( logger.exception("Error adding worker %s", url) return False - async def register_worker(self, url: str) -> bool: + def register_worker(self, url: str) -> bool: server_info_url = f"{url}/server_info" try: - async with httpx.AsyncClient(timeout=10) as client: - resp = await client.get(server_info_url) + with httpx.Client(timeout=10) as client: + resp = client.get(server_info_url) if resp.status_code != 200: return False data = resp.json() @@ -267,7 +267,7 @@ async def register_worker(self, url: str) -> bool: url, worker_type, ) - return await self.add_worker_to_router( + return self.add_worker_to_router( url, worker_type, bootstrap_port, diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index fdd6249adb..d79b0c7932 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -19,6 +19,7 @@ RouterContext, get_router, ) +from dstack._internal.proxy.lib import models from dstack._internal.proxy.lib.errors import ProxyError, UnexpectedProxyError from dstack._internal.utils.common import run_async from dstack._internal.utils.logging import get_logger @@ -149,10 +150,13 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: if conf.router.pd_disaggregation: # PD path: replica_urls from internal_ip (router talks directly to workers) + if any(not r.internal_ip for r in conf.replicas): + raise ProxyError( + "PD disaggregation requires internal IP for all replicas." + ) replica_urls = [ f"http://{replica.internal_ip}:{replica.port}" for replica in conf.replicas - if replica.internal_ip ] self._domain_to_worker_urls[conf.domain] = replica_urls else: @@ -172,7 +176,7 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: self._domain_to_worker_urls[conf.domain] = replica_urls try: - await router.update_replicas(replica_urls) + await run_async(router.update_replicas, replica_urls) except Exception as e: logger.exception( "Failed to add replicas to router for domain=%s: %s", @@ -185,7 +189,7 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: logger.info("Registered %s domain %s", conf.type, conf.domain) - async def unregister(self, domain: str) -> None: + async def unregister(self, domain: str, service: models.Service) -> None: logger.debug("Unregistering domain %s", domain) conf_path = self._conf_dir / self.get_config_name(domain) if not conf_path.exists(): @@ -199,7 +203,11 @@ async def unregister(self, domain: str) -> None: if domain in self._domain_to_worker_urls: worker_urls = self._domain_to_worker_urls[domain] await run_async(router.remove_replicas, worker_urls) - self._discard_ports(worker_urls) + pd_disaggregation = ( + service.router.pd_disaggregation if service.router else False + ) + if not pd_disaggregation: + self._discard_ports(worker_urls) del self._domain_to_worker_urls[domain] logger.debug("Removed worker URLs for domain %s", domain) # Stop and kill the router diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index 061ee8bec8..fd523e8d12 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -116,7 +116,7 @@ async def unregister_service( ids=(r.id for r in service.replicas), service_conn_pool=service_conn_pool, ) - await nginx.unregister(service.domain_safe) + await nginx.unregister(service.domain_safe, service) await repo.delete_models_by_run(project_name, run_name) await repo.delete_service(project_name, run_name) diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 8f3d10c9dc..201351e50d 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -26,6 +26,7 @@ ) from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus from dstack._internal.core.models.instances import SSHConnectionParams +from dstack._internal.core.models.routers import RouterType from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec from dstack._internal.server import settings from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel @@ -91,6 +92,15 @@ async def _register_service_in_gateway( raise ServerClientError("Gateway status is not running") gateway_configuration = get_gateway_configuration(gateway) + if ( + run_spec.configuration.router is not None + and run_spec.configuration.router.type == RouterType.SGLANG + ): + if gateway_configuration.router != RouterType.SGLANG: + raise ServerClientError( + f"Service requires a SGLang gateway but gateway '{gateway.name}' " + "does not have the SGLang router configured." + ) service_https = _get_service_https(run_spec, gateway_configuration) router = run_spec.configuration.router service_protocol = "https" if service_https else "http" @@ -152,6 +162,14 @@ async def _register_service_in_gateway( def _register_service_in_server(run_model: RunModel, run_spec: RunSpec) -> ServiceSpec: assert run_spec.configuration.type == "service" + if ( + run_spec.configuration.router is not None + and run_spec.configuration.router.type == RouterType.SGLANG + ): + raise ServerClientError( + "Service with SGLang router configuration requires a gateway. " + "Please configure a gateway with the SGLang router enabled." + ) if run_spec.configuration.https != SERVICE_HTTPS_DEFAULT: # Note: if the user sets `https: `, it will be ignored silently # TODO: in 0.19, make `https` Optional to be able to tell if it was set or omitted From 860ea230ff53cfd1b80925f353de2cfb15a0f84c Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Fri, 13 Feb 2026 13:49:35 +0545 Subject: [PATCH 05/13] Resolve Lint Error --- src/dstack/_internal/core/backends/kubernetes/compute.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index e3a430b35c..5223cdaa7c 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -862,9 +862,7 @@ def _wait_for_load_balancer_address( time.sleep(1) -def _get_gateway_commands( - authorized_keys: List[str], router: Optional[str] = None -) -> List[str]: +def _get_gateway_commands(authorized_keys: List[str], router: Optional[str] = None) -> List[str]: authorized_keys_content = "\n".join(authorized_keys).strip() gateway_commands = " && ".join(get_dstack_gateway_commands(router=router)) quoted_gateway_commands = shlex.quote(gateway_commands) From 38eee94052c9d44eb8d6a3ada91899761a46857d Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Fri, 13 Feb 2026 20:35:02 +0545 Subject: [PATCH 06/13] Minor Update --- src/dstack/_internal/core/backends/base/compute.py | 9 +++++---- src/dstack/_internal/core/backends/kubernetes/compute.py | 5 ++++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 5939751132..ade1b3daeb 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -39,6 +39,7 @@ SSHKey, ) from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData +from dstack._internal.core.models.routers import RouterType from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import ( Volume, @@ -923,7 +924,7 @@ def get_run_shim_script( ] -def get_gateway_user_data(authorized_key: str, router: Optional[str] = None) -> str: +def get_gateway_user_data(authorized_key: str, router: Optional[RouterType] = None) -> str: return get_cloud_config( package_update=True, packages=[ @@ -1035,7 +1036,7 @@ def get_latest_runner_build() -> Optional[str]: return None -def get_dstack_gateway_wheel(build: str, router: Optional[str] = None) -> str: +def get_dstack_gateway_wheel(build: str, router: Optional[RouterType] = None) -> str: channel = "release" if settings.DSTACK_RELEASE else "stgn" base_url = f"https://dstack-gateway-downloads.s3.amazonaws.com/{channel}" if build == "latest": @@ -1044,11 +1045,11 @@ def get_dstack_gateway_wheel(build: str, router: Optional[str] = None) -> str: wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" # Build package spec with extras if router is specified if router: - return f"dstack-gateway[{router}] @ {wheel}" + return f"dstack-gateway[{router.value}] @ {wheel}" return f"dstack-gateway @ {wheel}" -def get_dstack_gateway_commands(router: Optional[str] = None) -> List[str]: +def get_dstack_gateway_commands(router: Optional[RouterType] = None) -> List[str]: build = get_dstack_runner_version() or "latest" gateway_package = get_dstack_gateway_wheel(build, router) return [ diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index 5223cdaa7c..d98dfc94a8 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -66,6 +66,7 @@ ) from dstack._internal.core.models.placement import PlacementGroup from dstack._internal.core.models.resources import CPUSpec, GPUSpec +from dstack._internal.core.models.routers import RouterType from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import Volume from dstack._internal.utils.common import get_or_error @@ -862,7 +863,9 @@ def _wait_for_load_balancer_address( time.sleep(1) -def _get_gateway_commands(authorized_keys: List[str], router: Optional[str] = None) -> List[str]: +def _get_gateway_commands( + authorized_keys: List[str], router: Optional[RouterType] = None +) -> List[str]: authorized_keys_content = "\n".join(authorized_keys).strip() gateway_commands = " && ".join(get_dstack_gateway_commands(router=router)) quoted_gateway_commands = shlex.quote(gateway_commands) From 63ed75c53adb97b3656971172ce69a4ea4985c34 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Mon, 16 Feb 2026 10:31:22 +0545 Subject: [PATCH 07/13] Resolve Minor Comments --- .../_internal/core/backends/base/compute.py | 3 +- .../_internal/core/compatibility/gateways.py | 6 +-- .../_internal/core/compatibility/runs.py | 6 +++ .../_internal/proxy/gateway/repo/state_v1.py | 1 - .../gateway/services/model_routers/base.py | 22 ---------- .../gateway/services/model_routers/sglang.py | 13 +++--- .../proxy/gateway/services/registry.py | 42 ++++++++++++++----- .../_internal/server/services/proxy/repo.py | 1 + 8 files changed, 50 insertions(+), 44 deletions(-) diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index ade1b3daeb..76b1199f0a 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -1042,7 +1042,8 @@ def get_dstack_gateway_wheel(build: str, router: Optional[RouterType] = None) -> if build == "latest": build = _fetch_version(f"{base_url}/latest-version") or "latest" logger.debug("Found the latest gateway build: %s", build) - wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" + # wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" + wheel = "https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.1-py3-none-any.whl" # Build package spec with extras if router is specified if router: return f"dstack-gateway[{router.value}] @ {wheel}" diff --git a/src/dstack/_internal/core/compatibility/gateways.py b/src/dstack/_internal/core/compatibility/gateways.py index de94f6a18e..949d6515f8 100644 --- a/src/dstack/_internal/core/compatibility/gateways.py +++ b/src/dstack/_internal/core/compatibility/gateways.py @@ -31,9 +31,7 @@ def _get_gateway_configuration_excludes( ) -> IncludeExcludeDictType: configuration_excludes: IncludeExcludeDictType = {} - # Add excludes like this: - # - # if configuration.tags is None: - # configuration_excludes["tags"] = True + if configuration.router is None: + configuration_excludes["router"] = True return configuration_excludes diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index 19c08cde55..96db2574b3 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -72,6 +72,12 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType: # Servers prior to 0.20.8 do not support probes=None configuration_excludes["probes"] = True + router = run_spec.configuration.router + if router is None: + configuration_excludes["router"] = True + elif router.pd_disaggregation is False: + configuration_excludes["router"] = {"pd_disaggregation": True} + if configuration_excludes: spec_excludes["configuration"] = configuration_excludes if profile_excludes: diff --git a/src/dstack/_internal/proxy/gateway/repo/state_v1.py b/src/dstack/_internal/proxy/gateway/repo/state_v1.py index 2f4ec79418..b49550ef5f 100644 --- a/src/dstack/_internal/proxy/gateway/repo/state_v1.py +++ b/src/dstack/_internal/proxy/gateway/repo/state_v1.py @@ -98,7 +98,6 @@ def parse_replica(replica: dict) -> Replica: ssh_destination=replica["ssh_host"], ssh_port=replica["ssh_port"], ssh_proxy=ssh_proxy, - internal_ip=replica.get("internal_ip"), ) diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py index a9b54347e4..867591ca13 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py @@ -89,25 +89,3 @@ def update_replicas(self, replica_urls: List[str]) -> None: Exception: If updating replicas fails. """ ... - - def add_worker_to_router( - self, - url: str, - worker_type: str = "regular", - bootstrap_port: Optional[int] = None, - ) -> bool: - """Add a worker to the router. - - Args: - url: Worker URL (e.g. http://10.0.5.134:8000). - worker_type: Type of worker ("regular", "prefill", or "decode"). - bootstrap_port: Bootstrap port for prefill workers (optional). - - Returns: - True if the worker was accepted, False otherwise. - """ - raise NotImplementedError - - def register_worker(self, url: str) -> bool: - """Register worker with one attempt (no polling). Returns True if ready and added.""" - raise NotImplementedError diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py index 5214bb8e93..6c7dafd1b0 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py @@ -175,7 +175,7 @@ def update_replicas(self, replica_urls: List[str]) -> None: # Add workers for worker_url in sorted(workers_to_add): - success = self.register_worker(worker_url) + success = self._register_worker(worker_url) if not success: logger.warning("Failed to add worker %s, continuing with others", worker_url) @@ -198,7 +198,7 @@ def _get_router_workers(self) -> List[dict]: logger.exception("Error getting sglang router workers") return [] - def add_worker_to_router( + def _add_worker_to_router( self, url: str, worker_type: str = "regular", @@ -242,7 +242,10 @@ def add_worker_to_router( logger.exception("Error adding worker %s", url) return False - def register_worker(self, url: str) -> bool: + def _register_worker(self, url: str) -> bool: + if not self.config.pd_disaggregation: + return self._add_worker_to_router(url, "regular", None) + server_info_url = f"{url}/server_info" try: with httpx.Client(timeout=10) as client: @@ -267,7 +270,7 @@ def register_worker(self, url: str) -> bool: url, worker_type, ) - return self.add_worker_to_router( + return self._add_worker_to_router( url, worker_type, bootstrap_port, @@ -287,7 +290,7 @@ def _remove_worker_from_router(self, worker_url: str) -> bool: if worker_id and isinstance(worker_id, str): break if not worker_id: - logger.exception("No worker id found for url %s", worker_url) + logger.error("No worker id found for url %s", worker_url) return False with httpx.Client(timeout=5.0) as client: response = client.delete( diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index fd523e8d12..ed0ea07d77 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -237,6 +237,13 @@ async def register_model_entrypoint( logger.info("Entrypoint %s is now registered in project %s", domain, project_name) +def _uses_pd_disaggregation(service: models.Service) -> bool: + """PD disaggregation: router talks to replicas via internal_ip, no SSH tunnels needed.""" + return ( + service.router is not None and getattr(service.router, "pd_disaggregation", False) is True + ) + + async def apply_service( service: models.Service, old_service: Optional[models.Service], @@ -256,18 +263,31 @@ async def apply_service( ), service_conn_pool=service_conn_pool, ) - replica_conns, replica_failures = await get_or_add_replica_connections( - service, repo, service_conn_pool - ) - replica_configs = [ - ReplicaConfig( - id=replica.id, - socket=conn.app_socket_path, - port=replica.app_port, - internal_ip=replica.internal_ip, + if _uses_pd_disaggregation(service): + replica_conns = {} + replica_failures = {} + replica_configs = [ + ReplicaConfig( + id=replica.id, + socket=Path("/dev/null"), + port=replica.app_port, + internal_ip=replica.internal_ip, + ) + for replica in service.replicas + ] + else: + replica_conns, replica_failures = await get_or_add_replica_connections( + service, repo, service_conn_pool ) - for replica, conn in replica_conns.items() - ] + replica_configs = [ + ReplicaConfig( + id=replica.id, + socket=conn.app_socket_path, + port=replica.app_port, + internal_ip=replica.internal_ip, + ) + for replica, conn in replica_conns.items() + ] service_config = await get_nginx_service_config(service, replica_configs) await nginx.register(service_config, (await repo.get_config()).acme_settings) return replica_failures diff --git a/src/dstack/_internal/server/services/proxy/repo.py b/src/dstack/_internal/server/services/proxy/repo.py index 7f1564fe62..385c9e654f 100644 --- a/src/dstack/_internal/server/services/proxy/repo.py +++ b/src/dstack/_internal/server/services/proxy/repo.py @@ -111,6 +111,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic ssh_proxy=ssh_proxy, ssh_head_proxy=ssh_head_proxy, ssh_head_proxy_private_key=ssh_head_proxy_private_key, + internal_ip=jpd.internal_ip, ) replicas.append(replica) return Service( From 8560bcb0bc2e030e277b42c41686022130742e74 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Mon, 16 Feb 2026 11:05:18 +0545 Subject: [PATCH 08/13] Update wheel url --- src/dstack/_internal/core/backends/base/compute.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 76b1199f0a..ade1b3daeb 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -1042,8 +1042,7 @@ def get_dstack_gateway_wheel(build: str, router: Optional[RouterType] = None) -> if build == "latest": build = _fetch_version(f"{base_url}/latest-version") or "latest" logger.debug("Found the latest gateway build: %s", build) - # wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" - wheel = "https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.1-py3-none-any.whl" + wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" # Build package spec with extras if router is specified if router: return f"dstack-gateway[{router.value}] @ {wheel}" From 56286c47fb60e2ab2c5b10792d8f66b36a2c7168 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Tue, 17 Feb 2026 19:44:43 +0545 Subject: [PATCH 09/13] Resolve backward incompatibility --- .../_internal/core/backends/base/compute.py | 10 +-- .../core/backends/kubernetes/compute.py | 4 +- .../_internal/core/compatibility/runs.py | 3 +- src/dstack/_internal/core/models/gateways.py | 13 ++-- src/dstack/_internal/core/models/routers.py | 22 ++++++- .../server/services/services/__init__.py | 62 +++++++++++++++---- 6 files changed, 89 insertions(+), 25 deletions(-) diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index ade1b3daeb..49513e3211 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -39,7 +39,7 @@ SSHKey, ) from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData -from dstack._internal.core.models.routers import RouterType +from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import ( Volume, @@ -924,7 +924,7 @@ def get_run_shim_script( ] -def get_gateway_user_data(authorized_key: str, router: Optional[RouterType] = None) -> str: +def get_gateway_user_data(authorized_key: str, router: Optional[AnyRouterConfig] = None) -> str: return get_cloud_config( package_update=True, packages=[ @@ -1036,7 +1036,7 @@ def get_latest_runner_build() -> Optional[str]: return None -def get_dstack_gateway_wheel(build: str, router: Optional[RouterType] = None) -> str: +def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = None) -> str: channel = "release" if settings.DSTACK_RELEASE else "stgn" base_url = f"https://dstack-gateway-downloads.s3.amazonaws.com/{channel}" if build == "latest": @@ -1045,11 +1045,11 @@ def get_dstack_gateway_wheel(build: str, router: Optional[RouterType] = None) -> wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" # Build package spec with extras if router is specified if router: - return f"dstack-gateway[{router.value}] @ {wheel}" + return f"dstack-gateway[{router.type}] @ {wheel}" return f"dstack-gateway @ {wheel}" -def get_dstack_gateway_commands(router: Optional[RouterType] = None) -> List[str]: +def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]: build = get_dstack_runner_version() or "latest" gateway_package = get_dstack_gateway_wheel(build, router) return [ diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index d98dfc94a8..51abddc70c 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -66,7 +66,7 @@ ) from dstack._internal.core.models.placement import PlacementGroup from dstack._internal.core.models.resources import CPUSpec, GPUSpec -from dstack._internal.core.models.routers import RouterType +from dstack._internal.core.models.routers import AnyRouterConfig from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import Volume from dstack._internal.utils.common import get_or_error @@ -864,7 +864,7 @@ def _wait_for_load_balancer_address( def _get_gateway_commands( - authorized_keys: List[str], router: Optional[RouterType] = None + authorized_keys: List[str], router: Optional[AnyRouterConfig] = None ) -> List[str]: authorized_keys_content = "\n".join(authorized_keys).strip() gateway_commands = " && ".join(get_dstack_gateway_commands(router=router)) diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index 96db2574b3..ac355b7be1 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -2,6 +2,7 @@ from dstack._internal.core.models.common import IncludeExcludeDictType, IncludeExcludeSetType from dstack._internal.core.models.configurations import ServiceConfiguration +from dstack._internal.core.models.routers import SGLangRouterConfig from dstack._internal.core.models.runs import ( DEFAULT_PROBE_UNTIL_READY, DEFAULT_REPLICA_GROUP_NAME, @@ -75,7 +76,7 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType: router = run_spec.configuration.router if router is None: configuration_excludes["router"] = True - elif router.pd_disaggregation is False: + elif isinstance(router, SGLangRouterConfig) and router.pd_disaggregation is False: configuration_excludes["router"] = {"pd_disaggregation": True} if configuration_excludes: diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index e9581b83f0..0ef4f38a49 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -8,7 +8,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel -from dstack._internal.core.models.routers import RouterType +from dstack._internal.core.models.routers import GatewayRouterConfig from dstack._internal.utils.tags import tags_validator @@ -63,8 +63,13 @@ class GatewayConfiguration(CoreModel): ), ] = None router: Annotated[ - Optional[RouterType], - Field(description="The router type enabled on this gateway. E.g. 'sglang'."), + Optional[GatewayRouterConfig], + Field( + description=( + "The router configuration for this gateway. " + "E.g. `{ type: sglang, policy: round_robin }`." + ), + ), ] = None domain: Annotated[ Optional[str], Field(description="The gateway domain, e.g. `example.com`") @@ -134,7 +139,7 @@ class GatewayComputeConfiguration(CoreModel): ssh_key_pub: str certificate: Optional[AnyGatewayCertificate] = None tags: Optional[Dict[str, str]] = None - router: Optional[RouterType] = None + router: Optional[GatewayRouterConfig] = None class GatewayProvisioningData(CoreModel): diff --git a/src/dstack/_internal/core/models/routers.py b/src/dstack/_internal/core/models/routers.py index e42cd9976d..12f3d18584 100644 --- a/src/dstack/_internal/core/models/routers.py +++ b/src/dstack/_internal/core/models/routers.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Literal +from typing import Literal, Union from pydantic import Field from typing_extensions import Annotated @@ -11,6 +11,24 @@ class RouterType(str, Enum): SGLANG = "sglang" +class GatewayRouterConfig(CoreModel): + """Gateway-level router configuration. type and policy only. pd_disaggregation is service-level.""" + + type: Annotated[ + Literal["sglang"], + Field(description="The router type enabled on this gateway."), + ] = "sglang" + policy: Annotated[ + Literal["random", "round_robin", "cache_aware", "power_of_two"], + Field( + description=( + "The routing policy. Deprecated: prefer setting policy in the service's router config. " + "Options: `random`, `round_robin`, `cache_aware`, `power_of_two`" + ), + ), + ] = "cache_aware" + + class SGLangRouterConfig(CoreModel): type: Annotated[Literal["sglang"], Field(description="The router type")] = "sglang" policy: Annotated[ @@ -25,4 +43,4 @@ class SGLangRouterConfig(CoreModel): ] = False -AnyRouterConfig = SGLangRouterConfig +AnyRouterConfig = Union[SGLangRouterConfig, GatewayRouterConfig] diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 201351e50d..a916e8b795 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -26,7 +26,7 @@ ) from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus from dstack._internal.core.models.instances import SSHConnectionParams -from dstack._internal.core.models.routers import RouterType +from dstack._internal.core.models.routers import RouterType, SGLangRouterConfig from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec from dstack._internal.server import settings from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel @@ -45,6 +45,41 @@ logger = get_logger(__name__) +def _gateway_has_sglang_router(config: GatewayConfiguration) -> bool: + return config.router is not None and config.router.type == RouterType.SGLANG.value + + +def _build_service_router_config( + gateway_configuration: GatewayConfiguration, + service_configuration: ServiceConfiguration, +) -> Optional[SGLangRouterConfig]: + """ + Build router config from gateway (type, policy) + service (pd_disaggregation, policy override). + Service's policy overrides gateway's if present. Keeps backward compat: SGLang enabled + automatically when gateway has it configured. + """ + if not _gateway_has_sglang_router(gateway_configuration): + return None + + gateway_router = gateway_configuration.router + assert gateway_router is not None # ensured by _gateway_has_sglang_router + router_type = gateway_router.type + policy = gateway_router.policy + + service_router = service_configuration.router + if service_router is not None and isinstance(service_router, SGLangRouterConfig): + policy = service_router.policy + pd_disaggregation = service_router.pd_disaggregation + else: + pd_disaggregation = False + + return SGLangRouterConfig( + type=router_type, + policy=policy, + pd_disaggregation=pd_disaggregation, + ) + + async def register_service(session: AsyncSession, run_model: RunModel, run_spec: RunSpec): assert isinstance(run_spec.configuration, ServiceConfiguration) @@ -92,17 +127,22 @@ async def _register_service_in_gateway( raise ServerClientError("Gateway status is not running") gateway_configuration = get_gateway_configuration(gateway) - if ( - run_spec.configuration.router is not None - and run_spec.configuration.router.type == RouterType.SGLANG - ): - if gateway_configuration.router != RouterType.SGLANG: - raise ServerClientError( - f"Service requires a SGLang gateway but gateway '{gateway.name}' " - "does not have the SGLang router configured." - ) + + # Check: service wants pd_disaggregation but gateway has no SGLang router + service_router = run_spec.configuration.router + service_pd_disaggregation = ( + service_router is not None + and isinstance(service_router, SGLangRouterConfig) + and service_router.pd_disaggregation + ) + if service_pd_disaggregation and not _gateway_has_sglang_router(gateway_configuration): + raise ServerClientError( + "Service requires gateway with SGLang router for pd_disaggregation but gateway " + f"'{gateway.name}' does not have the SGLang router configured." + ) + service_https = _get_service_https(run_spec, gateway_configuration) - router = run_spec.configuration.router + router = _build_service_router_config(gateway_configuration, run_spec.configuration) service_protocol = "https" if service_https else "http" if service_https and gateway_configuration.certificate is None: From b71080d3a7e4726dd642f0489f033b13e0ef6758 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Wed, 18 Feb 2026 06:18:15 +0545 Subject: [PATCH 10/13] Update RouterConfigs --- docs/docs/reference/dstack.yml/gateway.md | 2 +- .../_internal/core/backends/base/compute.py | 13 ++++++---- .../core/backends/kubernetes/compute.py | 4 ++-- .../_internal/core/compatibility/runs.py | 4 ++-- .../_internal/core/models/configurations.py | 4 ++-- src/dstack/_internal/core/models/gateways.py | 6 ++--- src/dstack/_internal/core/models/routers.py | 9 +++---- .../proxy/gateway/schemas/registry.py | 4 ++-- .../services/model_routers/__init__.py | 4 ++-- .../gateway/services/model_routers/base.py | 4 ++-- .../gateway/services/model_routers/sglang.py | 4 ++-- .../_internal/proxy/gateway/services/nginx.py | 4 ++-- .../proxy/gateway/services/registry.py | 4 ++-- src/dstack/_internal/proxy/lib/models.py | 4 ++-- .../server/services/gateways/client.py | 4 ++-- .../server/services/services/__init__.py | 24 ++++++++++--------- 16 files changed, 52 insertions(+), 46 deletions(-) diff --git a/docs/docs/reference/dstack.yml/gateway.md b/docs/docs/reference/dstack.yml/gateway.md index b8e2742891..1d74c95705 100644 --- a/docs/docs/reference/dstack.yml/gateway.md +++ b/docs/docs/reference/dstack.yml/gateway.md @@ -14,7 +14,7 @@ The `gateway` configuration type allows creating and updating [gateways](../../c === "SGLang Model Gateway" - #SCHEMA# dstack._internal.core.models.routers.SGLangRouterConfig + #SCHEMA# dstack._internal.core.models.routers.SGLangGatewayRouterConfig overrides: show_root_heading: false type: diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 49513e3211..ff0f9323b6 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -39,7 +39,7 @@ SSHKey, ) from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import AnyGatewayRouterConfig from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import ( Volume, @@ -924,7 +924,9 @@ def get_run_shim_script( ] -def get_gateway_user_data(authorized_key: str, router: Optional[AnyRouterConfig] = None) -> str: +def get_gateway_user_data( + authorized_key: str, router: Optional[AnyGatewayRouterConfig] = None +) -> str: return get_cloud_config( package_update=True, packages=[ @@ -1036,20 +1038,21 @@ def get_latest_runner_build() -> Optional[str]: return None -def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = None) -> str: +def get_dstack_gateway_wheel(build: str, router: Optional[AnyGatewayRouterConfig] = None) -> str: channel = "release" if settings.DSTACK_RELEASE else "stgn" base_url = f"https://dstack-gateway-downloads.s3.amazonaws.com/{channel}" if build == "latest": build = _fetch_version(f"{base_url}/latest-version") or "latest" logger.debug("Found the latest gateway build: %s", build) - wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" + # wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" + wheel = "https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.1-py3-none-any.whl" # Build package spec with extras if router is specified if router: return f"dstack-gateway[{router.type}] @ {wheel}" return f"dstack-gateway @ {wheel}" -def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]: +def get_dstack_gateway_commands(router: Optional[AnyGatewayRouterConfig] = None) -> List[str]: build = get_dstack_runner_version() or "latest" gateway_package = get_dstack_gateway_wheel(build, router) return [ diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index 51abddc70c..870b6bb657 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -66,7 +66,7 @@ ) from dstack._internal.core.models.placement import PlacementGroup from dstack._internal.core.models.resources import CPUSpec, GPUSpec -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import AnyGatewayRouterConfig from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import Volume from dstack._internal.utils.common import get_or_error @@ -864,7 +864,7 @@ def _wait_for_load_balancer_address( def _get_gateway_commands( - authorized_keys: List[str], router: Optional[AnyRouterConfig] = None + authorized_keys: List[str], router: Optional[AnyGatewayRouterConfig] = None ) -> List[str]: authorized_keys_content = "\n".join(authorized_keys).strip() gateway_commands = " && ".join(get_dstack_gateway_commands(router=router)) diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index ac355b7be1..4ece12392c 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -2,7 +2,7 @@ from dstack._internal.core.models.common import IncludeExcludeDictType, IncludeExcludeSetType from dstack._internal.core.models.configurations import ServiceConfiguration -from dstack._internal.core.models.routers import SGLangRouterConfig +from dstack._internal.core.models.routers import SGLangServiceRouterConfig from dstack._internal.core.models.runs import ( DEFAULT_PROBE_UNTIL_READY, DEFAULT_REPLICA_GROUP_NAME, @@ -76,7 +76,7 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType: router = run_spec.configuration.router if router is None: configuration_excludes["router"] = True - elif isinstance(router, SGLangRouterConfig) and router.pd_disaggregation is False: + elif isinstance(router, SGLangServiceRouterConfig) and router.pd_disaggregation is False: configuration_excludes["router"] = {"pd_disaggregation": True} if configuration_excludes: diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index f66053831a..81c57f198f 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -28,7 +28,7 @@ parse_off_duration, ) from dstack._internal.core.models.resources import Range, ResourcesSpec -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import AnyServiceRouterConfig from dstack._internal.core.models.services import AnyModel, OpenAIChatModel from dstack._internal.core.models.unix import UnixUser from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point @@ -890,7 +890,7 @@ class ServiceConfigurationParams(CoreModel): ), ] = None router: Annotated[ - Optional[AnyRouterConfig], + Optional[AnyServiceRouterConfig], Field( description=( "Router configuration for the service. Requires a gateway with matching router enabled. " diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index 0ef4f38a49..816395fc82 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -8,7 +8,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel -from dstack._internal.core.models.routers import GatewayRouterConfig +from dstack._internal.core.models.routers import AnyGatewayRouterConfig from dstack._internal.utils.tags import tags_validator @@ -63,7 +63,7 @@ class GatewayConfiguration(CoreModel): ), ] = None router: Annotated[ - Optional[GatewayRouterConfig], + Optional[AnyGatewayRouterConfig], Field( description=( "The router configuration for this gateway. " @@ -139,7 +139,7 @@ class GatewayComputeConfiguration(CoreModel): ssh_key_pub: str certificate: Optional[AnyGatewayCertificate] = None tags: Optional[Dict[str, str]] = None - router: Optional[GatewayRouterConfig] = None + router: Optional[AnyGatewayRouterConfig] = None class GatewayProvisioningData(CoreModel): diff --git a/src/dstack/_internal/core/models/routers.py b/src/dstack/_internal/core/models/routers.py index 12f3d18584..49769fb8f1 100644 --- a/src/dstack/_internal/core/models/routers.py +++ b/src/dstack/_internal/core/models/routers.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Literal, Union +from typing import Literal from pydantic import Field from typing_extensions import Annotated @@ -11,7 +11,7 @@ class RouterType(str, Enum): SGLANG = "sglang" -class GatewayRouterConfig(CoreModel): +class SGLangGatewayRouterConfig(CoreModel): """Gateway-level router configuration. type and policy only. pd_disaggregation is service-level.""" type: Annotated[ @@ -29,7 +29,7 @@ class GatewayRouterConfig(CoreModel): ] = "cache_aware" -class SGLangRouterConfig(CoreModel): +class SGLangServiceRouterConfig(CoreModel): type: Annotated[Literal["sglang"], Field(description="The router type")] = "sglang" policy: Annotated[ Literal["random", "round_robin", "cache_aware", "power_of_two"], @@ -43,4 +43,5 @@ class SGLangRouterConfig(CoreModel): ] = False -AnyRouterConfig = Union[SGLangRouterConfig, GatewayRouterConfig] +AnyServiceRouterConfig = SGLangServiceRouterConfig +AnyGatewayRouterConfig = SGLangGatewayRouterConfig diff --git a/src/dstack/_internal/proxy/gateway/schemas/registry.py b/src/dstack/_internal/proxy/gateway/schemas/registry.py index 967e9d9960..802d23a700 100644 --- a/src/dstack/_internal/proxy/gateway/schemas/registry.py +++ b/src/dstack/_internal/proxy/gateway/schemas/registry.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field from dstack._internal.core.models.instances import SSHConnectionParams -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import AnyServiceRouterConfig from dstack._internal.proxy.lib.models import RateLimit @@ -45,7 +45,7 @@ class RegisterServiceRequest(BaseModel): options: Options ssh_private_key: str rate_limits: tuple[RateLimit, ...] = () - router: Optional[AnyRouterConfig] = None + router: Optional[AnyServiceRouterConfig] = None class RegisterReplicaRequest(BaseModel): diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/__init__.py b/src/dstack/_internal/proxy/gateway/services/model_routers/__init__.py index 9678699ac6..43477d2d3f 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/__init__.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/__init__.py @@ -1,11 +1,11 @@ -from dstack._internal.core.models.routers import AnyRouterConfig, RouterType +from dstack._internal.core.models.routers import AnyServiceRouterConfig, RouterType from dstack._internal.proxy.gateway.services.model_routers.sglang import SglangRouter from dstack._internal.proxy.lib.errors import ProxyError from .base import Router, RouterContext -def get_router(router: AnyRouterConfig, context: RouterContext) -> Router: +def get_router(router: AnyServiceRouterConfig, context: RouterContext) -> Router: if router.type == RouterType.SGLANG: return SglangRouter(config=router, context=context) raise ProxyError(f"Router type '{router.type}' is not available") diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py index 867591ca13..83ec14cb4d 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import AnyServiceRouterConfig class RouterContext(BaseModel): @@ -29,7 +29,7 @@ class Router(ABC): def __init__( self, context: RouterContext, - config: Optional[AnyRouterConfig] = None, + config: Optional[AnyServiceRouterConfig] = None, ): """Initialize router with context. diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py index 6c7dafd1b0..c1c03c5a11 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py @@ -7,7 +7,7 @@ import httpx import psutil -from dstack._internal.core.models.routers import RouterType, SGLangRouterConfig +from dstack._internal.core.models.routers import AnyServiceRouterConfig, RouterType from dstack._internal.proxy.lib.errors import UnexpectedProxyError from dstack._internal.utils.logging import get_logger @@ -21,7 +21,7 @@ class SglangRouter(Router): TYPE = RouterType.SGLANG - def __init__(self, config: SGLangRouterConfig, context: RouterContext): + def __init__(self, config: AnyServiceRouterConfig, context: RouterContext): """Initialize SGLang router. Args: diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index d79b0c7932..d400c24880 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -11,7 +11,7 @@ from pydantic import BaseModel from typing_extensions import Literal -from dstack._internal.core.models.routers import AnyRouterConfig, RouterType +from dstack._internal.core.models.routers import AnyServiceRouterConfig, RouterType from dstack._internal.proxy.gateway.const import PROXY_PORT_ON_GATEWAY from dstack._internal.proxy.gateway.models import ACMESettings from dstack._internal.proxy.gateway.services.model_routers import ( @@ -74,7 +74,7 @@ class ServiceConfig(SiteConfig): limit_req_zones: list[LimitReqZoneConfig] locations: list[LocationConfig] replicas: list[ReplicaConfig] - router: Optional[AnyRouterConfig] = None + router: Optional[AnyServiceRouterConfig] = None router_port: Optional[int] = None diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index ed0ea07d77..9db2d831eb 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -6,7 +6,7 @@ import dstack._internal.proxy.gateway.schemas.registry as schemas from dstack._internal.core.models.instances import SSHConnectionParams -from dstack._internal.core.models.routers import AnyRouterConfig, RouterType +from dstack._internal.core.models.routers import AnyServiceRouterConfig, RouterType from dstack._internal.proxy.gateway import models as gateway_models from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo from dstack._internal.proxy.gateway.services.nginx import ( @@ -45,7 +45,7 @@ async def register_service( repo: GatewayProxyRepo, nginx: Nginx, service_conn_pool: ServiceConnectionPool, - router: Optional[AnyRouterConfig] = None, + router: Optional[AnyServiceRouterConfig] = None, ) -> None: service = models.Service( project_name=project_name, diff --git a/src/dstack/_internal/proxy/lib/models.py b/src/dstack/_internal/proxy/lib/models.py index 9ae6b4d2ad..a51766fdd5 100644 --- a/src/dstack/_internal/proxy/lib/models.py +++ b/src/dstack/_internal/proxy/lib/models.py @@ -7,7 +7,7 @@ from typing_extensions import Annotated from dstack._internal.core.models.instances import SSHConnectionParams -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import AnyServiceRouterConfig from dstack._internal.proxy.lib.errors import UnexpectedProxyError @@ -59,7 +59,7 @@ class Service(ImmutableModel): client_max_body_size: int # only enforced on gateways strip_prefix: bool = True # only used in-server replicas: tuple[Replica, ...] - router: Optional[AnyRouterConfig] = None + router: Optional[AnyServiceRouterConfig] = None @property def domain_safe(self) -> str: diff --git a/src/dstack/_internal/server/services/gateways/client.py b/src/dstack/_internal/server/services/gateways/client.py index e68b874728..9bc7a1f903 100644 --- a/src/dstack/_internal/server/services/gateways/client.py +++ b/src/dstack/_internal/server/services/gateways/client.py @@ -9,7 +9,7 @@ from dstack._internal.core.errors import GatewayError from dstack._internal.core.models.configurations import RateLimit from dstack._internal.core.models.instances import SSHConnectionParams -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import AnyServiceRouterConfig from dstack._internal.core.models.runs import JobSpec, JobSubmission, Run, get_service_port from dstack._internal.proxy.gateway.schemas.stats import ServiceStats from dstack._internal.server import settings @@ -46,7 +46,7 @@ async def register_service( options: dict, rate_limits: list[RateLimit], ssh_private_key: str, - router: Optional[AnyRouterConfig] = None, + router: Optional[AnyServiceRouterConfig] = None, ): if "openai" in options: entrypoint = f"gateway.{domain.split('.', maxsplit=1)[1]}" diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index a916e8b795..c0a283f924 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -26,7 +26,11 @@ ) from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus from dstack._internal.core.models.instances import SSHConnectionParams -from dstack._internal.core.models.routers import RouterType, SGLangRouterConfig +from dstack._internal.core.models.routers import ( + AnyServiceRouterConfig, + RouterType, + SGLangServiceRouterConfig, +) from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec from dstack._internal.server import settings from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel @@ -52,7 +56,7 @@ def _gateway_has_sglang_router(config: GatewayConfiguration) -> bool: def _build_service_router_config( gateway_configuration: GatewayConfiguration, service_configuration: ServiceConfiguration, -) -> Optional[SGLangRouterConfig]: +) -> Optional[AnyServiceRouterConfig]: """ Build router config from gateway (type, policy) + service (pd_disaggregation, policy override). Service's policy overrides gateway's if present. Keeps backward compat: SGLang enabled @@ -67,13 +71,13 @@ def _build_service_router_config( policy = gateway_router.policy service_router = service_configuration.router - if service_router is not None and isinstance(service_router, SGLangRouterConfig): + if service_router is not None and isinstance(service_router, SGLangServiceRouterConfig): policy = service_router.policy pd_disaggregation = service_router.pd_disaggregation else: pd_disaggregation = False - return SGLangRouterConfig( + return SGLangServiceRouterConfig( type=router_type, policy=policy, pd_disaggregation=pd_disaggregation, @@ -128,16 +132,14 @@ async def _register_service_in_gateway( gateway_configuration = get_gateway_configuration(gateway) - # Check: service wants pd_disaggregation but gateway has no SGLang router + # Check: service specifies SGLang router but gateway does not have it service_router = run_spec.configuration.router - service_pd_disaggregation = ( - service_router is not None - and isinstance(service_router, SGLangRouterConfig) - and service_router.pd_disaggregation + service_wants_sglang = service_router is not None and isinstance( + service_router, SGLangServiceRouterConfig ) - if service_pd_disaggregation and not _gateway_has_sglang_router(gateway_configuration): + if service_wants_sglang and not _gateway_has_sglang_router(gateway_configuration): raise ServerClientError( - "Service requires gateway with SGLang router for pd_disaggregation but gateway " + "Service requires gateway with SGLang router but gateway " f"'{gateway.name}' does not have the SGLang router configured." ) From ac4b2a4b8fc1db99df755e5e7b84d7c3151e71a7 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Wed, 18 Feb 2026 06:35:13 +0545 Subject: [PATCH 11/13] Resolve Lint Error --- src/dstack/_internal/proxy/lib/models.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/dstack/_internal/proxy/lib/models.py b/src/dstack/_internal/proxy/lib/models.py index b3e203657b..a0a724dbea 100644 --- a/src/dstack/_internal/proxy/lib/models.py +++ b/src/dstack/_internal/proxy/lib/models.py @@ -59,12 +59,8 @@ class Service(ImmutableModel): client_max_body_size: int # only enforced on gateways strip_prefix: bool = True # only used in-server replicas: tuple[Replica, ...] -<<<<<<< add_pd_disaggregated_inference router: Optional[AnyServiceRouterConfig] = None -======= - router: Optional[AnyRouterConfig] = None cors_enabled: bool = False # only used on gateways; enabled for openai-format models ->>>>>>> master @property def domain_safe(self) -> str: From b619657854ae9aa588d5d5a4714ff7601f702304 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Wed, 18 Feb 2026 13:26:05 +0545 Subject: [PATCH 12/13] Update gateway wheel --- src/dstack/_internal/core/backends/base/compute.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index ff0f9323b6..a2507f4240 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -1044,8 +1044,7 @@ def get_dstack_gateway_wheel(build: str, router: Optional[AnyGatewayRouterConfig if build == "latest": build = _fetch_version(f"{base_url}/latest-version") or "latest" logger.debug("Found the latest gateway build: %s", build) - # wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" - wheel = "https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.1-py3-none-any.whl" + wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" # Build package spec with extras if router is specified if router: return f"dstack-gateway[{router.type}] @ {wheel}" From f540943c28154f2f2f7ee1def6b257dc343e26f6 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Wed, 18 Feb 2026 14:06:44 +0545 Subject: [PATCH 13/13] Minor Update --- src/dstack/_internal/proxy/gateway/services/nginx.py | 3 ++- src/dstack/_internal/proxy/gateway/services/registry.py | 6 ++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index 8f77a9dbca..47b93d074d 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -190,7 +190,8 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: logger.info("Registered %s domain %s", conf.type, conf.domain) - async def unregister(self, domain: str, service: models.Service) -> None: + async def unregister(self, service: models.Service) -> None: + domain = service.domain_safe logger.debug("Unregistering domain %s", domain) conf_path = self._conf_dir / self.get_config_name(domain) if not conf_path.exists(): diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index 0584515268..dc6407d245 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -118,7 +118,7 @@ async def unregister_service( ids=(r.id for r in service.replicas), service_conn_pool=service_conn_pool, ) - await nginx.unregister(service.domain_safe, service) + await nginx.unregister(service) await repo.delete_models_by_run(project_name, run_name) await repo.delete_service(project_name, run_name) @@ -241,9 +241,7 @@ async def register_model_entrypoint( def _uses_pd_disaggregation(service: models.Service) -> bool: """PD disaggregation: router talks to replicas via internal_ip, no SSH tunnels needed.""" - return ( - service.router is not None and getattr(service.router, "pd_disaggregation", False) is True - ) + return service.router is not None and service.router.pd_disaggregation async def apply_service(