diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 3829305b4..86bf72418 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -698,10 +698,10 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche ), } - vector_db_ids = [ - vector_db.identifier for vector_db in await client.vector_dbs.list() + vector_store_ids = [ + vector_store.id for vector_store in (await client.vector_stores.list()).data ] - toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ + toolgroups = (get_rag_toolgroups(vector_store_ids) or []) + [ mcp_server.name for mcp_server in configuration.mcp_servers ] # Convert empty list to None for consistency with existing behavior @@ -797,30 +797,30 @@ def validate_attachments_metadata(attachments: list[Attachment]) -> None: def get_rag_toolgroups( - vector_db_ids: list[str], + vector_store_ids: list[str], ) -> list[Toolgroup] | None: """ - Return a list of RAG Tool groups if the given vector DB list is not empty. + Return a list of RAG Tool groups if the given vector store list is not empty. Generate a list containing a RAG knowledge search toolgroup if - vector database IDs are provided. + vector store IDs are provided. Parameters: - vector_db_ids (list[str]): List of vector database identifiers to include in the toolgroup. + vector_store_ids (list[str]): List of vector store identifiers to include in the toolgroup. Returns: list[Toolgroup] | None: A list with a single RAG toolgroup if - vector_db_ids is non-empty; otherwise, None. + vector_store_ids is non-empty; otherwise, None. """ return ( [ ToolgroupAgentToolGroupWithArgs( name="builtin::rag/knowledge_search", args={ - "vector_db_ids": vector_db_ids, + "vector_store_ids": vector_store_ids, }, ) ] - if vector_db_ids + if vector_store_ids else None ) diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py new file mode 100644 index 000000000..4c6047b28 --- /dev/null +++ b/src/app/endpoints/query_v2.py @@ -0,0 +1,395 @@ +"""Handler for REST API call to provide answer to query using Response API.""" + +import logging +from typing import Annotated, Any, cast + +from llama_stack_client import AsyncLlamaStackClient # type: ignore +from llama_stack_client import APIConnectionError +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseObject, +) + +from fastapi import APIRouter, HTTPException, Request, status, Depends + +from app.endpoints.query import ( + evaluate_model_hints, + is_transcripts_enabled, + persist_user_conversation_details, + query_response, + select_model_and_provider_id, + validate_attachments_metadata, + get_topic_summary, +) +from authentication import get_auth_dependency +from authentication.interface import AuthTuple +from authorization.middleware import authorize +from client import AsyncLlamaStackClientHolder +from configuration import configuration +import metrics +from models.config import Action +from models.database.conversations import UserConversation +from app.database import get_session +from models.requests import QueryRequest +from models.responses import QueryResponse +from utils.endpoints import ( + check_configuration_loaded, + get_system_prompt, + validate_model_provider_override, +) +from utils.mcp_headers import mcp_headers_dependency +from utils.transcripts import store_transcript +from utils.types import TurnSummary, ToolCallSummary + + +logger = logging.getLogger("app.endpoints.handlers") +router = APIRouter(tags=["query_v2"]) +auth_dependency = get_auth_dependency() + + 
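+# NOTE: Unlike the v1 /query endpoint, which drives the Agents API, the v2
+# handlers below call the Llama Stack Responses API directly via
+# client.responses.create(). RAG is exposed as "file_search" tools over the
+# available vector stores, configured MCP servers are passed as "mcp" tools,
+# and the Responses API response ID doubles as the conversation ID so that
+# follow-up requests can chain via previous_response_id. Shields are not yet
+# supported by the Responses API, so no safety checks are applied.
+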
+@router.post("/query", responses=query_response) +@authorize(Action.QUERY) +async def query_endpoint_handler_v2( + request: Request, + query_request: QueryRequest, + auth: Annotated[AuthTuple, Depends(auth_dependency)], + mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), +) -> QueryResponse: + """ + Handle request to the /query endpoint using Response API. + + Processes a POST request to the /query endpoint, forwarding the + user's query to a selected Llama Stack LLM using Response API + and returning the generated response. + + Validates configuration and authentication, selects the appropriate model + and provider, retrieves the LLM response, updates metrics, and optionally + stores a transcript of the interaction. Handles connection errors to the + Llama Stack service by returning an HTTP 500 error. + + Returns: + QueryResponse: Contains the conversation ID and the LLM-generated response. + """ + check_configuration_loaded(configuration) + + # Enforce RBAC: optionally disallow overriding model/provider in requests + validate_model_provider_override(query_request, request.state.authorized_actions) + + # log Llama Stack configuration + logger.info("Llama stack config: %s", configuration.llama_stack_configuration) + + user_id, _, _, token = auth + + user_conversation: UserConversation | None = None + if query_request.conversation_id: + # TODO: Implement conversation once Llama Stack supports its API + pass + + try: + # try to get Llama Stack client + client = AsyncLlamaStackClientHolder().get_client() + llama_stack_model_id, model_id, provider_id = select_model_and_provider_id( + await client.models.list(), + *evaluate_model_hints( + user_conversation=user_conversation, query_request=query_request + ), + ) + summary, conversation_id = await retrieve_response( + client, + llama_stack_model_id, + query_request, + token, + mcp_headers=mcp_headers, + provider_id=provider_id, + ) + # Update metrics for the LLM call + metrics.llm_calls_total.labels(provider_id, model_id).inc() + + # Compute topic summary if this is a brand new conversation + topic_summary = None + with get_session() as session: + existing_conversation = ( + session.query(UserConversation) + .filter_by(id=conversation_id) + .first() + ) + if not existing_conversation: + topic_summary = await get_topic_summary( + query_request.query, client, model_id + ) + + process_transcript_and_persist_conversation( + user_id=user_id, + conversation_id=conversation_id, + model_id=model_id, + provider_id=provider_id, + query_request=query_request, + summary=summary, + topic_summary=topic_summary, + ) + + return QueryResponse( + conversation_id=conversation_id, + response=summary.llm_response, + ) + + # connection to Llama Stack server + except APIConnectionError as e: + # Update metrics for the LLM call failure + metrics.llm_calls_failures_total.inc() + logger.error("Unable to connect to Llama Stack: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "response": "Unable to connect to Llama Stack", + "cause": str(e), + }, + ) from e + + +async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches + client: AsyncLlamaStackClient, + model_id: str, + query_request: QueryRequest, + token: str, + mcp_headers: dict[str, dict[str, str]] | None = None, + provider_id: str = "", +) -> tuple[TurnSummary, str]: + """ + Retrieve response from LLMs and agents. 
+
+    Retrieves a response from the Llama Stack Responses API for a
+    given query, handling tool configuration and attachment
+    validation.
+
+    This function builds the system prompt and the tools passed to
+    the Responses API: "file_search" tools for the available vector
+    stores (RAG) and "mcp" tools for the configured MCP (Model
+    Context Protocol) servers. It validates attachments, issues a
+    single non-streaming Responses API call, and parses the returned
+    output items into a turn summary. Shields are not yet supported
+    by the Responses API, so no safety checks are applied.
+
+    Parameters:
+        model_id (str): The identifier of the LLM model to use.
+        query_request (QueryRequest): The user's query and associated metadata.
+        token (str): The authentication token used for MCP tool authorization.
+        mcp_headers (dict[str, dict[str, str]], optional): Per-server MCP headers;
+            accepted for interface compatibility, not used by the Responses API path.
+        provider_id (str, optional): Provider identifier, reserved for token-usage metrics.
+
+    Returns:
+        tuple[TurnSummary, str]: A tuple containing a summary of the LLM's response
+        content and the conversation ID (the Responses API response ID).
+    """
+    logger.info("Shields are not yet supported in Responses API. Disabling safety")
+
+    # use system prompt from request or default one
+    system_prompt = get_system_prompt(query_request, configuration)
+    logger.debug("Using system prompt: %s", system_prompt)
+
+    # TODO(lucasagomes): redact attachments content before sending to LLM
+    # if attachments are provided, validate them
+    if query_request.attachments:
+        validate_attachments_metadata(query_request.attachments)
+
+    # Prepare tools for responses API
+    tools: list[dict[str, Any]] = []
+    if not query_request.no_tools:
+        # Get vector stores for RAG tools
+        vector_store_ids = [
+            vector_store.id for vector_store in (await client.vector_stores.list()).data
+        ]
+
+        # Add RAG tools if vector stores are available
+        rag_tools = get_rag_tools(vector_store_ids)
+        if rag_tools:
+            tools.extend(rag_tools)
+
+        # Add MCP server tools
+        mcp_tools = get_mcp_tools(configuration.mcp_servers, token)
+        if mcp_tools:
+            tools.extend(mcp_tools)
+            logger.debug(
+                "Configured %d MCP tools: %s",
+                len(mcp_tools),
+                [tool.get("server_label", "unknown") for tool in mcp_tools],
+            )
+
+    # Create OpenAI response using responses API
+    response = await client.responses.create(
+        input=query_request.query,
+        model=model_id,
+        instructions=system_prompt,
+        previous_response_id=query_request.conversation_id,
+        tools=(cast(Any, tools) if tools else cast(Any, None)),
+        stream=False,
+        store=True,
+    )
+    response = cast(OpenAIResponseObject, response)
+
+    logger.debug(
+        "Received response with ID: %s, output items: %d",
+        response.id,
+        len(response.output),
+    )
+    # Return the response ID - client can use it for chaining if desired
+    conversation_id = response.id
+
+    # Process OpenAI response format
+    llm_response = ""
+    tool_calls: list[ToolCallSummary] = []
+
+    for idx, output_item in enumerate(response.output):
+        logger.debug(
+            "Processing output item %d, type: %s", idx, type(output_item).__name__
+        )
+
+        if hasattr(output_item, "content") and output_item.content:
+            # Extract text content from message output
+            if isinstance(output_item.content, list):
+                for content_item in output_item.content:
+                    if hasattr(content_item, "text"):
+                        llm_response += content_item.text
+            elif hasattr(output_item.content, "text"):
+                llm_response += output_item.content.text
+            elif isinstance(output_item.content, str):
+                llm_response += output_item.content
+
+        if llm_response:
+            logger.info(
+                "Model response content: '%s'",
+                (
+                    llm_response[:200] + "..."
+ if len(llm_response) > 200 + else llm_response + ), + ) + + # Process tool calls if present + if hasattr(output_item, "tool_calls") and output_item.tool_calls: + logger.debug( + "Found %d tool calls in output item %d", + len(output_item.tool_calls), + idx, + ) + for tool_idx, tool_call in enumerate(output_item.tool_calls): + tool_name = ( + tool_call.function.name + if hasattr(tool_call, "function") + else "unknown" + ) + tool_args = ( + tool_call.function.arguments + if hasattr(tool_call, "function") + else {} + ) + + logger.debug( + "Tool call %d - Name: %s, Args: %s", + tool_idx, + tool_name, + str(tool_args)[:100], + ) + + tool_calls.append( + ToolCallSummary( + id=( + tool_call.id + if hasattr(tool_call, "id") + else str(len(tool_calls)) + ), + name=tool_name, + args=tool_args, + response=None, # Tool responses would be in subsequent output items + ) + ) + + logger.info( + "Response processing complete - Tool calls: %d, Response length: %d chars", + len(tool_calls), + len(llm_response), + ) + + summary = TurnSummary( + llm_response=llm_response, + tool_calls=tool_calls, + ) + + # TODO(ltomasbo): update token count metrics for the LLM call + # Update token count metrics for the LLM call + # model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id + # update_llm_token_count_from_response(response, model_label, provider_id, system_prompt) + + if not summary.llm_response: + logger.warning( + "Response lacks content (conversation_id=%s)", + conversation_id, + ) + return summary, conversation_id + + +def get_rag_tools(vector_store_ids: list[str]) -> list[dict[str, Any]] | None: + """Convert vector store IDs to tools format for responses API.""" + if not vector_store_ids: + return None + + return [ + { + "type": "file_search", + "vector_store_ids": vector_store_ids, + "max_num_results": 10, + } + ] + + +def get_mcp_tools(mcp_servers: list, token: str | None = None) -> list[dict[str, Any]]: + """Convert MCP servers to tools format for responses API.""" + tools = [] + for mcp_server in mcp_servers: + tool_def = { + "type": "mcp", + "server_label": mcp_server.name, + "server_url": mcp_server.url, + "require_approval": "never", + } + + # Add authentication if token provided (Response API format) + if token: + tool_def["headers"] = {"Authorization": f"Bearer {token}"} + + tools.append(tool_def) + return tools + + +def process_transcript_and_persist_conversation( + user_id: str, + conversation_id: str, + model_id: str, + provider_id: str, + query_request: QueryRequest, + summary: TurnSummary, + topic_summary: str | None = None, +) -> None: + """Process transcript storage and persist conversation details.""" + if not is_transcripts_enabled(): + logger.debug("Transcript collection is disabled in the configuration") + else: + store_transcript( + user_id=user_id, + conversation_id=conversation_id, + model_id=model_id, + provider_id=provider_id, + query_is_valid=True, # TODO(lucasagomes): implement as part of query validation + query=query_request.query, + query_request=query_request, + summary=summary, + rag_chunks=[], # TODO(lucasagomes): implement rag_chunks + truncated=False, # TODO(lucasagomes): implement truncation as part of quota work + attachments=query_request.attachments or [], + ) + + persist_user_conversation_details( + user_id=user_id, + conversation_id=conversation_id, + model=model_id, + provider_id=provider_id, + topic_summary=topic_summary, + ) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index d4ad3088a..fde9b2d25 100644 
--- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -1031,10 +1031,10 @@ async def retrieve_response( ), } - vector_db_ids = [ - vector_db.identifier for vector_db in await client.vector_dbs.list() + vector_store_ids = [ + vector_store.id for vector_store in (await client.vector_stores.list()).data ] - toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ + toolgroups = (get_rag_toolgroups(vector_store_ids) or []) + [ mcp_server.name for mcp_server in configuration.mcp_servers ] # Convert empty list to None for consistency with existing behavior diff --git a/src/app/endpoints/streaming_query_v2.py b/src/app/endpoints/streaming_query_v2.py new file mode 100644 index 000000000..f2179639e --- /dev/null +++ b/src/app/endpoints/streaming_query_v2.py @@ -0,0 +1,402 @@ +"""Streaming query handler using Responses API (v2).""" + +import logging +from typing import Annotated, Any, AsyncIterator, cast + +from llama_stack_client import APIConnectionError +from llama_stack_client import AsyncLlamaStackClient # type: ignore +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseObjectStream, +) + +from fastapi import APIRouter, Depends, Request, HTTPException +from fastapi.responses import StreamingResponse +from starlette import status + + +from app.endpoints.query import ( + evaluate_model_hints, + is_transcripts_enabled, + persist_user_conversation_details, + select_model_and_provider_id, + validate_attachments_metadata, + validate_conversation_ownership, +) +from app.endpoints.query_v2 import ( + get_rag_tools, + get_mcp_tools, +) +from app.endpoints.streaming_query import ( + format_stream_data, + stream_start_event, + stream_end_event, +) +from authentication import get_auth_dependency +from authentication.interface import AuthTuple +from authorization.middleware import authorize +from client import AsyncLlamaStackClientHolder +from configuration import configuration +import metrics +from models.config import Action +from models.database.conversations import UserConversation +from models.requests import QueryRequest +from utils.endpoints import ( + check_configuration_loaded, + get_system_prompt, + validate_model_provider_override, +) +from utils.mcp_headers import mcp_headers_dependency +from utils.transcripts import store_transcript +from utils.types import TurnSummary, ToolCallSummary + +logger = logging.getLogger("app.endpoints.handlers") +router = APIRouter(tags=["streaming_query_v2"]) +auth_dependency = get_auth_dependency() + + +@router.post("/streaming_query") +@authorize(Action.STREAMING_QUERY) +async def streaming_query_endpoint_handler_v2( # pylint: disable=too-many-locals + request: Request, + query_request: QueryRequest, + auth: Annotated[AuthTuple, Depends(auth_dependency)], + mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), +) -> StreamingResponse: + """ + Handle request to the /streaming_query endpoint. + + This endpoint receives a query request, authenticates the user, + selects the appropriate model and provider, and streams + incremental response events from the Llama Stack backend to the + client. Events include start, token updates, tool calls, turn + completions, errors, and end-of-stream metadata. Optionally + stores the conversation transcript if enabled in configuration. + + Returns: + StreamingResponse: An HTTP streaming response yielding + SSE-formatted events for the query lifecycle. + + Raises: + HTTPException: Returns HTTP 500 if unable to connect to the + Llama Stack server. 
+ """ + check_configuration_loaded(configuration) + + # Enforce RBAC: optionally disallow overriding model/provider in requests + validate_model_provider_override(query_request, request.state.authorized_actions) + + # log Llama Stack configuration + logger.info("Llama stack config: %s", configuration.llama_stack_configuration) + + user_id, _user_name, _skip_userid_check, token = auth + + user_conversation: UserConversation | None = None + if query_request.conversation_id: + user_conversation = validate_conversation_ownership( + user_id=user_id, conversation_id=query_request.conversation_id + ) + + if user_conversation is None: + logger.warning( + "User %s attempted to query conversation %s they don't own", + user_id, + query_request.conversation_id, + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "response": "Access denied", + "cause": "You do not have permission to access this conversation", + }, + ) + + try: + # try to get Llama Stack client + client = AsyncLlamaStackClientHolder().get_client() + llama_stack_model_id, model_id, provider_id = select_model_and_provider_id( + await client.models.list(), + *evaluate_model_hints(user_conversation=None, query_request=query_request), + ) + + response, _ = await retrieve_response( + client, + llama_stack_model_id, + query_request, + token, + mcp_headers=mcp_headers, + ) + metadata_map: dict[str, dict[str, Any]] = {} + + async def response_generator( + turn_response: AsyncIterator[OpenAIResponseObjectStream], + ) -> AsyncIterator[str]: + """ + Generate SSE formatted streaming response. + + Asynchronously generates a stream of Server-Sent Events + (SSE) representing incremental responses from a + language model turn. + + Yields start, token, tool call, turn completion, and + end events as SSE-formatted strings. Collects the + complete response for transcript storage if enabled. 
+ """ + chunk_id = 0 + summary = TurnSummary(llm_response="", tool_calls=[]) + + # Accumulators for Responses API + text_parts: list[str] = [] + tool_item_registry: dict[str, dict[str, str]] = {} + emitted_turn_complete = False + + # Handle conversation id and start event in-band on response.created + conv_id = "" + + logger.debug("Starting streaming response (Responses API) processing") + + async for chunk in turn_response: + event_type = getattr(chunk, "type", None) + logger.debug("Processing chunk %d, type: %s", chunk_id, event_type) + + # Emit start and persist on response.created + if event_type == "response.created": + try: + conv_id = getattr(chunk, "response").id + except Exception: + conv_id = "" + yield stream_start_event(conv_id) + if conv_id: + persist_user_conversation_details( + user_id=user_id, + conversation_id=conv_id, + model=model_id, + provider_id=provider_id, + ) + continue + + # Text streaming + if event_type == "response.output_text.delta": + delta = getattr(chunk, "delta", "") + if delta: + text_parts.append(delta) + yield format_stream_data( + { + "event": "token", + "data": { + "id": chunk_id, + "token": delta, + }, + } + ) + chunk_id += 1 + + # Final text of the output (capture, but emit at response.completed) + elif event_type == "response.output_text.done": + final_text = getattr(chunk, "text", "") + if final_text: + summary.llm_response = final_text + + # Content part started - emit an empty token to kick off UI streaming if desired + elif event_type == "response.content_part.added": + yield format_stream_data( + { + "event": "token", + "data": { + "id": chunk_id, + "token": "", + }, + } + ) + chunk_id += 1 + + # Track tool call items as they are added so we can build a summary later + elif event_type == "response.output_item.added": + item = getattr(chunk, "item", None) + item_type = getattr(item, "type", None) + if item and item_type == "function_call": + item_id = getattr(item, "id", "") + name = getattr(item, "name", "function_call") + call_id = getattr(item, "call_id", item_id) + if item_id: + tool_item_registry[item_id] = { + "name": name, + "call_id": call_id, + } + + # Stream tool call arguments as tool_call events + elif event_type == "response.function_call_arguments.delta": + delta = getattr(chunk, "delta", "") + yield format_stream_data( + { + "event": "tool_call", + "data": { + "id": chunk_id, + "role": "tool_execution", + "token": delta, + }, + } + ) + chunk_id += 1 + + # Finalize tool call arguments and append to summary + elif event_type in ( + "response.function_call_arguments.done", + "response.mcp_call.arguments.done", + ): + item_id = getattr(chunk, "item_id", "") + arguments = getattr(chunk, "arguments", "") + meta = tool_item_registry.get(item_id, {}) + summary.tool_calls.append( + ToolCallSummary( + id=meta.get("call_id", item_id or "unknown"), + name=meta.get("name", "tool_call"), + args=arguments, + response=None, + ) + ) + + # Completed response - capture final text if any + elif event_type == "response.completed": + if not emitted_turn_complete: + final_message = summary.llm_response or "".join(text_parts) + yield format_stream_data( + { + "event": "turn_complete", + "data": { + "id": chunk_id, + "token": final_message, + }, + } + ) + chunk_id += 1 + emitted_turn_complete = True + + # Ignore other event types for now; could add heartbeats if desired + + logger.debug( + "Streaming complete - Tool calls: %d, Response chars: %d", + len(summary.tool_calls), + len(summary.llm_response), + ) + + yield stream_end_event(metadata_map) + + if 
not is_transcripts_enabled():
+                logger.debug("Transcript collection is disabled in the configuration")
+            else:
+                store_transcript(
+                    user_id=user_id,
+                    conversation_id=conv_id,
+                    model_id=model_id,
+                    provider_id=provider_id,
+                    query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
+                    query=query_request.query,
+                    query_request=query_request,
+                    summary=summary,
+                    rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
+                    truncated=False,  # TODO(lucasagomes): implement truncation as part
+                    # of quota work
+                    attachments=query_request.attachments or [],
+                )
+
+        # Conversation persistence is handled inside the stream
+        # once the response.created event provides the ID
+
+        # Update metrics for the LLM call
+        metrics.llm_calls_total.labels(provider_id, model_id).inc()
+
+        return StreamingResponse(response_generator(response))
+
+    # connection to Llama Stack server
+    except APIConnectionError as e:
+        # Update metrics for the LLM call failure
+        metrics.llm_calls_failures_total.inc()
+        logger.error("Unable to connect to Llama Stack: %s", e)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail={
+                "response": "Unable to connect to Llama Stack",
+                "cause": str(e),
+            },
+        ) from e
+
+
+async def retrieve_response(
+    client: AsyncLlamaStackClient,
+    model_id: str,
+    query_request: QueryRequest,
+    token: str,
+    mcp_headers: dict[str, dict[str, str]] | None = None,
+) -> tuple[AsyncIterator[OpenAIResponseObjectStream], str]:
+    """
+    Retrieve a streaming response from the Llama Stack Responses API.
+
+    Asynchronously initiates a streaming Responses API call for the
+    given user query and returns the resulting event stream.
+
+    This function builds the system prompt and the tools passed to
+    the Responses API: "file_search" tools for the available vector
+    stores (RAG) and "mcp" tools for the configured MCP (Model
+    Context Protocol) servers, and validates attachments if present.
+    Shields are not yet supported by the Responses API, so no safety
+    checks are applied.
+
+    Parameters:
+        model_id (str): Identifier of the model to use for the query.
+        query_request (QueryRequest): The user's query and associated metadata.
+        token (str): Authentication token used for MCP tool authorization.
+        mcp_headers (dict[str, dict[str, str]], optional): Per-server MCP headers;
+            accepted for interface compatibility, not used by the Responses API path.
+
+    Returns:
+        tuple: A tuple containing the streaming response object and an
+        empty conversation ID; the actual ID arrives with the first
+        'response.created' chunk.
+    """
+    logger.info("Shields are not yet supported in Responses API. 
Disabling safety") + + # use system prompt from request or default one + system_prompt = get_system_prompt(query_request, configuration) + logger.debug("Using system prompt: %s", system_prompt) + + # TODO(lucasagomes): redact attachments content before sending to LLM + # if attachments are provided, validate them + if query_request.attachments: + validate_attachments_metadata(query_request.attachments) + + # Prepare tools for responses API + tools: list[dict[str, Any]] = [] + if not query_request.no_tools: + # Get vector stores for RAG tools + vector_store_ids = [ + vector_store.id for vector_store in (await client.vector_stores.list()).data + ] + + # Add RAG tools if vector stores are available + rag_tools = get_rag_tools(vector_store_ids) + if rag_tools: + tools.extend(rag_tools) + + # Add MCP server tools + mcp_tools = get_mcp_tools(configuration.mcp_servers, token) + if mcp_tools: + tools.extend(mcp_tools) + logger.debug( + "Configured %d MCP tools: %s", + len(mcp_tools), + [tool.get("server_label", "unknown") for tool in mcp_tools], + ) + + response = await client.responses.create( + input=query_request.query, + model=model_id, + instructions=system_prompt, + previous_response_id=query_request.conversation_id, + tools=(cast(Any, tools) if tools else cast(Any, None)), + stream=True, + store=True, + ) + + response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response) + + # For streaming responses, the ID arrives in the first 'response.created' chunk + # Return empty conversation_id here; it will be set once the first chunk is received + return response_stream, "" diff --git a/src/app/routers.py b/src/app/routers.py index 66c707668..00cf047f4 100644 --- a/src/app/routers.py +++ b/src/app/routers.py @@ -13,11 +13,14 @@ config, feedback, streaming_query, + streaming_query_v2, authorized, conversations, conversations_v2, metrics, tools, + # V2 endpoints for Response API support + query_v2, ) @@ -28,6 +31,8 @@ def include_routers(app: FastAPI) -> None: app: The `FastAPI` app instance. """ app.include_router(root.router) + + # V1 endpoints - Agent API (legacy) app.include_router(info.router, prefix="/v1") app.include_router(models.router, prefix="/v1") app.include_router(tools.router, prefix="/v1") @@ -40,6 +45,11 @@ def include_routers(app: FastAPI) -> None: app.include_router(conversations.router, prefix="/v1") app.include_router(conversations_v2.router, prefix="/v2") + # V2 endpoints - Response API support + app.include_router(query_v2.router, prefix="/v2") + app.include_router(streaming_query_v2.router, prefix="/v2") + app.include_router(conversations.router, prefix="/v3") + # road-core does not version these endpoints app.include_router(health.router) app.include_router(authorized.router) diff --git a/src/utils/suid.py b/src/utils/suid.py index 4dc9ca5e8..a3783ebf3 100644 --- a/src/utils/suid.py +++ b/src/utils/suid.py @@ -18,9 +18,10 @@ def get_suid() -> str: def check_suid(suid: str) -> bool: """ - Check if given string is a proper session ID. + Check if given string is a proper session ID or response ID. - Returns True if the string is a valid UUID, False otherwise. + Returns True if the string is a valid UUID or if it starts with resp-/resp_ + and it follows a valid UUID string, False otherwise. Parameters: suid (str | bytes): UUID value to validate — accepts a UUID string or @@ -30,8 +31,25 @@ def check_suid(suid: str) -> bool: Validation is performed by attempting to construct uuid.UUID(suid); invalid formats or types result in False. 
""" + if not isinstance(suid, str) or not suid: + return False + + # Handle Responses API IDs + if suid.startswith("resp-") or suid.startswith("resp_"): + token = suid[5:] + if not token: + return False + # If truncated (e.g., shell cut reduced length), pad to canonical UUID length + if len(token) < 36: + token = token + ("0" * (36 - len(token))) + try: + uuid.UUID(token) + return True + except (ValueError, TypeError): + return False + + # Otherwise, enforce UUID format try: - # accepts strings and bytes only uuid.UUID(suid) return True except (ValueError, TypeError): diff --git a/tests/unit/app/endpoints/test_query_v2.py b/tests/unit/app/endpoints/test_query_v2.py new file mode 100644 index 000000000..7b43edb9a --- /dev/null +++ b/tests/unit/app/endpoints/test_query_v2.py @@ -0,0 +1,253 @@ +# pylint: disable=redefined-outer-name, import-error +"""Unit tests for the /query (v2) REST API endpoint using Responses API.""" + +import pytest +from fastapi import HTTPException, status, Request + +from llama_stack_client import APIConnectionError + +from models.requests import QueryRequest, Attachment +from models.config import ModelContextProtocolServer + +from app.endpoints.query_v2 import ( + get_rag_tools, + get_mcp_tools, + retrieve_response, + query_endpoint_handler_v2, +) + + +@pytest.fixture +def dummy_request() -> Request: + req = Request(scope={"type": "http"}) + return req + + +def test_get_rag_tools(): + assert get_rag_tools([]) is None + + tools = get_rag_tools(["db1", "db2"]) + assert isinstance(tools, list) + assert tools[0]["type"] == "file_search" + assert tools[0]["vector_store_ids"] == ["db1", "db2"] + assert tools[0]["max_num_results"] == 10 + + +def test_get_mcp_tools_with_and_without_token(): + servers = [ + ModelContextProtocolServer(name="fs", url="http://localhost:3000"), + ModelContextProtocolServer(name="git", url="https://git.example.com/mcp"), + ] + + tools_no_token = get_mcp_tools(servers, token=None) + assert len(tools_no_token) == 2 + assert tools_no_token[0]["type"] == "mcp" + assert tools_no_token[0]["server_label"] == "fs" + assert tools_no_token[0]["server_url"] == "http://localhost:3000" + assert "headers" not in tools_no_token[0] + + tools_with_token = get_mcp_tools(servers, token="abc") + assert len(tools_with_token) == 2 + assert tools_with_token[1]["type"] == "mcp" + assert tools_with_token[1]["server_label"] == "git" + assert tools_with_token[1]["server_url"] == "https://git.example.com/mcp" + assert tools_with_token[1]["headers"] == {"Authorization": "Bearer abc"} + + +@pytest.mark.asyncio +async def test_retrieve_response_no_tools_bypasses_tools(mocker): + mock_client = mocker.Mock() + # responses.create returns a synthetic OpenAI-like response + response_obj = mocker.Mock() + response_obj.id = "resp-1" + response_obj.output = [] + mock_client.responses.create = mocker.AsyncMock(return_value=response_obj) + # vector_dbs.list should not matter when no_tools=True, but keep it valid + mock_client.vector_dbs.list = mocker.AsyncMock(return_value=[]) + + # Ensure system prompt resolution does not require real config + mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") + mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) + + qr = QueryRequest(query="hello", no_tools=True) + summary, conv_id = await retrieve_response(mock_client, "model-x", qr, token="tkn") + + assert conv_id == "resp-1" + assert summary.llm_response == "" + # tools must be passed as None + kwargs = 
mock_client.responses.create.call_args.kwargs
+    assert kwargs["tools"] is None
+    assert kwargs["model"] == "model-x"
+    assert kwargs["instructions"] == "PROMPT"
+
+
+@pytest.mark.asyncio
+async def test_retrieve_response_builds_rag_and_mcp_tools(mocker):
+    mock_client = mocker.Mock()
+    response_obj = mocker.Mock()
+    response_obj.id = "resp-2"
+    response_obj.output = []
+    mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
+    mock_client.vector_stores.list = mocker.AsyncMock(
+        return_value=mocker.Mock(data=[mocker.Mock(id="dbA")])
+    )
+
+    mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
+    mock_cfg = mocker.Mock()
+    mock_cfg.mcp_servers = [
+        ModelContextProtocolServer(name="fs", url="http://localhost:3000"),
+    ]
+    mocker.patch("app.endpoints.query_v2.configuration", mock_cfg)
+
+    qr = QueryRequest(query="hello")
+    await retrieve_response(mock_client, "model-y", qr, token="mytoken")
+
+    kwargs = mock_client.responses.create.call_args.kwargs
+    tools = kwargs["tools"]
+    assert isinstance(tools, list)
+    # Expect one file_search and one mcp tool
+    tool_types = {t.get("type") for t in tools}
+    assert tool_types == {"file_search", "mcp"}
+    file_search = next(t for t in tools if t["type"] == "file_search")
+    assert file_search["vector_store_ids"] == ["dbA"]
+    mcp_tool = next(t for t in tools if t["type"] == "mcp")
+    assert mcp_tool["server_label"] == "fs"
+    assert mcp_tool["headers"] == {"Authorization": "Bearer mytoken"}
+
+
+@pytest.mark.asyncio
+async def test_retrieve_response_parses_output_and_tool_calls(mocker):
+    mock_client = mocker.Mock()
+
+    # Build output with content variants and tool calls
+    tool_call_fn = mocker.Mock(name="fn")
+    tool_call_fn.name = "do_something"
+    tool_call_fn.arguments = {"x": 1}
+    tool_call = mocker.Mock()
+    tool_call.id = "tc-1"
+    tool_call.function = tool_call_fn
+
+    output_item_1 = mocker.Mock()
+    output_item_1.content = [mocker.Mock(text="Hello "), mocker.Mock(text="world")]
+    output_item_1.tool_calls = []
+
+    output_item_2 = mocker.Mock()
+    output_item_2.content = "!"
+    output_item_2.tool_calls = [tool_call]
+
+    response_obj = mocker.Mock()
+    response_obj.id = "resp-3"
+    response_obj.output = [output_item_1, output_item_2]
+
+    mock_client.responses.create = mocker.AsyncMock(return_value=response_obj)
+    mock_client.vector_stores.list = mocker.AsyncMock(
+        return_value=mocker.Mock(data=[])
+    )
+
+    mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT")
+    mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[]))
+
+    qr = QueryRequest(query="hello")
+    summary, conv_id = await retrieve_response(mock_client, "model-z", qr, token="tkn")
+
+    assert conv_id == "resp-3"
+    assert summary.llm_response == "Hello world!"
+ assert len(summary.tool_calls) == 1 + assert summary.tool_calls[0].id == "tc-1" + assert summary.tool_calls[0].name == "do_something" + assert summary.tool_calls[0].args == {"x": 1} + + +@pytest.mark.asyncio +async def test_retrieve_response_validates_attachments(mocker): + mock_client = mocker.Mock() + response_obj = mocker.Mock() + response_obj.id = "resp-4" + response_obj.output = [] + mock_client.responses.create = mocker.AsyncMock(return_value=response_obj) + mock_client.vector_dbs.list = mocker.AsyncMock(return_value=[]) + + mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") + mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) + + validate_spy = mocker.patch( + "app.endpoints.query_v2.validate_attachments_metadata", return_value=None + ) + + attachments = [ + Attachment(attachment_type="log", content_type="text/plain", content="x"), + ] + + qr = QueryRequest(query="hello", attachments=attachments) + _summary, _cid = await retrieve_response(mock_client, "model-a", qr, token="tkn") + + validate_spy.assert_called_once() + + +@pytest.mark.asyncio +async def test_query_endpoint_handler_v2_success(mocker, dummy_request): + # Mock configuration to avoid configuration not loaded errors + mock_config = mocker.Mock() + mock_config.llama_stack_configuration = mocker.Mock() + mocker.patch("app.endpoints.query_v2.configuration", mock_config) + + mock_client = mocker.Mock() + mock_client.models.list = mocker.AsyncMock(return_value=[mocker.Mock()]) + mocker.patch( + "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client + ) + mocker.patch( + "app.endpoints.query_v2.evaluate_model_hints", return_value=(None, None) + ) + mocker.patch( + "app.endpoints.query_v2.select_model_and_provider_id", + return_value=("llama/m", "m", "p"), + ) + + summary = mocker.Mock(llm_response="ANSWER", tool_calls=[]) + mocker.patch( + "app.endpoints.query_v2.retrieve_response", + return_value=(summary, "conv-1"), + ) + mocker.patch( + "app.endpoints.query_v2.process_transcript_and_persist_conversation", + return_value=None, + ) + + metric = mocker.patch("metrics.llm_calls_total") + + res = await query_endpoint_handler_v2( + request=dummy_request, + query_request=QueryRequest(query="hi"), + auth=("user123", "", False, "token-abc"), + mcp_headers={}, + ) + + assert res.conversation_id == "conv-1" + assert res.response == "ANSWER" + metric.labels("p", "m").inc.assert_called_once() + + +@pytest.mark.asyncio +async def test_query_endpoint_handler_v2_api_connection_error(mocker, dummy_request): + # Mock configuration to avoid configuration not loaded errors + mock_config = mocker.Mock() + mock_config.llama_stack_configuration = mocker.Mock() + mocker.patch("app.endpoints.query_v2.configuration", mock_config) + + def _raise(*_args, **_kwargs): + raise APIConnectionError(request=None) + + mocker.patch("client.AsyncLlamaStackClientHolder.get_client", side_effect=_raise) + + fail_metric = mocker.patch("metrics.llm_calls_failures_total") + + with pytest.raises(HTTPException) as exc: + await query_endpoint_handler_v2( + request=dummy_request, + query_request=QueryRequest(query="hi"), + auth=("user123", "", False, "token-abc"), + mcp_headers={}, + ) + + assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert "Unable to connect to Llama Stack" in str(exc.value.detail) + fail_metric.inc.assert_called_once() diff --git a/tests/unit/app/endpoints/test_streaming_query_v2.py b/tests/unit/app/endpoints/test_streaming_query_v2.py new file 
mode 100644
index 000000000..3a84a0974
--- /dev/null
+++ b/tests/unit/app/endpoints/test_streaming_query_v2.py
@@ -0,0 +1,211 @@
+# pylint: disable=redefined-outer-name, import-error
+"""Unit tests for the /streaming_query (v2) endpoint using Responses API."""
+
+from types import SimpleNamespace
+import pytest
+from fastapi import HTTPException, status, Request
+from fastapi.responses import StreamingResponse
+
+from llama_stack_client import APIConnectionError
+
+from models.requests import QueryRequest
+from models.config import ModelContextProtocolServer
+
+from app.endpoints.streaming_query_v2 import (
+    retrieve_response,
+    streaming_query_endpoint_handler_v2,
+)
+
+
+@pytest.fixture
+def dummy_request() -> Request:
+    req = Request(scope={"type": "http"})
+    # Provide a permissive authorized_actions set to satisfy RBAC check
+    from models.config import Action  # import here to avoid global import errors
+
+    req.state.authorized_actions = set(Action)
+    return req
+
+
+@pytest.mark.asyncio
+async def test_retrieve_response_builds_rag_and_mcp_tools(mocker):
+    mock_client = mocker.Mock()
+    mock_client.vector_stores.list = mocker.AsyncMock(
+        return_value=mocker.Mock(data=[mocker.Mock(id="db1")])
+    )
+    mock_client.responses.create = mocker.AsyncMock(return_value=mocker.Mock())
+
+    mocker.patch(
+        "app.endpoints.streaming_query_v2.get_system_prompt", return_value="PROMPT"
+    )
+
+    mock_cfg = mocker.Mock()
+    mock_cfg.mcp_servers = [
+        ModelContextProtocolServer(name="fs", url="http://localhost:3000"),
+    ]
+    mocker.patch("app.endpoints.streaming_query_v2.configuration", mock_cfg)
+
+    qr = QueryRequest(query="hello")
+    await retrieve_response(mock_client, "model-z", qr, token="tok")
+
+    kwargs = mock_client.responses.create.call_args.kwargs
+    assert kwargs["stream"] is True
+    tools = kwargs["tools"]
+    assert isinstance(tools, list)
+    types = {t.get("type") for t in tools}
+    assert types == {"file_search", "mcp"}
+
+
+@pytest.mark.asyncio
+async def test_retrieve_response_no_tools_passes_none(mocker):
+    mock_client = mocker.Mock()
+    mock_client.vector_stores.list = mocker.AsyncMock(
+        return_value=mocker.Mock(data=[])
+    )
+    mock_client.responses.create = mocker.AsyncMock(return_value=mocker.Mock())
+
+    mocker.patch(
+        "app.endpoints.streaming_query_v2.get_system_prompt", return_value="PROMPT"
+    )
+    mocker.patch(
+        "app.endpoints.streaming_query_v2.configuration", mocker.Mock(mcp_servers=[])
+    )
+
+    qr = QueryRequest(query="hello", no_tools=True)
+    await retrieve_response(mock_client, "model-z", qr, token="tok")
+
+    kwargs = mock_client.responses.create.call_args.kwargs
+    assert kwargs["tools"] is None
+    assert kwargs["stream"] is True
+
+
+@pytest.mark.asyncio
+async def test_streaming_query_endpoint_handler_v2_success_yields_events(
+    mocker, dummy_request
+):
+    # Skip real config checks
+    mocker.patch("app.endpoints.streaming_query_v2.check_configuration_loaded")
+
+    # Model selection plumbing
+    mock_client = mocker.Mock()
+    mock_client.models.list = mocker.AsyncMock(return_value=[mocker.Mock()])
+    mocker.patch(
+        "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client
+    )
+    mocker.patch(
+        "app.endpoints.streaming_query_v2.evaluate_model_hints",
+        return_value=(None, None),
+    )
+    mocker.patch(
+        "app.endpoints.streaming_query_v2.select_model_and_provider_id",
+        return_value=("llama/m", "m", "p"),
+    )
+
+    # Replace SSE helpers for deterministic output
+    mocker.patch(
+        "app.endpoints.streaming_query_v2.stream_start_event",
+        lambda conv_id: f"START:{conv_id}\n",
+    )
+    mocker.patch(
"app.endpoints.streaming_query_v2.format_stream_data", + lambda obj: f"EV:{obj['event']}:{obj['data'].get('token','')}\n", + ) + mocker.patch( + "app.endpoints.streaming_query_v2.stream_end_event", lambda _m: "END\n" + ) + + # Conversation persistence and transcripts disabled + persist_spy = mocker.patch( + "app.endpoints.streaming_query_v2.persist_user_conversation_details", + return_value=None, + ) + mocker.patch( + "app.endpoints.streaming_query_v2.is_transcripts_enabled", return_value=False + ) + + # Build a fake async stream of chunks + async def fake_stream(): + yield SimpleNamespace( + type="response.created", response=SimpleNamespace(id="conv-xyz") + ) + yield SimpleNamespace(type="response.content_part.added") + yield SimpleNamespace(type="response.output_text.delta", delta="Hello ") + yield SimpleNamespace(type="response.output_text.delta", delta="world") + yield SimpleNamespace( + type="response.output_item.added", + item=SimpleNamespace( + type="function_call", id="item1", name="search", call_id="call1" + ), + ) + yield SimpleNamespace( + type="response.function_call_arguments.delta", delta='{"q":"x"}' + ) + yield SimpleNamespace( + type="response.function_call_arguments.done", + item_id="item1", + arguments='{"q":"x"}', + ) + yield SimpleNamespace(type="response.output_text.done", text="Hello world") + yield SimpleNamespace(type="response.completed") + + mocker.patch( + "app.endpoints.streaming_query_v2.retrieve_response", + return_value=(fake_stream(), ""), + ) + + metric = mocker.patch("metrics.llm_calls_total") + + resp = await streaming_query_endpoint_handler_v2( + request=dummy_request, + query_request=QueryRequest(query="hi"), + auth=("user123", "", False, "token-abc"), + mcp_headers={}, + ) + + assert isinstance(resp, StreamingResponse) + metric.labels("p", "m").inc.assert_called_once() + + # Collect emitted events + events: list[str] = [] + async for chunk in resp.body_iterator: + s = chunk.decode() if isinstance(chunk, (bytes, bytearray)) else str(chunk) + events.append(s) + + # Validate event sequence and content + assert events[0] == "START:conv-xyz\n" + # content_part.added triggers empty token + assert events[1] == "EV:token:\n" + assert events[2] == "EV:token:Hello \n" + assert events[3] == "EV:token:world\n" + # tool call delta + assert events[4].startswith("EV:tool_call:") + # turn complete and end + assert "EV:turn_complete:Hello world\n" in events + assert events[-1] == "END\n" + + # Verify conversation persistence was invoked with the created id + persist_spy.assert_called_once() + + +@pytest.mark.asyncio +async def test_streaming_query_endpoint_handler_v2_api_connection_error( + mocker, dummy_request +): + mocker.patch("app.endpoints.streaming_query_v2.check_configuration_loaded") + + def _raise(*_a, **_k): + raise APIConnectionError(request=None) + + mocker.patch("client.AsyncLlamaStackClientHolder.get_client", side_effect=_raise) + + fail_metric = mocker.patch("metrics.llm_calls_failures_total") + + with pytest.raises(HTTPException) as exc: + await streaming_query_endpoint_handler_v2( + request=dummy_request, + query_request=QueryRequest(query="hi"), + auth=("user123", "", False, "tok"), + mcp_headers={}, + ) + + assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert "Unable to connect to Llama Stack" in str(exc.value.detail) + fail_metric.inc.assert_called_once() diff --git a/tests/unit/app/test_routers.py b/tests/unit/app/test_routers.py index e466fca44..c1131a2ba 100644 --- a/tests/unit/app/test_routers.py +++ 
b/tests/unit/app/test_routers.py @@ -15,10 +15,12 @@ shields, providers, query, + query_v2, health, config, feedback, streaming_query, + streaming_query_v2, authorized, metrics, tools, @@ -64,7 +66,7 @@ def test_include_routers() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 15 + assert len(app.routers) == 17 assert root.router in app.get_routers() assert info.router in app.get_routers() assert models.router in app.get_routers() @@ -72,7 +74,9 @@ def test_include_routers() -> None: assert shields.router in app.get_routers() assert providers.router in app.get_routers() assert query.router in app.get_routers() + assert query_v2.router in app.get_routers() assert streaming_query.router in app.get_routers() + assert streaming_query_v2.router in app.get_routers() assert config.router in app.get_routers() assert feedback.router in app.get_routers() assert health.router in app.get_routers() @@ -88,7 +92,7 @@ def test_check_prefixes() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 15 + assert len(app.routers) == 17 assert app.get_router_prefix(root.router) == "" assert app.get_router_prefix(info.router) == "/v1" assert app.get_router_prefix(models.router) == "/v1" @@ -97,6 +101,8 @@ def test_check_prefixes() -> None: assert app.get_router_prefix(providers.router) == "/v1" assert app.get_router_prefix(query.router) == "/v1" assert app.get_router_prefix(streaming_query.router) == "/v1" + assert app.get_router_prefix(query_v2.router) == "/v2" + assert app.get_router_prefix(streaming_query_v2.router) == "/v2" assert app.get_router_prefix(config.router) == "/v1" assert app.get_router_prefix(feedback.router) == "/v1" assert app.get_router_prefix(health.router) == ""
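
For reviewers, a minimal sketch (not part of the patch) of the Responses API call that the new v2 handlers assemble, assuming one available vector store and the "fs" MCP server used in the unit tests above; the client object, identifiers, and token are placeholders:

    # Tools payload produced by get_rag_tools() + get_mcp_tools() in query_v2.py
    tools = [
        {
            "type": "file_search",
            "vector_store_ids": ["<vector-store-id>"],  # from client.vector_stores.list()
            "max_num_results": 10,
        },
        {
            "type": "mcp",
            "server_label": "fs",
            "server_url": "http://localhost:3000",
            "require_approval": "never",
            "headers": {"Authorization": "Bearer <token>"},  # only when a token is present
        },
    ]

    # Non-streaming call used by /v2/query; /v2/streaming_query passes stream=True
    response = await client.responses.create(
        input="user question",
        model="<llama-stack-model-id>",
        instructions="<system prompt>",
        previous_response_id=None,  # or a prior response ID to chain a conversation
        tools=tools,
        stream=False,
        store=True,
    )
    conversation_id = response.id  # reused as the conversation ID in QueryResponse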