From da584aada353dc08f3fada6ef83d044b1b85dc9f Mon Sep 17 00:00:00 2001 From: Luis Tomas Bolivar Date: Fri, 12 Sep 2025 14:11:25 +0200 Subject: [PATCH 1/7] Base implementation of query_v2 Note: metrics are not being reported for now --- src/app/endpoints/query_v2.py | 377 +++++++++++++++++++++++++++++++++ src/app/routers.py | 7 + tests/unit/app/test_routers.py | 7 +- 3 files changed, 389 insertions(+), 2 deletions(-) create mode 100644 src/app/endpoints/query_v2.py diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py new file mode 100644 index 000000000..1a7174603 --- /dev/null +++ b/src/app/endpoints/query_v2.py @@ -0,0 +1,377 @@ +"""Handler for REST API call to provide answer to query using Response API.""" + +import logging +from typing import Annotated, Any, cast + +from llama_stack_client import AsyncLlamaStackClient # type: ignore +from llama_stack_client import APIConnectionError +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseObject, +) + +from fastapi import APIRouter, HTTPException, Request, status, Depends + +from app.endpoints.query import ( + evaluate_model_hints, + is_transcripts_enabled, + persist_user_conversation_details, + query_response, + select_model_and_provider_id, + validate_attachments_metadata, +) +from authentication import get_auth_dependency +from authentication.interface import AuthTuple +from authorization.middleware import authorize +from client import AsyncLlamaStackClientHolder +from configuration import configuration +import metrics +from models.config import Action +from models.database.conversations import UserConversation +from models.requests import QueryRequest +from models.responses import QueryResponse +from utils.endpoints import ( + check_configuration_loaded, + get_system_prompt, + validate_model_provider_override, +) +from utils.mcp_headers import mcp_headers_dependency +from utils.transcripts import store_transcript +from utils.types import TurnSummary, ToolCallSummary + + +logger = logging.getLogger("app.endpoints.handlers") +router = APIRouter(tags=["query_v2"]) +auth_dependency = get_auth_dependency() + + +@router.post("/query", responses=query_response) +@authorize(Action.QUERY) +async def query_endpoint_handler_v2( + request: Request, + query_request: QueryRequest, + auth: Annotated[AuthTuple, Depends(auth_dependency)], + mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), +) -> QueryResponse: + """ + Handle request to the /query endpoint using Response API. + + Processes a POST request to the /query endpoint, forwarding the + user's query to a selected Llama Stack LLM using Response API + and returning the generated response. + + Validates configuration and authentication, selects the appropriate model + and provider, retrieves the LLM response, updates metrics, and optionally + stores a transcript of the interaction. Handles connection errors to the + Llama Stack service by returning an HTTP 500 error. + + Returns: + QueryResponse: Contains the conversation ID and the LLM-generated response. 
+ """ + check_configuration_loaded(configuration) + + # Enforce RBAC: optionally disallow overriding model/provider in requests + validate_model_provider_override(query_request, request.state.authorized_actions) + + # log Llama Stack configuration + logger.info("Llama stack config: %s", configuration.llama_stack_configuration) + + user_id, _, _, token = auth + + user_conversation: UserConversation | None = None + if query_request.conversation_id: + # TODO: Implement conversation once Llama Stack supports its API + pass + + try: + # try to get Llama Stack client + client = AsyncLlamaStackClientHolder().get_client() + llama_stack_model_id, model_id, provider_id = select_model_and_provider_id( + await client.models.list(), + *evaluate_model_hints( + user_conversation=user_conversation, query_request=query_request + ), + ) + summary, conversation_id = await retrieve_response( + client, + llama_stack_model_id, + query_request, + token, + mcp_headers=mcp_headers, + provider_id=provider_id, + ) + # Update metrics for the LLM call + metrics.llm_calls_total.labels(provider_id, model_id).inc() + + process_transcript_and_persist_conversation( + user_id=user_id, + conversation_id=conversation_id, + model_id=model_id, + provider_id=provider_id, + query_request=query_request, + summary=summary, + ) + + return QueryResponse( + conversation_id=conversation_id, + response=summary.llm_response, + ) + + # connection to Llama Stack server + except APIConnectionError as e: + # Update metrics for the LLM call failure + metrics.llm_calls_failures_total.inc() + logger.error("Unable to connect to Llama Stack: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "response": "Unable to connect to Llama Stack", + "cause": str(e), + }, + ) from e + + +async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches + client: AsyncLlamaStackClient, + model_id: str, + query_request: QueryRequest, + token: str, + mcp_headers: dict[str, dict[str, str]] | None = None, + provider_id: str = "", +) -> tuple[TurnSummary, str]: + """ + Retrieve response from LLMs and agents. + + Retrieves a response from the Llama Stack LLM or agent for a + given query, handling shield configuration, tool usage, and + attachment validation. + + This function configures input/output shields, system prompts, + and toolgroups (including RAG and MCP integration) as needed + based on the query request and system configuration. It + validates attachments, manages conversation and session + context, and processes MCP headers for multi-component + processing. Shield violations in the response are detected and + corresponding metrics are updated. + + Parameters: + model_id (str): The identifier of the LLM model to use. + query_request (QueryRequest): The user's query and associated metadata. + token (str): The authentication token for authorization. + mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing. + + Returns: + tuple[TurnSummary, str]: A tuple containing a summary of the LLM or agent's response content + and the conversation ID. + """ + logger.info("Shields are not yet supported in Responses API. 
Disabling safety") + + # use system prompt from request or default one + system_prompt = get_system_prompt(query_request, configuration) + logger.debug("Using system prompt: %s", system_prompt) + + # TODO(lucasagomes): redact attachments content before sending to LLM + # if attachments are provided, validate them + if query_request.attachments: + validate_attachments_metadata(query_request.attachments) + + # Prepare tools for responses API + tools: list[dict[str, Any]] = [] + if not query_request.no_tools: + # Get vector databases for RAG tools + vector_db_ids = [ + vector_db.identifier for vector_db in await client.vector_dbs.list() + ] + + # Add RAG tools if vector databases are available + rag_tools = get_rag_tools(vector_db_ids) + if rag_tools: + tools.extend(rag_tools) + + # Add MCP server tools + mcp_tools = get_mcp_tools(configuration.mcp_servers, token) + if mcp_tools: + tools.extend(mcp_tools) + logger.debug( + "Configured %d MCP tools: %s", + len(mcp_tools), + [tool.get("server_label", "unknown") for tool in mcp_tools], + ) + + # Create OpenAI response using responses API + response = await client.responses.create( + input=query_request.query, + model=model_id, + instructions=system_prompt, + previous_response_id=query_request.conversation_id, + tools=(cast(Any, tools) if tools else cast(Any, None)), + stream=False, + store=True, + ) + response = cast(OpenAIResponseObject, response) + + logger.debug( + "Received response with ID: %s, output items: %d", + response.id, + len(response.output), + ) + # Return the response ID - client can use it for chaining if desired + conversation_id = response.id + + # Process OpenAI response format + llm_response = "" + tool_calls: list[ToolCallSummary] = [] + + for idx, output_item in enumerate(response.output): + logger.debug( + "Processing output item %d, type: %s", idx, type(output_item).__name__ + ) + + if hasattr(output_item, "content") and output_item.content: + # Extract text content from message output + if isinstance(output_item.content, list): + for content_item in output_item.content: + if hasattr(content_item, "text"): + llm_response += content_item.text + elif hasattr(output_item.content, "text"): + llm_response += output_item.content.text + elif isinstance(output_item.content, str): + llm_response += output_item.content + + if llm_response: + logger.info( + "Model response content: '%s'", + ( + llm_response[:200] + "..." 
+ if len(llm_response) > 200 + else llm_response + ), + ) + + # Process tool calls if present + if hasattr(output_item, "tool_calls") and output_item.tool_calls: + logger.debug( + "Found %d tool calls in output item %d", + len(output_item.tool_calls), + idx, + ) + for tool_idx, tool_call in enumerate(output_item.tool_calls): + tool_name = ( + tool_call.function.name + if hasattr(tool_call, "function") + else "unknown" + ) + tool_args = ( + tool_call.function.arguments + if hasattr(tool_call, "function") + else {} + ) + + logger.debug( + "Tool call %d - Name: %s, Args: %s", + tool_idx, + tool_name, + str(tool_args)[:100], + ) + + tool_calls.append( + ToolCallSummary( + id=( + tool_call.id + if hasattr(tool_call, "id") + else str(len(tool_calls)) + ), + name=tool_name, + args=tool_args, + response=None, # Tool responses would be in subsequent output items + ) + ) + + logger.info( + "Response processing complete - Tool calls: %d, Response length: %d chars", + len(tool_calls), + len(llm_response), + ) + + summary = TurnSummary( + llm_response=llm_response, + tool_calls=tool_calls, + ) + + # TODO(ltomasbo): update token count metrics for the LLM call + # Update token count metrics for the LLM call + # model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id + # update_llm_token_count_from_response(response, model_label, provider_id, system_prompt) + + if not summary.llm_response: + logger.warning( + "Response lacks content (conversation_id=%s)", + conversation_id, + ) + return summary, conversation_id + + +def get_rag_tools(vector_db_ids: list[str]) -> list[dict[str, Any]] | None: + """Convert vector DB IDs to tools format for responses API.""" + if not vector_db_ids: + return None + + return [ + { + "type": "file_search", + "vector_store_ids": vector_db_ids, + "max_num_results": 10, + } + ] + + +def get_mcp_tools(mcp_servers: list, token: str | None = None) -> list[dict[str, Any]]: + """Convert MCP servers to tools format for responses API.""" + tools = [] + for mcp_server in mcp_servers: + tool_def = { + "type": "mcp", + "server_label": mcp_server.name, + "server_url": mcp_server.url, + "require_approval": "never", + } + + # Add authentication if token provided (Response API format) + if token: + tool_def["headers"] = {"Authorization": f"Bearer {token}"} + + tools.append(tool_def) + return tools + + +def process_transcript_and_persist_conversation( + user_id: str, + conversation_id: str, + model_id: str, + provider_id: str, + query_request: QueryRequest, + summary: TurnSummary, +) -> None: + """Process transcript storage and persist conversation details.""" + if not is_transcripts_enabled(): + logger.debug("Transcript collection is disabled in the configuration") + else: + store_transcript( + user_id=user_id, + conversation_id=conversation_id, + model_id=model_id, + provider_id=provider_id, + query_is_valid=True, # TODO(lucasagomes): implement as part of query validation + query=query_request.query, + query_request=query_request, + summary=summary, + rag_chunks=[], # TODO(lucasagomes): implement rag_chunks + truncated=False, # TODO(lucasagomes): implement truncation as part of quota work + attachments=query_request.attachments or [], + ) + + persist_user_conversation_details( + user_id=user_id, + conversation_id=conversation_id, + model=model_id, + provider_id=provider_id, + ) diff --git a/src/app/routers.py b/src/app/routers.py index 66c707668..3350f5eed 100644 --- a/src/app/routers.py +++ b/src/app/routers.py @@ -18,6 +18,8 @@ conversations_v2, metrics, tools, + # V2 
endpoints for Response API support + query_v2, ) @@ -28,6 +30,8 @@ def include_routers(app: FastAPI) -> None: app: The `FastAPI` app instance. """ app.include_router(root.router) + + # V1 endpoints - Agent API (legacy) app.include_router(info.router, prefix="/v1") app.include_router(models.router, prefix="/v1") app.include_router(tools.router, prefix="/v1") @@ -40,6 +44,9 @@ def include_routers(app: FastAPI) -> None: app.include_router(conversations.router, prefix="/v1") app.include_router(conversations_v2.router, prefix="/v2") + # V2 endpoints - Response API support + app.include_router(query_v2.router, prefix="/v2") + # road-core does not version these endpoints app.include_router(health.router) app.include_router(authorized.router) diff --git a/tests/unit/app/test_routers.py b/tests/unit/app/test_routers.py index e466fca44..a6c23a2b6 100644 --- a/tests/unit/app/test_routers.py +++ b/tests/unit/app/test_routers.py @@ -15,6 +15,7 @@ shields, providers, query, + query_v2, health, config, feedback, @@ -64,7 +65,7 @@ def test_include_routers() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 15 + assert len(app.routers) == 16 assert root.router in app.get_routers() assert info.router in app.get_routers() assert models.router in app.get_routers() @@ -72,6 +73,7 @@ def test_include_routers() -> None: assert shields.router in app.get_routers() assert providers.router in app.get_routers() assert query.router in app.get_routers() + assert query_v2.router in app.get_routers() assert streaming_query.router in app.get_routers() assert config.router in app.get_routers() assert feedback.router in app.get_routers() @@ -88,7 +90,7 @@ def test_check_prefixes() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 15 + assert len(app.routers) == 16 assert app.get_router_prefix(root.router) == "" assert app.get_router_prefix(info.router) == "/v1" assert app.get_router_prefix(models.router) == "/v1" @@ -97,6 +99,7 @@ def test_check_prefixes() -> None: assert app.get_router_prefix(providers.router) == "/v1" assert app.get_router_prefix(query.router) == "/v1" assert app.get_router_prefix(streaming_query.router) == "/v1" + assert app.get_router_prefix(query_v2.router) == "/v2" assert app.get_router_prefix(config.router) == "/v1" assert app.get_router_prefix(feedback.router) == "/v1" assert app.get_router_prefix(health.router) == "" From 063e90da660de97890f69119238ed616aa1cf7fb Mon Sep 17 00:00:00 2001 From: Luis Tomas Bolivar Date: Wed, 17 Sep 2025 15:30:34 +0200 Subject: [PATCH 2/7] Adding unit test coverage for query_v2 (cursor generated) --- tests/unit/app/endpoints/test_query_v2.py | 261 ++++++++++++++++++++++ 1 file changed, 261 insertions(+) create mode 100644 tests/unit/app/endpoints/test_query_v2.py diff --git a/tests/unit/app/endpoints/test_query_v2.py b/tests/unit/app/endpoints/test_query_v2.py new file mode 100644 index 000000000..91e42c97b --- /dev/null +++ b/tests/unit/app/endpoints/test_query_v2.py @@ -0,0 +1,261 @@ +# pylint: disable=redefined-outer-name, import-error +"""Unit tests for the /query (v2) REST API endpoint using Responses API.""" + +import pytest +from fastapi import HTTPException, status, Request + +from llama_stack_client import APIConnectionError + +from models.requests import QueryRequest, Attachment +from models.config import ModelContextProtocolServer + +from app.endpoints.query_v2 import ( + get_rag_tools, + get_mcp_tools, + retrieve_response, + query_endpoint_handler_v2, +) + + +@pytest.fixture +def 
dummy_request() -> Request: + req = Request(scope={"type": "http"}) + return req + + +def test_get_rag_tools(): + assert get_rag_tools([]) is None + + tools = get_rag_tools(["db1", "db2"]) + assert isinstance(tools, list) + assert tools[0]["type"] == "file_search" + assert tools[0]["vector_store_ids"] == ["db1", "db2"] + assert tools[0]["max_num_results"] == 10 + + +def test_get_mcp_tools_with_and_without_token(): + servers = [ + ModelContextProtocolServer(name="fs", url="http://localhost:3000"), + ModelContextProtocolServer(name="git", url="https://git.example.com/mcp"), + ] + + tools_no_token = get_mcp_tools(servers, token=None) + assert len(tools_no_token) == 2 + assert tools_no_token[0]["type"] == "mcp" + assert tools_no_token[0]["server_label"] == "fs" + assert tools_no_token[0]["server_url"] == "http://localhost:3000" + assert "headers" not in tools_no_token[0] + + tools_with_token = get_mcp_tools(servers, token="abc") + assert len(tools_with_token) == 2 + assert tools_with_token[1]["type"] == "mcp" + assert tools_with_token[1]["server_label"] == "git" + assert tools_with_token[1]["server_url"] == "https://git.example.com/mcp" + assert tools_with_token[1]["headers"] == {"Authorization": "Bearer abc"} + + +@pytest.mark.asyncio +async def test_retrieve_response_no_tools_bypasses_tools(mocker): + mock_client = mocker.Mock() + # responses.create returns a synthetic OpenAI-like response + response_obj = mocker.Mock() + response_obj.id = "resp-1" + response_obj.output = [] + mock_client.responses.create = mocker.AsyncMock(return_value=response_obj) + # vector_dbs.list should not matter when no_tools=True, but keep it valid + mock_client.vector_dbs.list = mocker.AsyncMock(return_value=[]) + + # Ensure system prompt resolution does not require real config + mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") + mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) + + qr = QueryRequest(query="hello", no_tools=True) + summary, conv_id = await retrieve_response( + mock_client, "model-x", qr, token="tkn" + ) + + assert conv_id == "resp-1" + assert summary.llm_response == "" + # tools must be passed as None + kwargs = mock_client.responses.create.call_args.kwargs + assert kwargs["tools"] is None + assert kwargs["model"] == "model-x" + assert kwargs["instructions"] == "PROMPT" + + +@pytest.mark.asyncio +async def test_retrieve_response_builds_rag_and_mcp_tools(mocker): + mock_client = mocker.Mock() + response_obj = mocker.Mock() + response_obj.id = "resp-2" + response_obj.output = [] + mock_client.responses.create = mocker.AsyncMock(return_value=response_obj) + mock_client.vector_dbs.list = mocker.AsyncMock( + return_value=[mocker.Mock(identifier="dbA")] + ) + + mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") + mock_cfg = mocker.Mock() + mock_cfg.mcp_servers = [ + ModelContextProtocolServer(name="fs", url="http://localhost:3000"), + ] + mocker.patch("app.endpoints.query_v2.configuration", mock_cfg) + + qr = QueryRequest(query="hello") + await retrieve_response(mock_client, "model-y", qr, token="mytoken") + + kwargs = mock_client.responses.create.call_args.kwargs + tools = kwargs["tools"] + assert isinstance(tools, list) + # Expect one file_search and one mcp tool + tool_types = {t.get("type") for t in tools} + assert tool_types == {"file_search", "mcp"} + file_search = next(t for t in tools if t["type"] == "file_search") + assert file_search["vector_store_ids"] == ["dbA"] + mcp_tool = next(t for t in tools if 
t["type"] == "mcp") + assert mcp_tool["server_label"] == "fs" + assert mcp_tool["headers"] == {"Authorization": "Bearer mytoken"} + + +@pytest.mark.asyncio +async def test_retrieve_response_parses_output_and_tool_calls(mocker): + mock_client = mocker.Mock() + + # Build output with content variants and tool calls + tool_call_fn = mocker.Mock(name="fn") + tool_call_fn.name = "do_something" + tool_call_fn.arguments = {"x": 1} + tool_call = mocker.Mock() + tool_call.id = "tc-1" + tool_call.function = tool_call_fn + + output_item_1 = mocker.Mock() + output_item_1.content = [mocker.Mock(text="Hello "), mocker.Mock(text="world")] + output_item_1.tool_calls = [] + + output_item_2 = mocker.Mock() + output_item_2.content = "!" + output_item_2.tool_calls = [tool_call] + + response_obj = mocker.Mock() + response_obj.id = "resp-3" + response_obj.output = [output_item_1, output_item_2] + + mock_client.responses.create = mocker.AsyncMock(return_value=response_obj) + mock_client.vector_dbs.list = mocker.AsyncMock(return_value=[]) + + mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") + mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) + + qr = QueryRequest(query="hello") + summary, conv_id = await retrieve_response( + mock_client, "model-z", qr, token="tkn" + ) + + assert conv_id == "resp-3" + assert summary.llm_response == "Hello world!" + assert len(summary.tool_calls) == 1 + assert summary.tool_calls[0].id == "tc-1" + assert summary.tool_calls[0].name == "do_something" + assert summary.tool_calls[0].args == {"x": 1} + + +@pytest.mark.asyncio +async def test_retrieve_response_validates_attachments(mocker): + mock_client = mocker.Mock() + response_obj = mocker.Mock() + response_obj.id = "resp-4" + response_obj.output = [] + mock_client.responses.create = mocker.AsyncMock(return_value=response_obj) + mock_client.vector_dbs.list = mocker.AsyncMock(return_value=[]) + + mocker.patch("app.endpoints.query_v2.get_system_prompt", return_value="PROMPT") + mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) + + validate_spy = mocker.patch( + "app.endpoints.query_v2.validate_attachments_metadata", return_value=None + ) + + attachments = [ + Attachment(attachment_type="log", content_type="text/plain", content="x"), + ] + + qr = QueryRequest(query="hello", attachments=attachments) + _summary, _cid = await retrieve_response( + mock_client, "model-a", qr, token="tkn" + ) + + validate_spy.assert_called_once() + + +@pytest.mark.asyncio +async def test_query_endpoint_handler_v2_success(mocker, dummy_request): + # Mock configuration to avoid configuration not loaded errors + mock_config = mocker.Mock() + mock_config.llama_stack_configuration = mocker.Mock() + mocker.patch("app.endpoints.query_v2.configuration", mock_config) + + mock_client = mocker.Mock() + mock_client.models.list = mocker.AsyncMock(return_value=[mocker.Mock()]) + mocker.patch( + "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client + ) + mocker.patch( + "app.endpoints.query_v2.evaluate_model_hints", return_value=(None, None) + ) + mocker.patch( + "app.endpoints.query_v2.select_model_and_provider_id", + return_value=("llama/m", "m", "p"), + ) + + summary = mocker.Mock(llm_response="ANSWER", tool_calls=[]) + mocker.patch( + "app.endpoints.query_v2.retrieve_response", + return_value=(summary, "conv-1"), + ) + mocker.patch( + "app.endpoints.query_v2.process_transcript_and_persist_conversation", + return_value=None, + ) + + metric = 
mocker.patch("metrics.llm_calls_total") + + res = await query_endpoint_handler_v2( + request=dummy_request, + query_request=QueryRequest(query="hi"), + auth=("user123", "", False, "token-abc"), + mcp_headers={}, + ) + + assert res.conversation_id == "conv-1" + assert res.response == "ANSWER" + metric.labels("p", "m").inc.assert_called_once() + + +@pytest.mark.asyncio +async def test_query_endpoint_handler_v2_api_connection_error(mocker, dummy_request): + # Mock configuration to avoid configuration not loaded errors + mock_config = mocker.Mock() + mock_config.llama_stack_configuration = mocker.Mock() + mocker.patch("app.endpoints.query_v2.configuration", mock_config) + + def _raise(*_args, **_kwargs): + raise APIConnectionError(request=None) + + mocker.patch("client.AsyncLlamaStackClientHolder.get_client", side_effect=_raise) + + fail_metric = mocker.patch("metrics.llm_calls_failures_total") + + with pytest.raises(HTTPException) as exc: + await query_endpoint_handler_v2( + request=dummy_request, + query_request=QueryRequest(query="hi"), + auth=("user123", "", False, "token-abc"), + mcp_headers={}, + ) + + assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert "Unable to connect to Llama Stack" in str(exc.value.detail) + fail_metric.inc.assert_called_once() + + From 3a7818c1000dd665c380acdeee5c8aa6c8507b38 Mon Sep 17 00:00:00 2001 From: Luis Tomas Bolivar Date: Mon, 15 Sep 2025 11:53:56 +0200 Subject: [PATCH 3/7] First version of streaming_query for Responses API (v2) --- src/app/endpoints/streaming_query_v2.py | 402 ++++++++++++++++++++++++ src/app/routers.py | 2 + tests/unit/app/test_routers.py | 7 +- 3 files changed, 409 insertions(+), 2 deletions(-) create mode 100644 src/app/endpoints/streaming_query_v2.py diff --git a/src/app/endpoints/streaming_query_v2.py b/src/app/endpoints/streaming_query_v2.py new file mode 100644 index 000000000..b807340c5 --- /dev/null +++ b/src/app/endpoints/streaming_query_v2.py @@ -0,0 +1,402 @@ +"""Streaming query handler using Responses API (v2).""" + +import logging +from typing import Annotated, Any, AsyncIterator, cast + +from llama_stack_client import APIConnectionError +from llama_stack_client import AsyncLlamaStackClient # type: ignore +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseObjectStream, +) + +from fastapi import APIRouter, Depends, Request, HTTPException +from fastapi.responses import StreamingResponse +from starlette import status + + +from app.endpoints.query import ( + evaluate_model_hints, + is_transcripts_enabled, + persist_user_conversation_details, + select_model_and_provider_id, + validate_attachments_metadata, + validate_conversation_ownership, +) +from app.endpoints.query_v2 import ( + get_rag_tools, + get_mcp_tools, +) +from app.endpoints.streaming_query import ( + format_stream_data, + stream_start_event, + stream_end_event, +) +from authentication import get_auth_dependency +from authentication.interface import AuthTuple +from authorization.middleware import authorize +from client import AsyncLlamaStackClientHolder +from configuration import configuration +import metrics +from models.config import Action +from models.database.conversations import UserConversation +from models.requests import QueryRequest +from utils.endpoints import ( + check_configuration_loaded, + get_system_prompt, + validate_model_provider_override, +) +from utils.mcp_headers import mcp_headers_dependency +from utils.transcripts import store_transcript +from utils.types import TurnSummary, ToolCallSummary + 
+logger = logging.getLogger("app.endpoints.handlers") +router = APIRouter(tags=["streaming_query_v2"]) +auth_dependency = get_auth_dependency() + + +@router.post("/streaming_query") +@authorize(Action.STREAMING_QUERY) +async def streaming_query_endpoint_handler_v2( # pylint: disable=too-many-locals + request: Request, + query_request: QueryRequest, + auth: Annotated[AuthTuple, Depends(auth_dependency)], + mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), +) -> StreamingResponse: + """ + Handle request to the /streaming_query endpoint. + + This endpoint receives a query request, authenticates the user, + selects the appropriate model and provider, and streams + incremental response events from the Llama Stack backend to the + client. Events include start, token updates, tool calls, turn + completions, errors, and end-of-stream metadata. Optionally + stores the conversation transcript if enabled in configuration. + + Returns: + StreamingResponse: An HTTP streaming response yielding + SSE-formatted events for the query lifecycle. + + Raises: + HTTPException: Returns HTTP 500 if unable to connect to the + Llama Stack server. + """ + check_configuration_loaded(configuration) + + # Enforce RBAC: optionally disallow overriding model/provider in requests + validate_model_provider_override(query_request, request.state.authorized_actions) + + # log Llama Stack configuration + logger.info("Llama stack config: %s", configuration.llama_stack_configuration) + + user_id, _user_name, _skip_userid_check, token = auth + + user_conversation: UserConversation | None = None + if query_request.conversation_id: + user_conversation = validate_conversation_ownership( + user_id=user_id, conversation_id=query_request.conversation_id + ) + + if user_conversation is None: + logger.warning( + "User %s attempted to query conversation %s they don't own", + user_id, + query_request.conversation_id, + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "response": "Access denied", + "cause": "You do not have permission to access this conversation", + }, + ) + + try: + # try to get Llama Stack client + client = AsyncLlamaStackClientHolder().get_client() + llama_stack_model_id, model_id, provider_id = select_model_and_provider_id( + await client.models.list(), + *evaluate_model_hints(user_conversation=None, query_request=query_request), + ) + + response, _ = await retrieve_response( + client, + llama_stack_model_id, + query_request, + token, + mcp_headers=mcp_headers, + ) + metadata_map: dict[str, dict[str, Any]] = {} + + async def response_generator( + turn_response: AsyncIterator[OpenAIResponseObjectStream], + ) -> AsyncIterator[str]: + """ + Generate SSE formatted streaming response. + + Asynchronously generates a stream of Server-Sent Events + (SSE) representing incremental responses from a + language model turn. + + Yields start, token, tool call, turn completion, and + end events as SSE-formatted strings. Collects the + complete response for transcript storage if enabled. 
+ """ + chunk_id = 0 + summary = TurnSummary(llm_response="", tool_calls=[]) + + # Accumulators for Responses API + text_parts: list[str] = [] + tool_item_registry: dict[str, dict[str, str]] = {} + emitted_turn_complete = False + + # Handle conversation id and start event in-band on response.created + conv_id = "" + + logger.debug("Starting streaming response (Responses API) processing") + + async for chunk in turn_response: + event_type = getattr(chunk, "type", None) + logger.debug("Processing chunk %d, type: %s", chunk_id, event_type) + + # Emit start and persist on response.created + if event_type == "response.created": + try: + conv_id = getattr(chunk, "response").id + except Exception: + conv_id = "" + yield stream_start_event(conv_id) + if conv_id: + persist_user_conversation_details( + user_id=user_id, + conversation_id=conv_id, + model=model_id, + provider_id=provider_id, + ) + continue + + # Text streaming + if event_type == "response.output_text.delta": + delta = getattr(chunk, "delta", "") + if delta: + text_parts.append(delta) + yield format_stream_data( + { + "event": "token", + "data": { + "id": chunk_id, + "token": delta, + }, + } + ) + chunk_id += 1 + + # Final text of the output (capture, but emit at response.completed) + elif event_type == "response.output_text.done": + final_text = getattr(chunk, "text", "") + if final_text: + summary.llm_response = final_text + + # Content part started - emit an empty token to kick off UI streaming if desired + elif event_type == "response.content_part.added": + yield format_stream_data( + { + "event": "token", + "data": { + "id": chunk_id, + "token": "", + }, + } + ) + chunk_id += 1 + + # Track tool call items as they are added so we can build a summary later + elif event_type == "response.output_item.added": + item = getattr(chunk, "item", None) + item_type = getattr(item, "type", None) + if item and item_type == "function_call": + item_id = getattr(item, "id", "") + name = getattr(item, "name", "function_call") + call_id = getattr(item, "call_id", item_id) + if item_id: + tool_item_registry[item_id] = { + "name": name, + "call_id": call_id, + } + + # Stream tool call arguments as tool_call events + elif event_type == "response.function_call_arguments.delta": + delta = getattr(chunk, "delta", "") + yield format_stream_data( + { + "event": "tool_call", + "data": { + "id": chunk_id, + "role": "tool_execution", + "token": delta, + }, + } + ) + chunk_id += 1 + + # Finalize tool call arguments and append to summary + elif event_type in ( + "response.function_call_arguments.done", + "response.mcp_call.arguments.done", + ): + item_id = getattr(chunk, "item_id", "") + arguments = getattr(chunk, "arguments", "") + meta = tool_item_registry.get(item_id, {}) + summary.tool_calls.append( + ToolCallSummary( + id=meta.get("call_id", item_id or "unknown"), + name=meta.get("name", "tool_call"), + args=arguments, + response=None, + ) + ) + + # Completed response - capture final text if any + elif event_type == "response.completed": + if not emitted_turn_complete: + final_message = summary.llm_response or "".join(text_parts) + yield format_stream_data( + { + "event": "turn_complete", + "data": { + "id": chunk_id, + "token": final_message, + }, + } + ) + chunk_id += 1 + emitted_turn_complete = True + + # Ignore other event types for now; could add heartbeats if desired + + logger.debug( + "Streaming complete - Tool calls: %d, Response chars: %d", + len(summary.tool_calls), + len(summary.llm_response), + ) + + yield stream_end_event(metadata_map) + + if 
not is_transcripts_enabled():
+                logger.debug("Transcript collection is disabled in the configuration")
+            else:
+                store_transcript(
+                    user_id=user_id,
+                    conversation_id=conv_id,
+                    model_id=model_id,
+                    provider_id=provider_id,
+                    query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
+                    query=query_request.query,
+                    query_request=query_request,
+                    summary=summary,
+                    rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
+                    truncated=False,  # TODO(lucasagomes): implement truncation as part
+                    # of quota work
+                    attachments=query_request.attachments or [],
+                )
+
+        # Conversation persistence is handled inside the stream
+        # once the response.created event provides the ID
+
+        # Update metrics for the LLM call
+        metrics.llm_calls_total.labels(provider_id, model_id).inc()
+
+        return StreamingResponse(response_generator(response))
+
+    # connection to Llama Stack server
+    except APIConnectionError as e:
+        # Update metrics for the LLM call failure
+        metrics.llm_calls_failures_total.inc()
+        logger.error("Unable to connect to Llama Stack: %s", e)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail={
+                "response": "Unable to connect to Llama Stack",
+                "cause": str(e),
+            },
+        ) from e
+
+
+async def retrieve_response(
+    client: AsyncLlamaStackClient,
+    model_id: str,
+    query_request: QueryRequest,
+    token: str,
+    mcp_headers: dict[str, dict[str, str]] | None = None,
+) -> tuple[AsyncIterator[OpenAIResponseObjectStream], str]:
+    """
+    Retrieve a streaming response from the Llama Stack Responses API.
+
+    Asynchronously initiates a streaming response for the given user
+    query and returns the stream together with the conversation ID.
+
+    This function resolves the system prompt, validates attachments if
+    present, and prepares the tools passed to the Responses API: RAG
+    file_search tools built from the available vector stores plus MCP
+    server tools. Shields are not yet supported by the Responses API,
+    so safety shields are disabled for this call.
+
+    Parameters:
+        model_id (str): Identifier of the model to use for the query.
+        query_request (QueryRequest): The user's query and associated metadata.
+        token (str): Authentication token for downstream services.
+        mcp_headers (dict[str, dict[str, str]], optional):
+            Per-server Model Context Protocol (MCP) headers for tool integrations.
+
+    Returns:
+        tuple: A tuple containing the streaming response object and the
+        conversation ID (empty until the first 'response.created' chunk
+        is received).
+    """
+    logger.info("Shields are not yet supported in Responses API. 
Disabling safety") + + # use system prompt from request or default one + system_prompt = get_system_prompt(query_request, configuration) + logger.debug("Using system prompt: %s", system_prompt) + + # TODO(lucasagomes): redact attachments content before sending to LLM + # if attachments are provided, validate them + if query_request.attachments: + validate_attachments_metadata(query_request.attachments) + + # Prepare tools for responses API + tools: list[dict[str, Any]] = [] + if not query_request.no_tools: + # Get vector databases for RAG tools + vector_db_ids = [ + vector_db.identifier for vector_db in await client.vector_dbs.list() + ] + + # Add RAG tools if vector databases are available + rag_tools = get_rag_tools(vector_db_ids) + if rag_tools: + tools.extend(rag_tools) + + # Add MCP server tools + mcp_tools = get_mcp_tools(configuration.mcp_servers, token) + if mcp_tools: + tools.extend(mcp_tools) + logger.debug( + "Configured %d MCP tools: %s", + len(mcp_tools), + [tool.get("server_label", "unknown") for tool in mcp_tools], + ) + + response = await client.responses.create( + input=query_request.query, + model=model_id, + instructions=system_prompt, + previous_response_id=query_request.conversation_id, + tools=(cast(Any, tools) if tools else cast(Any, None)), + stream=True, + store=True, + ) + + response_stream = cast(AsyncIterator[OpenAIResponseObjectStream], response) + + # For streaming responses, the ID arrives in the first 'response.created' chunk + # Return empty conversation_id here; it will be set once the first chunk is received + return response_stream, "" diff --git a/src/app/routers.py b/src/app/routers.py index 3350f5eed..3dac0f650 100644 --- a/src/app/routers.py +++ b/src/app/routers.py @@ -13,6 +13,7 @@ config, feedback, streaming_query, + streaming_query_v2, authorized, conversations, conversations_v2, @@ -46,6 +47,7 @@ def include_routers(app: FastAPI) -> None: # V2 endpoints - Response API support app.include_router(query_v2.router, prefix="/v2") + app.include_router(streaming_query_v2.router, prefix="/v2") # road-core does not version these endpoints app.include_router(health.router) diff --git a/tests/unit/app/test_routers.py b/tests/unit/app/test_routers.py index a6c23a2b6..c1131a2ba 100644 --- a/tests/unit/app/test_routers.py +++ b/tests/unit/app/test_routers.py @@ -20,6 +20,7 @@ config, feedback, streaming_query, + streaming_query_v2, authorized, metrics, tools, @@ -65,7 +66,7 @@ def test_include_routers() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 16 + assert len(app.routers) == 17 assert root.router in app.get_routers() assert info.router in app.get_routers() assert models.router in app.get_routers() @@ -75,6 +76,7 @@ def test_include_routers() -> None: assert query.router in app.get_routers() assert query_v2.router in app.get_routers() assert streaming_query.router in app.get_routers() + assert streaming_query_v2.router in app.get_routers() assert config.router in app.get_routers() assert feedback.router in app.get_routers() assert health.router in app.get_routers() @@ -90,7 +92,7 @@ def test_check_prefixes() -> None: include_routers(app) # are all routers added? 
- assert len(app.routers) == 16 + assert len(app.routers) == 17 assert app.get_router_prefix(root.router) == "" assert app.get_router_prefix(info.router) == "/v1" assert app.get_router_prefix(models.router) == "/v1" @@ -100,6 +102,7 @@ def test_check_prefixes() -> None: assert app.get_router_prefix(query.router) == "/v1" assert app.get_router_prefix(streaming_query.router) == "/v1" assert app.get_router_prefix(query_v2.router) == "/v2" + assert app.get_router_prefix(streaming_query_v2.router) == "/v2" assert app.get_router_prefix(config.router) == "/v1" assert app.get_router_prefix(feedback.router) == "/v1" assert app.get_router_prefix(health.router) == "" From b3c4c3ec73b53b6f50d85442dff3b955a09f598e Mon Sep 17 00:00:00 2001 From: Luis Tomas Bolivar Date: Wed, 17 Sep 2025 15:39:24 +0200 Subject: [PATCH 4/7] Adding unit test coverage for streaming_query_v2 (cursor generated) --- tests/unit/app/endpoints/test_query_v2.py | 14 +- .../app/endpoints/test_streaming_query_v2.py | 211 ++++++++++++++++++ 2 files changed, 214 insertions(+), 11 deletions(-) create mode 100644 tests/unit/app/endpoints/test_streaming_query_v2.py diff --git a/tests/unit/app/endpoints/test_query_v2.py b/tests/unit/app/endpoints/test_query_v2.py index 91e42c97b..7b43edb9a 100644 --- a/tests/unit/app/endpoints/test_query_v2.py +++ b/tests/unit/app/endpoints/test_query_v2.py @@ -70,9 +70,7 @@ async def test_retrieve_response_no_tools_bypasses_tools(mocker): mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) qr = QueryRequest(query="hello", no_tools=True) - summary, conv_id = await retrieve_response( - mock_client, "model-x", qr, token="tkn" - ) + summary, conv_id = await retrieve_response(mock_client, "model-x", qr, token="tkn") assert conv_id == "resp-1" assert summary.llm_response == "" @@ -148,9 +146,7 @@ async def test_retrieve_response_parses_output_and_tool_calls(mocker): mocker.patch("app.endpoints.query_v2.configuration", mocker.Mock(mcp_servers=[])) qr = QueryRequest(query="hello") - summary, conv_id = await retrieve_response( - mock_client, "model-z", qr, token="tkn" - ) + summary, conv_id = await retrieve_response(mock_client, "model-z", qr, token="tkn") assert conv_id == "resp-3" assert summary.llm_response == "Hello world!" 
@@ -181,9 +177,7 @@ async def test_retrieve_response_validates_attachments(mocker): ] qr = QueryRequest(query="hello", attachments=attachments) - _summary, _cid = await retrieve_response( - mock_client, "model-a", qr, token="tkn" - ) + _summary, _cid = await retrieve_response(mock_client, "model-a", qr, token="tkn") validate_spy.assert_called_once() @@ -257,5 +251,3 @@ def _raise(*_args, **_kwargs): assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert "Unable to connect to Llama Stack" in str(exc.value.detail) fail_metric.inc.assert_called_once() - - diff --git a/tests/unit/app/endpoints/test_streaming_query_v2.py b/tests/unit/app/endpoints/test_streaming_query_v2.py new file mode 100644 index 000000000..3a84a0974 --- /dev/null +++ b/tests/unit/app/endpoints/test_streaming_query_v2.py @@ -0,0 +1,211 @@ +# pylint: disable=redefined-outer-name, import-error +"""Unit tests for the /streaming_query (v2) endpoint using Responses API.""" + +from types import SimpleNamespace +import pytest +from fastapi import HTTPException, status, Request +from fastapi.responses import StreamingResponse + +from llama_stack_client import APIConnectionError + +from models.requests import QueryRequest +from models.config import ModelContextProtocolServer + +from app.endpoints.streaming_query_v2 import ( + retrieve_response, + streaming_query_endpoint_handler_v2, +) + + +@pytest.fixture +def dummy_request() -> Request: + req = Request(scope={"type": "http"}) + # Provide a permissive authorized_actions set to satisfy RBAC check + from models.config import Action # import here to avoid global import errors + + req.state.authorized_actions = set(Action) + return req + + +@pytest.mark.asyncio +async def test_retrieve_response_builds_rag_and_mcp_tools(mocker): + mock_client = mocker.Mock() + mock_client.vector_dbs.list = mocker.AsyncMock( + return_value=[mocker.Mock(identifier="db1")] + ) + mock_client.responses.create = mocker.AsyncMock(return_value=mocker.Mock()) + + mocker.patch( + "app.endpoints.streaming_query_v2.get_system_prompt", return_value="PROMPT" + ) + + mock_cfg = mocker.Mock() + mock_cfg.mcp_servers = [ + ModelContextProtocolServer(name="fs", url="http://localhost:3000"), + ] + mocker.patch("app.endpoints.streaming_query_v2.configuration", mock_cfg) + + qr = QueryRequest(query="hello") + await retrieve_response(mock_client, "model-z", qr, token="tok") + + kwargs = mock_client.responses.create.call_args.kwargs + assert kwargs["stream"] is True + tools = kwargs["tools"] + assert isinstance(tools, list) + types = {t.get("type") for t in tools} + assert types == {"file_search", "mcp"} + + +@pytest.mark.asyncio +async def test_retrieve_response_no_tools_passes_none(mocker): + mock_client = mocker.Mock() + mock_client.vector_dbs.list = mocker.AsyncMock(return_value=[]) + mock_client.responses.create = mocker.AsyncMock(return_value=mocker.Mock()) + + mocker.patch( + "app.endpoints.streaming_query_v2.get_system_prompt", return_value="PROMPT" + ) + mocker.patch( + "app.endpoints.streaming_query_v2.configuration", mocker.Mock(mcp_servers=[]) + ) + + qr = QueryRequest(query="hello", no_tools=True) + await retrieve_response(mock_client, "model-z", qr, token="tok") + + kwargs = mock_client.responses.create.call_args.kwargs + assert kwargs["tools"] is None + assert kwargs["stream"] is True + + +@pytest.mark.asyncio +async def test_streaming_query_endpoint_handler_v2_success_yields_events( + mocker, dummy_request +): + # Skip real config checks + 
mocker.patch("app.endpoints.streaming_query_v2.check_configuration_loaded") + + # Model selection plumbing + mock_client = mocker.Mock() + mock_client.models.list = mocker.AsyncMock(return_value=[mocker.Mock()]) + mocker.patch( + "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client + ) + mocker.patch( + "app.endpoints.streaming_query_v2.evaluate_model_hints", + return_value=(None, None), + ) + mocker.patch( + "app.endpoints.streaming_query_v2.select_model_and_provider_id", + return_value=("llama/m", "m", "p"), + ) + + # Replace SSE helpers for deterministic output + mocker.patch( + "app.endpoints.streaming_query_v2.stream_start_event", + lambda conv_id: f"START:{conv_id}\n", + ) + mocker.patch( + "app.endpoints.streaming_query_v2.format_stream_data", + lambda obj: f"EV:{obj['event']}:{obj['data'].get('token','')}\n", + ) + mocker.patch( + "app.endpoints.streaming_query_v2.stream_end_event", lambda _m: "END\n" + ) + + # Conversation persistence and transcripts disabled + persist_spy = mocker.patch( + "app.endpoints.streaming_query_v2.persist_user_conversation_details", + return_value=None, + ) + mocker.patch( + "app.endpoints.streaming_query_v2.is_transcripts_enabled", return_value=False + ) + + # Build a fake async stream of chunks + async def fake_stream(): + yield SimpleNamespace( + type="response.created", response=SimpleNamespace(id="conv-xyz") + ) + yield SimpleNamespace(type="response.content_part.added") + yield SimpleNamespace(type="response.output_text.delta", delta="Hello ") + yield SimpleNamespace(type="response.output_text.delta", delta="world") + yield SimpleNamespace( + type="response.output_item.added", + item=SimpleNamespace( + type="function_call", id="item1", name="search", call_id="call1" + ), + ) + yield SimpleNamespace( + type="response.function_call_arguments.delta", delta='{"q":"x"}' + ) + yield SimpleNamespace( + type="response.function_call_arguments.done", + item_id="item1", + arguments='{"q":"x"}', + ) + yield SimpleNamespace(type="response.output_text.done", text="Hello world") + yield SimpleNamespace(type="response.completed") + + mocker.patch( + "app.endpoints.streaming_query_v2.retrieve_response", + return_value=(fake_stream(), ""), + ) + + metric = mocker.patch("metrics.llm_calls_total") + + resp = await streaming_query_endpoint_handler_v2( + request=dummy_request, + query_request=QueryRequest(query="hi"), + auth=("user123", "", False, "token-abc"), + mcp_headers={}, + ) + + assert isinstance(resp, StreamingResponse) + metric.labels("p", "m").inc.assert_called_once() + + # Collect emitted events + events: list[str] = [] + async for chunk in resp.body_iterator: + s = chunk.decode() if isinstance(chunk, (bytes, bytearray)) else str(chunk) + events.append(s) + + # Validate event sequence and content + assert events[0] == "START:conv-xyz\n" + # content_part.added triggers empty token + assert events[1] == "EV:token:\n" + assert events[2] == "EV:token:Hello \n" + assert events[3] == "EV:token:world\n" + # tool call delta + assert events[4].startswith("EV:tool_call:") + # turn complete and end + assert "EV:turn_complete:Hello world\n" in events + assert events[-1] == "END\n" + + # Verify conversation persistence was invoked with the created id + persist_spy.assert_called_once() + + +@pytest.mark.asyncio +async def test_streaming_query_endpoint_handler_v2_api_connection_error( + mocker, dummy_request +): + mocker.patch("app.endpoints.streaming_query_v2.check_configuration_loaded") + + def _raise(*_a, **_k): + raise 
APIConnectionError(request=None) + + mocker.patch("client.AsyncLlamaStackClientHolder.get_client", side_effect=_raise) + + fail_metric = mocker.patch("metrics.llm_calls_failures_total") + + with pytest.raises(HTTPException) as exc: + await streaming_query_endpoint_handler_v2( + request=dummy_request, + query_request=QueryRequest(query="hi"), + auth=("user123", "", False, "tok"), + mcp_headers={}, + ) + + assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert "Unable to connect to Llama Stack" in str(exc.value.detail) + fail_metric.inc.assert_called_once() From b740d02f052fe89cbd5cc1c88cd231b2b97454f1 Mon Sep 17 00:00:00 2001 From: Luis Tomas Bolivar Date: Mon, 29 Sep 2025 10:08:18 +0200 Subject: [PATCH 5/7] Add conversations V3 to allow testing of query/streaming_query V2 It is using conversations as conversation v3 endpoint --- src/app/routers.py | 1 + src/utils/suid.py | 24 +++++++++++++++++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/app/routers.py b/src/app/routers.py index 3dac0f650..00cf047f4 100644 --- a/src/app/routers.py +++ b/src/app/routers.py @@ -48,6 +48,7 @@ def include_routers(app: FastAPI) -> None: # V2 endpoints - Response API support app.include_router(query_v2.router, prefix="/v2") app.include_router(streaming_query_v2.router, prefix="/v2") + app.include_router(conversations.router, prefix="/v3") # road-core does not version these endpoints app.include_router(health.router) diff --git a/src/utils/suid.py b/src/utils/suid.py index 4dc9ca5e8..a3783ebf3 100644 --- a/src/utils/suid.py +++ b/src/utils/suid.py @@ -18,9 +18,10 @@ def get_suid() -> str: def check_suid(suid: str) -> bool: """ - Check if given string is a proper session ID. + Check if given string is a proper session ID or response ID. - Returns True if the string is a valid UUID, False otherwise. + Returns True if the string is a valid UUID or if it starts with resp-/resp_ + and it follows a valid UUID string, False otherwise. Parameters: suid (str | bytes): UUID value to validate — accepts a UUID string or @@ -30,8 +31,25 @@ def check_suid(suid: str) -> bool: Validation is performed by attempting to construct uuid.UUID(suid); invalid formats or types result in False. 
""" + if not isinstance(suid, str) or not suid: + return False + + # Handle Responses API IDs + if suid.startswith("resp-") or suid.startswith("resp_"): + token = suid[5:] + if not token: + return False + # If truncated (e.g., shell cut reduced length), pad to canonical UUID length + if len(token) < 36: + token = token + ("0" * (36 - len(token))) + try: + uuid.UUID(token) + return True + except (ValueError, TypeError): + return False + + # Otherwise, enforce UUID format try: - # accepts strings and bytes only uuid.UUID(suid) return True except (ValueError, TypeError): From 1f953ec8b93f1b31f070c6f4e4ad31b89d9234e7 Mon Sep 17 00:00:00 2001 From: Luis Tomas Bolivar Date: Tue, 14 Oct 2025 14:22:35 +0200 Subject: [PATCH 6/7] Add topic summary --- src/app/endpoints/query_v2.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py index 1a7174603..964a55b17 100644 --- a/src/app/endpoints/query_v2.py +++ b/src/app/endpoints/query_v2.py @@ -18,6 +18,7 @@ query_response, select_model_and_provider_id, validate_attachments_metadata, + get_topic_summary, ) from authentication import get_auth_dependency from authentication.interface import AuthTuple @@ -27,6 +28,7 @@ import metrics from models.config import Action from models.database.conversations import UserConversation +from app.database import get_session from models.requests import QueryRequest from models.responses import QueryResponse from utils.endpoints import ( @@ -102,6 +104,19 @@ async def query_endpoint_handler_v2( # Update metrics for the LLM call metrics.llm_calls_total.labels(provider_id, model_id).inc() + # Compute topic summary if this is a brand new conversation + topic_summary = None + with get_session() as session: + existing_conversation = ( + session.query(UserConversation) + .filter_by(id=conversation_id) + .first() + ) + if not existing_conversation: + topic_summary = await get_topic_summary( + query_request.query, client, model_id + ) + process_transcript_and_persist_conversation( user_id=user_id, conversation_id=conversation_id, @@ -109,6 +124,7 @@ async def query_endpoint_handler_v2( provider_id=provider_id, query_request=query_request, summary=summary, + topic_summary=topic_summary, ) return QueryResponse( @@ -350,6 +366,7 @@ def process_transcript_and_persist_conversation( provider_id: str, query_request: QueryRequest, summary: TurnSummary, + topic_summary: str | None = None, ) -> None: """Process transcript storage and persist conversation details.""" if not is_transcripts_enabled(): @@ -374,4 +391,5 @@ def process_transcript_and_persist_conversation( conversation_id=conversation_id, model=model_id, provider_id=provider_id, + topic_summary=topic_summary, ) From afe99e3c988e2d413c5c78b8070eb961a3be463d Mon Sep 17 00:00:00 2001 From: Luis Tomas Bolivar Date: Wed, 15 Oct 2025 14:52:07 +0200 Subject: [PATCH 7/7] Adapt to vector store, instead of vector DB API --- src/app/endpoints/query.py | 20 ++++++++++---------- src/app/endpoints/query_v2.py | 18 +++++++++--------- src/app/endpoints/streaming_query.py | 6 +++--- src/app/endpoints/streaming_query_v2.py | 10 +++++----- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 3829305b4..86bf72418 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -698,10 +698,10 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche ), } - vector_db_ids = [ - vector_db.identifier for vector_db in 
await client.vector_dbs.list() + vector_store_ids = [ + vector_store.id for vector_store in (await client.vector_stores.list()).data ] - toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ + toolgroups = (get_rag_toolgroups(vector_store_ids) or []) + [ mcp_server.name for mcp_server in configuration.mcp_servers ] # Convert empty list to None for consistency with existing behavior @@ -797,30 +797,30 @@ def validate_attachments_metadata(attachments: list[Attachment]) -> None: def get_rag_toolgroups( - vector_db_ids: list[str], + vector_store_ids: list[str], ) -> list[Toolgroup] | None: """ - Return a list of RAG Tool groups if the given vector DB list is not empty. + Return a list of RAG Tool groups if the given vector store list is not empty. Generate a list containing a RAG knowledge search toolgroup if - vector database IDs are provided. + vector store IDs are provided. Parameters: - vector_db_ids (list[str]): List of vector database identifiers to include in the toolgroup. + vector_store_ids (list[str]): List of vector store identifiers to include in the toolgroup. Returns: list[Toolgroup] | None: A list with a single RAG toolgroup if - vector_db_ids is non-empty; otherwise, None. + vector_store_ids is non-empty; otherwise, None. """ return ( [ ToolgroupAgentToolGroupWithArgs( name="builtin::rag/knowledge_search", args={ - "vector_db_ids": vector_db_ids, + "vector_store_ids": vector_store_ids, }, ) ] - if vector_db_ids + if vector_store_ids else None ) diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py index 964a55b17..4c6047b28 100644 --- a/src/app/endpoints/query_v2.py +++ b/src/app/endpoints/query_v2.py @@ -193,13 +193,13 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche # Prepare tools for responses API tools: list[dict[str, Any]] = [] if not query_request.no_tools: - # Get vector databases for RAG tools - vector_db_ids = [ - vector_db.identifier for vector_db in await client.vector_dbs.list() + # Get vector stores for RAG tools + vector_store_ids = [ + vector_store.id for vector_store in (await client.vector_stores.list()).data ] - # Add RAG tools if vector databases are available - rag_tools = get_rag_tools(vector_db_ids) + # Add RAG tools if vector stores are available + rag_tools = get_rag_tools(vector_store_ids) if rag_tools: tools.extend(rag_tools) @@ -326,15 +326,15 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche return summary, conversation_id -def get_rag_tools(vector_db_ids: list[str]) -> list[dict[str, Any]] | None: - """Convert vector DB IDs to tools format for responses API.""" - if not vector_db_ids: +def get_rag_tools(vector_store_ids: list[str]) -> list[dict[str, Any]] | None: + """Convert vector store IDs to tools format for responses API.""" + if not vector_store_ids: return None return [ { "type": "file_search", - "vector_store_ids": vector_db_ids, + "vector_store_ids": vector_store_ids, "max_num_results": 10, } ] diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index d4ad3088a..fde9b2d25 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -1031,10 +1031,10 @@ async def retrieve_response( ), } - vector_db_ids = [ - vector_db.identifier for vector_db in await client.vector_dbs.list() + vector_store_ids = [ + vector_store.id for vector_store in (await client.vector_stores.list()).data ] - toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ + toolgroups = 
(get_rag_toolgroups(vector_store_ids) or []) + [
         mcp_server.name for mcp_server in configuration.mcp_servers
     ]
     # Convert empty list to None for consistency with existing behavior
diff --git a/src/app/endpoints/streaming_query_v2.py b/src/app/endpoints/streaming_query_v2.py
index b807340c5..f2179639e 100644
--- a/src/app/endpoints/streaming_query_v2.py
+++ b/src/app/endpoints/streaming_query_v2.py
@@ -365,13 +365,13 @@ async def retrieve_response(
     # Prepare tools for responses API
     tools: list[dict[str, Any]] = []
     if not query_request.no_tools:
-        # Get vector databases for RAG tools
-        vector_db_ids = [
-            vector_db.identifier for vector_db in await client.vector_dbs.list()
+        # Get vector stores for RAG tools
+        vector_store_ids = [
+            vector_store.id for vector_store in (await client.vector_stores.list()).data
         ]
-        # Add RAG tools if vector databases are available
-        rag_tools = get_rag_tools(vector_db_ids)
+        # Add RAG tools if vector stores are available
+        rag_tools = get_rag_tools(vector_store_ids)
         if rag_tools:
             tools.extend(rag_tools)
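
For reference, a minimal sketch (not part of the patch) of how the two helpers introduced in query_v2.py, get_rag_tools() and get_mcp_tools(), combine into the tools payload that both retrieve_response() variants hand to client.responses.create(). The vector store ID, server name/URL, and token below are illustrative values only, not taken from this series:

# Illustrative only: the vector store ID, MCP server, and token are made-up examples.
from app.endpoints.query_v2 import get_rag_tools, get_mcp_tools
from models.config import ModelContextProtocolServer

tools = []
tools += get_rag_tools(["vs_0123"]) or []
# -> [{"type": "file_search", "vector_store_ids": ["vs_0123"], "max_num_results": 10}]
tools += get_mcp_tools(
    [ModelContextProtocolServer(name="fs", url="http://localhost:3000")],
    token="example-token",
)
# -> adds {"type": "mcp", "server_label": "fs", "server_url": "http://localhost:3000",
#          "require_approval": "never",
#          "headers": {"Authorization": "Bearer example-token"}}
# retrieve_response() passes this combined list as the `tools` argument to
# client.responses.create(..., stream=..., store=True); when the list is empty
# (e.g. no_tools=True or no vector stores/MCP servers), it passes None instead.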