|
1 | 1 | # (C) 2022 GoodData Corporation |
2 | 2 | from __future__ import annotations |
3 | 3 |
|
| 4 | +import json |
4 | 5 | import logging |
| 6 | +from collections.abc import Iterator |
5 | 7 | from typing import Any, Optional |
6 | 8 |
|
7 | 9 | from gooddata_api_client import ApiException |
@@ -89,12 +91,48 @@ def ai_chat(self, workspace_id: str, question: str) -> ChatResult: |
89 | 91 | workspace_id: workspace identifier |
90 | 92 | question: question to ask AI |
91 | 93 | Returns: |
92 | | - str: Chat response |
| 94 | + ChatResult: Chat response |
93 | 95 | """ |
94 | 96 | chat_request = ChatRequest(question=question) |
95 | 97 | response = self._actions_api.ai_chat(workspace_id, chat_request, _check_return_type=False) |
96 | 98 | return response |
97 | 99 |
|
| 100 | + def _parse_sse_events(self, raw: str) -> Iterator[Any]: |
| 101 | + """Helper to parse SSE events and yield JSON from data lines.""" |
| 102 | + events = raw.split("\n\n") |
| 103 | + for event in events: |
| 104 | + for line in event.split("\n"): |
| 105 | + if line.startswith("data:"): |
| 106 | + try: |
| 107 | + yield json.loads(line[5:].strip()) |
| 108 | + except json.JSONDecodeError: |
| 109 | + continue |
| 110 | + |
| 111 | + def ai_chat_stream(self, workspace_id: str, question: str) -> Iterator[Any]: |
| 112 | + """ |
| 113 | + Chat Stream with AI in GoodData workspace. |
| 114 | +
|
| 115 | + Args: |
| 116 | + workspace_id: workspace identifier |
| 117 | + question: question to ask AI |
| 118 | + Returns: |
| 119 | + Iterator[Any]: Yields parsed JSON objects from each SSE event's data field |
| 120 | + """ |
| 121 | + chat_request = ChatRequest(question=question) |
| 122 | + response = self._actions_api.ai_chat_stream( |
| 123 | + workspace_id, chat_request, _check_return_type=False, _preload_content=False |
| 124 | + ) |
| 125 | + buffer = "" |
| 126 | + try: |
| 127 | + for chunk in response.stream(decode_content=True): |
| 128 | + if chunk: |
| 129 | + buffer += chunk.decode("utf-8") |
| 130 | + *events, buffer = buffer.split("\n\n") |
| 131 | + for event in events: |
| 132 | + yield from self._parse_sse_events(event) |
| 133 | + finally: |
| 134 | + response.release_conn() |
| 135 | + |
98 | 136 | def get_ai_chat_history( |
99 | 137 | self, |
100 | 138 | workspace_id: str, |
|
0 commit comments