Skip to content

Commit d8f747d

Browse files
committed
fix(server): add TestContextMiddleware for integration test replay
Ensures test context from X-LlamaStack-Provider-Data header is available to FastAPI router routes, enabling deterministic ID generation in replay mode. Signed-off-by: Matthew F Leader <mleader@redhat.com>
1 parent 95b2948 commit d8f747d

File tree

2 files changed

+139
-0
lines changed

2 files changed

+139
-0
lines changed

src/llama_stack/core/server/server.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,36 @@ async def send_version_error(send):
356356
return await self.app(scope, receive, send)
357357

358358

359+
class TestContextMiddleware:
    """ASGI middleware to propagate test context from request headers to all routes.

    For HTTP requests, the request headers are passed to
    request_provider_data_context() and sync_test_context_from_provider_data()
    is invoked, so the test ID carried in the X-LlamaStack-Provider-Data header
    is available via get_test_context() while the request is handled. This
    enables deterministic ID generation during integration test replay mode.

    Any other scope type is forwarded to the wrapped app untouched.
    """

    def __init__(self, app):
        """Store *app*, the next ASGI application in the chain."""
        self.app = app

    @staticmethod
    def _decode_headers(scope):
        """Return the raw ASGI header pairs of *scope* as a ``str -> str`` dict.

        Header bytes are decoded as latin-1 rather than UTF-8: latin-1 maps
        every byte to a character, so a malformed header can never raise
        UnicodeDecodeError and abort the request inside the middleware.
        A scope without a "headers" entry yields an empty dict.
        """
        return {
            name.decode("latin-1"): value.decode("latin-1")
            for name, value in scope.get("headers", [])
        }

    async def __call__(self, scope, receive, send):
        """Handle one ASGI event stream, setting up test context for HTTP requests."""
        # Only HTTP requests carry the provider data header; everything else
        # goes straight through.
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        # Imported lazily so the testing helpers are only loaded when an HTTP
        # request actually passes through this middleware.
        from llama_stack.core.testing_context import (
            reset_test_context,
            sync_test_context_from_provider_data,
        )

        headers = self._decode_headers(scope)
        with request_provider_data_context(headers, None):
            token = sync_test_context_from_provider_data()
            try:
                return await self.app(scope, receive, send)
            finally:
                # Undo the context change even when the downstream app raises.
                if token:
                    reset_test_context(token)
387+
388+
359389
def create_app() -> StackApp:
360390
"""Create and configure the FastAPI application.
361391
@@ -395,6 +425,9 @@ def create_app() -> StackApp:
395425
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
396426
app.add_middleware(ClientVersionMiddleware)
397427

428+
if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
429+
app.add_middleware(TestContextMiddleware)
430+
398431
impls = app.stack.impls
399432

400433
if config.server.auth:
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
import json
8+
import os
9+
10+
import pytest
11+
from fastapi import APIRouter, FastAPI
12+
from starlette.testclient import TestClient
13+
14+
from llama_stack.core.server.server import TestContextMiddleware
15+
from llama_stack.core.testing_context import get_test_context
16+
17+
18+
@pytest.fixture
def app_with_middleware():
    """Build a bare FastAPI app wrapped in TestContextMiddleware.

    The app has a single GET /test-context route that reports the value
    get_test_context() returns while the request is being handled.
    """
    context_router = APIRouter()

    @context_router.get("/test-context")
    def get_current_test_context():
        current = get_test_context()
        return {"test_id": current}

    application = FastAPI()
    application.include_router(context_router)
    application.add_middleware(TestContextMiddleware)
    return application
33+
34+
35+
@pytest.fixture
def test_mode_env(monkeypatch):
    """Export the env vars needed for test context extraction (replay mode, server config)."""
    required_env = {
        "LLAMA_STACK_TEST_INFERENCE_MODE": "replay",
        "LLAMA_STACK_TEST_STACK_CONFIG_TYPE": "server",
    }
    # monkeypatch undoes these assignments when the requesting test finishes.
    for name, value in required_env.items():
        monkeypatch.setenv(name, value)
40+
41+
42+
def test_middleware_returns_none_without_header(app_with_middleware, test_mode_env):
    """A request that carries no provider data header reports a null test_id."""
    reply = TestClient(app_with_middleware).get("/test-context")

    assert reply.status_code == 200
    body = reply.json()
    assert body["test_id"] is None
49+
50+
51+
def test_middleware_extracts_test_id_from_header(app_with_middleware, test_mode_env):
    """A __test_id inside the provider data header is surfaced to the route."""
    header_value = json.dumps({"__test_id": "test-abc-123"})
    request_headers = {"X-LlamaStack-Provider-Data": header_value}

    http = TestClient(app_with_middleware)
    reply = http.get("/test-context", headers=request_headers)

    assert reply.status_code == 200
    body = reply.json()
    assert body["test_id"] == "test-abc-123"
63+
64+
65+
def test_middleware_handles_empty_provider_data(app_with_middleware, test_mode_env):
    """An empty JSON object as provider data leaves the test context at None."""
    request_headers = {"X-LlamaStack-Provider-Data": "{}"}

    http = TestClient(app_with_middleware)
    reply = http.get("/test-context", headers=request_headers)

    assert reply.status_code == 200
    body = reply.json()
    assert body["test_id"] is None
76+
77+
78+
def test_middleware_handles_invalid_json(app_with_middleware, test_mode_env):
    """A header that is not valid JSON must not crash; test context stays None."""
    request_headers = {"X-LlamaStack-Provider-Data": "not-valid-json"}

    http = TestClient(app_with_middleware)
    reply = http.get("/test-context", headers=request_headers)

    assert reply.status_code == 200
    body = reply.json()
    assert body["test_id"] is None
89+
90+
91+
def test_middleware_noop_without_test_mode(app_with_middleware, monkeypatch):
    """Without test mode env vars, middleware should not extract test context.

    A request carrying a __test_id in its provider data header must still
    report a null test_id when neither test-mode variable is set.
    """
    # Remove the variables through monkeypatch instead of os.environ.pop so
    # that any values present before the test are restored at teardown and
    # the mutation does not leak into tests that run later in this process.
    for name in (
        "LLAMA_STACK_TEST_INFERENCE_MODE",
        "LLAMA_STACK_TEST_STACK_CONFIG_TYPE",
    ):
        monkeypatch.delitem(os.environ, name, raising=False)

    client = TestClient(app_with_middleware)

    provider_data = json.dumps({"__test_id": "test-abc-123"})
    response = client.get(
        "/test-context",
        headers={"X-LlamaStack-Provider-Data": provider_data},
    )

    assert response.status_code == 200
    assert response.json()["test_id"] is None

0 commit comments

Comments
 (0)