Skip to content
2 changes: 1 addition & 1 deletion docs/docs/reference/dstack.yml/gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ The `gateway` configuration type allows creating and updating [gateways](../../c

=== "SGLang Model Gateway"

#SCHEMA# dstack._internal.core.models.routers.SGLangRouterConfig
#SCHEMA# dstack._internal.core.models.routers.SGLangGatewayRouterConfig
overrides:
show_root_heading: false
type:
Expand Down
2 changes: 1 addition & 1 deletion gateway/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dependencies = [
]

[project.optional-dependencies]
sglang = ["sglang-router==0.2.1"]
sglang = ["sglang-router==0.3.2"]

[tool.setuptools.package-data]
"dstack.gateway" = [
Expand Down
10 changes: 6 additions & 4 deletions src/dstack/_internal/core/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
SSHKey,
)
from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.core.models.routers import AnyGatewayRouterConfig
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
from dstack._internal.core.models.volumes import (
Volume,
Expand Down Expand Up @@ -924,7 +924,9 @@ def get_run_shim_script(
]


def get_gateway_user_data(authorized_key: str, router: Optional[AnyRouterConfig] = None) -> str:
def get_gateway_user_data(
authorized_key: str, router: Optional[AnyGatewayRouterConfig] = None
) -> str:
return get_cloud_config(
package_update=True,
packages=[
Expand Down Expand Up @@ -1036,7 +1038,7 @@ def get_latest_runner_build() -> Optional[str]:
return None


def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = None) -> str:
def get_dstack_gateway_wheel(build: str, router: Optional[AnyGatewayRouterConfig] = None) -> str:
channel = "release" if settings.DSTACK_RELEASE else "stgn"
base_url = f"https://dstack-gateway-downloads.s3.amazonaws.com/{channel}"
if build == "latest":
Expand All @@ -1049,7 +1051,7 @@ def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = Non
return f"dstack-gateway @ {wheel}"


def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]:
def get_dstack_gateway_commands(router: Optional[AnyGatewayRouterConfig] = None) -> List[str]:
build = get_dstack_runner_version() or "latest"
gateway_package = get_dstack_gateway_wheel(build, router)
return [
Expand Down
4 changes: 2 additions & 2 deletions src/dstack/_internal/core/backends/kubernetes/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
)
from dstack._internal.core.models.placement import PlacementGroup
from dstack._internal.core.models.resources import CPUSpec, GPUSpec
from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.core.models.routers import AnyGatewayRouterConfig
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
from dstack._internal.core.models.volumes import Volume
from dstack._internal.utils.common import get_or_error
Expand Down Expand Up @@ -864,7 +864,7 @@ def _wait_for_load_balancer_address(


def _get_gateway_commands(
authorized_keys: List[str], router: Optional[AnyRouterConfig] = None
authorized_keys: List[str], router: Optional[AnyGatewayRouterConfig] = None
) -> List[str]:
authorized_keys_content = "\n".join(authorized_keys).strip()
gateway_commands = " && ".join(get_dstack_gateway_commands(router=router))
Expand Down
6 changes: 2 additions & 4 deletions src/dstack/_internal/core/compatibility/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ def _get_gateway_configuration_excludes(
) -> IncludeExcludeDictType:
configuration_excludes: IncludeExcludeDictType = {}

# Add excludes like this:
#
# if configuration.tags is None:
# configuration_excludes["tags"] = True
if configuration.router is None:
configuration_excludes["router"] = True

return configuration_excludes
7 changes: 7 additions & 0 deletions src/dstack/_internal/core/compatibility/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dstack._internal.core.models.common import IncludeExcludeDictType, IncludeExcludeSetType
from dstack._internal.core.models.configurations import ServiceConfiguration
from dstack._internal.core.models.routers import SGLangServiceRouterConfig
from dstack._internal.core.models.runs import (
DEFAULT_PROBE_UNTIL_READY,
DEFAULT_REPLICA_GROUP_NAME,
Expand Down Expand Up @@ -72,6 +73,12 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType:
# Servers prior to 0.20.8 do not support probes=None
configuration_excludes["probes"] = True

router = run_spec.configuration.router
if router is None:
configuration_excludes["router"] = True
elif isinstance(router, SGLangServiceRouterConfig) and router.pd_disaggregation is False:
configuration_excludes["router"] = {"pd_disaggregation": True}

if configuration_excludes:
spec_excludes["configuration"] = configuration_excludes
if profile_excludes:
Expand Down
9 changes: 9 additions & 0 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
parse_off_duration,
)
from dstack._internal.core.models.resources import Range, ResourcesSpec
from dstack._internal.core.models.routers import AnyServiceRouterConfig
from dstack._internal.core.models.services import AnyModel, OpenAIChatModel
from dstack._internal.core.models.unix import UnixUser
from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point
Expand Down Expand Up @@ -887,6 +888,14 @@ class ServiceConfigurationParams(CoreModel):
)
),
] = None
router: Annotated[
Optional[AnyServiceRouterConfig],
Field(
description=(
"Router configuration for the service. Requires a gateway with matching router enabled. "
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit)

Suggested change
"Router configuration for the service. Requires a gateway with matching router enabled. "
"Router configuration for the service. Requires a gateway with matching router enabled"

),
),
] = None

@validator("port")
def convert_port(cls, v) -> PortMapping:
Expand Down
13 changes: 9 additions & 4 deletions src/dstack/_internal/core/models/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel
from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.core.models.routers import AnyGatewayRouterConfig
from dstack._internal.utils.tags import tags_validator


Expand Down Expand Up @@ -63,8 +63,13 @@ class GatewayConfiguration(CoreModel):
),
] = None
router: Annotated[
Optional[AnyRouterConfig],
Field(description="The router configuration"),
Optional[AnyGatewayRouterConfig],
Field(
description=(
"The router configuration for this gateway. "
"E.g. `{ type: sglang, policy: round_robin }`."
),
),
] = None
domain: Annotated[
Optional[str], Field(description="The gateway domain, e.g. `example.com`")
Expand Down Expand Up @@ -134,7 +139,7 @@ class GatewayComputeConfiguration(CoreModel):
ssh_key_pub: str
certificate: Optional[AnyGatewayCertificate] = None
tags: Optional[Dict[str, str]] = None
router: Optional[AnyRouterConfig] = None
router: Optional[AnyGatewayRouterConfig] = None


class GatewayProvisioningData(CoreModel):
Expand Down
27 changes: 25 additions & 2 deletions src/dstack/_internal/core/models/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,37 @@ class RouterType(str, Enum):
SGLANG = "sglang"


class SGLangRouterConfig(CoreModel):
class SGLangGatewayRouterConfig(CoreModel):
    """Gateway-level router configuration. type and policy only. pd_disaggregation is service-level."""

    # Router type selector; the only accepted value is "sglang", which is also
    # the default.
    type: Annotated[
        Literal["sglang"],
        Field(description="The router type enabled on this gateway."),
    ] = "sglang"
    # Routing policy; one of the four literals below, defaulting to "cache_aware".
    # The field description marks it as deprecated in favour of the policy set in
    # the service's router config.
    policy: Annotated[
        Literal["random", "round_robin", "cache_aware", "power_of_two"],
        Field(
            description=(
                "The routing policy. Deprecated: prefer setting policy in the service's router config. "
                "Options: `random`, `round_robin`, `cache_aware`, `power_of_two`"
            ),
        ),
    ] = "cache_aware"


class SGLangServiceRouterConfig(CoreModel):
    # Service-level SGLang router settings: the router type, the routing policy,
    # and the PD disaggregation toggle.

    # Router type selector; the only accepted value is "sglang", which is also
    # the default.
    type: Annotated[Literal["sglang"], Field(description="The router type")] = "sglang"
    # Routing policy; one of the four literals below, defaulting to "cache_aware".
    policy: Annotated[
        Literal["random", "round_robin", "cache_aware", "power_of_two"],
        Field(
            description="The routing policy. Options: `random`, `round_robin`, `cache_aware`, `power_of_two`"
        ),
    ] = "cache_aware"
    # Opt-in flag for PD disaggregation mode; off (False) unless set explicitly.
    pd_disaggregation: Annotated[
        bool,
        Field(description="Enable PD disaggregation mode for the SGLang router"),
    ] = False


AnyRouterConfig = SGLangRouterConfig
AnyServiceRouterConfig = SGLangServiceRouterConfig
AnyGatewayRouterConfig = SGLangGatewayRouterConfig
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/gateway/routers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ async def register_replica(
ssh_proxy=body.ssh_proxy,
ssh_head_proxy=body.ssh_head_proxy,
ssh_head_proxy_private_key=body.ssh_head_proxy_private_key,
internal_ip=body.internal_ip,
repo=repo,
nginx=nginx,
service_conn_pool=service_conn_pool,
Expand Down
5 changes: 3 additions & 2 deletions src/dstack/_internal/proxy/gateway/schemas/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import BaseModel, Field

from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.core.models.routers import AnyServiceRouterConfig
from dstack._internal.proxy.lib.models import RateLimit


Expand Down Expand Up @@ -45,7 +45,7 @@ class RegisterServiceRequest(BaseModel):
options: Options
ssh_private_key: str
rate_limits: tuple[RateLimit, ...] = ()
router: Optional[AnyRouterConfig] = None
router: Optional[AnyServiceRouterConfig] = None


class RegisterReplicaRequest(BaseModel):
Expand All @@ -56,6 +56,7 @@ class RegisterReplicaRequest(BaseModel):
ssh_proxy: Optional[SSHConnectionParams]
ssh_head_proxy: Optional[SSHConnectionParams]
ssh_head_proxy_private_key: Optional[str]
internal_ip: Optional[str] = None


class RegisterEntrypointRequest(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from dstack._internal.core.models.routers import AnyRouterConfig, RouterType
from dstack._internal.core.models.routers import AnyServiceRouterConfig, RouterType
from dstack._internal.proxy.gateway.services.model_routers.sglang import SglangRouter
from dstack._internal.proxy.lib.errors import ProxyError

from .base import Router, RouterContext


def get_router(router: AnyRouterConfig, context: RouterContext) -> Router:
def get_router(router: AnyServiceRouterConfig, context: RouterContext) -> Router:
    """Build the router implementation that matches a service router config.

    Args:
        router: The service-level router configuration; its ``type`` selects
            the implementation.
        context: The router context passed through to the implementation.

    Returns:
        An ``SglangRouter`` constructed from ``router`` and ``context`` when
        ``router.type`` equals ``RouterType.SGLANG``.

    Raises:
        ProxyError: If ``router.type`` is any other value.
    """
    # Guard clause: reject unsupported router types before constructing anything.
    if router.type != RouterType.SGLANG:
        raise ProxyError(f"Router type '{router.type}' is not available")
    return SglangRouter(config=router, context=context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel

from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.core.models.routers import AnyServiceRouterConfig


class RouterContext(BaseModel):
Expand All @@ -29,7 +29,7 @@ class Router(ABC):
def __init__(
self,
context: RouterContext,
config: Optional[AnyRouterConfig] = None,
config: Optional[AnyServiceRouterConfig] = None,
):
"""Initialize router with context.

Expand Down
Loading