diff --git a/langfuse/_client/attributes.py b/langfuse/_client/attributes.py index 5ae81000c..75c5645ea 100644 --- a/langfuse/_client/attributes.py +++ b/langfuse/_client/attributes.py @@ -18,7 +18,6 @@ ObservationTypeGenerationLike, ObservationTypeSpanLike, ) - from langfuse._utils.serializer import EventSerializer from langfuse.model import PromptClient from langfuse.types import MapValue, SpanLevel diff --git a/langfuse/_client/client.py b/langfuse/_client/client.py index ebc65e988..a98cbda2c 100644 --- a/langfuse/_client/client.py +++ b/langfuse/_client/client.py @@ -16,6 +16,7 @@ Any, Callable, Dict, + Generator, List, Literal, Optional, @@ -27,8 +28,15 @@ import backoff import httpx -from opentelemetry import trace -from opentelemetry import trace as otel_trace_api +from opentelemetry import ( + baggage as otel_baggage_api, +) +from opentelemetry import ( + context as otel_context_api, +) +from opentelemetry import ( + trace as otel_trace_api, +) from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.id_generator import RandomIdGenerator from opentelemetry.util._decorator import ( @@ -39,6 +47,7 @@ from langfuse._client.attributes import LangfuseOtelSpanAttributes from langfuse._client.constants import ( + LANGFUSE_CORRELATION_CONTEXT_KEY, ObservationTypeGenerationLike, ObservationTypeLiteral, ObservationTypeLiteralNoEvent, @@ -69,7 +78,10 @@ LangfuseSpan, LangfuseTool, ) -from langfuse._client.utils import run_async_safely +from langfuse._client.utils import ( + get_attribute_key_from_correlation_context, + run_async_safely, +) from langfuse._utils import _get_timestamp from langfuse._utils.parse_error import handle_fern_exception from langfuse._utils.prompt_cache import PromptCache @@ -189,6 +201,7 @@ class Langfuse: _resources: Optional[LangfuseResourceManager] = None _mask: Optional[MaskFunction] = None _otel_tracer: otel_trace_api.Tracer + _host: str def __init__( self, @@ -350,6 +363,83 @@ def start_span( status_message=status_message, ) + @_agnosticcontextmanager + def correlation_context( + self, + correlation_context: Dict[str, str], + *, + as_baggage: bool = False, + ) -> Generator[None, None, None]: + """Create a context manager that propagates the given correlation_context to all spans within the context manager's scope. + + Args: + correlation_context (Dict[str, str]): Dictionary containing key-value pairs to be propagated + to all spans within the context manager's scope. Common keys include user_id, session_id, + and custom metadata. All values must be strings below 200 characters. + as_baggage (bool, optional): If True, stores the values in OpenTelemetry baggage + for cross-service propagation. If False, stores only in local context for + current-service propagation. Defaults to False. + + Returns: + Context manager that sets values on all spans created within its scope. + + Warning: + When as_baggage=True, the values will be included in HTTP headers of any + outbound requests made within this context. Only use this for non-sensitive + identifiers that are safe to transmit across service boundaries. + + Examples: + ```python + # Local context only (default) - pass context as dictionary + with langfuse.correlation_context({"session_id": "session_123"}): + with langfuse.start_as_current_span(name="process-request") as span: + # This span and all its children will have session_id="session_123" + child_span = langfuse.start_span(name="child-operation") + + # Multiple values in context dictionary + with langfuse.correlation_context({"user_id": "user_456", "experiment": "A"}): + # All spans will have both user_id and experiment attributes + span = langfuse.start_span(name="experiment-operation") + + # Cross-service propagation (use with caution) + with langfuse.correlation_context({"session_id": "session_123"}, as_baggage=True): + # session_id will be propagated to external service calls + response = requests.get("https://api.example.com/data") + ``` + """ + current_context = otel_context_api.get_current() + current_span = otel_trace_api.get_current_span() + + current_context = otel_context_api.set_value( + LANGFUSE_CORRELATION_CONTEXT_KEY, correlation_context, current_context + ) + + for key, value in correlation_context.items(): + if len(value) > 200: + langfuse_logger.warning( + f"Correlation context key '{key}' is over 200 characters ({len(value)} chars). Dropping value." + ) + continue + + attribute_key = get_attribute_key_from_correlation_context(key) + + if current_span is not None and current_span.is_recording(): + current_span.set_attribute(attribute_key, value) + + if as_baggage: + current_context = otel_baggage_api.set_baggage( + key, value, current_context + ) + + # Activate context, execute, and detach context + token = otel_context_api.attach(current_context) + + try: + yield + + finally: + otel_context_api.detach(token) + def start_as_current_span( self, *, @@ -1667,6 +1757,11 @@ def update_current_trace( span.update(output=response) ``` """ + warnings.warn( + "update_current_trace is deprecated and will be removed in a future version. Use `with langfuse.correlation_context(...)` instead. ", + DeprecationWarning, + stacklevel=2, + ) if not self._tracing_enabled: langfuse_logger.debug( "Operation skipped: update_current_trace - Tracing is disabled or client is in no-op mode." @@ -1811,7 +1906,7 @@ def _create_remote_parent_span( is_remote=False, ) - return trace.NonRecordingSpan(span_context) + return otel_trace_api.NonRecordingSpan(span_context) def _is_valid_trace_id(self, trace_id: str) -> bool: pattern = r"^[0-9a-f]{32}$" diff --git a/langfuse/_client/constants.py b/langfuse/_client/constants.py index b699480c0..b385f0a8f 100644 --- a/langfuse/_client/constants.py +++ b/langfuse/_client/constants.py @@ -3,11 +3,14 @@ This module defines constants used throughout the Langfuse OpenTelemetry integration. """ -from typing import Literal, List, get_args, Union, Any +from typing import Any, List, Literal, Union, get_args + from typing_extensions import TypeAlias LANGFUSE_TRACER_NAME = "langfuse-sdk" +LANGFUSE_CORRELATION_CONTEXT_KEY = "langfuse.ctx.correlation" + """Note: this type is used with .__args__ / get_args in some cases and therefore must remain flat""" ObservationTypeGenerationLike: TypeAlias = Literal[ diff --git a/langfuse/_client/span.py b/langfuse/_client/span.py index 9fa9c7489..fb9a2849d 100644 --- a/langfuse/_client/span.py +++ b/langfuse/_client/span.py @@ -13,9 +13,9 @@ and scoring integration specific to Langfuse's observability platform. """ +import warnings from datetime import datetime from time import time_ns -import warnings from typing import ( TYPE_CHECKING, Any, @@ -44,10 +44,10 @@ create_trace_attributes, ) from langfuse._client.constants import ( - ObservationTypeLiteral, ObservationTypeGenerationLike, - ObservationTypeSpanLike, + ObservationTypeLiteral, ObservationTypeLiteralNoEvent, + ObservationTypeSpanLike, get_observation_types_list, ) from langfuse.logger import langfuse_logger @@ -233,6 +233,11 @@ def update_trace( tags: List of tags to categorize the trace public: Whether the trace should be publicly accessible """ + warnings.warn( + "update_trace is deprecated and will be removed in a future version. Use `with langfuse.correlation_context(...)` instead. ", + DeprecationWarning, + stacklevel=2, + ) if not self._otel_span.is_recording(): return self diff --git a/langfuse/_client/span_processor.py b/langfuse/_client/span_processor.py index baa72360c..c62c5ad83 100644 --- a/langfuse/_client/span_processor.py +++ b/langfuse/_client/span_processor.py @@ -15,17 +15,29 @@ import os from typing import Dict, List, Optional +from opentelemetry import baggage +from opentelemetry import context as context_api +from opentelemetry.context import Context from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter -from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace import ReadableSpan, Span from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.trace import format_span_id -from langfuse._client.constants import LANGFUSE_TRACER_NAME +from langfuse._client.attributes import LangfuseOtelSpanAttributes +from langfuse._client.constants import ( + LANGFUSE_CORRELATION_CONTEXT_KEY, + LANGFUSE_TRACER_NAME, +) from langfuse._client.environment_variables import ( LANGFUSE_FLUSH_AT, LANGFUSE_FLUSH_INTERVAL, LANGFUSE_OTEL_TRACES_EXPORT_PATH, ) -from langfuse._client.utils import span_formatter +from langfuse._client.utils import ( + correlation_context_to_attribute_map, + get_attribute_key_from_correlation_context, + span_formatter, +) from langfuse.logger import langfuse_logger from langfuse.version import __version__ as langfuse_version @@ -114,6 +126,49 @@ def __init__( else None, ) + def on_start(self, span: Span, parent_context: Optional[Context] = None) -> None: + # Propagate correlation context to span + current_context = parent_context or context_api.get_current() + propagated_attributes = {} + + # Propagate correlation context in baggage + baggage_entries = baggage.get_all(context=current_context) + + for key, value in baggage_entries.items(): + if ( + key.startswith(LangfuseOtelSpanAttributes.TRACE_METADATA) + or key in correlation_context_to_attribute_map.values() + ): + propagated_attributes[key] = value + + # Propagate correlation context in OTEL context + correlation_context = ( + context_api.get_value(LANGFUSE_CORRELATION_CONTEXT_KEY, current_context) + or {} + ) + + if not isinstance(correlation_context, dict): + langfuse_logger.error( + f"Correlation context is not of type dict. Got type '{type(correlation_context)}'." + ) + + return super().on_start(span, parent_context) + + for key, value in correlation_context.items(): + attribute_key = get_attribute_key_from_correlation_context(key) + propagated_attributes[attribute_key] = value + + # Write attributes on span + if propagated_attributes: + for key, value in propagated_attributes.items(): + span.set_attribute(key, str(value)) + + langfuse_logger.debug( + f"Propagated {len(propagated_attributes)} attributes to span '{format_span_id(span.context.span_id)}': {propagated_attributes}" + ) + + return super().on_start(span, parent_context) + def on_end(self, span: ReadableSpan) -> None: # Only export spans that belong to the scoped project # This is important to not send spans to wrong project in multi-project setups diff --git a/langfuse/_client/utils.py b/langfuse/_client/utils.py index d34857ebd..340daddb6 100644 --- a/langfuse/_client/utils.py +++ b/langfuse/_client/utils.py @@ -13,6 +13,8 @@ from opentelemetry.sdk import util from opentelemetry.sdk.trace import ReadableSpan +from langfuse._client.attributes import LangfuseOtelSpanAttributes + def span_formatter(span: ReadableSpan) -> str: parent_id = ( @@ -125,3 +127,16 @@ async def my_async_function(): else: # Loop exists but not running, safe to use asyncio.run() return asyncio.run(coro) + + +correlation_context_to_attribute_map = { + "session_id": LangfuseOtelSpanAttributes.TRACE_SESSION_ID, + "user_id": LangfuseOtelSpanAttributes.TRACE_USER_ID, +} + + +def get_attribute_key_from_correlation_context(correlation_context_key: str) -> str: + return ( + correlation_context_to_attribute_map.get(correlation_context_key) + or f"{LangfuseOtelSpanAttributes.TRACE_METADATA}.{correlation_context_key}" + ) diff --git a/tests/test_core_sdk.py b/tests/test_core_sdk.py index 26d11746c..829d9d971 100644 --- a/tests/test_core_sdk.py +++ b/tests/test_core_sdk.py @@ -338,7 +338,7 @@ def test_create_update_current_trace(): user_id="test", metadata={"key": "value"}, public=True, - input="test_input" + input="test_input", ) # Get trace ID for later reference trace_id = span.trace_id @@ -347,7 +347,9 @@ def test_create_update_current_trace(): sleep(1) # Update trace properties using update_current_trace - langfuse.update_current_trace(metadata={"key2": "value2"}, public=False, version="1.0") + langfuse.update_current_trace( + metadata={"key2": "value2"}, public=False, version="1.0" + ) # Ensure data is sent to the API langfuse.flush() @@ -1957,9 +1959,9 @@ def test_start_as_current_observation_types(): expected_types = {obs_type.upper() for obs_type in observation_types} | { "SPAN" } # includes parent span - assert expected_types.issubset(found_types), ( - f"Missing types: {expected_types - found_types}" - ) + assert expected_types.issubset( + found_types + ), f"Missing types: {expected_types - found_types}" # Verify each specific observation exists for obs_type in observation_types: @@ -2003,25 +2005,25 @@ def test_that_generation_like_properties_are_actually_created(): ) as obs: # Verify the properties are accessible on the observation object if hasattr(obs, "model"): - assert obs.model == test_model, ( - f"{obs_type} should have model property" - ) + assert ( + obs.model == test_model + ), f"{obs_type} should have model property" if hasattr(obs, "completion_start_time"): - assert obs.completion_start_time == test_completion_start_time, ( - f"{obs_type} should have completion_start_time property" - ) + assert ( + obs.completion_start_time == test_completion_start_time + ), f"{obs_type} should have completion_start_time property" if hasattr(obs, "model_parameters"): - assert obs.model_parameters == test_model_parameters, ( - f"{obs_type} should have model_parameters property" - ) + assert ( + obs.model_parameters == test_model_parameters + ), f"{obs_type} should have model_parameters property" if hasattr(obs, "usage_details"): - assert obs.usage_details == test_usage_details, ( - f"{obs_type} should have usage_details property" - ) + assert ( + obs.usage_details == test_usage_details + ), f"{obs_type} should have usage_details property" if hasattr(obs, "cost_details"): - assert obs.cost_details == test_cost_details, ( - f"{obs_type} should have cost_details property" - ) + assert ( + obs.cost_details == test_cost_details + ), f"{obs_type} should have cost_details property" langfuse.flush() @@ -2035,28 +2037,232 @@ def test_that_generation_like_properties_are_actually_created(): for obs in trace.observations if obs.name == f"test-{obs_type}" and obs.type == obs_type.upper() ] - assert len(observations) == 1, ( - f"Expected one {obs_type.upper()} observation, but found {len(observations)}" - ) + assert ( + len(observations) == 1 + ), f"Expected one {obs_type.upper()} observation, but found {len(observations)}" obs = observations[0] assert obs.model == test_model, f"{obs_type} should have model property" - assert obs.model_parameters == test_model_parameters, ( - f"{obs_type} should have model_parameters property" - ) + assert ( + obs.model_parameters == test_model_parameters + ), f"{obs_type} should have model_parameters property" # usage_details assert hasattr(obs, "usage_details"), f"{obs_type} should have usage_details" - assert obs.usage_details == dict(test_usage_details, total=30), ( - f"{obs_type} should persist usage_details" - ) # API adds total + assert obs.usage_details == dict( + test_usage_details, total=30 + ), f"{obs_type} should persist usage_details" # API adds total - assert obs.cost_details == test_cost_details, ( - f"{obs_type} should persist cost_details" - ) + assert ( + obs.cost_details == test_cost_details + ), f"{obs_type} should persist cost_details" # completion_start_time, because of time skew not asserting time - assert obs.completion_start_time is not None, ( - f"{obs_type} should persist completion_start_time property" - ) + assert ( + obs.completion_start_time is not None + ), f"{obs_type} should persist completion_start_time property" + + +def test_context_manager_user_propagation(): + """Test that user context manager propagates user_id to child spans.""" + langfuse = Langfuse() + + user_id = "test_user_123" + + with langfuse.start_as_current_span(name="parent-span") as parent_span: + with langfuse.correlation_context({"user_id": user_id}): + trace_id = parent_span.trace_id + + # Create child spans that should inherit user_id + child_span = langfuse.start_span(name="child-span") + child_span.end() + + # Create generation that should inherit user_id + generation = parent_span.start_generation(name="child-generation") + generation.end() + + langfuse.flush() + sleep(2) + + # Verify trace has user_id (child spans inherit via context propagation) + trace = get_api().trace.get(trace_id) + assert trace.user_id == user_id + + # Verify child observations were created and have user_id + child_observations = [ + obs + for obs in trace.observations + if obs.name in ["child-span", "child-generation"] + # Skip user.id validation as we currently drop it from the visible attributes server-side. + # and obs.metadata["attributes"]["user.id"] == user_id + ] + assert len(child_observations) == 2 + + +def test_context_manager_session_propagation(): + """Test that session context manager propagates session_id to child spans.""" + langfuse = Langfuse() + + session_id = "test_session_456" + + with langfuse.start_as_current_span(name="parent-span") as parent_span: + with langfuse.correlation_context({"session_id": session_id}): + trace_id = parent_span.trace_id + + # Create child spans that should inherit session_id + child_span = langfuse.start_span(name="child-span") + child_span.end() + + # Create nested context to test multiple levels + with langfuse.start_as_current_span(name="nested-span"): + grandchild_span = langfuse.start_span(name="grandchild-span") + grandchild_span.end() + + langfuse.flush() + sleep(2) + + # Verify trace has session_id + trace = get_api().trace.get(trace_id) + assert trace.session_id == session_id + + # Verify nested spans were created + nested_observations = [ + obs + for obs in trace.observations + if "span" in obs.name + # Skip session.id validation as we currently drop it from the visible attributes server-side. + # and obs.metadata["attributes"]["session.id"] == session_id + ] + assert len(nested_observations) >= 2 + + +def test_context_manager_metadata_propagation(): + """Test that metadata context manager propagates metadata to child spans.""" + langfuse = Langfuse() + + with langfuse.start_as_current_span(name="parent-span") as parent_span: + with langfuse.correlation_context( + { + "experiment": "A/B", + "version": "1.2.3", + "feature_flag": "enabled", + } + ): + trace_id = parent_span.trace_id + + # Create child spans that should inherit metadata + child_span = langfuse.start_span(name="child-span") + child_span.end() + + # Create generation that should inherit metadata + generation = parent_span.start_generation(name="child-generation") + generation.end() + + langfuse.flush() + sleep(2) + + # Verify trace has metadata + trace = get_api().trace.get(trace_id) + assert trace.metadata["experiment"] == "A/B" + assert trace.metadata["version"] == "1.2.3" + assert trace.metadata["feature_flag"] == "enabled" + + # Verify all observations have the metadata distributed as individual keys + for obs in trace.observations: + if obs.name in ["child-span", "child-generation", "parent-span"]: + # Check that metadata was set on the observation + assert hasattr(obs, "metadata"), f"Observation {obs.name} missing metadata" + assert ( + obs.metadata["experiment"] == "A/B" + ), f"Observation {obs.name} missing experiment metadata" + assert ( + obs.metadata["version"] == "1.2.3" + ), f"Observation {obs.name} missing version metadata" + assert ( + obs.metadata["feature_flag"] == "enabled" + ), f"Observation {obs.name} missing feature_flag metadata" + + +def test_context_manager_nested_contexts(): + """Test nested context managers with overrides and merging.""" + langfuse = Langfuse() + + with langfuse.start_as_current_span(name="outer-span") as outer_span: + with langfuse.correlation_context( + {"user_id": "user_1", "session_id": "session_1"} + ): + with langfuse.correlation_context({"env": "prod", "region": "us-east"}): + outer_trace_id = outer_span.trace_id + + # Create span in outer context + outer_child = langfuse.start_span(name="outer-child") + outer_child.end() + + nested_span = langfuse.start_span(name="nested-span") + nested_span.end() + + langfuse.flush() + sleep(2) + + # Verify trace was created with nested spans + trace = get_api().trace.get(outer_trace_id) + + # Verify trace-level properties from the context + assert trace.user_id == "user_1" + assert trace.session_id == "session_1" + assert trace.metadata["env"] == "prod" + assert trace.metadata["region"] == "us-east" + + # Verify child observations were created + child_observations = [ + obs for obs in trace.observations if "child" in obs.name or "nested" in obs.name + ] + assert len(child_observations) >= 2 + + # Verify specific child spans exist and have correct metadata + outer_child_obs = [obs for obs in trace.observations if obs.name == "outer-child"] + nested_span_obs = [obs for obs in trace.observations if obs.name == "nested-span"] + + assert len(outer_child_obs) == 1, "outer-child span should exist" + assert len(nested_span_obs) == 1, "nested-span should exist" + + +def test_context_manager_baggage_propagation(): + """Test context managers with as_baggage=True for cross-service propagation.""" + langfuse = Langfuse() + + # Test with baggage enabled (careful with sensitive data) + with langfuse.start_as_current_span(name="service-span") as span: + with langfuse.correlation_context( + {"session_id": "public_session_789"}, as_baggage=True + ): + with langfuse.correlation_context( + {"service": "api", "version": "v1.0"}, as_baggage=True + ): + trace_id = span.trace_id + + # Create child spans that inherit baggage context + child_span = langfuse.start_span(name="external-call-span") + child_span.end() + + langfuse.flush() + sleep(2) + + # Verify trace properties were set + trace = get_api().trace.get(trace_id) + assert trace.session_id == "public_session_789" + assert trace.metadata["service"] == "api" + assert trace.metadata["version"] == "v1.0" + + # Verify all observations have the metadata and session_id + for obs in trace.observations: + if obs.name in ["external-call-span", "service-span"]: + # Check that metadata was set on the observation + assert hasattr(obs, "metadata"), f"Observation {obs.name} missing metadata" + assert ( + obs.metadata["service"] == "api" + ), f"Observation {obs.name} missing service metadata" + assert ( + obs.metadata["version"] == "v1.0" + ), f"Observation {obs.name} missing version metadata"