Commit f0ef731

feat(models): add sglang provider and vllm token ids
Add an SGLang native /generate provider with token-in/out and SSE streaming. Refactor the vLLM provider for token-in/out and tool-use streaming, preserving provider-specific response fields. Update event-loop streaming to carry additionalModelResponseFields through to the final message.
1 parent 2944abf commit f0ef731
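
The headline feature is token-in/out: callers can feed prompt token IDs in and read prompt/output token IDs back from the final messageStop event. For orientation, a minimal sketch of the event sequence the new provider yields (all values illustrative, not from a real run):

    # Illustrative shape of SGLangModel.stream() output; values are made up.
    events = [
        {"messageStart": {"role": "assistant"}},
        {"contentBlockStart": {"start": {}}},
        {"contentBlockDelta": {"delta": {"text": "Hello!"}}},
        {"contentBlockStop": {}},
        {
            "messageStop": {
                "stopReason": "end_turn",
                "additionalModelResponseFields": {
                    "prompt_token_ids": [151644, 872, 198],  # token-in/out: prompt ids
                    "token_ids": [9707, 0],                  # generated ids
                },
            }
        },
        {"metadata": {"usage": {"inputTokens": 3, "outputTokens": 2, "totalTokens": 5},
                      "metrics": {"latencyMs": 42}}},
    ]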

File tree

11 files changed: +1611 −13 lines

src/strands/event_loop/event_loop.py

7 additions, 0 deletions
@@ -338,13 +338,20 @@ async def _handle_model_execution(
     else:
         tool_specs = agent.tool_registry.get_all_tool_specs()
     try:
+        model_kwargs = invocation_state.get("model_kwargs", {})
+        if model_kwargs is None:
+            model_kwargs = {}
+        if not isinstance(model_kwargs, dict):
+            raise TypeError("invocation_state['model_kwargs'] must be a dict if provided.")
+
         async for event in stream_messages(
             agent.model,
             agent.system_prompt,
             agent.messages,
             tool_specs,
             system_prompt_content=agent._system_prompt_content,
             tool_choice=structured_output_context.tool_choice,
+            **model_kwargs,
         ):
             yield event
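
With this guard in place, a per-invocation model_kwargs dict travels from invocation_state through stream_messages into Model.stream() unchanged. A minimal sketch of the flow, assuming the caller can seed invocation_state (the entry point that populates it is not part of this diff):

    # Hypothetical caller-side state; only the dict shape matters here.
    invocation_state = {"model_kwargs": {"return_token_ids": True}}

    # The added block normalizes None to {} and rejects non-dict values...
    model_kwargs = invocation_state.get("model_kwargs", {}) or {}
    if not isinstance(model_kwargs, dict):
        raise TypeError("invocation_state['model_kwargs'] must be a dict if provided.")

    # ...so stream_messages(..., **model_kwargs) effectively calls
    # model.stream(..., return_token_ids=True) for this one invocation.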

src/strands/event_loop/streaming.py

7 additions, 1 deletion
@@ -404,7 +404,12 @@ async def process_stream(
         elif "contentBlockStop" in chunk:
             state = handle_content_block_stop(state)
         elif "messageStop" in chunk:
-            stop_reason = handle_message_stop(chunk["messageStop"])
+            message_stop = chunk["messageStop"]
+            stop_reason = handle_message_stop(message_stop)
+            additional_fields = message_stop.get("additionalModelResponseFields")
+            if additional_fields is not None:
+                # Preserve provider-specific response fields (e.g., token IDs/logprobs) for downstream consumers.
+                state["message"]["additionalModelResponseFields"] = additional_fields
         elif "metadata" in chunk:
             time_to_first_byte_ms = (
                 int(1000 * (first_byte_time - start_time)) if (start_time and first_byte_time) else None

@@ -452,6 +457,7 @@ async def stream_messages(
         system_prompt,
         tool_choice=tool_choice,
         system_prompt_content=system_prompt_content,
+        **kwargs,
     )

     async for event in process_stream(chunks, start_time):
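
Because process_stream now copies additionalModelResponseFields onto the message under assembly, downstream consumers can read provider extras off the final assistant message. A minimal consumer sketch (the extras keys, such as token_ids, are provider-specific; values illustrative):

    # Final message as assembled by process_stream, per the hunk above.
    message = {
        "role": "assistant",
        "content": [{"text": "Hello!"}],
        "additionalModelResponseFields": {
            "prompt_token_ids": [151644, 872, 198],  # illustrative
            "token_ids": [9707, 0],                  # illustrative
        },
    }

    extras = message.get("additionalModelResponseFields") or {}
    output_ids = extras.get("token_ids")         # generated token ids, if the provider sent them
    prompt_ids = extras.get("prompt_token_ids")  # prompt token ids, when token-in/out was requested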

src/strands/models/sglang.py

332 additions, 0 deletions (new file)
@@ -0,0 +1,332 @@
"""SGLang model provider (native API).

This provider integrates with the SGLang Runtime **native** HTTP APIs, primarily:

- `/generate` for text generation (supports SSE streaming)
- `/tokenize` for tokenizing a prompt (optional; used for token-out prompt ids)

Docs:
- https://docs.sglang.io/basic_usage/native_api.html

Notes
-----
`/generate` is completion-style: it accepts a single prompt (or input token IDs) and returns a single
completion. Strands uses a message-based interface, so this provider serializes text-only conversations
into a single prompt. Tool calling is not supported via `/generate`.
"""

from __future__ import annotations

import json
import logging
from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypedDict, TypeVar, Union, cast

import httpx
from pydantic import BaseModel
from typing_extensions import Unpack, override

from ..types.content import Messages, SystemContentBlock
from ..types.event_loop import Metrics, Usage
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent
from ..types.tools import ToolChoice, ToolSpec
from ._validation import validate_config_keys
from .model import Model

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


class SGLangModel(Model):
    """SGLang native `/generate` provider with token-in/out helpers."""

    class SGLangConfig(TypedDict, total=False):
        """Configuration options for SGLang native API models."""

        base_url: str
        model_id: Optional[str]
        params: Optional[dict[str, Any]]  # default sampling params (merged into sampling_params)
        timeout: Optional[Union[float, tuple[float, float]]]

    def __init__(
        self,
        *,
        return_token_ids: bool = False,
        **model_config: Unpack[SGLangConfig],
    ) -> None:
        """Create an SGLang model client."""
        validate_config_keys(model_config, self.SGLangConfig)

        base_url = str(model_config.get("base_url") or "http://localhost:30000").rstrip("/")
        timeout = model_config.get("timeout")
        if isinstance(timeout, tuple):
            # (connect, read) tuple. httpx.Timeout requires either a default or all four
            # parameters, so reuse the read timeout as the default for write/pool.
            timeout_obj = httpx.Timeout(timeout[1], connect=timeout[0])
        else:
            timeout_obj = httpx.Timeout(timeout or 30.0)

        self.client = httpx.AsyncClient(base_url=base_url, timeout=timeout_obj)
        self.config = dict(model_config)
        self.config["base_url"] = base_url
        self._return_token_ids_default = bool(return_token_ids)

        logger.debug("config=<%s> | initializing", self.config)

    @override
    def update_config(self, **model_config: Unpack[SGLangConfig]) -> None:  # type: ignore[override]
        """Update the model configuration, re-canonicalizing base_url."""
        validate_config_keys(model_config, self.SGLangConfig)
        self.config.update(model_config)
        if self.config.get("base_url"):
            # Preserve base_url canonicalization (applied after update so it is not overwritten).
            self.config["base_url"] = str(self.config["base_url"]).rstrip("/")

    @override
    def get_config(self) -> SGLangConfig:
        """Return the current model configuration."""
        return cast(SGLangModel.SGLangConfig, self.config)

    def _messages_to_prompt(
        self,
        messages: Messages,
        system_prompt: Optional[str],
        *,
        system_prompt_content: Optional[list[SystemContentBlock]] = None,
    ) -> str:
        # Only support text content blocks. Tools and multimodal content are not supported via /generate.
        def text_from_blocks(role: str, blocks: list[dict[str, Any]]) -> str:
            parts: list[str] = []
            for block in blocks:
                if "text" in block:
                    parts.append(str(block["text"]))
                else:
                    raise TypeError(f"SGLangModel only supports text content blocks. got role={role} block={block}")
            return "".join(parts)

        # Back-compat: if system_prompt is provided but system_prompt_content is None.
        if system_prompt and system_prompt_content is None:
            system_prompt_content = [{"text": system_prompt}]

        lines: list[str] = []
        for block in system_prompt_content or []:
            if "text" in block:
                lines.append(f"system: {block['text']}")

        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", [])
            # Reject tool/multimodal blocks early.
            if any(k in b for b in content for k in ("toolUse", "toolResult", "image", "document", "reasoningContent")):
                raise TypeError("SGLangModel /generate does not support tools or multimodal message blocks.")
            text = text_from_blocks(str(role), cast(list[dict[str, Any]], content))
            if text.strip():
                lines.append(f"{role}: {text}")

        # Add a final assistant prefix to make the completion shape stable.
        lines.append("assistant:")
        return "\n".join(lines).strip() + "\n"

    async def _tokenize(self, prompt: str) -> list[int]:
        model_id = self.get_config().get("model_id")
        payload: dict[str, Any] = {
            "prompt": prompt,
            "add_special_tokens": False,
        }
        if model_id:
            payload["model"] = model_id

        resp = await self.client.post("/tokenize", json=payload)
        resp.raise_for_status()
        data = resp.json()
        tokens = data.get("tokens")
        if not isinstance(tokens, list) or not all(isinstance(x, int) for x in tokens):
            raise ValueError(f"Unexpected /tokenize response: {data}")
        return cast(list[int], tokens)

    def _build_generate_payload(
        self,
        *,
        prompt: Optional[str],
        prompt_token_ids: Optional[list[int]],
        sampling_params: dict[str, Any],
        stream: bool,
    ) -> dict[str, Any]:
        model_id = self.get_config().get("model_id")
        payload: dict[str, Any] = {"stream": stream}

        if model_id:
            payload["model"] = model_id

        if prompt_token_ids is not None:
            payload["input_ids"] = prompt_token_ids
        else:
            payload["text"] = prompt or ""

        if sampling_params:
            payload["sampling_params"] = sampling_params

        return payload

    @override
    async def stream(
        self,
        messages: Messages,
        tool_specs: Optional[list[ToolSpec]] = None,
        system_prompt: Optional[str] = None,
        *,
        tool_choice: ToolChoice | None = None,
        system_prompt_content: list[SystemContentBlock] | None = None,
        **kwargs: Any,
    ) -> AsyncIterable[StreamEvent]:
        """Stream a completion from /generate as Strands stream events."""
        if tool_specs is not None or tool_choice is not None:
            raise TypeError("SGLangModel /generate does not support tool_specs/tool_choice.")

        return_token_ids = bool(kwargs.pop("return_token_ids", self._return_token_ids_default))
        prompt_token_ids = kwargs.pop("prompt_token_ids", None)
        if prompt_token_ids is not None:
            if (
                not isinstance(prompt_token_ids, list)
                or not prompt_token_ids
                or not all(isinstance(x, int) for x in prompt_token_ids)
            ):
                raise TypeError("prompt_token_ids must be a non-empty list[int].")
            prompt_token_ids = cast(list[int], prompt_token_ids)

        sampling_params: dict[str, Any] = {}
        cfg_params = self.get_config().get("params")
        if isinstance(cfg_params, dict):
            sampling_params.update(cfg_params)

        if "sampling_params" in kwargs:
            sp = kwargs.pop("sampling_params")
            if sp is not None:
                if not isinstance(sp, dict):
                    raise TypeError("sampling_params must be a dict when provided.")
                sampling_params.update(cast(dict[str, Any], sp))

        sampling_params.update(kwargs)

        prompt_text: str | None = None
        prompt_token_ids_out: list[int] | None = None
        if prompt_token_ids is None:
            prompt_text = self._messages_to_prompt(messages, system_prompt, system_prompt_content=system_prompt_content)
            if return_token_ids:
                try:
                    prompt_token_ids_out = await self._tokenize(prompt_text)
                except httpx.HTTPStatusError as e:
                    if e.response.status_code == 429:
                        raise ModelThrottledException(str(e)) from e
                    raise

        payload = self._build_generate_payload(
            prompt=prompt_text,
            prompt_token_ids=prompt_token_ids,
            sampling_params=sampling_params,
            stream=True,
        )

        yield {"messageStart": {"role": "assistant"}}
        yield {"contentBlockStart": {"start": {}}}

        prev_text = ""
        last_output_ids: list[int] = []
        last_meta: dict[str, Any] | None = None

        try:
            async with self.client.stream("POST", "/generate", json=payload) as resp:
                resp.raise_for_status()

                async for line in resp.aiter_lines():
                    if not line:
                        continue
                    if not line.startswith("data:"):
                        continue
                    data_content = line[len("data:") :].strip()
                    if data_content == "[DONE]":
                        break
                    try:
                        event = json.loads(data_content)
                    except json.JSONDecodeError:
                        continue

                    # SGLang streams cumulative text; emit only the new suffix as a delta.
                    new_text = event.get("text")
                    if isinstance(new_text, str):
                        if new_text.startswith(prev_text):
                            delta = new_text[len(prev_text) :]
                        else:
                            delta = new_text
                        prev_text = new_text
                        if delta:
                            yield {"contentBlockDelta": {"delta": {"text": delta}}}

                    output_ids = event.get("output_ids")
                    if isinstance(output_ids, list) and all(isinstance(x, int) for x in output_ids):
                        last_output_ids = cast(list[int], output_ids)

                    meta = event.get("meta_info")
                    if isinstance(meta, dict):
                        last_meta = cast(dict[str, Any], meta)

        except httpx.HTTPStatusError as e:
            status = e.response.status_code
            if status == 400:
                raise ContextWindowOverflowException(str(e)) from e
            if status in (429, 503):
                raise ModelThrottledException(str(e)) from e
            raise

        yield {"contentBlockStop": {}}

        additional: dict[str, Any] = {}
        if prompt_token_ids is not None:
            additional["prompt_token_ids"] = prompt_token_ids
        elif prompt_token_ids_out is not None:
            additional["prompt_token_ids"] = prompt_token_ids_out
        if last_output_ids:
            additional["token_ids"] = last_output_ids

        stop_reason: str = "end_turn"
        if last_meta and isinstance(last_meta.get("finish_reason"), dict):
            fr = cast(dict[str, Any], last_meta.get("finish_reason"))
            if fr.get("type") == "length":
                stop_reason = "max_tokens"

        yield {"messageStop": {"stopReason": cast(Any, stop_reason), "additionalModelResponseFields": additional}}

        if last_meta:
            usage: Usage = {
                "inputTokens": int(last_meta.get("prompt_tokens") or 0),
                "outputTokens": int(last_meta.get("completion_tokens") or 0),
                "totalTokens": int((last_meta.get("prompt_tokens") or 0) + (last_meta.get("completion_tokens") or 0)),
            }
            latency_ms = int(float(last_meta.get("e2e_latency") or 0.0) * 1000)
            metrics: Metrics = {"latencyMs": latency_ms}
            yield {"metadata": {"usage": usage, "metrics": metrics}}

    @override
    async def structured_output(
        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
        """Prompt for schema-constrained JSON and parse it into output_model."""
        instruction = (
            "Return ONLY valid JSON that matches the schema. Do not include any extra keys or prose.\n"
            f"Schema: {output_model.model_json_schema()}\n"
        )
        prompt2: Messages = [
            {"role": "user", "content": [{"text": instruction}]},
            *prompt,
        ]

        text = ""
        async for event in self.stream(
            prompt2,
            tool_specs=None,
            system_prompt=system_prompt,
            system_prompt_content=kwargs.pop("system_prompt_content", None),
            **kwargs,
        ):
            if "contentBlockDelta" in event:
                delta = event["contentBlockDelta"]["delta"]
                if "text" in delta:
                    text += delta["text"]

        try:
            yield {"output": output_model.model_validate_json(text.strip())}
        except Exception as e:
            raise ValueError(f"Failed to parse structured output JSON: {e}") from e
