diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index f25057e4d..63437b66b 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -338,6 +338,12 @@ async def _handle_model_execution( else: tool_specs = agent.tool_registry.get_all_tool_specs() try: + model_kwargs = invocation_state.get("model_kwargs", {}) + if model_kwargs is None: + model_kwargs = {} + if not isinstance(model_kwargs, dict): + raise TypeError("invocation_state['model_kwargs'] must be a dict if provided.") + async for event in stream_messages( agent.model, agent.system_prompt, @@ -345,6 +351,7 @@ async def _handle_model_execution( tool_specs, system_prompt_content=agent._system_prompt_content, tool_choice=structured_output_context.tool_choice, + **model_kwargs, ): yield event diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 43836fe34..92bd6fdf3 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -404,7 +404,12 @@ async def process_stream( elif "contentBlockStop" in chunk: state = handle_content_block_stop(state) elif "messageStop" in chunk: - stop_reason = handle_message_stop(chunk["messageStop"]) + message_stop = chunk["messageStop"] + stop_reason = handle_message_stop(message_stop) + additional_fields = message_stop.get("additionalModelResponseFields") + if additional_fields is not None: + # Preserve provider-specific response fields (e.g., token IDs/logprobs) for downstream consumers. + state["message"]["additionalModelResponseFields"] = additional_fields elif "metadata" in chunk: time_to_first_byte_ms = ( int(1000 * (first_byte_time - start_time)) if (start_time and first_byte_time) else None @@ -452,6 +457,7 @@ async def stream_messages( system_prompt, tool_choice=tool_choice, system_prompt_content=system_prompt_content, + **kwargs, ) async for event in process_stream(chunks, start_time): diff --git a/src/strands/models/sglang.py b/src/strands/models/sglang.py new file mode 100644 index 000000000..deaa38118 --- /dev/null +++ b/src/strands/models/sglang.py @@ -0,0 +1,332 @@ +"""SGLang model provider (native API). + +This provider integrates with the SGLang Runtime **native** HTTP APIs, primarily: +- `/generate` for text generation (supports SSE streaming) +- `/tokenize` for tokenizing a prompt (optional; used for token-out prompt ids) + +Docs: +- https://docs.sglang.io/basic_usage/native_api.html + +Notes: +----- +`/generate` is completion-style: it accepts a single prompt (or input token IDs) and returns a single completion. +Strands uses a message-based interface, so this provider serializes text-only conversations into a single prompt. +Tool calling is not supported via `/generate`. 
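+
+Example:
+-------
+A minimal streaming sketch (assumes a local SGLang server at http://localhost:30000; the sampling
+params shown here are illustrative, not required):
+
+    import asyncio
+    from strands.models.sglang import SGLangModel
+
+    async def main() -> None:
+        model = SGLangModel(
+            base_url="http://localhost:30000",
+            params={"max_new_tokens": 64, "temperature": 0},
+            return_token_ids=True,
+        )
+        messages = [{"role": "user", "content": [{"text": "hi"}]}]
+        async for event in model.stream(messages):
+            if "contentBlockDelta" in event:
+                print(event["contentBlockDelta"]["delta"].get("text", ""), end="")
+            elif "messageStop" in event:
+                # Prompt/output token IDs (when available) ride along in additionalModelResponseFields.
+                print(event["messageStop"].get("additionalModelResponseFields"))
+
+    asyncio.run(main())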
+""" + +from __future__ import annotations + +import json +import logging +from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypedDict, TypeVar, Union, cast + +import httpx +from pydantic import BaseModel +from typing_extensions import Unpack, override + +from ..types.content import Messages, SystemContentBlock +from ..types.event_loop import Metrics, Usage +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException +from ..types.streaming import StreamEvent +from ..types.tools import ToolChoice, ToolSpec +from ._validation import validate_config_keys +from .model import Model + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class SGLangModel(Model): + """SGLang native `/generate` provider with token-in/out helpers.""" + + class SGLangConfig(TypedDict, total=False): + """Configuration options for SGLang native API models.""" + + base_url: str + model_id: Optional[str] + params: Optional[dict[str, Any]] # default sampling params (merged into sampling_params) + timeout: Optional[Union[float, tuple[float, float]]] + + def __init__( + self, + *, + return_token_ids: bool = False, + **model_config: Unpack[SGLangConfig], + ) -> None: + """Create an SGLang model client.""" + validate_config_keys(model_config, self.SGLangConfig) + + base_url = str(model_config.get("base_url") or "http://localhost:30000").rstrip("/") + timeout = model_config.get("timeout") + if isinstance(timeout, tuple): + timeout_obj = httpx.Timeout(connect=timeout[0], read=timeout[1]) + else: + timeout_obj = httpx.Timeout(timeout or 30.0) + + self.client = httpx.AsyncClient(base_url=base_url, timeout=timeout_obj) + self.config = dict(model_config) + self.config["base_url"] = base_url + self._return_token_ids_default = bool(return_token_ids) + + logger.debug("config=<%s> | initializing", self.config) + + @override + def update_config(self, **model_config: Unpack[SGLangConfig]) -> None: # type: ignore[override] + validate_config_keys(model_config, self.SGLangConfig) + if "base_url" in model_config and model_config["base_url"]: + # Preserve base_url canonicalization + self.config["base_url"] = str(model_config["base_url"]).rstrip("/") + self.config.update(model_config) + + @override + def get_config(self) -> SGLangConfig: + return cast(SGLangModel.SGLangConfig, self.config) + + def _messages_to_prompt( + self, + messages: Messages, + system_prompt: Optional[str], + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + ) -> str: + # Only support text content blocks. Tools and multimodal content are not supported via /generate. + def text_from_blocks(role: str, blocks: list[dict[str, Any]]) -> str: + parts: list[str] = [] + for block in blocks: + if "text" in block: + parts.append(str(block["text"])) + else: + raise TypeError(f"SGLangModel only supports text content blocks. got role={role} block={block}") + return "".join(parts) + + # Back-compat: if system_prompt is provided but system_prompt_content is None. 
+ if system_prompt and system_prompt_content is None: + system_prompt_content = [{"text": system_prompt}] + + lines: list[str] = [] + for block in system_prompt_content or []: + if "text" in block: + lines.append(f"system: {block['text']}") + + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", []) + # Reject tool/multimodal blocks early + if any(k in b for b in content for k in ("toolUse", "toolResult", "image", "document", "reasoningContent")): + raise TypeError("SGLangModel /generate does not support tools or multimodal message blocks.") + text = text_from_blocks(str(role), cast(list[dict[str, Any]], content)) + if text.strip(): + lines.append(f"{role}: {text}") + + # Add a final assistant prefix to make the completion shape stable. + lines.append("assistant:") + return "\n".join(lines).strip() + "\n" + + async def _tokenize(self, prompt: str) -> list[int]: + model_id = self.get_config().get("model_id") + payload: dict[str, Any] = { + "prompt": prompt, + "add_special_tokens": False, + } + if model_id: + payload["model"] = model_id + + resp = await self.client.post("/tokenize", json=payload) + resp.raise_for_status() + data = resp.json() + tokens = data.get("tokens") + if not isinstance(tokens, list) or not all(isinstance(x, int) for x in tokens): + raise ValueError(f"Unexpected /tokenize response: {data}") + return cast(list[int], tokens) + + def _build_generate_payload( + self, + *, + prompt: Optional[str], + prompt_token_ids: Optional[list[int]], + sampling_params: dict[str, Any], + stream: bool, + ) -> dict[str, Any]: + model_id = self.get_config().get("model_id") + payload: dict[str, Any] = {"stream": stream} + + if model_id: + payload["model"] = model_id + + if prompt_token_ids is not None: + payload["input_ids"] = prompt_token_ids + else: + payload["text"] = prompt or "" + + if sampling_params: + payload["sampling_params"] = sampling_params + + return payload + + @override + async def stream( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + *, + tool_choice: ToolChoice | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + **kwargs: Any, + ) -> AsyncIterable[StreamEvent]: + if tool_specs is not None or tool_choice is not None: + raise TypeError("SGLangModel /generate does not support tool_specs/tool_choice.") + + return_token_ids = bool(kwargs.pop("return_token_ids", self._return_token_ids_default)) + prompt_token_ids = kwargs.pop("prompt_token_ids", None) + if prompt_token_ids is not None: + if ( + not isinstance(prompt_token_ids, list) + or not prompt_token_ids + or not all(isinstance(x, int) for x in prompt_token_ids) + ): + raise TypeError("prompt_token_ids must be a non-empty list[int].") + prompt_token_ids = cast(list[int], prompt_token_ids) + + sampling_params: dict[str, Any] = {} + cfg_params = self.get_config().get("params") + if isinstance(cfg_params, dict): + sampling_params.update(cfg_params) + + if "sampling_params" in kwargs: + sp = kwargs.pop("sampling_params") + if sp is not None: + if not isinstance(sp, dict): + raise TypeError("sampling_params must be a dict when provided.") + sampling_params.update(cast(dict[str, Any], sp)) + + sampling_params.update(kwargs) + + prompt_text: str | None = None + prompt_token_ids_out: list[int] | None = None + if prompt_token_ids is None: + prompt_text = self._messages_to_prompt(messages, system_prompt, system_prompt_content=system_prompt_content) + if return_token_ids: + try: + prompt_token_ids_out = 
await self._tokenize(prompt_text) + except httpx.HTTPStatusError as e: + if e.response.status_code == 429: + raise ModelThrottledException(str(e)) from e + raise + + payload = self._build_generate_payload( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + stream=True, + ) + + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + + prev_text = "" + last_output_ids: list[int] = [] + last_meta: dict[str, Any] | None = None + + try: + async with self.client.stream("POST", "/generate", json=payload) as resp: + resp.raise_for_status() + + async for line in resp.aiter_lines(): + if not line: + continue + if not line.startswith("data:"): + continue + data_content = line[len("data:") :].strip() + if data_content == "[DONE]": + break + try: + event = json.loads(data_content) + except json.JSONDecodeError: + continue + + new_text = event.get("text") + if isinstance(new_text, str): + if new_text.startswith(prev_text): + delta = new_text[len(prev_text) :] + else: + delta = new_text + prev_text = new_text + if delta: + yield {"contentBlockDelta": {"delta": {"text": delta}}} + + output_ids = event.get("output_ids") + if isinstance(output_ids, list) and all(isinstance(x, int) for x in output_ids): + last_output_ids = cast(list[int], output_ids) + + meta = event.get("meta_info") + if isinstance(meta, dict): + last_meta = cast(dict[str, Any], meta) + + except httpx.HTTPStatusError as e: + status = e.response.status_code + if status == 400: + raise ContextWindowOverflowException(str(e)) from e + if status in (429, 503): + raise ModelThrottledException(str(e)) from e + raise + + yield {"contentBlockStop": {}} + + additional: dict[str, Any] = {} + if prompt_token_ids is not None: + additional["prompt_token_ids"] = prompt_token_ids + elif prompt_token_ids_out is not None: + additional["prompt_token_ids"] = prompt_token_ids_out + if last_output_ids: + additional["token_ids"] = last_output_ids + + stop_reason: str = "end_turn" + if last_meta and isinstance(last_meta.get("finish_reason"), dict): + fr = cast(dict[str, Any], last_meta.get("finish_reason")) + if fr.get("type") == "length": + stop_reason = "max_tokens" + + yield {"messageStop": {"stopReason": cast(Any, stop_reason), "additionalModelResponseFields": additional}} + + if last_meta: + usage: Usage = { + "inputTokens": int(last_meta.get("prompt_tokens") or 0), + "outputTokens": int(last_meta.get("completion_tokens") or 0), + "totalTokens": int((last_meta.get("prompt_tokens") or 0) + (last_meta.get("completion_tokens") or 0)), + } + latency_ms = int(float(last_meta.get("e2e_latency") or 0.0) * 1000) + metrics: Metrics = {"latencyMs": latency_ms} + yield {"metadata": {"usage": usage, "metrics": metrics}} + + @override + async def structured_output( + self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + instruction = ( + "Return ONLY valid JSON that matches the schema. 
Do not include any extra keys or prose.\n" + f"Schema: {output_model.model_json_schema()}\n" + ) + prompt2: Messages = [ + {"role": "user", "content": [{"text": instruction}]}, + *prompt, + ] + + text = "" + async for event in self.stream( + prompt2, + tool_specs=None, + system_prompt=system_prompt, + system_prompt_content=kwargs.pop("system_prompt_content", None), + **kwargs, + ): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + text += delta["text"] + + try: + yield {"output": output_model.model_validate_json(text.strip())} + except Exception as e: + raise ValueError(f"Failed to parse structured output JSON: {e}") from e diff --git a/src/strands/models/vllm.py b/src/strands/models/vllm.py new file mode 100644 index 000000000..c8b6fcf16 --- /dev/null +++ b/src/strands/models/vllm.py @@ -0,0 +1,632 @@ +"""vLLM model provider (OpenAI-compatible). + +This provider is implemented as a first-class Strands `Model` (not a subclass of `OpenAIModel`). + +It targets vLLM's OpenAI-compatible server and supports: +- **token-out**: `prompt_token_ids`, `token_ids`, logprobs (when the server includes them) +- **token-in**: request-scoped `prompt_token_ids` via `extra_body` +- **tools**: via `/v1/chat/completions` (tool calling) + +vLLM exposes provider-specific fields that are not part of the official OpenAI API schema. We send +those fields via `extra_body` to avoid OpenAI SDK validation errors, and we preserve them back onto +`messageStop.additionalModelResponseFields` for downstream consumers. +""" + +from __future__ import annotations + +import base64 +import json +import logging +import mimetypes +from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypedDict, TypeVar, Union, cast + +import openai +from openai.types.chat.parsed_chat_completion import ParsedChatCompletion +from pydantic import BaseModel +from typing_extensions import Unpack, override + +from ..types.content import ContentBlock, Messages, SystemContentBlock +from ..types.event_loop import StopReason +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException +from ..types.streaming import MessageStopEvent, StreamEvent +from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse +from ._validation import validate_config_keys +from .model import Model + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class VLLMModel(Model): + """OpenAI-compatible vLLM provider with token-in/out helpers.""" + + class VLLMConfig(TypedDict, total=False): + """Configuration options for vLLM OpenAI-compatible models. + + Attributes: + model_id: Model ID to pass to the server (e.g., "meta-llama/Llama-3.1-8B-Instruct"). + params: Base request params merged into every request (e.g., max_tokens, temperature). 
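+
+        Example (illustrative; base_url and model_id are placeholders for your deployment):
+
+            VLLMModel(
+                client_args={"api_key": "EMPTY", "base_url": "http://localhost:8000/v1"},
+                model_id="AMead10/Llama-3.2-3B-Instruct-AWQ",
+                params={"max_tokens": 64, "temperature": 0},
+                return_token_ids=True,
+            )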
+ """ + + model_id: str + params: Optional[dict[str, Any]] + + def __init__( + self, + client_args: Optional[dict[str, Any]] = None, + *, + return_token_ids: bool = False, + **model_config: Unpack[VLLMConfig], + ) -> None: + """Create a vLLM OpenAI-compatible model client.""" + validate_config_keys(model_config, self.VLLMConfig) + self.config = dict(model_config) + self.client_args = client_args or {} + self._return_token_ids_default = bool(return_token_ids) + + @override + def update_config(self, **model_config: Unpack[VLLMConfig]) -> None: # type: ignore[override] + validate_config_keys(model_config, self.VLLMConfig) + self.config.update(model_config) + + @override + def get_config(self) -> VLLMConfig: + return cast(VLLMModel.VLLMConfig, self.config) + + @staticmethod + def _safe_model_dump(obj: Any) -> dict[str, Any]: + model_dump = getattr(obj, "model_dump", None) + if not callable(model_dump): + return {} + try: + dumped = model_dump() + except Exception: + return {} + return dumped if isinstance(dumped, dict) else {} + + @staticmethod + def _choice0_dump(dumped: dict[str, Any]) -> dict[str, Any] | None: + choices = dumped.get("choices") + if not isinstance(choices, list) or not choices: + return None + c0 = choices[0] + return c0 if isinstance(c0, dict) else None + + @staticmethod + def _extend_token_ids_from_choice_dump(token_ids: list[int], choice_dump: dict[str, Any] | None) -> None: + if not choice_dump: + return + maybe = choice_dump.get("token_ids") + if isinstance(maybe, list) and maybe and all(isinstance(x, int) for x in maybe): + token_ids.extend(cast(list[int], maybe)) + + @staticmethod + def _extend_logprobs_from_choice_dump(completion_logprobs: list[Any], choice_dump: dict[str, Any] | None) -> None: + if not choice_dump: + return + lp = choice_dump.get("logprobs") + if lp is None: + return + if isinstance(lp, dict): + content = lp.get("content") + if isinstance(content, list): + completion_logprobs.extend(content) + return + completion_logprobs.append(lp) + return + completion_logprobs.append(lp) + + @staticmethod + def _stream_switch_content(next_type: str, prev_type: str | None) -> tuple[list[StreamEvent], str]: + chunks: list[StreamEvent] = [] + if prev_type != next_type: + if prev_type is not None: + chunks.append({"contentBlockStop": {}}) + chunks.append({"contentBlockStart": {"start": {}}}) + return chunks, next_type + + @classmethod + def _format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + if "document" in content: + mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") + file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") + return { + "file": { + "file_data": f"data:{mime_type};base64,{file_data}", + "filename": content["document"]["name"], + }, + "type": "file", + } + + if "image" in content: + mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") + image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") + return { + "image_url": { + "detail": "auto", + "format": mime_type, + "url": f"data:{mime_type};base64,{image_data}", + }, + "type": "image_url", + } + + if "text" in content: + return {"text": content["text"], "type": "text"} + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + @classmethod + def _format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: + return { + "function": { + "arguments": json.dumps(tool_use["input"]), + 
"name": tool_use["name"], + }, + "id": tool_use["toolUseId"], + "type": "function", + } + + @classmethod + def _format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + contents = cast( + list[ContentBlock], + [ + {"text": json.dumps(content["json"])} if "json" in content else content + for content in tool_result["content"] + ], + ) + return { + "role": "tool", + "tool_call_id": tool_result["toolUseId"], + "content": [cls._format_request_message_content(content) for content in contents], + } + + @classmethod + def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str, Any]: + if not tool_choice: + return {} + + match tool_choice: + case {"auto": _}: + return {"tool_choice": "auto"} + case {"any": _}: + return {"tool_choice": "required"} + case {"tool": {"name": tool_name}}: + return {"tool_choice": {"type": "function", "function": {"name": tool_name}}} + case _: + return {"tool_choice": "auto"} + + @classmethod + def _format_system_messages( + cls, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **_kwargs: Any, + ) -> list[dict[str, Any]]: + if system_prompt and system_prompt_content is None: + system_prompt_content = [{"text": system_prompt}] + + return [ + {"role": "system", "content": content["text"]} + for content in system_prompt_content or [] + if "text" in content + ] + + @classmethod + def _format_regular_messages(cls, messages: Messages, **_kwargs: Any) -> list[dict[str, Any]]: + formatted_messages: list[dict[str, Any]] = [] + + for message in messages: + contents = message["content"] + + if any("reasoningContent" in content for content in contents): + logger.warning( + "reasoningContent is not supported in multi-turn conversations with the Chat Completions API." 
+ ) + + formatted_contents = [ + cls._format_request_message_content(content) + for content in contents + if not any(block_type in content for block_type in ["toolResult", "toolUse", "reasoningContent"]) + ] + formatted_tool_calls = [ + cls._format_request_message_tool_call(content["toolUse"]) + for content in contents + if "toolUse" in content + ] + formatted_tool_messages = [ + cls._format_request_tool_message(content["toolResult"]) + for content in contents + if "toolResult" in content + ] + + formatted_message = { + "role": message["role"], + "content": formatted_contents, + **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), + } + formatted_messages.append(formatted_message) + formatted_messages.extend(formatted_tool_messages) + + return formatted_messages + + @classmethod + def _format_request_messages( + cls, + messages: Messages, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **kwargs: Any, + ) -> list[dict[str, Any]]: + formatted_messages = cls._format_system_messages( + system_prompt, + system_prompt_content=system_prompt_content, + **kwargs, + ) + formatted_messages.extend(cls._format_regular_messages(messages, **kwargs)) + return [message for message in formatted_messages if message.get("content") or "tool_calls" in message] + + def _format_request( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + tool_choice: ToolChoice | None = None, + *, + system_prompt_content: list[SystemContentBlock] | None = None, + **_kwargs: Any, + ) -> dict[str, Any]: + return { + "messages": self._format_request_messages( + messages, + system_prompt, + system_prompt_content=system_prompt_content, + ), + "model": self.config["model_id"], + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "type": "function", + "function": { + "name": tool_spec["name"], + "description": tool_spec["description"], + "parameters": tool_spec["inputSchema"]["json"], + }, + } + for tool_spec in tool_specs or [] + ], + **(self._format_request_tool_choice(tool_choice)), + **cast(dict[str, Any], self.config.get("params", {})), + } + + @override + async def structured_output( + self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **_kwargs: Any + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + async with openai.AsyncOpenAI(**self.client_args) as client: + try: + response: ParsedChatCompletion = await client.beta.chat.completions.parse( + model=self.get_config()["model_id"], + messages=self._format_request(prompt, system_prompt=system_prompt)["messages"], + response_format=output_model, + ) + except openai.BadRequestError as e: + if hasattr(e, "code") and e.code == "context_length_exceeded": + raise ContextWindowOverflowException(str(e)) from e + raise + except openai.RateLimitError as e: + raise ModelThrottledException(str(e)) from e + + parsed: T | None = None + if len(response.choices) > 1: + raise ValueError("Multiple choices found in the OpenAI response.") + + for choice in response.choices: + if isinstance(choice.message.parsed, output_model): + parsed = choice.message.parsed + break + + if parsed: + yield {"output": parsed} + else: + raise ValueError("No valid tool use or tool use input was found in the OpenAI response.") + + async def _stream_completions_token_in( + self, + *, + prompt_token_ids: list[int], + max_tokens: int | None = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + 
"""Token-in streaming (no messages) via vLLM `/v1/completions`. + + This bypasses chat message formatting and sends the already-tokenized prompt to vLLM. + """ + if ( + not isinstance(prompt_token_ids, list) + or not prompt_token_ids + or not all(isinstance(x, int) for x in prompt_token_ids) + ): + raise TypeError("prompt_token_ids must be a non-empty list[int].") + + req_kwargs = dict(kwargs) + req_kwargs["prompt_token_ids"] = prompt_token_ids + req_kwargs = self._merge_vllm_extra_body(kwargs=req_kwargs) + + extra_body = cast(dict[str, Any], req_kwargs.get("extra_body") or {}) + if self._return_token_ids_default and "return_token_ids" not in extra_body: + extra_body["return_token_ids"] = True + req_kwargs["extra_body"] = extra_body + + if max_tokens is not None: + req_kwargs["max_tokens"] = max_tokens + + # vLLM completions validates that `prompt` is non-empty (even if prompt_token_ids is provided), + # so provide a harmless placeholder and rely on prompt_token_ids for the actual tokens. + request: dict[str, Any] = { + "model": self.get_config()["model_id"], + "prompt": " ", + "stream": True, + **(self.get_config().get("params") or {}), + **req_kwargs, + } + + async with openai.AsyncOpenAI(**self.client_args) as client: + response = await client.completions.create(**request) + + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + + token_ids: list[int] = [] + finish_reason: str | None = None + + async for event in response: + if not getattr(event, "choices", None): + continue + + dumped = self._safe_model_dump(event) + self._extend_token_ids_from_choice_dump(token_ids, self._choice0_dump(dumped)) + + choice0 = event.choices[0] + if getattr(choice0, "text", None): + yield {"contentBlockDelta": {"delta": {"text": choice0.text}}} + if getattr(choice0, "finish_reason", None): + finish_reason = choice0.finish_reason + break + + yield {"contentBlockStop": {}} + + additional: dict[str, Any] = {"prompt_token_ids": prompt_token_ids} + if token_ids: + additional["token_ids"] = token_ids + + stop_reason: StopReason = "end_turn" if finish_reason in (None, "stop") else "max_tokens" + yield { + "messageStop": { + "stopReason": stop_reason, + "additionalModelResponseFields": additional, + } + } + + def _merge_vllm_extra_body(self, *, kwargs: dict[str, Any]) -> dict[str, Any]: + """Merge vLLM-specific request fields into `extra_body` and remove them from kwargs. + + We keep vLLM-only fields inside `extra_body` to avoid OpenAI SDK validation errors. + """ + extra_body = kwargs.get("extra_body") + if extra_body is None: + extra_body_dict: dict[str, Any] = {} + else: + if not isinstance(extra_body, dict): + raise TypeError("extra_body must be a dict when provided.") + extra_body_dict = dict(extra_body) + + # Allow per-request override via kwarg while keeping OpenAIModel compatible: + # - `return_token_ids` is a vLLM extension. + if "return_token_ids" in kwargs: + extra_body_dict["return_token_ids"] = bool(kwargs.pop("return_token_ids")) + elif self._return_token_ids_default and "return_token_ids" not in extra_body_dict: + extra_body_dict["return_token_ids"] = True + + # Token-in: pass the fully formatted prompt token IDs. + # This is a vLLM extension; keep it in extra_body. 
+ if "prompt_token_ids" in kwargs: + prompt_token_ids = kwargs.pop("prompt_token_ids") + if prompt_token_ids is not None: + if not isinstance(prompt_token_ids, list) or not all(isinstance(x, int) for x in prompt_token_ids): + raise TypeError("prompt_token_ids must be a list[int] when provided.") + extra_body_dict["prompt_token_ids"] = prompt_token_ids + + # vLLM logprobs: allow passing an int (e.g. 1) without OpenAI SDK type constraints. + if "logprobs" in kwargs: + extra_body_dict["logprobs"] = kwargs.pop("logprobs") + + kwargs["extra_body"] = extra_body_dict + return kwargs + + async def _stream_chat_vllm( + self, + *, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + """Chat-completions streaming with vLLM-specific token/logprobs extraction.""" + req_kwargs = self._merge_vllm_extra_body(kwargs=dict(kwargs)) + request_prompt_token_ids: list[int] | None = None + extra_body = req_kwargs.get("extra_body") + if isinstance(extra_body, dict): + pti = extra_body.get("prompt_token_ids") + if isinstance(pti, list) and all(isinstance(x, int) for x in pti): + request_prompt_token_ids = cast(list[int], pti) + + request = self._format_request( + messages, + tool_specs, + system_prompt, + tool_choice, + system_prompt_content=system_prompt_content, + ) + if req_kwargs: + request.update(req_kwargs) + + async with openai.AsyncOpenAI(**self.client_args) as client: + response = await client.chat.completions.create(**request) + + yield {"messageStart": {"role": "assistant"}} + + tool_calls: dict[int, list[Any]] = {} + data_type: str | None = None + finish_reason: str | None = None + event = None + + prompt_token_ids: list[int] | None = None + prompt_logprobs: Any = None + token_ids: list[int] = [] + completion_logprobs: list[Any] = [] + + async for event in response: + if not getattr(event, "choices", None): + continue + + dumped = self._safe_model_dump(event) + + if prompt_token_ids is None and dumped.get("prompt_token_ids") is not None: + prompt_token_ids = cast(list[int], dumped.get("prompt_token_ids")) + if prompt_logprobs is None and dumped.get("prompt_logprobs") is not None: + prompt_logprobs = dumped.get("prompt_logprobs") + + choice = event.choices[0] + choice0_dump = self._choice0_dump(dumped) + self._extend_token_ids_from_choice_dump(token_ids, choice0_dump) + self._extend_logprobs_from_choice_dump(completion_logprobs, choice0_dump) + + if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content: + chunks, data_type = self._stream_switch_content("reasoning_content", data_type) + for chunk in chunks: + yield chunk + yield { + "contentBlockDelta": {"delta": {"reasoningContent": {"text": choice.delta.reasoning_content}}} + } + + if choice.delta.content: + chunks, data_type = self._stream_switch_content("text", data_type) + for chunk in chunks: + yield chunk + yield {"contentBlockDelta": {"delta": {"text": choice.delta.content}}} + + for tool_call in choice.delta.tool_calls or []: + tool_calls.setdefault(tool_call.index, []).append(tool_call) + + if choice.finish_reason: + finish_reason = choice.finish_reason + break + + if data_type is not None: + yield {"contentBlockStop": {}} + + for tool_deltas in tool_calls.values(): + first = tool_deltas[0] + yield { + "contentBlockStart": {"start": {"toolUse": {"toolUseId": first.id, "name": first.function.name}}} + } + 
for td in tool_deltas: + yield {"contentBlockDelta": {"delta": {"toolUse": {"input": td.function.arguments or ""}}}} + yield {"contentBlockStop": {}} + + if tool_calls and finish_reason in (None, "stop"): + finish_reason = "tool_calls" + + additional: dict[str, Any] = {} + if request_prompt_token_ids is not None: + additional["prompt_token_ids"] = request_prompt_token_ids + elif prompt_token_ids is not None: + additional["prompt_token_ids"] = prompt_token_ids + if prompt_logprobs is not None: + additional["prompt_logprobs"] = prompt_logprobs + if token_ids: + additional["token_ids"] = token_ids + if completion_logprobs: + additional["logprobs"] = completion_logprobs + + stop_reason = ( + "tool_use" + if finish_reason == "tool_calls" + else ("max_tokens" if finish_reason == "length" else "end_turn") + ) + message_stop: MessageStopEvent = {"stopReason": cast(StopReason, stop_reason)} + if additional: + message_stop["additionalModelResponseFields"] = additional + yield {"messageStop": message_stop} + + async for event in response: + _ = event + if event and hasattr(event, "usage") and event.usage: + yield { + "metadata": { + "usage": { + "inputTokens": event.usage.prompt_tokens, + "outputTokens": event.usage.completion_tokens, + "totalTokens": event.usage.total_tokens, + }, + "metrics": {"latencyMs": 0}, + } + } + + @override + def stream( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + *, + tool_choice: ToolChoice | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + **kwargs: Any, + ) -> AsyncIterable[StreamEvent]: + prompt_token_ids = kwargs.pop("prompt_token_ids", None) + if prompt_token_ids is not None: + token_in_endpoint = kwargs.pop("token_in_endpoint", "auto") + if token_in_endpoint not in ("auto", "chat", "completions"): + raise ValueError("token_in_endpoint must be one of: 'auto', 'chat', 'completions'.") + + if token_in_endpoint == "auto": + token_in_endpoint = "chat" if (tool_specs or tool_choice) else "completions" + + if token_in_endpoint == "completions": + if tool_specs is not None or tool_choice is not None: + raise TypeError("tool_specs/tool_choice are not supported in token-only mode.") + if system_prompt is not None or system_prompt_content is not None: + raise TypeError("system_prompt/system_prompt_content are not supported in token-only mode.") + max_tokens = kwargs.pop("max_tokens", None) + return self._stream_completions_token_in( + prompt_token_ids=cast(list[int], prompt_token_ids), + max_tokens=max_tokens, + **kwargs, + ) + + return self._stream_chat_vllm( + messages=[{"role": "user", "content": [{"text": ""}]}], + tool_specs=tool_specs, + system_prompt=system_prompt, + tool_choice=tool_choice, + system_prompt_content=system_prompt_content, + prompt_token_ids=cast(list[int], prompt_token_ids), + **kwargs, + ) + + return self._stream_chat_vllm( + messages=messages, + tool_specs=tool_specs, + system_prompt=system_prompt, + tool_choice=tool_choice, + system_prompt_content=system_prompt_content, + **kwargs, + ) diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 02be400b1..4bf1cc407 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -899,6 +899,61 @@ async def test_stream_messages(agenerator, alist): ) +@pytest.mark.asyncio +async def test_stream_messages_forwards_kwargs(agenerator, alist): + """Test that stream_messages forwards kwargs to model.stream().""" + 
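+    # extra_body below stands in for any provider-specific kwarg; it should reach model.stream() unchanged.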
mock_model = unittest.mock.MagicMock() + mock_model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + ] + ) + + stream = strands.event_loop.streaming.stream_messages( + mock_model, + system_prompt_content=[{"text": "test prompt"}], + messages=[{"role": "user", "content": [{"text": "Hello"}]}], + tool_specs=[], + system_prompt=None, + extra_body={"return_token_ids": True}, + ) + + await alist(stream) + + mock_model.stream.assert_called_with( + [{"role": "user", "content": [{"text": "Hello"}]}], + None, + None, + tool_choice=None, + system_prompt_content=[{"text": "test prompt"}], + extra_body={"return_token_ids": True}, + ) + + +@pytest.mark.asyncio +async def test_process_stream_preserves_additional_model_response_fields(agenerator, alist): + """Test that messageStop.additionalModelResponseFields is preserved into the final message.""" + response = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "Hello"}}}, + {"contentBlockStop": {}}, + { + "messageStop": { + "stopReason": "end_turn", + "additionalModelResponseFields": {"prompt_token_ids": [1, 2, 3], "token_ids": [4, 5]}, + } + }, + ] + + stream = strands.event_loop.streaming.process_stream(agenerator(response)) + last_event = cast(ModelStopReason, (await alist(stream))[-1]) + message = cast(Message, last_event["stop"][1]) + + assert message["additionalModelResponseFields"]["prompt_token_ids"] == [1, 2, 3] + assert message["additionalModelResponseFields"]["token_ids"] == [4, 5] + + @pytest.mark.asyncio async def test_stream_messages_with_system_prompt_content(agenerator, alist): """Test stream_messages with SystemContentBlock input.""" diff --git a/tests/strands/models/test_sglang.py b/tests/strands/models/test_sglang.py new file mode 100644 index 000000000..5df3a6a65 --- /dev/null +++ b/tests/strands/models/test_sglang.py @@ -0,0 +1,135 @@ +import unittest.mock + +import pytest + +import strands +from strands.models.sglang import SGLangModel + + +@pytest.fixture +def httpx_client(): + with unittest.mock.patch.object(strands.models.sglang.httpx, "AsyncClient") as mock_client_cls: + mock_client = unittest.mock.Mock() + mock_client_cls.return_value = mock_client + + # httpx.AsyncClient.stream(...) returns an async context manager. + stream_cm = unittest.mock.Mock() + stream_cm.__aenter__ = unittest.mock.AsyncMock() + stream_cm.__aexit__ = unittest.mock.AsyncMock(return_value=None) + mock_client.stream.return_value = stream_cm + + yield mock_client + + +def _aline_iter(lines: list[str]): + async def gen(): + for line in lines: + yield line + + return gen() + + +@pytest.mark.asyncio +async def test_sglang_stream_parses_sse_and_emits_text_deltas(httpx_client, alist): + # Mock /generate stream SSE + resp = unittest.mock.Mock() + resp.raise_for_status = unittest.mock.Mock() + resp.aiter_lines = unittest.mock.Mock( + return_value=_aline_iter( + [ + 'data: {"text":"h","output_ids":[1],"meta_info":{"finish_reason":{"type":"stop"},' + '"prompt_tokens":2,"completion_tokens":1,"e2e_latency":0.01}}', + 'data: {"text":"hi","output_ids":[1,2],"meta_info":{"finish_reason":{"type":"stop"},' + '"prompt_tokens":2,"completion_tokens":2,"e2e_latency":0.01}}', + "data: [DONE]", + ] + ) + ) + + # Async context manager for client.stream(...) 
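+    # (mirrors the provider's `async with self.client.stream("POST", "/generate", json=payload) as resp:`)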
+ httpx_client.stream.return_value.__aenter__.return_value = resp + httpx_client.stream.return_value.__aexit__.return_value = None + + model = SGLangModel(base_url="http://localhost:30000", model_id=None, params=None, return_token_ids=False) + events = await alist(model.stream([{"role": "user", "content": [{"text": "hi"}]}])) + + assert events[0] == {"messageStart": {"role": "assistant"}} + assert events[1] == {"contentBlockStart": {"start": {}}} + # deltas should be incremental: "h" then "i" + deltas = [e["contentBlockDelta"]["delta"]["text"] for e in events if "contentBlockDelta" in e] + assert deltas == ["h", "i"] + + stop = next(e for e in events if "messageStop" in e)["messageStop"] + additional = stop["additionalModelResponseFields"] + assert additional["token_ids"] == [1, 2] + + +@pytest.mark.asyncio +async def test_sglang_token_in_preserves_prompt_token_ids(httpx_client, alist): + resp = unittest.mock.Mock() + resp.raise_for_status = unittest.mock.Mock() + resp.aiter_lines = unittest.mock.Mock( + return_value=_aline_iter( + [ + 'data: {"text":"ok","output_ids":[9,10],"meta_info":{"finish_reason":{"type":"stop"},' + '"prompt_tokens":3,"completion_tokens":2,"e2e_latency":0.01}}', + "data: [DONE]", + ] + ) + ) + httpx_client.stream.return_value.__aenter__.return_value = resp + httpx_client.stream.return_value.__aexit__.return_value = None + + model = SGLangModel(base_url="http://localhost:30000", model_id=None, params=None, return_token_ids=False) + events = await alist( + model.stream( + [{"role": "user", "content": [{"text": "ignored"}]}], + prompt_token_ids=[1, 2, 3], + temperature=0, + ) + ) + + # Ensure token-in was sent as input_ids + called = httpx_client.stream.call_args.kwargs["json"] + assert called["input_ids"] == [1, 2, 3] + assert called["stream"] is True + + stop = next(e for e in events if "messageStop" in e)["messageStop"] + additional = stop["additionalModelResponseFields"] + assert additional["prompt_token_ids"] == [1, 2, 3] + assert additional["token_ids"] == [9, 10] + + +@pytest.mark.asyncio +async def test_sglang_text_prompt_token_out_uses_tokenize_when_enabled(httpx_client, alist): + # Mock /tokenize + tok_resp = unittest.mock.Mock() + tok_resp.raise_for_status = unittest.mock.Mock() + tok_resp.json = unittest.mock.Mock(return_value={"tokens": [101, 102]}) + httpx_client.post = unittest.mock.AsyncMock(return_value=tok_resp) + + # Mock /generate stream + resp = unittest.mock.Mock() + resp.raise_for_status = unittest.mock.Mock() + resp.aiter_lines = unittest.mock.Mock( + return_value=_aline_iter( + [ + 'data: {"text":"yo","output_ids":[7],"meta_info":{"finish_reason":{"type":"stop"}}}', + "data: [DONE]", + ] + ) + ) + httpx_client.stream.return_value.__aenter__.return_value = resp + httpx_client.stream.return_value.__aexit__.return_value = None + + model = SGLangModel(base_url="http://localhost:30000", model_id="m1", params=None, return_token_ids=True) + events = await alist(model.stream([{"role": "user", "content": [{"text": "hello"}]}])) + + # tokenization called + httpx_client.post.assert_awaited() + assert httpx_client.post.call_args.args[0] == "/tokenize" + + stop = next(e for e in events if "messageStop" in e)["messageStop"] + additional = stop["additionalModelResponseFields"] + assert additional["prompt_token_ids"] == [101, 102] + assert additional["token_ids"] == [7] diff --git a/tests/strands/models/test_vllm.py b/tests/strands/models/test_vllm.py new file mode 100644 index 000000000..aefd50f37 --- /dev/null +++ b/tests/strands/models/test_vllm.py @@ -0,0 
+1,177 @@ +import unittest.mock + +import pytest + +import strands +from strands.models.vllm import VLLMModel + + +@pytest.fixture +def openai_client(): + with unittest.mock.patch.object(strands.models.vllm.openai, "AsyncOpenAI") as mock_client_cls: + mock_client = unittest.mock.AsyncMock() + mock_client_cls.return_value.__aenter__.return_value = mock_client + yield mock_client + + +@pytest.mark.asyncio +async def test_vllm_model_injects_return_token_ids_by_default(openai_client, agenerator, alist): + model = VLLMModel(model_id="m1", params={"max_tokens": 1}, return_token_ids=True) + + mock_delta = unittest.mock.Mock(content="hi", tool_calls=None, reasoning_content=None) + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_event_3 = unittest.mock.Mock(usage=None) + openai_client.chat.completions.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_event_1, mock_event_2, mock_event_3]) + ) + + messages = [{"role": "user", "content": [{"text": "hello"}]}] + _ = await alist(model.stream(messages)) + + called_kwargs = openai_client.chat.completions.create.call_args.kwargs + assert called_kwargs["extra_body"]["return_token_ids"] is True + + +@pytest.mark.asyncio +async def test_vllm_model_moves_prompt_token_ids_into_extra_body(openai_client, agenerator, alist): + model = VLLMModel(model_id="m1", params={"max_tokens": 1}) + + mock_delta = unittest.mock.Mock(content="hi", tool_calls=None, reasoning_content=None) + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_event_3 = unittest.mock.Mock(usage=None) + openai_client.chat.completions.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_event_1, mock_event_2, mock_event_3]) + ) + + messages = [{"role": "user", "content": [{"text": "hello"}]}] + # Force chat token-in to validate prompt_token_ids placement in chat requests. + _ = await alist( + model.stream( + messages, + prompt_token_ids=[1, 2, 3], + token_in_endpoint="chat", + extra_body={"foo": "bar"}, + ) + ) + + called_kwargs = openai_client.chat.completions.create.call_args.kwargs + # prompt_token_ids should *not* be a top-level OpenAI request parameter. + assert "prompt_token_ids" not in called_kwargs + assert called_kwargs["extra_body"]["prompt_token_ids"] == [1, 2, 3] + assert called_kwargs["extra_body"]["foo"] == "bar" + + +@pytest.mark.asyncio +async def test_vllm_stream_token_ids_uses_completions(openai_client, agenerator, alist): + model = VLLMModel(model_id="m1", params={}, return_token_ids=True) + + # Mock streaming completion events (text deltas + finish) + choice1 = unittest.mock.Mock(text="hi", finish_reason=None) + choice2 = unittest.mock.Mock(text=None, finish_reason="stop") + ev1 = unittest.mock.Mock(choices=[choice1]) + ev2 = unittest.mock.Mock(choices=[choice2]) + + openai_client.completions.create = unittest.mock.AsyncMock(return_value=agenerator([ev1, ev2])) + + # Token-only mode is exercised via the main stream() entrypoint. + messages = [{"role": "user", "content": [{"text": "ignored"}]}] + events = await alist( + model.stream( + messages, + prompt_token_ids=[1, 2, 3], + token_in_endpoint="completions", + max_tokens=4, + ) + ) + + # Ensure we called the completions endpoint (token-only mode). 
+ called_kwargs = openai_client.completions.create.call_args.kwargs + assert called_kwargs["prompt"] == " " + assert called_kwargs["stream"] is True + assert called_kwargs["extra_body"]["prompt_token_ids"] == [1, 2, 3] + assert called_kwargs["extra_body"]["return_token_ids"] is True + + # Basic event shape + assert events[0] == {"messageStart": {"role": "assistant"}} + assert any("messageStop" in e for e in events) + + +@pytest.mark.asyncio +async def test_vllm_stream_chat_token_ids_uses_chat_completions(openai_client, agenerator, alist): + model = VLLMModel(model_id="m1", params={"max_tokens": 4}, return_token_ids=True) + + mock_delta = unittest.mock.Mock(content="hi", tool_calls=None, reasoning_content=None) + ev1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + ev2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + ev3 = unittest.mock.Mock(usage=None) + + openai_client.chat.completions.create = unittest.mock.AsyncMock(return_value=agenerator([ev1, ev2, ev3])) + + # Token-in chat mode is exercised via the main stream() entrypoint. + messages = [{"role": "user", "content": [{"text": "ignored"}]}] + events = await alist(model.stream(messages, prompt_token_ids=[11, 22, 33], token_in_endpoint="chat")) + + called_kwargs = openai_client.chat.completions.create.call_args.kwargs + assert called_kwargs["extra_body"]["prompt_token_ids"] == [11, 22, 33] + assert called_kwargs["extra_body"]["return_token_ids"] is True + assert called_kwargs["stream"] is True + assert isinstance(events, list) and events + + +@pytest.mark.asyncio +async def test_vllm_stream_routes_prompt_token_ids_to_completions(openai_client, agenerator, alist): + model = VLLMModel(model_id="m1", params={}, return_token_ids=True) + + choice1 = unittest.mock.Mock(text="hi", finish_reason=None) + choice2 = unittest.mock.Mock(text=None, finish_reason="stop") + ev1 = unittest.mock.Mock(choices=[choice1]) + ev2 = unittest.mock.Mock(choices=[choice2]) + openai_client.completions.create = unittest.mock.AsyncMock(return_value=agenerator([ev1, ev2])) + + # messages are required by the base interface, but will be ignored in completions token-in mode. 
+ messages = [{"role": "user", "content": [{"text": "ignored"}]}] + events = await alist( + model.stream( + messages, + prompt_token_ids=[1, 2, 3], + token_in_endpoint="completions", + max_tokens=4, + ) + ) + + called_kwargs = openai_client.completions.create.call_args.kwargs + assert called_kwargs["extra_body"]["prompt_token_ids"] == [1, 2, 3] + assert any("messageStop" in e for e in events) + + +@pytest.mark.asyncio +async def test_vllm_stream_routes_prompt_token_ids_to_chat_when_tools_present(openai_client, agenerator, alist): + model = VLLMModel(model_id="m1", params={"max_tokens": 4}, return_token_ids=True) + + mock_delta = unittest.mock.Mock(content="hi", tool_calls=None, reasoning_content=None) + ev1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + ev2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + ev3 = unittest.mock.Mock(usage=None) + openai_client.chat.completions.create = unittest.mock.AsyncMock(return_value=agenerator([ev1, ev2, ev3])) + + tool_specs = [ + { + "name": "echo_tool", + "description": "Echo input text.", + "inputSchema": { + "json": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + } + }, + } + ] + + messages = [{"role": "user", "content": [{"text": "ignored"}]}] + _ = await alist(model.stream(messages, tool_specs=tool_specs, prompt_token_ids=[9, 9, 9], token_in_endpoint="auto")) + + called_kwargs = openai_client.chat.completions.create.call_args.kwargs + assert called_kwargs["extra_body"]["prompt_token_ids"] == [9, 9, 9] diff --git a/tests_integ/conftest.py b/tests_integ/conftest.py index 26453e1f7..f9fae9c26 100644 --- a/tests_integ/conftest.py +++ b/tests_integ/conftest.py @@ -4,6 +4,7 @@ import boto3 import pytest +from botocore.exceptions import NoRegionError logger = logging.getLogger(__name__) @@ -54,20 +55,29 @@ async def alist(items): def _load_api_keys_from_secrets_manager(): """Load API keys as environment variables from AWS Secrets Manager.""" - session = boto3.session.Session() - client = session.client(service_name="secretsmanager") - if "STRANDS_TEST_API_KEYS_SECRET_NAME" in os.environ: + secret_name = os.getenv("STRANDS_TEST_API_KEYS_SECRET_NAME") + if secret_name: try: - secret_name = os.getenv("STRANDS_TEST_API_KEYS_SECRET_NAME") - response = client.get_secret_value(SecretId=secret_name) - - if "SecretString" in response: - secret = json.loads(response["SecretString"]) - for key, value in secret.items(): - os.environ[f"{key.upper()}_API_KEY"] = str(value) - + region = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION") + if not region: + logger.warning( + "STRANDS_TEST_API_KEYS_SECRET_NAME is set but AWS region is not configured; " + "skipping Secrets Manager lookup" + ) + else: + session = boto3.session.Session(region_name=region) + client = session.client(service_name="secretsmanager", region_name=region) + response = client.get_secret_value(SecretId=secret_name) + + if "SecretString" in response: + secret = json.loads(response["SecretString"]) + for key, value in secret.items(): + os.environ[f"{key.upper()}_API_KEY"] = str(value) + + except NoRegionError: + logger.warning("AWS region not configured; skipping Secrets Manager lookup") except Exception as e: - logger.warning("Error retrieving secret", e) + logger.warning("Error retrieving secret", exc_info=e) """ Validate that required environment variables are set when running in GitHub Actions. 
diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index 75cc58f74..2e745ed24 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -16,6 +16,8 @@ from strands.models.mistral import MistralModel from strands.models.ollama import OllamaModel from strands.models.openai import OpenAIModel +from strands.models.sglang import SGLangModel +from strands.models.vllm import VLLMModel from strands.models.writer import WriterModel @@ -59,6 +61,64 @@ def __init__(self): ) +class VLLMProviderInfo(ProviderInfo): + """Special case vLLM as it's dependent on the server being available.""" + + def __init__(self): + super().__init__( + id="vllm", + factory=lambda: VLLMModel( + client_args={ + "api_key": "EMPTY", + "base_url": os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"), + }, + model_id=os.getenv("VLLM_MODEL_ID", "AMead10/Llama-3.2-3B-Instruct-AWQ"), + params={"max_tokens": 64, "temperature": 0}, + return_token_ids=True, + ), + ) + + base_url = os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1").rstrip("/") + is_server_available = False + try: + # OpenAI-compatible discovery endpoint. + is_server_available = requests.get(f"{base_url}/models", timeout=2).ok + except requests.exceptions.RequestException: + pass + + self.mark = mark.skipif( + not is_server_available, + reason=f"Local vLLM endpoint not available at {base_url}", + ) + + +class SGLangProviderInfo(ProviderInfo): + """Special case SGLang as it's dependent on the server being available.""" + + def __init__(self): + super().__init__( + id="sglang", + factory=lambda: SGLangModel( + base_url=os.getenv("SGLANG_BASE_URL", "http://localhost:30000"), + model_id=os.getenv("SGLANG_MODEL_ID"), + params={"max_new_tokens": 64, "temperature": 0}, + return_token_ids=True, + ), + ) + + base_url = os.getenv("SGLANG_BASE_URL", "http://localhost:30000").rstrip("/") + is_server_available = False + try: + is_server_available = requests.get(f"{base_url}/health", timeout=2).ok + except requests.exceptions.RequestException: + pass + + self.mark = mark.skipif( + not is_server_available, + reason=f"Local SGLang endpoint not available at {base_url}", + ) + + anthropic = ProviderInfo( id="anthropic", environment_variable="ANTHROPIC_API_KEY", @@ -138,6 +198,8 @@ def __init__(self): ) ollama = OllamaProviderInfo() +vllm = VLLMProviderInfo() +sglang = SGLangProviderInfo() all_providers = [ diff --git a/tests_integ/models/test_model_sglang.py b/tests_integ/models/test_model_sglang.py new file mode 100644 index 000000000..1b9c3169d --- /dev/null +++ b/tests_integ/models/test_model_sglang.py @@ -0,0 +1,74 @@ +import pytest + +from strands import Agent +from strands.models.sglang import SGLangModel +from tests_integ.models import providers + +# These tests only run if a local SGLang server is reachable. 
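+# For example (assuming the standard SGLang launcher; the model path is a placeholder):
+#   python -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --port 30000
+# SGLANG_BASE_URL / SGLANG_MODEL_ID (read in tests_integ/models/providers.py) can point the suite elsewhere.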
+pytestmark = providers.sglang.mark + + +@pytest.fixture +def model() -> SGLangModel: + return providers.sglang.create_model() # type: ignore[return-value] + + +@pytest.fixture +def agent(model: SGLangModel) -> Agent: + return Agent(model=model) + + +def _additional(result_message: dict) -> dict: + additional = result_message.get("additionalModelResponseFields") + assert isinstance(additional, dict), f"missing additionalModelResponseFields: {result_message}" + return additional + + +def test_agent_invoke_preserves_token_ids(agent: Agent) -> None: + result = agent("hi", invocation_state={"model_kwargs": {"return_token_ids": True}}) + additional = _additional(result.message) + assert isinstance(additional.get("token_ids"), list) and additional["token_ids"] + assert isinstance(additional.get("prompt_token_ids"), list) and additional["prompt_token_ids"] + + +@pytest.mark.asyncio +async def test_agent_stream_async_preserves_token_ids(agent: Agent) -> None: + stream = agent.stream_async("hi", invocation_state={"model_kwargs": {"return_token_ids": True}}) + async for event in stream: + _ = event + result = event["result"] + additional = _additional(result.message) + assert isinstance(additional.get("token_ids"), list) and additional["token_ids"] + assert isinstance(additional.get("prompt_token_ids"), list) and additional["prompt_token_ids"] + + +@pytest.mark.asyncio +async def test_token_in_round_trip_preserves_prompt_token_ids(agent: Agent) -> None: + # Step 1: get prompt token ids from a text prompt + res1 = await agent.invoke_async( + "hi", + invocation_state={ + "model_kwargs": { + "return_token_ids": True, + # Ensure the model stops naturally (avoid MaxTokensReachedException in Agent loop). + "sampling_params": {"max_new_tokens": 64, "stop": ["\n"]}, + } + }, + ) + add1 = _additional(res1.message) + pti = add1["prompt_token_ids"] + assert isinstance(pti, list) and pti + + # Step 2: token-in call using those prompt_token_ids + res2 = await agent.invoke_async( + "ignored", + invocation_state={ + "model_kwargs": { + "prompt_token_ids": pti, + "sampling_params": {"max_new_tokens": 64, "stop": ["\n"]}, + } + }, + ) + add2 = _additional(res2.message) + assert add2.get("prompt_token_ids") == pti + assert isinstance(add2.get("token_ids"), list) and add2["token_ids"] diff --git a/tests_integ/models/test_model_vllm.py b/tests_integ/models/test_model_vllm.py new file mode 100644 index 000000000..c3695141b --- /dev/null +++ b/tests_integ/models/test_model_vllm.py @@ -0,0 +1,108 @@ +import os + +import pytest + +from strands import Agent, tool +from strands.event_loop.streaming import stream_messages +from strands.models.vllm import VLLMModel +from tests_integ.models import providers + +# These tests only run if a local vLLM OpenAI-compatible server is reachable. 
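+# For example (assuming the standard vLLM CLI; the model matches the default VLLM_MODEL_ID below):
+#   vllm serve AMead10/Llama-3.2-3B-Instruct-AWQ --port 8000
+# VLLM_BASE_URL / VLLM_MODEL_ID can point the suite at a different server.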
+pytestmark = providers.vllm.mark + + +@pytest.fixture +def model() -> VLLMModel: + base_url = os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1") + model_id = os.getenv("VLLM_MODEL_ID", "AMead10/Llama-3.2-3B-Instruct-AWQ") + return VLLMModel( + client_args={"api_key": "EMPTY", "base_url": base_url}, + model_id=model_id, + params={"max_tokens": 32, "temperature": 0}, + return_token_ids=True, + ) + + +@pytest.fixture +def agent(model: VLLMModel) -> Agent: + return Agent(model=model) + + +def _additional(result_message: dict) -> dict: + additional = result_message.get("additionalModelResponseFields") + assert isinstance(additional, dict), f"missing additionalModelResponseFields: {result_message}" + return additional + + +def test_agent_invoke_preserves_token_ids(agent: Agent) -> None: + result = agent( + "hi", + invocation_state={"model_kwargs": {"extra_body": {"return_token_ids": True}}}, + ) + + additional = _additional(result.message) + assert isinstance(additional.get("prompt_token_ids"), list) and additional["prompt_token_ids"] + assert isinstance(additional.get("token_ids"), list) and additional["token_ids"] + + +@pytest.mark.asyncio +async def test_agent_invoke_async_preserves_token_ids(agent: Agent) -> None: + result = await agent.invoke_async( + "hi", + invocation_state={"model_kwargs": {"extra_body": {"return_token_ids": True}}}, + ) + + additional = _additional(result.message) + assert isinstance(additional.get("prompt_token_ids"), list) and additional["prompt_token_ids"] + assert isinstance(additional.get("token_ids"), list) and additional["token_ids"] + + +@pytest.mark.asyncio +async def test_agent_stream_async_preserves_token_ids(agent: Agent) -> None: + stream = agent.stream_async( + "hi", + invocation_state={"model_kwargs": {"extra_body": {"return_token_ids": True}}}, + ) + + async for event in stream: + _ = event + + result = event["result"] + additional = _additional(result.message) + assert isinstance(additional.get("prompt_token_ids"), list) and additional["prompt_token_ids"] + assert isinstance(additional.get("token_ids"), list) and additional["token_ids"] + + +@pytest.mark.asyncio +async def test_tool_use_stop_event_preserves_token_ids(model: VLLMModel) -> None: + # Minimal tool; we only need tool specs, not tool execution. + @tool + def echo_tool(text: str) -> str: + return text + + tool_specs = Agent(model=model, tools=[echo_tool]).tool_registry.get_all_tool_specs() + + events: list[dict] = [] + async for event in stream_messages( + model, + system_prompt=None, + messages=[{"role": "user", "content": [{"text": "Call echo_tool with text='hello'. 
Return nothing else."}]}], + tool_specs=tool_specs, + tool_choice={"tool": {"name": "echo_tool"}}, + return_token_ids=True, + logprobs=1, + max_tokens=64, + ): + events.append(event) + + stop_events = [e["event"] for e in events if isinstance(e, dict) and "event" in e and "messageStop" in e["event"]] + assert stop_events, f"no messageStop found; got: {events}" + + tool_stop = next((e for e in stop_events if e["messageStop"].get("stopReason") == "tool_use"), None) + assert tool_stop is not None, "expected stopReason='tool_use' (tool calling may not be enabled on server)" + + additional = tool_stop["messageStop"].get("additionalModelResponseFields") + assert isinstance(additional, dict), f"missing additionalModelResponseFields: {tool_stop}" + assert isinstance(additional.get("prompt_token_ids"), list) and additional["prompt_token_ids"] + assert isinstance(additional.get("token_ids"), list) and additional["token_ids"] + assert additional.get("logprobs") is not None