|
11 | 11 | ) |
12 | 12 |
|
13 | 13 | from logging import getLogger |
| 14 | +from pydantic import BaseModel |
| 15 | + |
14 | 16 | from ._context import InstrumentorContext |
15 | 17 |
|
16 | 18 | logger = getLogger(__name__) |
|
25 | 27 | StreamingResponse, |
26 | 28 | AsyncStreamingResponse, |
27 | 29 | ) |
| 30 | + from llama_index.core.workflow import Context |
28 | 31 |
|
29 | 32 | except ImportError: |
30 | 33 | raise ModuleNotFoundError( |
@@ -243,6 +246,9 @@ def _parse_generation_input( |
243 | 246 | def _parse_output_metadata( |
244 | 247 | self, instance: Optional[Any], result: Optional[Any] |
245 | 248 | ) -> Tuple[Optional[Any], Optional[Any]]: |
| 249 | + if isinstance(result, BaseModel): |
| 250 | + return result.__dict__, None |
| 251 | + |
246 | 252 | if not result or isinstance( |
247 | 253 | result, |
248 | 254 | (Generator, AsyncGenerator, StreamingResponse, AsyncStreamingResponse), |
@@ -289,4 +295,12 @@ def _parse_input(self, *, bound_args: inspect.BoundArguments): |
289 | 295 | if "nodes" in arguments: |
290 | 296 | return {"num_nodes": len(arguments["nodes"])} |
291 | 297 |
|
| 298 | + # Remove Context since it is in not properly serialized |
| 299 | + ctx_key = None |
| 300 | + for arg, val in arguments.items(): |
| 301 | + if isinstance(val, Context): |
| 302 | + ctx_key = arg |
| 303 | + if ctx_key in arguments: |
| 304 | + return {arg: val for arg, val in arguments.items() if arg != ctx_key} |
| 305 | + |
292 | 306 | return arguments |
0 commit comments