diff --git a/docs/docs/reference/dstack.yml/gateway.md b/docs/docs/reference/dstack.yml/gateway.md index b8e2742891..1d74c95705 100644 --- a/docs/docs/reference/dstack.yml/gateway.md +++ b/docs/docs/reference/dstack.yml/gateway.md @@ -14,7 +14,7 @@ The `gateway` configuration type allows creating and updating [gateways](../../c === "SGLang Model Gateway" - #SCHEMA# dstack._internal.core.models.routers.SGLangRouterConfig + #SCHEMA# dstack._internal.core.models.routers.SGLangGatewayRouterConfig overrides: show_root_heading: false type: diff --git a/gateway/pyproject.toml b/gateway/pyproject.toml index c40a37b7f5..6c4d406a6f 100644 --- a/gateway/pyproject.toml +++ b/gateway/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ ] [project.optional-dependencies] -sglang = ["sglang-router==0.2.1"] +sglang = ["sglang-router==0.3.2"] [tool.setuptools.package-data] "dstack.gateway" = [ diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 49513e3211..a2507f4240 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -39,7 +39,7 @@ SSHKey, ) from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import AnyGatewayRouterConfig from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import ( Volume, @@ -924,7 +924,9 @@ def get_run_shim_script( ] -def get_gateway_user_data(authorized_key: str, router: Optional[AnyRouterConfig] = None) -> str: +def get_gateway_user_data( + authorized_key: str, router: Optional[AnyGatewayRouterConfig] = None +) -> str: return get_cloud_config( package_update=True, packages=[ @@ -1036,7 +1038,7 @@ def get_latest_runner_build() -> Optional[str]: return None -def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = None) -> str: +def get_dstack_gateway_wheel(build: str, router: Optional[AnyGatewayRouterConfig] = None) -> str: channel = "release" if settings.DSTACK_RELEASE else "stgn" base_url = f"https://dstack-gateway-downloads.s3.amazonaws.com/{channel}" if build == "latest": @@ -1049,7 +1051,7 @@ def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = Non return f"dstack-gateway @ {wheel}" -def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]: +def get_dstack_gateway_commands(router: Optional[AnyGatewayRouterConfig] = None) -> List[str]: build = get_dstack_runner_version() or "latest" gateway_package = get_dstack_gateway_wheel(build, router) return [ diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index 51abddc70c..870b6bb657 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -66,7 +66,7 @@ ) from dstack._internal.core.models.placement import PlacementGroup from dstack._internal.core.models.resources import CPUSpec, GPUSpec -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import AnyGatewayRouterConfig from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import Volume from dstack._internal.utils.common import get_or_error @@ -864,7 +864,7 @@ def _wait_for_load_balancer_address( def _get_gateway_commands( - authorized_keys: List[str], router: Optional[AnyRouterConfig] = None + authorized_keys: List[str], router: Optional[AnyGatewayRouterConfig] = None ) -> List[str]: authorized_keys_content = "\n".join(authorized_keys).strip() gateway_commands = " && ".join(get_dstack_gateway_commands(router=router)) diff --git a/src/dstack/_internal/core/compatibility/gateways.py b/src/dstack/_internal/core/compatibility/gateways.py index de94f6a18e..949d6515f8 100644 --- a/src/dstack/_internal/core/compatibility/gateways.py +++ b/src/dstack/_internal/core/compatibility/gateways.py @@ -31,9 +31,7 @@ def _get_gateway_configuration_excludes( ) -> IncludeExcludeDictType: configuration_excludes: IncludeExcludeDictType = {} - # Add excludes like this: - # - # if configuration.tags is None: - # configuration_excludes["tags"] = True + if configuration.router is None: + configuration_excludes["router"] = True return configuration_excludes diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index 19c08cde55..4ece12392c 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -2,6 +2,7 @@ from dstack._internal.core.models.common import IncludeExcludeDictType, IncludeExcludeSetType from dstack._internal.core.models.configurations import ServiceConfiguration +from dstack._internal.core.models.routers import SGLangServiceRouterConfig from dstack._internal.core.models.runs import ( DEFAULT_PROBE_UNTIL_READY, DEFAULT_REPLICA_GROUP_NAME, @@ -72,6 +73,12 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType: # Servers prior to 0.20.8 do not support probes=None configuration_excludes["probes"] = True + router = run_spec.configuration.router + if router is None: + configuration_excludes["router"] = True + elif isinstance(router, SGLangServiceRouterConfig) and router.pd_disaggregation is False: + configuration_excludes["router"] = {"pd_disaggregation": True} + if configuration_excludes: spec_excludes["configuration"] = configuration_excludes if profile_excludes: diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 040c382359..93c63e6b31 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -28,6 +28,7 @@ parse_off_duration, ) from dstack._internal.core.models.resources import Range, ResourcesSpec +from dstack._internal.core.models.routers import AnyServiceRouterConfig from dstack._internal.core.models.services import AnyModel, OpenAIChatModel from dstack._internal.core.models.unix import UnixUser from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point @@ -887,6 +888,14 @@ class ServiceConfigurationParams(CoreModel): ) ), ] = None + router: Annotated[ + Optional[AnyServiceRouterConfig], + Field( + description=( + "Router configuration for the service. Requires a gateway with matching router enabled. " + ), + ), + ] = None @validator("port") def convert_port(cls, v) -> PortMapping: diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index b342c0a73b..816395fc82 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -8,7 +8,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import AnyGatewayRouterConfig from dstack._internal.utils.tags import tags_validator @@ -63,8 +63,13 @@ class GatewayConfiguration(CoreModel): ), ] = None router: Annotated[ - Optional[AnyRouterConfig], - Field(description="The router configuration"), + Optional[AnyGatewayRouterConfig], + Field( + description=( + "The router configuration for this gateway. " + "E.g. `{ type: sglang, policy: round_robin }`." + ), + ), ] = None domain: Annotated[ Optional[str], Field(description="The gateway domain, e.g. `example.com`") @@ -134,7 +139,7 @@ class GatewayComputeConfiguration(CoreModel): ssh_key_pub: str certificate: Optional[AnyGatewayCertificate] = None tags: Optional[Dict[str, str]] = None - router: Optional[AnyRouterConfig] = None + router: Optional[AnyGatewayRouterConfig] = None class GatewayProvisioningData(CoreModel): diff --git a/src/dstack/_internal/core/models/routers.py b/src/dstack/_internal/core/models/routers.py index e07631e12e..49769fb8f1 100644 --- a/src/dstack/_internal/core/models/routers.py +++ b/src/dstack/_internal/core/models/routers.py @@ -11,7 +11,25 @@ class RouterType(str, Enum): SGLANG = "sglang" -class SGLangRouterConfig(CoreModel): +class SGLangGatewayRouterConfig(CoreModel): + """Gateway-level router configuration. type and policy only. pd_disaggregation is service-level.""" + + type: Annotated[ + Literal["sglang"], + Field(description="The router type enabled on this gateway."), + ] = "sglang" + policy: Annotated[ + Literal["random", "round_robin", "cache_aware", "power_of_two"], + Field( + description=( + "The routing policy. Deprecated: prefer setting policy in the service's router config. " + "Options: `random`, `round_robin`, `cache_aware`, `power_of_two`" + ), + ), + ] = "cache_aware" + + +class SGLangServiceRouterConfig(CoreModel): type: Annotated[Literal["sglang"], Field(description="The router type")] = "sglang" policy: Annotated[ Literal["random", "round_robin", "cache_aware", "power_of_two"], @@ -19,6 +37,11 @@ class SGLangRouterConfig(CoreModel): description="The routing policy. Options: `random`, `round_robin`, `cache_aware`, `power_of_two`" ), ] = "cache_aware" + pd_disaggregation: Annotated[ + bool, + Field(description="Enable PD disaggregation mode for the SGLang router"), + ] = False -AnyRouterConfig = SGLangRouterConfig +AnyServiceRouterConfig = SGLangServiceRouterConfig +AnyGatewayRouterConfig = SGLangGatewayRouterConfig diff --git a/src/dstack/_internal/proxy/gateway/routers/registry.py b/src/dstack/_internal/proxy/gateway/routers/registry.py index dd4f63f325..c5f4cf8a1a 100644 --- a/src/dstack/_internal/proxy/gateway/routers/registry.py +++ b/src/dstack/_internal/proxy/gateway/routers/registry.py @@ -80,6 +80,7 @@ async def register_replica( ssh_proxy=body.ssh_proxy, ssh_head_proxy=body.ssh_head_proxy, ssh_head_proxy_private_key=body.ssh_head_proxy_private_key, + internal_ip=body.internal_ip, repo=repo, nginx=nginx, service_conn_pool=service_conn_pool, diff --git a/src/dstack/_internal/proxy/gateway/schemas/registry.py b/src/dstack/_internal/proxy/gateway/schemas/registry.py index 53a29f68ca..802d23a700 100644 --- a/src/dstack/_internal/proxy/gateway/schemas/registry.py +++ b/src/dstack/_internal/proxy/gateway/schemas/registry.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field from dstack._internal.core.models.instances import SSHConnectionParams -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import AnyServiceRouterConfig from dstack._internal.proxy.lib.models import RateLimit @@ -45,7 +45,7 @@ class RegisterServiceRequest(BaseModel): options: Options ssh_private_key: str rate_limits: tuple[RateLimit, ...] = () - router: Optional[AnyRouterConfig] = None + router: Optional[AnyServiceRouterConfig] = None class RegisterReplicaRequest(BaseModel): @@ -56,6 +56,7 @@ class RegisterReplicaRequest(BaseModel): ssh_proxy: Optional[SSHConnectionParams] ssh_head_proxy: Optional[SSHConnectionParams] ssh_head_proxy_private_key: Optional[str] + internal_ip: Optional[str] = None class RegisterEntrypointRequest(BaseModel): diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/__init__.py b/src/dstack/_internal/proxy/gateway/services/model_routers/__init__.py index 9678699ac6..43477d2d3f 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/__init__.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/__init__.py @@ -1,11 +1,11 @@ -from dstack._internal.core.models.routers import AnyRouterConfig, RouterType +from dstack._internal.core.models.routers import AnyServiceRouterConfig, RouterType from dstack._internal.proxy.gateway.services.model_routers.sglang import SglangRouter from dstack._internal.proxy.lib.errors import ProxyError from .base import Router, RouterContext -def get_router(router: AnyRouterConfig, context: RouterContext) -> Router: +def get_router(router: AnyServiceRouterConfig, context: RouterContext) -> Router: if router.type == RouterType.SGLANG: return SglangRouter(config=router, context=context) raise ProxyError(f"Router type '{router.type}' is not available") diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py index 867591ca13..83ec14cb4d 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/base.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/base.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import AnyServiceRouterConfig class RouterContext(BaseModel): @@ -29,7 +29,7 @@ class Router(ABC): def __init__( self, context: RouterContext, - config: Optional[AnyRouterConfig] = None, + config: Optional[AnyServiceRouterConfig] = None, ): """Initialize router with context. diff --git a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py index c3a0dfaae9..c1c03c5a11 100644 --- a/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py +++ b/src/dstack/_internal/proxy/gateway/services/model_routers/sglang.py @@ -2,13 +2,12 @@ import subprocess import sys import time -import urllib.parse from typing import List, Optional import httpx import psutil -from dstack._internal.core.models.routers import RouterType, SGLangRouterConfig +from dstack._internal.core.models.routers import AnyServiceRouterConfig, RouterType from dstack._internal.proxy.lib.errors import UnexpectedProxyError from dstack._internal.utils.logging import get_logger @@ -22,7 +21,7 @@ class SglangRouter(Router): TYPE = RouterType.SGLANG - def __init__(self, config: SGLangRouterConfig, context: RouterContext): + def __init__(self, config: AnyServiceRouterConfig, context: RouterContext): """Initialize SGLang router. Args: @@ -68,6 +67,8 @@ def start(self) -> None: "--policy", self.config.policy, ] + if self.config.pd_disaggregation: + cmd.append("--pd-disaggregation") subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) @@ -174,7 +175,7 @@ def update_replicas(self, replica_urls: List[str]) -> None: # Add workers for worker_url in sorted(workers_to_add): - success = self._add_worker_to_router(worker_url) + success = self._register_worker(worker_url) if not success: logger.warning("Failed to add worker %s, continuing with others", worker_url) @@ -197,9 +198,16 @@ def _get_router_workers(self) -> List[dict]: logger.exception("Error getting sglang router workers") return [] - def _add_worker_to_router(self, worker_url: str) -> bool: + def _add_worker_to_router( + self, + url: str, + worker_type: str = "regular", + bootstrap_port: Optional[int] = None, + ) -> bool: try: - payload = {"url": worker_url, "worker_type": "regular"} + payload: dict = {"url": url, "worker_type": worker_type} + if bootstrap_port is not None: + payload["bootstrap_port"] = bootstrap_port with httpx.Client(timeout=5.0) as client: response = client.post( f"http://{self.context.host}:{self.context.port}/workers", @@ -209,8 +217,9 @@ def _add_worker_to_router(self, worker_url: str) -> bool: response_data = response.json() if response_data.get("status") == "accepted": logger.info( - "Worker %s accepted by sglang router on port %s", - worker_url, + "Worker %s (type=%s) accepted by sglang router on port %s", + url, + worker_type, self.context.port, ) return True @@ -224,21 +233,68 @@ def _add_worker_to_router(self, worker_url: str) -> bool: else: logger.error( "Failed to add worker %s: status %d, %s", - worker_url, + url, response.status_code, response.text, ) return False except Exception: - logger.exception("Error adding worker %s", worker_url) + logger.exception("Error adding worker %s", url) + return False + + def _register_worker(self, url: str) -> bool: + if not self.config.pd_disaggregation: + return self._add_worker_to_router(url, "regular", None) + + server_info_url = f"{url}/server_info" + try: + with httpx.Client(timeout=10) as client: + resp = client.get(server_info_url) + if resp.status_code != 200: + return False + data = resp.json() + if data.get("status") != "ready": + return False + disaggregation_mode = data.get("disaggregation_mode", "") + if disaggregation_mode == "prefill": + worker_type = "prefill" + bootstrap_port = data.get("disaggregation_bootstrap_port") + elif disaggregation_mode == "decode": + worker_type = "decode" + bootstrap_port = None + else: + worker_type = "regular" + bootstrap_port = None + logger.info( + "Registering worker %s (type=%s)", + url, + worker_type, + ) + return self._add_worker_to_router( + url, + worker_type, + bootstrap_port, + ) + except Exception: + logger.exception("Error registering worker %s", url) return False def _remove_worker_from_router(self, worker_url: str) -> bool: try: - encoded_url = urllib.parse.quote(worker_url, safe="") + current_workers = self._get_router_workers() + worker_id = None + for worker in current_workers: + url = worker.get("url") + if url and isinstance(url, str) and url == worker_url: + worker_id = worker.get("id") + if worker_id and isinstance(worker_id, str): + break + if not worker_id: + logger.error("No worker id found for url %s", worker_url) + return False with httpx.Client(timeout=5.0) as client: response = client.delete( - f"http://{self.context.host}:{self.context.port}/workers/{encoded_url}" + f"http://{self.context.host}:{self.context.port}/workers/{worker_id}" ) if response.status_code == 202: response_data = response.json() diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index c971d4197a..47b93d074d 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -5,12 +5,13 @@ from asyncio import Lock from pathlib import Path from typing import Dict, Optional +from urllib.parse import urlparse import jinja2 from pydantic import BaseModel from typing_extensions import Literal -from dstack._internal.core.models.routers import AnyRouterConfig, RouterType +from dstack._internal.core.models.routers import AnyServiceRouterConfig, RouterType from dstack._internal.proxy.gateway.const import PROXY_PORT_ON_GATEWAY from dstack._internal.proxy.gateway.models import ACMESettings from dstack._internal.proxy.gateway.services.model_routers import ( @@ -18,6 +19,7 @@ RouterContext, get_router, ) +from dstack._internal.proxy.lib import models from dstack._internal.proxy.lib.errors import ProxyError, UnexpectedProxyError from dstack._internal.utils.common import run_async from dstack._internal.utils.logging import get_logger @@ -43,6 +45,8 @@ def render(self) -> str: class ReplicaConfig(BaseModel): id: str socket: Path + port: int + internal_ip: Optional[str] = None class LimitReqZoneConfig(BaseModel): @@ -70,7 +74,7 @@ class ServiceConfig(SiteConfig): limit_req_zones: list[LimitReqZoneConfig] locations: list[LocationConfig] replicas: list[ReplicaConfig] - router: Optional[AnyRouterConfig] = None + router: Optional[AnyServiceRouterConfig] = None router_port: Optional[int] = None cors_enabled: bool = False @@ -96,7 +100,7 @@ def __init__(self, conf_dir: Path = Path("/etc/nginx/sites-enabled")) -> None: self._next_router_port: int = self._ROUTER_PORT_MIN # Tracking of worker ports to avoid conflicts across router instances self._allocated_worker_ports: set[int] = set() - self._domain_to_worker_ports: Dict[str, list[int]] = {} + self._domain_to_worker_urls: Dict[str, list[str]] = {} self._next_worker_port: int = self._WORKER_PORT_MIN async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: @@ -145,33 +149,40 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: del self._domain_to_router[conf.domain] raise - allocated_ports = self._allocate_worker_ports(len(conf.replicas)) - replica_urls = [ - f"http://{router.context.host}:{port}" for port in allocated_ports - ] - - # Write router workers config - try: + if conf.router.pd_disaggregation: + # PD path: replica_urls from internal_ip (router talks directly to workers) + if any(not r.internal_ip for r in conf.replicas): + raise ProxyError( + "PD disaggregation requires internal IP for all replicas." + ) + replica_urls = [ + f"http://{replica.internal_ip}:{replica.port}" + for replica in conf.replicas + ] + self._domain_to_worker_urls[conf.domain] = replica_urls + else: + # Non-PD path: allocate gateway-local ports, nginx proxies to replica sockets + allocated_ports = self._allocate_worker_ports(len(conf.replicas)) + replica_urls = [ + f"http://{router.context.host}:{port}" for port in allocated_ports + ] if conf.replicas: - await run_async(self.write_router_workers_conf, conf, allocated_ports) - # Discard old worker ports if domain already has allocated ports (required for scaling case) - if conf.domain in self._domain_to_worker_ports: - old_worker_ports = self._domain_to_worker_ports[conf.domain] - for port in old_worker_ports: - self._allocated_worker_ports.discard(port) - self._domain_to_worker_ports[conf.domain] = allocated_ports - except Exception as e: - logger.exception( - "write_router_workers_conf failed for domain=%s: %s", conf.domain, e - ) - raise + await run_async( + self.write_router_workers_conf, + conf, + allocated_ports, + ) + if conf.domain in self._domain_to_worker_urls: + self._discard_ports(self._domain_to_worker_urls[conf.domain]) + self._domain_to_worker_urls[conf.domain] = replica_urls - # Update replicas to router (actual HTTP API calls to add workers) try: await run_async(router.update_replicas, replica_urls) except Exception as e: logger.exception( - "Failed to add replicas to router for domain=%s: %s", conf.domain, e + "Failed to add replicas to router for domain=%s: %s", + conf.domain, + e, ) raise @@ -179,7 +190,8 @@ async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: logger.info("Registered %s domain %s", conf.type, conf.domain) - async def unregister(self, domain: str) -> None: + async def unregister(self, service: models.Service) -> None: + domain = service.domain_safe logger.debug("Unregistering domain %s", domain) conf_path = self._conf_dir / self.get_config_name(domain) if not conf_path.exists(): @@ -190,12 +202,16 @@ async def unregister(self, domain: str) -> None: if domain in self._domain_to_router: router = self._domain_to_router[domain] # Remove all workers for this domain - if domain in self._domain_to_worker_ports: - worker_ports = self._domain_to_worker_ports[domain] - replica_urls = [ - f"http://{router.context.host}:{port}" for port in worker_ports - ] - await run_async(router.remove_replicas, replica_urls) + if domain in self._domain_to_worker_urls: + worker_urls = self._domain_to_worker_urls[domain] + await run_async(router.remove_replicas, worker_urls) + pd_disaggregation = ( + service.router.pd_disaggregation if service.router else False + ) + if not pd_disaggregation: + self._discard_ports(worker_urls) + del self._domain_to_worker_urls[domain] + logger.debug("Removed worker URLs for domain %s", domain) # Stop and kill the router await run_async(router.stop) # Remove from mappings @@ -204,14 +220,6 @@ async def unregister(self, domain: str) -> None: del self._router_port_to_domain[router_port] del self._domain_to_router[domain] - # Discard worker ports for this domain - if domain in self._domain_to_worker_ports: - worker_ports = self._domain_to_worker_ports[domain] - for port in worker_ports: - self._allocated_worker_ports.discard(port) - del self._domain_to_worker_ports[domain] - logger.debug("Freed worker ports %s for domain %s", worker_ports, domain) - # Remove workers config file workers_conf_path = self._conf_dir / f"router-workers.{domain}.conf" if workers_conf_path.exists(): @@ -404,6 +412,12 @@ def _allocate_worker_ports(self, num_ports: int) -> list[int]: return allocated + def _discard_ports(self, urls: list[str]) -> None: + for u in urls: + parsed = urlparse(u) + if parsed.port is not None and parsed.port in self._allocated_worker_ports: + self._allocated_worker_ports.discard(parsed.port) + def write_global_conf(self) -> None: conf = read_package_resource("00-log-format.conf") self.write_conf(conf, "00-log-format.conf") diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index 036b864396..dc6407d245 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -6,7 +6,7 @@ import dstack._internal.proxy.gateway.schemas.registry as schemas from dstack._internal.core.models.instances import SSHConnectionParams -from dstack._internal.core.models.routers import AnyRouterConfig, RouterType +from dstack._internal.core.models.routers import AnyServiceRouterConfig, RouterType from dstack._internal.proxy.gateway import models as gateway_models from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo from dstack._internal.proxy.gateway.services.nginx import ( @@ -45,7 +45,7 @@ async def register_service( repo: GatewayProxyRepo, nginx: Nginx, service_conn_pool: ServiceConnectionPool, - router: Optional[AnyRouterConfig] = None, + router: Optional[AnyServiceRouterConfig] = None, ) -> None: cors_enabled = model is not None and model.type == "chat" and model.format == "openai" service = models.Service( @@ -118,7 +118,7 @@ async def unregister_service( ids=(r.id for r in service.replicas), service_conn_pool=service_conn_pool, ) - await nginx.unregister(service.domain_safe) + await nginx.unregister(service) await repo.delete_models_by_run(project_name, run_name) await repo.delete_service(project_name, run_name) @@ -138,6 +138,7 @@ async def register_replica( repo: GatewayProxyRepo, nginx: Nginx, service_conn_pool: ServiceConnectionPool, + internal_ip: Optional[str] = None, ) -> None: replica = models.Replica( id=replica_id, @@ -147,6 +148,7 @@ async def register_replica( ssh_proxy=ssh_proxy, ssh_head_proxy=ssh_head_proxy, ssh_head_proxy_private_key=ssh_head_proxy_private_key, + internal_ip=internal_ip, ) async with lock: @@ -237,6 +239,11 @@ async def register_model_entrypoint( logger.info("Entrypoint %s is now registered in project %s", domain, project_name) +def _uses_pd_disaggregation(service: models.Service) -> bool: + """PD disaggregation: router talks to replicas via internal_ip, no SSH tunnels needed.""" + return service.router is not None and service.router.pd_disaggregation + + async def apply_service( service: models.Service, old_service: Optional[models.Service], @@ -256,13 +263,31 @@ async def apply_service( ), service_conn_pool=service_conn_pool, ) - replica_conns, replica_failures = await get_or_add_replica_connections( - service, repo, service_conn_pool - ) - replica_configs = [ - ReplicaConfig(id=replica.id, socket=conn.app_socket_path) - for replica, conn in replica_conns.items() - ] + if _uses_pd_disaggregation(service): + replica_conns = {} + replica_failures = {} + replica_configs = [ + ReplicaConfig( + id=replica.id, + socket=Path("/dev/null"), + port=replica.app_port, + internal_ip=replica.internal_ip, + ) + for replica in service.replicas + ] + else: + replica_conns, replica_failures = await get_or_add_replica_connections( + service, repo, service_conn_pool + ) + replica_configs = [ + ReplicaConfig( + id=replica.id, + socket=conn.app_socket_path, + port=replica.app_port, + internal_ip=replica.internal_ip, + ) + for replica, conn in replica_conns.items() + ] service_config = await get_nginx_service_config(service, replica_configs) await nginx.register(service_config, (await repo.get_config()).acme_settings) return replica_failures diff --git a/src/dstack/_internal/proxy/lib/models.py b/src/dstack/_internal/proxy/lib/models.py index f304bbc394..a0a724dbea 100644 --- a/src/dstack/_internal/proxy/lib/models.py +++ b/src/dstack/_internal/proxy/lib/models.py @@ -7,7 +7,7 @@ from typing_extensions import Annotated from dstack._internal.core.models.instances import SSHConnectionParams -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import AnyServiceRouterConfig from dstack._internal.proxy.lib.errors import UnexpectedProxyError @@ -27,6 +27,7 @@ class Replica(ImmutableModel): # Optional outer proxy, a head node/bastion ssh_head_proxy: Optional[SSHConnectionParams] = None ssh_head_proxy_private_key: Optional[str] = None + internal_ip: Optional[str] = None class IPAddressPartitioningKey(ImmutableModel): @@ -58,7 +59,7 @@ class Service(ImmutableModel): client_max_body_size: int # only enforced on gateways strip_prefix: bool = True # only used in-server replicas: tuple[Replica, ...] - router: Optional[AnyRouterConfig] = None + router: Optional[AnyServiceRouterConfig] = None cors_enabled: bool = False # only used on gateways; enabled for openai-format models @property diff --git a/src/dstack/_internal/server/services/gateways/client.py b/src/dstack/_internal/server/services/gateways/client.py index d4f1c831e8..9bc7a1f903 100644 --- a/src/dstack/_internal/server/services/gateways/client.py +++ b/src/dstack/_internal/server/services/gateways/client.py @@ -9,7 +9,7 @@ from dstack._internal.core.errors import GatewayError from dstack._internal.core.models.configurations import RateLimit from dstack._internal.core.models.instances import SSHConnectionParams -from dstack._internal.core.models.routers import AnyRouterConfig +from dstack._internal.core.models.routers import AnyServiceRouterConfig from dstack._internal.core.models.runs import JobSpec, JobSubmission, Run, get_service_port from dstack._internal.proxy.gateway.schemas.stats import ServiceStats from dstack._internal.server import settings @@ -46,7 +46,7 @@ async def register_service( options: dict, rate_limits: list[RateLimit], ssh_private_key: str, - router: Optional[AnyRouterConfig] = None, + router: Optional[AnyServiceRouterConfig] = None, ): if "openai" in options: entrypoint = f"gateway.{domain.split('.', maxsplit=1)[1]}" @@ -99,6 +99,7 @@ async def register_replica( assert jpd is not None assert jpd.hostname is not None assert jpd.ssh_port is not None + payload["internal_ip"] = jpd.internal_ip if not jpd.dockerized: payload.update( { diff --git a/src/dstack/_internal/server/services/proxy/repo.py b/src/dstack/_internal/server/services/proxy/repo.py index 7f1564fe62..385c9e654f 100644 --- a/src/dstack/_internal/server/services/proxy/repo.py +++ b/src/dstack/_internal/server/services/proxy/repo.py @@ -111,6 +111,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic ssh_proxy=ssh_proxy, ssh_head_proxy=ssh_head_proxy, ssh_head_proxy_private_key=ssh_head_proxy_private_key, + internal_ip=jpd.internal_ip, ) replicas.append(replica) return Service( diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 511cf7cc93..b701b822b0 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -26,6 +26,11 @@ ) from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus from dstack._internal.core.models.instances import SSHConnectionParams +from dstack._internal.core.models.routers import ( + AnyServiceRouterConfig, + RouterType, + SGLangServiceRouterConfig, +) from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec from dstack._internal.core.models.services import OpenAIChatModel from dstack._internal.server import settings @@ -45,6 +50,41 @@ logger = get_logger(__name__) +def _gateway_has_sglang_router(config: GatewayConfiguration) -> bool: + return config.router is not None and config.router.type == RouterType.SGLANG.value + + +def _build_service_router_config( + gateway_configuration: GatewayConfiguration, + service_configuration: ServiceConfiguration, +) -> Optional[AnyServiceRouterConfig]: + """ + Build router config from gateway (type, policy) + service (pd_disaggregation, policy override). + Service's policy overrides gateway's if present. Keeps backward compat: SGLang enabled + automatically when gateway has it configured. + """ + if not _gateway_has_sglang_router(gateway_configuration): + return None + + gateway_router = gateway_configuration.router + assert gateway_router is not None # ensured by _gateway_has_sglang_router + router_type = gateway_router.type + policy = gateway_router.policy + + service_router = service_configuration.router + if service_router is not None and isinstance(service_router, SGLangServiceRouterConfig): + policy = service_router.policy + pd_disaggregation = service_router.pd_disaggregation + else: + pd_disaggregation = False + + return SGLangServiceRouterConfig( + type=router_type, + policy=policy, + pd_disaggregation=pd_disaggregation, + ) + + async def register_service(session: AsyncSession, run_model: RunModel, run_spec: RunSpec): assert isinstance(run_spec.configuration, ServiceConfiguration) @@ -92,8 +132,20 @@ async def _register_service_in_gateway( raise ServerClientError("Gateway status is not running") gateway_configuration = get_gateway_configuration(gateway) + + # Check: service specifies SGLang router but gateway does not have it + service_router = run_spec.configuration.router + service_wants_sglang = service_router is not None and isinstance( + service_router, SGLangServiceRouterConfig + ) + if service_wants_sglang and not _gateway_has_sglang_router(gateway_configuration): + raise ServerClientError( + "Service requires gateway with SGLang router but gateway " + f"'{gateway.name}' does not have the SGLang router configured." + ) + service_https = _get_service_https(run_spec, gateway_configuration) - router = gateway_configuration.router + router = _build_service_router_config(gateway_configuration, run_spec.configuration) service_protocol = "https" if service_https else "http" if service_https and gateway_configuration.certificate is None: @@ -158,6 +210,14 @@ async def _register_service_in_gateway( def _register_service_in_server(run_model: RunModel, run_spec: RunSpec) -> ServiceSpec: assert run_spec.configuration.type == "service" + if ( + run_spec.configuration.router is not None + and run_spec.configuration.router.type == RouterType.SGLANG + ): + raise ServerClientError( + "Service with SGLang router configuration requires a gateway. " + "Please configure a gateway with the SGLang router enabled." + ) if run_spec.configuration.https != SERVICE_HTTPS_DEFAULT: # Note: if the user sets `https: `, it will be ignored silently # TODO: in 0.19, make `https` Optional to be able to tell if it was set or omitted