Skip to content

Commit f68bbe8

Browse files
authored
refactor(llama-index): send generation updates directly from event handler (#981)
1 parent 9b24854 commit f68bbe8

File tree

3 files changed

+41
-47
lines changed

3 files changed

+41
-47
lines changed

langfuse/llama_index/_event_handler.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
from typing import Optional, Any, Union, Dict, Mapping
1+
from typing import Optional, Any, Union, Mapping
22

33
from langfuse.client import (
44
Langfuse,
5+
StatefulGenerationClient,
6+
StateType,
57
)
8+
from langfuse.utils import _get_timestamp
69
from langfuse.model import ModelUsage
7-
10+
from ._context import InstrumentorContext
11+
from uuid import uuid4 as create_uuid
812

913
try:
1014
from llama_index.core.base.llms.types import (
@@ -36,17 +40,12 @@
3640

3741

3842
class LlamaIndexEventHandler(BaseEventHandler, extra="allow"):
39-
def __init__(
40-
self,
41-
*,
42-
langfuse_client: Langfuse,
43-
observation_updates: Dict[str, Dict[str, Any]],
44-
):
43+
def __init__(self, *, langfuse_client: Langfuse):
4544
super().__init__()
4645

4746
self._langfuse = langfuse_client
48-
self._observation_updates = observation_updates
4947
self._token_counter = TokenCounter()
48+
self._context = InstrumentorContext()
5049

5150
@classmethod
5251
def class_name(cls) -> str:
@@ -92,8 +91,8 @@ def update_generation_from_start_event(
9291
]
9392
}
9493

95-
self._update_observation_updates(
96-
event.span_id, model=model, model_parameters=traced_model_data
94+
self._get_generation_client(event.span_id).update(
95+
model=model, model_parameters=traced_model_data
9796
)
9897

9998
def update_generation_from_end_event(
@@ -119,13 +118,9 @@ def update_generation_from_end_event(
119118
"total": token_count or None,
120119
}
121120

122-
self._update_observation_updates(event.span_id, usage=usage)
123-
124-
def _update_observation_updates(self, id_: str, **kwargs) -> None:
125-
if id_ not in self._observation_updates:
126-
return
127-
128-
self._observation_updates[id_].update(kwargs)
121+
self._get_generation_client(event.span_id).update(
122+
usage=usage, end_time=_get_timestamp()
123+
)
129124

130125
def _parse_token_usage(
131126
self, response: Union[ChatResponse, CompletionResponse]
@@ -140,6 +135,22 @@ def _parse_token_usage(
140135
if additional_kwargs := getattr(response, "additional_kwargs", None):
141136
return _parse_usage_from_mapping(additional_kwargs)
142137

138+
def _get_generation_client(self, id: str) -> StatefulGenerationClient:
139+
trace_id = self._context.trace_id
140+
if trace_id is None:
141+
logger.warning(
142+
"Trace ID is not set. Creating generation client with new trace id."
143+
)
144+
trace_id = str(create_uuid())
145+
146+
return StatefulGenerationClient(
147+
client=self._langfuse.client,
148+
id=id,
149+
trace_id=trace_id,
150+
task_manager=self._langfuse.task_manager,
151+
state_type=StateType.OBSERVATION,
152+
)
153+
143154

144155
def _parse_usage_from_mapping(
145156
usage: Union[object, Mapping[str, Any]],

langfuse/llama_index/_instrumentor.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,8 @@ def __init__(
9898
mask=mask,
9999
sdk_integration="llama-index_instrumentation",
100100
)
101-
self._observation_updates = {}
102-
self._span_handler = LlamaIndexSpanHandler(
103-
langfuse_client=self._langfuse,
104-
observation_updates=self._observation_updates,
105-
)
106-
self._event_handler = LlamaIndexEventHandler(
107-
langfuse_client=self._langfuse,
108-
observation_updates=self._observation_updates,
109-
)
101+
self._span_handler = LlamaIndexSpanHandler(langfuse_client=self._langfuse)
102+
self._event_handler = LlamaIndexEventHandler(langfuse_client=self._langfuse)
110103
self._context = InstrumentorContext()
111104

112105
def start(self):

langfuse/llama_index/_span_handler.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import inspect
2-
from typing import Optional, Any, Tuple, Dict, Generator, AsyncGenerator
2+
from typing import Optional, Any, Tuple, Generator, AsyncGenerator
33
import uuid
44

55
from langfuse.client import (
@@ -39,16 +39,10 @@ class LangfuseSpan(BaseSpan):
3939

4040

4141
class LlamaIndexSpanHandler(BaseSpanHandler[LangfuseSpan], extra="allow"):
42-
def __init__(
43-
self,
44-
*,
45-
langfuse_client: Langfuse,
46-
observation_updates: Dict[str, Dict[str, Any]],
47-
):
42+
def __init__(self, *, langfuse_client: Langfuse):
4843
super().__init__()
4944

5045
self._langfuse_client = langfuse_client
51-
self._observation_updates = observation_updates
5246
self._context = InstrumentorContext()
5347

5448
def new_span(
@@ -109,9 +103,6 @@ def new_span(
109103
metadata=kwargs,
110104
)
111105

112-
# Initialize observation update for the span to be populated by event handler
113-
self._observation_updates[id_] = {}
114-
115106
def prepare_to_exit_span(
116107
self,
117108
id_: str,
@@ -122,7 +113,6 @@ def prepare_to_exit_span(
122113
) -> Optional[LangfuseSpan]:
123114
logger.debug(f"Exiting span {instance.__class__.__name__} with ID {id_}")
124115

125-
observation_updates = self._observation_updates.pop(id_, {})
126116
output, metadata = self._parse_output_metadata(instance, result)
127117

128118
# Reset the context root if the span is the root span
@@ -138,15 +128,13 @@ def prepare_to_exit_span(
138128
if self._is_generation(id_, instance):
139129
generationClient = self._get_generation_client(id_)
140130
generationClient.end(
141-
**observation_updates,
142131
output=output,
143132
metadata=metadata,
144133
)
145134

146135
else:
147136
spanClient = self._get_span_client(id_)
148137
spanClient.end(
149-
**observation_updates,
150138
output=output,
151139
metadata=metadata,
152140
)
@@ -161,8 +149,6 @@ def prepare_to_drop_span(
161149
) -> Optional[LangfuseSpan]:
162150
logger.debug(f"Dropping span {instance.__class__.__name__} with ID {id_}")
163151

164-
observation_updates = self._observation_updates.pop(id_, {})
165-
166152
# Reset the context root if the span is the root span
167153
if id_ == self._context.root_llama_index_span_id:
168154
if self._context.update_parent:
@@ -177,15 +163,13 @@ def prepare_to_drop_span(
177163
if self._is_generation(id_, instance):
178164
generationClient = self._get_generation_client(id_)
179165
generationClient.end(
180-
**observation_updates,
181166
level="ERROR",
182167
status_message=str(err),
183168
)
184169

185170
else:
186171
spanClient = self._get_span_client(id_)
187172
spanClient.end(
188-
**observation_updates,
189173
level="ERROR",
190174
status_message=str(err),
191175
)
@@ -217,7 +201,10 @@ def _is_generation(self, id_: str, instance: Optional[Any] = None) -> bool:
217201
def _get_generation_client(self, id: str) -> StatefulGenerationClient:
218202
trace_id = self._context.trace_id
219203
if trace_id is None:
220-
raise ValueError("Trace ID is not set")
204+
logger.warning(
205+
"Trace ID is not set. Creating generation client with new trace id."
206+
)
207+
trace_id = str(uuid.uuid4())
221208

222209
return StatefulGenerationClient(
223210
client=self._langfuse_client.client,
@@ -230,7 +217,10 @@ def _get_generation_client(self, id: str) -> StatefulGenerationClient:
230217
def _get_span_client(self, id: str) -> StatefulSpanClient:
231218
trace_id = self._context.trace_id
232219
if trace_id is None:
233-
raise ValueError("Trace ID is not set")
220+
logger.warning(
221+
"Trace ID is not set. Creating generation client with new trace id."
222+
)
223+
trace_id = str(uuid.uuid4())
234224

235225
return StatefulSpanClient(
236226
client=self._langfuse_client.client,

0 commit comments

Comments
 (0)