diff --git a/src/uipath_langchain/retrievers/context_grounding_retriever.py b/src/uipath_langchain/retrievers/context_grounding_retriever.py index 10562b4ad..f8e14a20b 100644 --- a/src/uipath_langchain/retrievers/context_grounding_retriever.py +++ b/src/uipath_langchain/retrievers/context_grounding_retriever.py @@ -33,6 +33,7 @@ def _get_relevant_documents( page_content=x.content, metadata={ "source": x.source, + "reference": x.reference, "page_number": x.page_number, }, ) @@ -58,6 +59,7 @@ async def _aget_relevant_documents( page_content=x.content, metadata={ "source": x.source, + "reference": x.reference, "page_number": x.page_number, }, ) diff --git a/src/uipath_langchain/runtime/_citations.py b/src/uipath_langchain/runtime/_citations.py new file mode 100644 index 000000000..429133bbc --- /dev/null +++ b/src/uipath_langchain/runtime/_citations.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +import logging +import re +from dataclasses import dataclass +from uuid import uuid4 + +from uipath.core.chat import ( + UiPathConversationCitationEndEvent, + UiPathConversationCitationEvent, + UiPathConversationCitationSourceMedia, + UiPathConversationCitationSourceUrl, + UiPathConversationCitationStartEvent, + UiPathConversationContentPartChunkEvent, +) + +logger = logging.getLogger(__name__) + +_TAG_RE = re.compile(r'') +_ATTR_RE = re.compile(r'([a-z_]+)="([^"]*)"') + +@dataclass(frozen=True) # frozen to make hashable / de-dupe sources +class _ParsedCitation: + title: str + url: str | None = None + reference: str | None = None + page_number: str | None = None + +# tags -> [(text_segment, citation_or_none)] +def _parse_citations(text: str) -> list[tuple[str, _ParsedCitation | None]]: + segments: list[tuple[str, _ParsedCitation | None]] = [] + cursor = 0 + + for match in _TAG_RE.finditer(text): + preceding_text = text[cursor : match.start()] + raw_attributes = match.group(1) + attributes = dict(_ATTR_RE.findall(raw_attributes)) + + title = attributes.get("title", "") + url = attributes.get("url") + reference = attributes.get("reference") + page_number = attributes.get("page_number") + + has_url = url is not None + has_reference = reference is not None + + if has_url and not has_reference: + # web citation + citation = _ParsedCitation( + title=title, url=url, page_number=page_number + ) + elif has_reference and not has_url: + # context grounding citation + citation = _ParsedCitation( + title=title, reference=reference, page_number=page_number + ) + else: + # skip; citation has no url= or reference= + if preceding_text: + segments.append((preceding_text, None)) + cursor = match.end() + continue + + # Citation applies to the preceding text segment (e.g some text [citation]) + if preceding_text: + segments.append((preceding_text, citation)) + else: + # No preceding text (e.g. back-to-back citations) + segments.append(("", citation)) + + cursor = match.end() + + trailing_text = text[cursor:] + if trailing_text: + segments.append((trailing_text, None)) + + return segments + +def _find_partial_tag_start(text: str) -> int: + _TAG_PREFIX = "" + if "/>" in suffix: + return -1 + + # "<", " None: + self._buffer: str = "" + self._source_numbers: dict[_ParsedCitation, int] = {} + self._next_number: int = 1 + + def _build_content_part_citation( + self, + text: str, + citation: _ParsedCitation | None = None, + ) -> UiPathConversationContentPartChunkEvent: + if citation is None: + return UiPathConversationContentPartChunkEvent(data=text) + + if citation not in self._source_numbers: + self._source_numbers[citation] = self._next_number + self._next_number += 1 + number = self._source_numbers[citation] + + if citation.url is not None: + source = UiPathConversationCitationSourceUrl( + title=citation.title, + number=number, + url=citation.url, + ) + else: + source = UiPathConversationCitationSourceMedia( + title=citation.title, + number=number, + mime_type=None, + download_url=citation.reference, + page_number=citation.page_number, + ) + + return UiPathConversationContentPartChunkEvent( + data=text, + citation=UiPathConversationCitationEvent( + citation_id=str(uuid4()), + start=UiPathConversationCitationStartEvent(), + end=UiPathConversationCitationEndEvent(sources=[source]), + ), + ) + + def _process_segments(self, text: str) -> list[UiPathConversationContentPartChunkEvent]: + segments = _parse_citations(text) + if not segments: + return [] + + content_parts: list[UiPathConversationContentPartChunkEvent] = [] + for segment_text, citation in segments: + if citation is not None: + content_part_with_citation = self._build_content_part_citation(segment_text, citation) + content_parts.append(content_part_with_citation) + elif segment_text: + content_part_plain = self._build_content_part_citation(segment_text) + content_parts.append(content_part_plain) + + return content_parts + + def add_chunk(self, text: str) -> list[UiPathConversationContentPartChunkEvent]: + self._buffer += text + + partial_tag_start = _find_partial_tag_start(self._buffer) + if partial_tag_start >= 0: + completed_text = self._buffer[:partial_tag_start] + self._buffer = self._buffer[partial_tag_start:] + else: + completed_text = self._buffer + self._buffer = "" + + if not completed_text: + return [] + + return self._process_segments(completed_text) + + # Flush remaining content parts / citations + def finalize(self) -> list[UiPathConversationContentPartChunkEvent]: + if not self._buffer: + return [] + + remaining = self._buffer + self._buffer = "" + + return self._process_segments(remaining) diff --git a/src/uipath_langchain/runtime/messages.py b/src/uipath_langchain/runtime/messages.py index 038d37aba..dad332170 100644 --- a/src/uipath_langchain/runtime/messages.py +++ b/src/uipath_langchain/runtime/messages.py @@ -32,6 +32,8 @@ ) from uipath.runtime import UiPathRuntimeStorageProtocol +from ._citations import CitationStreamProcessor + logger = logging.getLogger(__name__) STORAGE_NAMESPACE_EVENT_MAPPER = "chat-event-mapper" @@ -52,6 +54,7 @@ def __init__(self, runtime_id: str, storage: UiPathRuntimeStorageProtocol | None self.current_message: AIMessageChunk self.seen_message_ids: set[str] = set() self._storage_lock = asyncio.Lock() + self._citation_stream_processor: CitationStreamProcessor | None = None def _extract_text(self, content: Any) -> str: """Normalize LangGraph message.content to plain text.""" @@ -256,6 +259,7 @@ async def map_ai_message_chunk_to_events( if message.id not in self.seen_message_ids: self.current_message = message self.seen_message_ids.add(message.id) + self._citation_stream_processor = CitationStreamProcessor() events.append(self.map_to_message_start_event(message.id)) if message.content_blocks: @@ -264,25 +268,27 @@ async def map_ai_message_chunk_to_events( block_type = block.get("type") match block_type: case "text": - events.append( - self.map_chunk_to_content_part_chunk_event( - message.id, cast(TextContentBlock, block) + text = cast(TextContentBlock, block)["text"] + for chunk in self._citation_stream_processor.add_chunk(text): + events.append( + self._chunk_to_message_event(message.id, chunk) ) - ) case "tool_call_chunk": # Accumulate the message chunk self.current_message = self.current_message + message elif isinstance(message.content, str) and message.content: # Fallback: raw string content on the chunk (rare when using content_blocks) - events.append( - self.map_content_to_content_part_chunk_event( - message.id, message.content - ) - ) + for chunk in self._citation_stream_processor.add_chunk(message.content): + events.append(self._chunk_to_message_event(message.id, chunk)) # Check if this is the last chunk by examining chunk_position, send end message event only if there are no pending tool calls if message.chunk_position == "last": + # Flush remaining text + for chunk in self._citation_stream_processor.finalize(): + events.append(self._chunk_to_message_event(message.id, chunk)) + self._citation_stream_processor = None + if ( self.current_message.tool_calls is not None and len(self.current_message.tool_calls) > 0 @@ -437,30 +443,14 @@ def map_tool_call_to_tool_call_start_event( ), ) - def map_chunk_to_content_part_chunk_event( - self, message_id: str, block: TextContentBlock - ) -> UiPathConversationMessageEvent: - text = block["text"] - return UiPathConversationMessageEvent( - message_id=message_id, - content_part=UiPathConversationContentPartEvent( - content_part_id=self.get_content_part_id(message_id), - chunk=UiPathConversationContentPartChunkEvent( - data=text, - ), - ), - ) - - def map_content_to_content_part_chunk_event( - self, message_id: str, content: str + def _chunk_to_message_event( + self, message_id: str, chunk: UiPathConversationContentPartChunkEvent ) -> UiPathConversationMessageEvent: return UiPathConversationMessageEvent( message_id=message_id, content_part=UiPathConversationContentPartEvent( content_part_id=self.get_content_part_id(message_id), - chunk=UiPathConversationContentPartChunkEvent( - data=content, - ), + chunk=chunk, ), )