37 changes: 15 additions & 22 deletions src/app/endpoints/rlsapi_v1.py
@@ -8,9 +8,8 @@
from typing import Annotated, Any, cast

from fastapi import APIRouter, Depends, HTTPException
from llama_stack_client import APIConnectionError # type: ignore
from llama_stack_client.types import UserMessage # type: ignore
from llama_stack_client.types.alpha.agents.turn import Turn
from llama_stack.apis.agents.openai_responses import OpenAIResponseObject
from llama_stack_client import APIConnectionError

import constants
from authentication import get_auth_dependency
@@ -27,9 +26,8 @@
)
from models.rlsapi.requests import RlsapiV1InferRequest
from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
from utils.endpoints import get_temp_agent
from utils.responses import extract_text_from_response_output_item
from utils.suid import get_suid
from utils.types import content_to_str

logger = logging.getLogger(__name__)
router = APIRouter(tags=["rlsapi-v1"])
@@ -82,8 +80,8 @@ def _get_default_model_id() -> str:
async def retrieve_simple_response(question: str) -> str:
"""Retrieve a simple response from the LLM for a stateless query.

Creates a temporary agent, sends a single turn with the user's question,
and returns the LLM response text. No conversation persistence or tools.
Uses the Responses API for simple stateless inference, consistent with
other endpoints (query_v2, streaming_query_v2).

Args:
question: The combined user input (question + context).
@@ -100,24 +98,19 @@ async def retrieve_simple_response(question: str) -> str:

logger.debug("Using model %s for rlsapi v1 inference", model_id)

agent, session_id, _ = await get_temp_agent(
client, model_id, constants.DEFAULT_SYSTEM_PROMPT
)

response = await agent.create_turn(
messages=[UserMessage(role="user", content=question).model_dump()],
session_id=session_id,
response = await client.responses.create(
input=question,
model=model_id,
instructions=constants.DEFAULT_SYSTEM_PROMPT,
stream=False,
store=False,
)
response = cast(Turn, response)

if getattr(response, "output_message", None) is None:
return ""
response = cast(OpenAIResponseObject, response)

if getattr(response.output_message, "content", None) is None:
return ""

return content_to_str(response.output_message.content)
return "".join(
extract_text_from_response_output_item(output_item)
for output_item in response.output
)


@router.post("/infer", responses=infer_responses)
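For context on the new return path: the refactored helper joins text across the items in `response.output`. Below is a minimal, self-contained sketch of that joining pattern, using plain `SimpleNamespace` stand-ins shaped like the mocks in the tests further down; the `_extract_text` helper here is illustrative, not the real `utils.responses.extract_text_from_response_output_item`.

from types import SimpleNamespace

# Illustrative stand-in for a Responses API result: an object whose `output`
# is a list of assistant message items, matching the shape the integration
# tests below build with mocker.Mock(). Not the real llama-stack types.
fake_response = SimpleNamespace(
    output=[
        SimpleNamespace(type="message", role="assistant", content="Use the `ls` command "),
        SimpleNamespace(type="message", role="assistant", content="to list files."),
    ]
)


def _extract_text(item: SimpleNamespace) -> str:
    """Hypothetical reduction of the extraction helper: keep assistant message text."""
    if getattr(item, "type", None) == "message" and isinstance(item.content, str):
        return item.content
    return ""


# Same "".join(...) pattern as the refactored retrieve_simple_response;
# an empty output list naturally yields "" and triggers the fallback response.
text = "".join(_extract_text(item) for item in fake_response.output)
assert text == "Use the `ls` command to list files."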
167 changes: 73 additions & 94 deletions tests/integration/endpoints/test_rlsapi_v1_integration.py
@@ -9,12 +9,11 @@
# pylint: disable=protected-access
# pylint: disable=unused-argument

from typing import Any, NamedTuple
from typing import Any

import pytest
from fastapi import HTTPException, status
from llama_stack_client import APIConnectionError
from llama_stack_client.types.alpha.agents.turn import Turn
from pytest_mock import MockerFixture

import constants
@@ -34,26 +33,14 @@
from utils.suid import check_suid


class MockAgentFixture(NamedTuple):
"""Container for mocked Llama Stack agent components."""

client: Any
agent: Any
holder_class: Any


# ==========================================
# Shared Fixtures
# ==========================================


@pytest.fixture(name="rlsapi_config")
def rlsapi_config_fixture(test_config: AppConfig, mocker: MockerFixture) -> AppConfig:
"""Extend test_config with inference defaults required by rlsapi v1.

NOTE(major): The standard test configuration doesn't include inference
settings (default_model, default_provider) which rlsapi v1 requires.
"""
"""Extend test_config with inference defaults required by rlsapi v1."""
test_config.inference.default_model = "test-model"
test_config.inference.default_provider = "test-provider"
mocker.patch("app.endpoints.rlsapi_v1.configuration", test_config)
@@ -66,60 +53,42 @@ def mock_authorization_fixture(mocker: MockerFixture) -> None:
mock_authorization_resolvers(mocker)


def _create_mock_agent(
def _create_mock_response_output(mocker: MockerFixture, text: str) -> Any:
"""Create a mock Responses API output item with assistant message."""
mock_output_item = mocker.Mock()
mock_output_item.type = "message"
mock_output_item.role = "assistant"
mock_output_item.content = text
return mock_output_item


def _setup_responses_mock(
mocker: MockerFixture,
response_content: str = "Use the `ls` command to list files in a directory.",
output_message: Any = "default",
) -> MockAgentFixture:
"""Create a mocked Llama Stack agent with configurable response.

Args:
mocker: pytest-mock fixture
response_content: Text content for the LLM response
output_message: Custom output_message Mock, or "default" to create one,
or None for no output_message

Returns:
MockAgentFixture with client, agent, and holder_class components
"""
response_text: str = "Use the `ls` command to list files in a directory.",
) -> Any:
"""Set up responses.create mock with the given response text."""
mock_response = mocker.Mock()
mock_response.output = [_create_mock_response_output(mocker, response_text)]

mock_responses = mocker.Mock()
mock_responses.create = mocker.AsyncMock(return_value=mock_response)

mock_client = mocker.Mock()
mock_client.responses = mock_responses

mock_holder_class = mocker.patch(
"app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder"
)
mock_client = mocker.AsyncMock()

# Configure output message
if output_message == "default":
mock_output_message = mocker.Mock()
mock_output_message.content = response_content
else:
mock_output_message = output_message

mock_turn = mocker.Mock(spec=Turn)
mock_turn.output_message = mock_output_message
mock_turn.steps = []

mock_agent = mocker.AsyncMock()
mock_agent.create_turn = mocker.AsyncMock(return_value=mock_turn)
mock_agent._agent_id = "test_agent_id"

mocker.patch(
"app.endpoints.rlsapi_v1.get_temp_agent",
return_value=(mock_agent, "test_session_id", None),
)

mock_holder_instance = mock_holder_class.return_value
mock_holder_instance.get_client.return_value = mock_client
mock_holder_class.return_value.get_client.return_value = mock_client

return MockAgentFixture(mock_client, mock_agent, mock_holder_class)
return mock_client


@pytest.fixture(name="mock_llama_stack")
def mock_llama_stack_fixture(
rlsapi_config: AppConfig, mocker: MockerFixture
) -> MockAgentFixture:
def mock_llama_stack_fixture(rlsapi_config: AppConfig, mocker: MockerFixture) -> Any:
"""Mock Llama Stack client with successful response."""
_ = rlsapi_config
return _create_mock_agent(mocker)
return _setup_responses_mock(mocker)
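A usage note on the fixture above: because it returns the mocked client, a test can assert directly on the Responses API call. A minimal sketch follows; the test name is illustrative, and the endpoint call is abbreviated to the arguments visible elsewhere in this diff.

@pytest.mark.asyncio
async def test_infer_calls_responses_api_sketch(
    mock_llama_stack: Any,
    mock_authorization: None,
    test_auth: AuthTuple,
) -> None:
    """Illustrative only: the endpoint should go through responses.create."""
    await infer_endpoint(
        infer_request=RlsapiV1InferRequest(question="How do I list files?"),
        auth=test_auth,
    )

    # _setup_responses_mock wires responses.create as an AsyncMock, so the
    # call and its keyword arguments are directly inspectable.
    mock_llama_stack.responses.create.assert_awaited_once()
    assert mock_llama_stack.responses.create.call_args.kwargs["store"] is False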


# ==========================================
@@ -129,7 +98,7 @@

@pytest.mark.asyncio
async def test_rlsapi_v1_infer_minimal_request(
mock_llama_stack: MockAgentFixture,
mock_llama_stack: Any,
mock_authorization: None,
test_auth: AuthTuple,
) -> None:
@@ -179,7 +148,7 @@ async def test_rlsapi_v1_infer_minimal_request(
],
)
async def test_rlsapi_v1_infer_with_context(
mock_llama_stack: MockAgentFixture,
mock_llama_stack: Any,
mock_authorization: None,
test_auth: AuthTuple,
context: RlsapiV1Context,
@@ -198,7 +167,7 @@

@pytest.mark.asyncio
async def test_rlsapi_v1_infer_generates_unique_request_ids(
mock_llama_stack: MockAgentFixture,
mock_llama_stack: Any,
mock_authorization: None,
test_auth: AuthTuple,
) -> None:
@@ -229,19 +198,18 @@ async def test_rlsapi_v1_infer_connection_error_returns_503(
"""Test /v1/infer returns 503 when Llama Stack is unavailable."""
_ = rlsapi_config

# Create agent that raises APIConnectionError
mock_holder_class = mocker.patch(
"app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder"
)
mock_agent = mocker.AsyncMock()
mock_agent.create_turn = mocker.AsyncMock(
mock_responses = mocker.Mock()
mock_responses.create = mocker.AsyncMock(
side_effect=APIConnectionError(request=mocker.Mock())
)
mocker.patch(
"app.endpoints.rlsapi_v1.get_temp_agent",
return_value=(mock_agent, "test_session_id", None),

mock_client = mocker.Mock()
mock_client.responses = mock_responses

mock_holder_class = mocker.patch(
"app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder"
)
mock_holder_class.return_value.get_client.return_value = mocker.AsyncMock()
mock_holder_class.return_value.get_client.return_value = mock_client

with pytest.raises(HTTPException) as exc_info:
await infer_endpoint(
@@ -255,29 +223,28 @@


@pytest.mark.asyncio
@pytest.mark.parametrize(
"output_message",
[
pytest.param(None, id="none_output_message"),
pytest.param("empty", id="empty_content"),
],
)
async def test_rlsapi_v1_infer_fallback_responses(
async def test_rlsapi_v1_infer_fallback_response_empty_output(
rlsapi_config: AppConfig,
mock_authorization: None,
test_auth: AuthTuple,
mocker: MockerFixture,
output_message: Any,
) -> None:
"""Test /v1/infer returns fallback for empty/None responses."""
"""Test /v1/infer returns fallback for empty output list."""
_ = rlsapi_config

if output_message == "empty":
mock_output = mocker.Mock()
mock_output.content = ""
_create_mock_agent(mocker, output_message=mock_output)
else:
_create_mock_agent(mocker, output_message=None)
mock_response = mocker.Mock()
mock_response.output = []

mock_responses = mocker.Mock()
mock_responses.create = mocker.AsyncMock(return_value=mock_response)

mock_client = mocker.Mock()
mock_client.responses = mock_responses

mock_holder_class = mocker.patch(
"app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder"
)
mock_holder_class.return_value.get_client.return_value = mock_client

response = await infer_endpoint(
infer_request=RlsapiV1InferRequest(question="Test"),
@@ -301,7 +268,20 @@ async def test_rlsapi_v1_infer_input_source_combination(
) -> None:
"""Test that input sources are properly combined before sending to LLM."""
_ = rlsapi_config
mocks = _create_mock_agent(mocker)

mock_response = mocker.Mock()
mock_response.output = [_create_mock_response_output(mocker, "response text")]

mock_responses = mocker.Mock()
mock_responses.create = mocker.AsyncMock(return_value=mock_response)

mock_client = mocker.Mock()
mock_client.responses = mock_responses

mock_holder_class = mocker.patch(
"app.endpoints.rlsapi_v1.AsyncLlamaStackClientHolder"
)
mock_holder_class.return_value.get_client.return_value = mock_client

await infer_endpoint(
infer_request=RlsapiV1InferRequest(
@@ -315,12 +295,11 @@
auth=test_auth,
)

# Verify all parts present in message sent to LLM
call_args = mocks.agent.create_turn.call_args
message_content = call_args.kwargs["messages"][0]["content"]
call_args = mock_responses.create.call_args
input_content = call_args.kwargs["input"]

for expected in ["My question", "stdin content", "attachment content", "terminal"]:
assert expected in message_content
assert expected in input_content


# ==========================================
@@ -334,7 +313,7 @@ async def test_rlsapi_v1_infer_input_source_combination(
[pytest.param(False, id="default_false"), pytest.param(True, id="explicit_true")],
)
async def test_rlsapi_v1_infer_skip_rag(
mock_llama_stack: MockAgentFixture,
mock_llama_stack: Any,
mock_authorization: None,
test_auth: AuthTuple,
skip_rag: bool,