
Commit 448cb60

Merge pull request #1018 from asimurka/tool_call_extraction_improvement
LCORE-1198: RAG chunk parsing improvement for streaming query
2 parents (4f168c3 + d7d11c7), commit 448cb60

File tree

3 files changed: +42 -22 lines


src/app/endpoints/query_v2.py

Lines changed: 33 additions & 18 deletions
@@ -83,21 +83,24 @@

 def _build_tool_call_summary(  # pylint: disable=too-many-return-statements,too-many-branches
     output_item: OpenAIResponseOutput,
+    rag_chunks: list[RAGChunk],
 ) -> tuple[Optional[ToolCallSummary], Optional[ToolResultSummary]]:
     """Translate Responses API tool outputs into ToolCallSummary and ToolResultSummary records.

     Processes OpenAI response output items and extracts tool call and result information.
+    Also parses RAG chunks from file_search_call items and appends them to the provided list.

     Args:
         output_item: An OpenAIResponseOutput item from the response.output array
+        rag_chunks: List to append extracted RAG chunks to (from file_search_call items)

     Returns:
         A tuple of (ToolCallSummary, ToolResultSummary) one of them possibly None
         if current llama stack Responses API does not provide the information.

     Supported tool types:
     - function_call: Function tool calls with parsed arguments (no result)
-    - file_search_call: File search operations with results
+    - file_search_call: File search operations with results (also extracts RAG chunks)
     - web_search_call: Web search operations (incomplete)
     - mcp_call: MCP calls with server labels
     - mcp_list_tools: MCP server tool listings
@@ -120,6 +123,7 @@ def _build_tool_call_summary( # pylint: disable=too-many-return-statements,too-

     if item_type == "file_search_call":
         item = cast(OpenAIResponseOutputMessageFileSearchToolCall, output_item)
+        extract_rag_chunks_from_file_search_item(item, rag_chunks)
         response_payload: Optional[dict[str, Any]] = None
         if item.results is not None:
             response_payload = {
@@ -430,12 +434,13 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
     llm_response = ""
     tool_calls: list[ToolCallSummary] = []
     tool_results: list[ToolResultSummary] = []
+    rag_chunks: list[RAGChunk] = []
     for output_item in response.output:
         message_text = extract_text_from_response_output_item(output_item)
         if message_text:
             llm_response += message_text

-        tool_call, tool_result = _build_tool_call_summary(output_item)
+        tool_call, tool_result = _build_tool_call_summary(output_item, rag_chunks)
         if tool_call:
             tool_calls.append(tool_call)
         if tool_result:
@@ -447,9 +452,6 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
         len(llm_response),
     )

-    # Extract rag chunks
-    rag_chunks = parse_rag_chunks_from_responses_api(response)
-
     summary = TurnSummary(
         llm_response=llm_response,
         tool_calls=tool_calls,
@@ -478,7 +480,27 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
     )


-def parse_rag_chunks_from_responses_api(response_obj: Any) -> list[RAGChunk]:
+def extract_rag_chunks_from_file_search_item(
+    item: OpenAIResponseOutputMessageFileSearchToolCall,
+    rag_chunks: list[RAGChunk],
+) -> None:
+    """Extract RAG chunks from a file search tool call item and append to rag_chunks.
+
+    Args:
+        item: The file search tool call item.
+        rag_chunks: List to append extracted RAG chunks to.
+    """
+    if item.results is not None:
+        for result in item.results:
+            rag_chunk = RAGChunk(
+                content=result.text, source="file_search", score=result.score
+            )
+            rag_chunks.append(rag_chunk)
+
+
+def parse_rag_chunks_from_responses_api(
+    response_obj: OpenAIResponseObject,
+) -> list[RAGChunk]:
     """
     Extract rag_chunks from the llama-stack OpenAI response.

@@ -488,20 +510,13 @@ def parse_rag_chunks_from_responses_api(response_obj: Any) -> list[RAGChunk]:
     Returns:
         List of RAGChunk with content, source, score
     """
-    rag_chunks = []
+    rag_chunks: list[RAGChunk] = []

     for output_item in response_obj.output:
-        if (
-            hasattr(output_item, "type")
-            and output_item.type == "file_search_call"
-            and hasattr(output_item, "results")
-        ):
-
-            for result in output_item.results:
-                rag_chunk = RAGChunk(
-                    content=result.text, source="file_search", score=result.score
-                )
-                rag_chunks.append(rag_chunk)
+        item_type = getattr(output_item, "type", None)
+        if item_type == "file_search_call":
+            item = cast(OpenAIResponseOutputMessageFileSearchToolCall, output_item)
+            extract_rag_chunks_from_file_search_item(item, rag_chunks)

     return rag_chunks
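To make the refactor concrete, here is a minimal, runnable sketch of the shared-accumulator pattern this commit introduces: the caller owns a rag_chunks list, and the helper appends to it in place during the same pass over response.output that builds the tool-call summaries. The RAGChunk, FileSearchResult, and FileSearchCallItem dataclasses below are hypothetical stand-ins for the real models (the actual RAGChunk is imported from utils.types); only the helper mirrors the function added in this diff.

# Hypothetical stand-in types; the real models come from llama-stack / utils.types.
from dataclasses import dataclass
from typing import Optional


@dataclass
class RAGChunk:
    content: str
    source: str
    score: Optional[float]


@dataclass
class FileSearchResult:
    text: str
    score: Optional[float]


@dataclass
class FileSearchCallItem:
    type: str
    results: Optional[list[FileSearchResult]]


def extract_rag_chunks_from_file_search_item(
    item: FileSearchCallItem, rag_chunks: list[RAGChunk]
) -> None:
    # Append one RAGChunk per file-search result to the caller's list.
    if item.results is not None:
        for result in item.results:
            rag_chunks.append(
                RAGChunk(content=result.text, source="file_search", score=result.score)
            )


# Single pass: the same loop that summarizes tool calls collects RAG chunks.
rag_chunks: list[RAGChunk] = []
output = [
    FileSearchCallItem(
        type="file_search_call",
        results=[FileSearchResult(text="retrieved chunk text", score=0.87)],
    )
]
for output_item in output:
    if getattr(output_item, "type", None) == "file_search_call":
        extract_rag_chunks_from_file_search_item(output_item, rag_chunks)

print(rag_chunks)  # one RAGChunk with source="file_search"

The in-place append is what lets the streaming path reuse the same helper: each response.output_item.done event can feed its item through _build_tool_call_summary without a second walk over the completed response.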

src/app/endpoints/streaming_query_v2.py

Lines changed: 8 additions & 3 deletions
@@ -70,7 +70,7 @@
 )
 from utils.token_counter import TokenCounter
 from utils.transcripts import store_transcript
-from utils.types import TurnSummary
+from utils.types import RAGChunk, TurnSummary

 logger = logging.getLogger("app.endpoints.handlers")
 router = APIRouter(tags=["streaming_query_v1"])
@@ -143,6 +143,9 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
     # Track the latest response object from response.completed event
     latest_response_object: Optional[Any] = None

+    # RAG chunks
+    rag_chunks: list[RAGChunk] = []
+
     logger.debug("Starting streaming response (Responses API) processing")

     async for chunk in turn_response:
@@ -198,7 +201,9 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
             )
             if done_chunk.item.type == "message":
                 continue
-            tool_call, tool_result = _build_tool_call_summary(done_chunk.item)
+            tool_call, tool_result = _build_tool_call_summary(
+                done_chunk.item, rag_chunks
+            )
             if tool_call:
                 summary.tool_calls.append(tool_call)
             yield stream_event(
@@ -321,7 +326,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
         is_transcripts_enabled_func=is_transcripts_enabled,
         store_transcript_func=store_transcript,
         persist_user_conversation_details_func=persist_user_conversation_details,
-        rag_chunks=[],  # Responses API uses empty list for rag_chunks
+        rag_chunks=[rag_chunk.model_dump() for rag_chunk in rag_chunks],
     )

     return response_generator
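A short sketch of the final serialization step, assuming RAGChunk is a pydantic v2 model, which the model_dump() call in the diff suggests (the class below is a stand-in, not the real utils.types.RAGChunk): chunks accumulated during streaming are dumped to plain dicts before being handed to cleanup_after_streaming, replacing the previously hard-coded empty list.

from typing import Optional

from pydantic import BaseModel


class RAGChunk(BaseModel):  # stand-in; assumes a pydantic v2 model
    content: str
    source: str
    score: Optional[float] = None


# Chunks collected from response.output_item.done events during the stream...
rag_chunks = [RAGChunk(content="retrieved chunk text", source="file_search", score=0.87)]

# ...are serialized to JSON-friendly dicts for transcript storage.
payload = [rag_chunk.model_dump() for rag_chunk in rag_chunks]
print(payload)  # [{'content': 'retrieved chunk text', 'source': 'file_search', 'score': 0.87}]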

src/utils/endpoints.py

Lines changed: 1 addition & 1 deletion
@@ -747,7 +747,7 @@ async def cleanup_after_streaming(
         is_transcripts_enabled_func: Function to check if transcripts are enabled
         store_transcript_func: Function to store transcript
         persist_user_conversation_details_func: Function to persist conversation details
-        rag_chunks: Optional RAG chunks dict (for Agent API, None for Responses API)
+        rag_chunks: Optional RAG chunks dict
     """
     # Store transcript if enabled
     if not is_transcripts_enabled_func():
