"""SGLang model provider (native API).

This provider integrates with the SGLang Runtime **native** HTTP APIs, primarily:
- `/generate` for text generation (supports SSE streaming)
- `/tokenize` for tokenizing a prompt (optional; used for token-out prompt ids)

Docs:
- https://docs.sglang.io/basic_usage/native_api.html

Notes:
------
`/generate` is completion-style: it accepts a single prompt (or input token IDs) and returns a single completion.
Strands uses a message-based interface, so this provider serializes text-only conversations into a single prompt.
Tool calling is not supported via `/generate`.
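
Example (illustrative sketch; assumes an SGLang server on http://localhost:30000 and that
this module is importable as `strands.models.sglang`):

    import asyncio

    from strands.models.sglang import SGLangModel

    async def main() -> None:
        model = SGLangModel(
            base_url="http://localhost:30000",
            params={"temperature": 0.2, "max_new_tokens": 128},
        )
        messages = [{"role": "user", "content": [{"text": "Hello!"}]}]
        async for event in model.stream(messages):
            print(event)

    asyncio.run(main())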
| 15 | +""" |
| 16 | + |
| 17 | +from __future__ import annotations |
| 18 | + |
| 19 | +import json |
| 20 | +import logging |
| 21 | +from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypedDict, TypeVar, Union, cast |
| 22 | + |
| 23 | +import httpx |
| 24 | +from pydantic import BaseModel |
| 25 | +from typing_extensions import Unpack, override |
| 26 | + |
| 27 | +from ..types.content import Messages, SystemContentBlock |
| 28 | +from ..types.event_loop import Metrics, Usage |
| 29 | +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException |
| 30 | +from ..types.streaming import StreamEvent |
| 31 | +from ..types.tools import ToolChoice, ToolSpec |
| 32 | +from ._validation import validate_config_keys |
| 33 | +from .model import Model |
| 34 | + |
| 35 | +logger = logging.getLogger(__name__) |
| 36 | + |
| 37 | +T = TypeVar("T", bound=BaseModel) |
| 38 | + |
| 39 | + |
class SGLangModel(Model):
    """SGLang native `/generate` provider with token-in/out helpers."""

    class SGLangConfig(TypedDict, total=False):
        """Configuration options for SGLang native API models."""

        base_url: str  # SGLang server base URL (default: http://localhost:30000)
        model_id: Optional[str]  # optional model name passed through to the server
        params: Optional[dict[str, Any]]  # default sampling params (merged into sampling_params)
        timeout: Optional[Union[float, tuple[float, float]]]  # seconds, or (connect, read)

    def __init__(
        self,
        *,
        return_token_ids: bool = False,
        **model_config: Unpack[SGLangConfig],
    ) -> None:
        """Create an SGLang model client.

        Args:
            return_token_ids: If True, tokenize the prompt up front so prompt token ids can be
                reported in the final message's `additionalModelResponseFields`.
            **model_config: Configuration options for the SGLang model (see `SGLangConfig`).
        """
        validate_config_keys(model_config, self.SGLangConfig)

        base_url = str(model_config.get("base_url") or "http://localhost:30000").rstrip("/")
        timeout = model_config.get("timeout")
        if isinstance(timeout, tuple):
            # httpx.Timeout requires a default (or all four values): treat the tuple as
            # (connect, read) and use the read value as the default for write/pool.
            timeout_obj = httpx.Timeout(timeout[1], connect=timeout[0])
        else:
            timeout_obj = httpx.Timeout(timeout or 30.0)

        self.client = httpx.AsyncClient(base_url=base_url, timeout=timeout_obj)
        self.config = dict(model_config)
        self.config["base_url"] = base_url
        self._return_token_ids_default = bool(return_token_ids)

        logger.debug("config=<%s> | initializing", self.config)

    @override
    def update_config(self, **model_config: Unpack[SGLangConfig]) -> None:  # type: ignore[override]
        """Update the SGLang model configuration with the provided arguments."""
        validate_config_keys(model_config, self.SGLangConfig)
        self.config.update(model_config)
        if self.config.get("base_url"):
            # Preserve base_url canonicalization (strip any trailing slash) after the update.
            self.config["base_url"] = str(self.config["base_url"]).rstrip("/")

    @override
    def get_config(self) -> SGLangConfig:
        """Get the SGLang model configuration."""
        return cast(SGLangModel.SGLangConfig, self.config)

    def _messages_to_prompt(
        self,
        messages: Messages,
        system_prompt: Optional[str],
        *,
        system_prompt_content: Optional[list[SystemContentBlock]] = None,
    ) -> str:
        # Only support text content blocks. Tools and multimodal content are not supported via /generate.
        def text_from_blocks(role: str, blocks: list[dict[str, Any]]) -> str:
            parts: list[str] = []
            for block in blocks:
                if "text" in block:
                    parts.append(str(block["text"]))
                else:
                    raise TypeError(f"SGLangModel only supports text content blocks. got role={role} block={block}")
            return "".join(parts)

        # Back-compat: if system_prompt is provided but system_prompt_content is None.
        if system_prompt and system_prompt_content is None:
            system_prompt_content = [{"text": system_prompt}]

        lines: list[str] = []
        for block in system_prompt_content or []:
            if "text" in block:
                lines.append(f"system: {block['text']}")

        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", [])
            # Reject tool/multimodal blocks early.
            if any(k in b for b in content for k in ("toolUse", "toolResult", "image", "document", "reasoningContent")):
                raise TypeError("SGLangModel /generate does not support tools or multimodal message blocks.")
            text = text_from_blocks(str(role), cast(list[dict[str, Any]], content))
            if text.strip():
                lines.append(f"{role}: {text}")

        # Add a final assistant prefix to make the completion shape stable.
        lines.append("assistant:")
        return "\n".join(lines).strip() + "\n"

    async def _tokenize(self, prompt: str) -> list[int]:
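        """Tokenize `prompt` via the native `/tokenize` endpoint and return its token ids."""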
        model_id = self.get_config().get("model_id")
        payload: dict[str, Any] = {
            "prompt": prompt,
            "add_special_tokens": False,
        }
        if model_id:
            payload["model"] = model_id

        resp = await self.client.post("/tokenize", json=payload)
        resp.raise_for_status()
        data = resp.json()
        tokens = data.get("tokens")
        if not isinstance(tokens, list) or not all(isinstance(x, int) for x in tokens):
            raise ValueError(f"Unexpected /tokenize response: {data}")
        return cast(list[int], tokens)

    def _build_generate_payload(
        self,
        *,
        prompt: Optional[str],
        prompt_token_ids: Optional[list[int]],
        sampling_params: dict[str, Any],
        stream: bool,
    ) -> dict[str, Any]:
        model_id = self.get_config().get("model_id")
        payload: dict[str, Any] = {"stream": stream}

        if model_id:
            payload["model"] = model_id

        if prompt_token_ids is not None:
            payload["input_ids"] = prompt_token_ids
        else:
            payload["text"] = prompt or ""

        if sampling_params:
            payload["sampling_params"] = sampling_params

        return payload

    @override
    async def stream(
        self,
        messages: Messages,
        tool_specs: Optional[list[ToolSpec]] = None,
        system_prompt: Optional[str] = None,
        *,
        tool_choice: ToolChoice | None = None,
        system_prompt_content: list[SystemContentBlock] | None = None,
        **kwargs: Any,
    ) -> AsyncIterable[StreamEvent]:
        if tool_specs is not None or tool_choice is not None:
            raise TypeError("SGLangModel /generate does not support tool_specs/tool_choice.")

        return_token_ids = bool(kwargs.pop("return_token_ids", self._return_token_ids_default))
        prompt_token_ids = kwargs.pop("prompt_token_ids", None)
        if prompt_token_ids is not None:
            if (
                not isinstance(prompt_token_ids, list)
                or not prompt_token_ids
                or not all(isinstance(x, int) for x in prompt_token_ids)
            ):
                raise TypeError("prompt_token_ids must be a non-empty list[int].")
            prompt_token_ids = cast(list[int], prompt_token_ids)

        # Merge sampling params: config defaults < explicit sampling_params dict < remaining kwargs.
        sampling_params: dict[str, Any] = {}
        cfg_params = self.get_config().get("params")
        if isinstance(cfg_params, dict):
            sampling_params.update(cfg_params)

        if "sampling_params" in kwargs:
            sp = kwargs.pop("sampling_params")
            if sp is not None:
                if not isinstance(sp, dict):
                    raise TypeError("sampling_params must be a dict when provided.")
                sampling_params.update(cast(dict[str, Any], sp))

        sampling_params.update(kwargs)

        prompt_text: str | None = None
        prompt_token_ids_out: list[int] | None = None
        if prompt_token_ids is None:
            prompt_text = self._messages_to_prompt(messages, system_prompt, system_prompt_content=system_prompt_content)
            if return_token_ids:
                try:
                    prompt_token_ids_out = await self._tokenize(prompt_text)
                except httpx.HTTPStatusError as e:
                    if e.response.status_code == 429:
                        raise ModelThrottledException(str(e)) from e
                    raise

        payload = self._build_generate_payload(
            prompt=prompt_text,
            prompt_token_ids=prompt_token_ids,
            sampling_params=sampling_params,
            stream=True,
        )

        yield {"messageStart": {"role": "assistant"}}
        yield {"contentBlockStart": {"start": {}}}

        prev_text = ""
        last_output_ids: list[int] = []
        last_meta: dict[str, Any] | None = None

        try:
            async with self.client.stream("POST", "/generate", json=payload) as resp:
                resp.raise_for_status()

                async for line in resp.aiter_lines():
                    if not line:
                        continue
                    if not line.startswith("data:"):
                        continue
                    data_content = line[len("data:") :].strip()
                    if data_content == "[DONE]":
                        break
                    try:
                        event = json.loads(data_content)
                    except json.JSONDecodeError:
                        continue

                    # The stream carries cumulative text; emit only the newly generated suffix.
                    new_text = event.get("text")
                    if isinstance(new_text, str):
                        if new_text.startswith(prev_text):
                            delta = new_text[len(prev_text) :]
                        else:
                            delta = new_text
                        prev_text = new_text
                        if delta:
                            yield {"contentBlockDelta": {"delta": {"text": delta}}}

                    output_ids = event.get("output_ids")
                    if isinstance(output_ids, list) and all(isinstance(x, int) for x in output_ids):
                        last_output_ids = cast(list[int], output_ids)

                    meta = event.get("meta_info")
                    if isinstance(meta, dict):
                        last_meta = cast(dict[str, Any], meta)

        except httpx.HTTPStatusError as e:
            status = e.response.status_code
            if status == 400:
                raise ContextWindowOverflowException(str(e)) from e
            if status in (429, 503):
                raise ModelThrottledException(str(e)) from e
            raise

        yield {"contentBlockStop": {}}

        additional: dict[str, Any] = {}
        if prompt_token_ids is not None:
            additional["prompt_token_ids"] = prompt_token_ids
        elif prompt_token_ids_out is not None:
            additional["prompt_token_ids"] = prompt_token_ids_out
        if last_output_ids:
            additional["token_ids"] = last_output_ids

        stop_reason: str = "end_turn"
        if last_meta and isinstance(last_meta.get("finish_reason"), dict):
            fr = cast(dict[str, Any], last_meta.get("finish_reason"))
            if fr.get("type") == "length":
                stop_reason = "max_tokens"

        yield {"messageStop": {"stopReason": cast(Any, stop_reason), "additionalModelResponseFields": additional}}

        if last_meta:
            usage: Usage = {
                "inputTokens": int(last_meta.get("prompt_tokens") or 0),
                "outputTokens": int(last_meta.get("completion_tokens") or 0),
                "totalTokens": int((last_meta.get("prompt_tokens") or 0) + (last_meta.get("completion_tokens") or 0)),
            }
            latency_ms = int(float(last_meta.get("e2e_latency") or 0.0) * 1000)
            metrics: Metrics = {"latencyMs": latency_ms}
            yield {"metadata": {"usage": usage, "metrics": metrics}}

    @override
    async def structured_output(
        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
        instruction = (
            "Return ONLY valid JSON that matches the schema. Do not include any extra keys or prose.\n"
            f"Schema: {output_model.model_json_schema()}\n"
        )
        prompt2: Messages = [
            {"role": "user", "content": [{"text": instruction}]},
            *prompt,
        ]

        text = ""
        async for event in self.stream(
            prompt2,
            tool_specs=None,
            system_prompt=system_prompt,
            system_prompt_content=kwargs.pop("system_prompt_content", None),
            **kwargs,
        ):
            if "contentBlockDelta" in event:
                delta = event["contentBlockDelta"]["delta"]
                if "text" in delta:
                    text += delta["text"]

        try:
            yield {"output": output_model.model_validate_json(text.strip())}
        except Exception as e:
            raise ValueError(f"Failed to parse structured output JSON: {e}") from e