Skip to content

Commit e6c1a99

Browse files
authored
Merge pull request #1040 from malmiron/sdk-chat-stream
feat(ia): Add stream implementation for chat
2 parents 29bae0e + 51921ca commit e6c1a99

File tree

3 files changed

+2193
-1
lines changed

3 files changed

+2193
-1
lines changed

gooddata-sdk/gooddata_sdk/compute/service.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# (C) 2022 GoodData Corporation
22
from __future__ import annotations
33

4+
import json
45
import logging
6+
from collections.abc import Iterator
57
from typing import Any, Optional
68

79
from gooddata_api_client import ApiException
@@ -89,12 +91,48 @@ def ai_chat(self, workspace_id: str, question: str) -> ChatResult:
8991
workspace_id: workspace identifier
9092
question: question to ask AI
9193
Returns:
92-
str: Chat response
94+
ChatResult: Chat response
9395
"""
9496
chat_request = ChatRequest(question=question)
9597
response = self._actions_api.ai_chat(workspace_id, chat_request, _check_return_type=False)
9698
return response
9799

100+
def _parse_sse_events(self, raw: str) -> Iterator[Any]:
101+
"""Helper to parse SSE events and yield JSON from data lines."""
102+
events = raw.split("\n\n")
103+
for event in events:
104+
for line in event.split("\n"):
105+
if line.startswith("data:"):
106+
try:
107+
yield json.loads(line[5:].strip())
108+
except json.JSONDecodeError:
109+
continue
110+
111+
def ai_chat_stream(self, workspace_id: str, question: str) -> Iterator[Any]:
112+
"""
113+
Chat Stream with AI in GoodData workspace.
114+
115+
Args:
116+
workspace_id: workspace identifier
117+
question: question to ask AI
118+
Returns:
119+
Iterator[Any]: Yields parsed JSON objects from each SSE event's data field
120+
"""
121+
chat_request = ChatRequest(question=question)
122+
response = self._actions_api.ai_chat_stream(
123+
workspace_id, chat_request, _check_return_type=False, _preload_content=False
124+
)
125+
buffer = ""
126+
try:
127+
for chunk in response.stream(decode_content=True):
128+
if chunk:
129+
buffer += chunk.decode("utf-8")
130+
*events, buffer = buffer.split("\n\n")
131+
for event in events:
132+
yield from self._parse_sse_events(event)
133+
finally:
134+
response.release_conn()
135+
98136
def get_ai_chat_history(
99137
self,
100138
workspace_id: str,

0 commit comments

Comments
 (0)