Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 188 additions & 0 deletions src/uipath_langchain/runtime/_citations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from __future__ import annotations

import logging
import re
from dataclasses import dataclass
from uuid import uuid4

from uipath.core.chat import (
UiPathConversationCitationEndEvent,
UiPathConversationCitationEvent,
UiPathConversationCitationSourceMedia,
UiPathConversationCitationSourceUrl,
UiPathConversationCitationStartEvent,
UiPathConversationContentPartChunkEvent,
)

logger = logging.getLogger(__name__)

_TAG_RE = re.compile(r'<uip:cite\s+((?:[a-z_]+="[^"]*"\s*)+)/\s*>')
_ATTR_RE = re.compile(r'([a-z_]+)="([^"]*)"')

@dataclass(frozen=True) # frozen to make hashable / de-dupe sources
class _ParsedCitation:
    """Attributes extracted from a single ``<uip:cite .../>`` tag.

    Instances are immutable and hashable, so equal citations compare equal
    and can be used as dictionary keys when numbering distinct sources.
    """

    # Value of the tag's title="..." attribute ("" when the tag has none).
    title: str
    # Value of url="..."; populated only for web citations.
    url: str | None = None
    # Value of reference="..."; populated only for context grounding citations.
    reference: str | None = None
    # Value of page_number="...", kept as the raw string taken from the tag.
    page_number: str | None = None

# <uip:cite .../> tags -> [(text_segment, citation_or_none)]
def _parse_citations(text: str) -> list[tuple[str, _ParsedCitation | None]]:
    """Split *text* around ``<uip:cite .../>`` tags.

    Returns ``(text_segment, citation_or_none)`` pairs in document order:

    * A usable tag attaches its citation to the text between the previous tag
      (or the start of *text*) and itself. When there is no such text, e.g.
      for back-to-back tags, the segment text is ``""``.
    * A tag that carries both ``url=`` and ``reference=``, or neither, is
      dropped. Any text preceding it is still returned, without a citation.
    * Text after the last tag is returned as a final segment without a
      citation.
    """
    result: list[tuple[str, _ParsedCitation | None]] = []
    position = 0

    for tag in _TAG_RE.finditer(text):
        before = text[position : tag.start()]
        position = tag.end()

        attrs = dict(_ATTR_RE.findall(tag.group(1)))
        url = attrs.get("url")
        reference = attrs.get("reference")

        # A citation is usable only when exactly one of url= / reference= is
        # present; both-missing and both-present tags are skipped.
        if (url is None) == (reference is None):
            if before:
                result.append((before, None))
            continue

        if url is not None:
            # web citation
            parsed = _ParsedCitation(
                title=attrs.get("title", ""),
                url=url,
                page_number=attrs.get("page_number"),
            )
        else:
            # context grounding citation
            parsed = _ParsedCitation(
                title=attrs.get("title", ""),
                reference=reference,
                page_number=attrs.get("page_number"),
            )

        # The citation belongs to the text just before its tag; that text is
        # "" when tags follow each other directly.
        result.append((before, parsed))

    tail = text[position:]
    if tail:
        result.append((tail, None))

    return result

def _find_partial_tag_start(text: str) -> int:
_TAG_PREFIX = "<uip:cite "

bracket_pos = text.rfind("<")
if bracket_pos == -1:
return -1

suffix = text[bracket_pos:]

# "<uip:cite title="some partial" />"
if "/>" in suffix:
return -1

# "<", "<u", "<uip:", "<uip:cite"
if len(suffix) <= len(_TAG_PREFIX):
if _TAG_PREFIX.startswith(suffix):
return bracket_pos
return -1

# "<uip:cite title="some partial"
if suffix.startswith(_TAG_PREFIX):
return bracket_pos

return -1


class CitationStreamProcessor:
    """Turn streamed text containing ``<uip:cite .../>`` tags into chunk events.

    Text is fed in piece by piece through :meth:`add_chunk`. Anything that
    might be the beginning of a not-yet-complete citation tag is held back
    in an internal buffer until more text arrives; :meth:`finalize` flushes
    whatever is still buffered. Each distinct citation is assigned a number
    the first time it is seen and keeps that number on later occurrences.
    """

    def __init__(self) -> None:
        # Text withheld because it may be the start of an unfinished tag.
        self._buffer: str = ""
        # Number assigned to each distinct citation seen so far.
        self._source_numbers: dict[_ParsedCitation, int] = {}
        # Number the next previously unseen citation will receive.
        self._next_number: int = 1

    def _build_content_part_citation(
        self,
        text: str,
        citation: _ParsedCitation | None = None,
    ) -> UiPathConversationContentPartChunkEvent:
        """Build one chunk event for *text*, attaching *citation* if given."""
        if citation is None:
            return UiPathConversationContentPartChunkEvent(data=text)

        # Reuse the number of an already seen citation, otherwise allocate
        # the next free one.
        number = self._source_numbers.get(citation)
        if number is None:
            number = self._next_number
            self._source_numbers[citation] = number
            self._next_number += 1

        if citation.url is None:
            # No url: the citation carries a reference instead.
            source = UiPathConversationCitationSourceMedia(
                title=citation.title,
                number=number,
                mime_type=None,
                download_url=citation.reference,
                page_number=citation.page_number,
            )
        else:
            source = UiPathConversationCitationSourceUrl(
                title=citation.title,
                number=number,
                url=citation.url,
            )

        citation_event = UiPathConversationCitationEvent(
            citation_id=str(uuid4()),
            start=UiPathConversationCitationStartEvent(),
            end=UiPathConversationCitationEndEvent(sources=[source]),
        )
        return UiPathConversationContentPartChunkEvent(
            data=text,
            citation=citation_event,
        )

    def _process_segments(
        self, text: str
    ) -> list[UiPathConversationContentPartChunkEvent]:
        """Parse *text* and build one chunk event per non-empty segment.

        Segments that carry a citation are always emitted, even when their
        text is empty; segments without a citation are emitted only when
        they contain text.
        """
        return [
            self._build_content_part_citation(piece, cite)
            for piece, cite in _parse_citations(text)
            if cite is not None or piece
        ]

    def add_chunk(self, text: str) -> list[UiPathConversationContentPartChunkEvent]:
        """Append *text* to the stream and return the events that are ready.

        Text from a possible unfinished tag onwards stays buffered for a
        later call; everything before it is processed immediately.
        """
        self._buffer += text
        pending = self._buffer

        split_at = _find_partial_tag_start(pending)
        if split_at < 0:
            ready, self._buffer = pending, ""
        else:
            ready, self._buffer = pending[:split_at], pending[split_at:]

        return self._process_segments(ready) if ready else []

    # Flush remaining content parts / citations
    def finalize(self) -> list[UiPathConversationContentPartChunkEvent]:
        """Process and clear whatever text is still buffered."""
        leftover, self._buffer = self._buffer, ""
        if not leftover:
            return []
        return self._process_segments(leftover)
46 changes: 18 additions & 28 deletions src/uipath_langchain/runtime/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
)
from uipath.runtime import UiPathRuntimeStorageProtocol

from ._citations import CitationStreamProcessor

logger = logging.getLogger(__name__)

STORAGE_NAMESPACE_EVENT_MAPPER = "chat-event-mapper"
Expand All @@ -52,6 +54,7 @@ def __init__(self, runtime_id: str, storage: UiPathRuntimeStorageProtocol | None
self.current_message: AIMessageChunk
self.seen_message_ids: set[str] = set()
self._storage_lock = asyncio.Lock()
self._citation_stream_processor: CitationStreamProcessor | None = None

def _extract_text(self, content: Any) -> str:
"""Normalize LangGraph message.content to plain text."""
Expand Down Expand Up @@ -256,6 +259,7 @@ async def map_ai_message_chunk_to_events(
if message.id not in self.seen_message_ids:
self.current_message = message
self.seen_message_ids.add(message.id)
self._citation_stream_processor = CitationStreamProcessor()
events.append(self.map_to_message_start_event(message.id))

if message.content_blocks:
Expand All @@ -264,25 +268,27 @@ async def map_ai_message_chunk_to_events(
block_type = block.get("type")
match block_type:
case "text":
events.append(
self.map_chunk_to_content_part_chunk_event(
message.id, cast(TextContentBlock, block)
text = cast(TextContentBlock, block)["text"]
for chunk in self._citation_stream_processor.add_chunk(text):
events.append(
self._chunk_to_message_event(message.id, chunk)
)
)
case "tool_call_chunk":
# Accumulate the message chunk
self.current_message = self.current_message + message

elif isinstance(message.content, str) and message.content:
# Fallback: raw string content on the chunk (rare when using content_blocks)
events.append(
self.map_content_to_content_part_chunk_event(
message.id, message.content
)
)
for chunk in self._citation_stream_processor.add_chunk(message.content):
events.append(self._chunk_to_message_event(message.id, chunk))

# Check if this is the last chunk by examining chunk_position, send end message event only if there are no pending tool calls
if message.chunk_position == "last":
# Flush remaining text
for chunk in self._citation_stream_processor.finalize():
events.append(self._chunk_to_message_event(message.id, chunk))
self._citation_stream_processor = None

if (
self.current_message.tool_calls is not None
and len(self.current_message.tool_calls) > 0
Expand Down Expand Up @@ -437,30 +443,14 @@ def map_tool_call_to_tool_call_start_event(
),
)

def map_chunk_to_content_part_chunk_event(
self, message_id: str, block: TextContentBlock
) -> UiPathConversationMessageEvent:
text = block["text"]
return UiPathConversationMessageEvent(
message_id=message_id,
content_part=UiPathConversationContentPartEvent(
content_part_id=self.get_content_part_id(message_id),
chunk=UiPathConversationContentPartChunkEvent(
data=text,
),
),
)

def map_content_to_content_part_chunk_event(
self, message_id: str, content: str
def _chunk_to_message_event(
self, message_id: str, chunk: UiPathConversationContentPartChunkEvent
) -> UiPathConversationMessageEvent:
return UiPathConversationMessageEvent(
message_id=message_id,
content_part=UiPathConversationContentPartEvent(
content_part_id=self.get_content_part_id(message_id),
chunk=UiPathConversationContentPartChunkEvent(
data=content,
),
chunk=chunk,
),
)

Expand Down
Loading