diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 108255163..f2bf3ac35 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -90,7 +90,13 @@ jobs: rm -rf .env echo "::group::Run server" - TELEMETRY_ENABLED=false LANGFUSE_S3_MEDIA_UPLOAD_ENDPOINT=http://localhost:9090 LANGFUSE_SDK_CI_SYNC_PROCESSING_ENABLED=true LANGFUSE_READ_FROM_POSTGRES_ONLY=true LANGFUSE_READ_FROM_CLICKHOUSE_ONLY=false LANGFUSE_RETURN_FROM_CLICKHOUSE=false docker compose up -d + + TELEMETRY_ENABLED=false \ + LANGFUSE_S3_MEDIA_UPLOAD_ENDPOINT=http://localhost:9090 \ + LANGFUSE_INGESTION_QUEUE_DELAY_MS=10 \ + LANGFUSE_INGESTION_CLICKHOUSE_WRITE_INTERVAL_MS=10 \ + docker compose up -d + echo "::endgroup::" # Add this step to check the health of the container @@ -149,7 +155,7 @@ jobs: - name: Run the automated tests run: | python --version - poetry run pytest -s -v --log-cli-level=INFO + poetry run pytest -n auto -s -v --log-cli-level=INFO all-tests-passed: # This allows us to have a branch protection rule for tests and deploys with matrix diff --git a/langfuse/callback/langchain.py b/langfuse/callback/langchain.py index e3bb29b88..fcd2058f9 100644 --- a/langfuse/callback/langchain.py +++ b/langfuse/callback/langchain.py @@ -1,9 +1,9 @@ -from collections import defaultdict -import httpx import logging import typing import warnings +from collections import defaultdict +import httpx import pydantic try: # Test that langchain is installed before proceeding @@ -15,16 +15,17 @@ ) from typing import Any, Dict, List, Optional, Sequence, Union, cast from uuid import UUID, uuid4 + from langfuse.api.resources.ingestion.types.sdk_log_body import SdkLogBody from langfuse.client import ( + StatefulGenerationClient, StatefulSpanClient, StatefulTraceClient, - StatefulGenerationClient, ) from langfuse.extract_model import _extract_model_name +from langfuse.types import MaskFunction from langfuse.utils import _get_timestamp from langfuse.utils.base_callback_handler import LangfuseBaseCallbackHandler -from langfuse.types import MaskFunction try: from langchain.callbacks.base import ( @@ -32,18 +33,18 @@ ) from langchain.schema.agent import AgentAction, AgentFinish from langchain.schema.document import Document - from langchain_core.outputs import ( - ChatGeneration, - LLMResult, - ) from langchain_core.messages import ( AIMessage, BaseMessage, ChatMessage, + FunctionMessage, HumanMessage, SystemMessage, ToolMessage, - FunctionMessage, + ) + from langchain_core.outputs import ( + ChatGeneration, + LLMResult, ) except ImportError: raise ModuleNotFoundError( @@ -149,7 +150,9 @@ def on_llm_new_token( self.updated_completion_start_time_memo.add(run_id) - def get_langchain_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs: Any) -> str: + def get_langchain_run_name( + self, serialized: Optional[Dict[str, Any]], **kwargs: Any + ) -> str: """Retrieve the name of a serialized LangChain runnable. 
The prioritization for the determination of the run name is as follows: @@ -1055,16 +1058,24 @@ def _parse_usage_model(usage: typing.Union[pydantic.BaseModel, dict]): ] usage_model = usage.copy() # Copy all existing key-value pairs - for model_key, langfuse_key in conversion_list: - if model_key in usage_model: - captured_count = usage_model.pop(model_key) - final_count = ( - sum(captured_count) - if isinstance(captured_count, list) - else captured_count - ) # For Bedrock, the token count is a list when streamed - - usage_model[langfuse_key] = final_count # Translate key and keep the value + + # Skip OpenAI usage types as they are handled server side + if not all( + openai_key in usage_model + for openai_key in ["prompt_tokens", "completion_tokens", "total_tokens"] + ): + for model_key, langfuse_key in conversion_list: + if model_key in usage_model: + captured_count = usage_model.pop(model_key) + final_count = ( + sum(captured_count) + if isinstance(captured_count, list) + else captured_count + ) # For Bedrock, the token count is a list when streamed + + usage_model[langfuse_key] = ( + final_count # Translate key and keep the value + ) if isinstance(usage_model, dict): if "input_token_details" in usage_model: diff --git a/tests/api_wrapper.py b/tests/api_wrapper.py index b69ef4b7e..42f941550 100644 --- a/tests/api_wrapper.py +++ b/tests/api_wrapper.py @@ -1,4 +1,5 @@ import os +from time import sleep import httpx @@ -11,23 +12,27 @@ def __init__(self, username=None, password=None, base_url=None): self.BASE_URL = base_url if base_url else os.environ["LANGFUSE_HOST"] def get_observation(self, observation_id): + sleep(1) url = f"{self.BASE_URL}/api/public/observations/{observation_id}" response = httpx.get(url, auth=self.auth) return response.json() def get_scores(self, page=None, limit=None, user_id=None, name=None): + sleep(1) params = {"page": page, "limit": limit, "userId": user_id, "name": name} url = f"{self.BASE_URL}/api/public/scores" response = httpx.get(url, params=params, auth=self.auth) return response.json() def get_traces(self, page=None, limit=None, user_id=None, name=None): + sleep(1) params = {"page": page, "limit": limit, "userId": user_id, "name": name} url = f"{self.BASE_URL}/api/public/traces" response = httpx.get(url, params=params, auth=self.auth) return response.json() def get_trace(self, trace_id): + sleep(1) url = f"{self.BASE_URL}/api/public/traces/{trace_id}" response = httpx.get(url, auth=self.auth) return response.json() diff --git a/tests/test_core_sdk.py b/tests/test_core_sdk.py index 99953ba2a..a0ebf27d6 100644 --- a/tests/test_core_sdk.py +++ b/tests/test_core_sdk.py @@ -2,6 +2,7 @@ import time from asyncio import gather from datetime import datetime, timedelta, timezone +from time import sleep import pytest @@ -45,7 +46,7 @@ async def update_generation(i, langfuse: Langfuse): for i in range(100): observation = api.observations.get_many(name=str(i)).data[0] assert observation.name == str(i) - assert observation.metadata == {"count": str(i)} + assert observation.metadata == {"count": i} def test_flush(): @@ -208,13 +209,12 @@ def test_create_categorical_score(): assert trace["scores"][0]["id"] == score_id assert trace["scores"][0]["dataType"] == "CATEGORICAL" - assert trace["scores"][0]["value"] is None + assert trace["scores"][0]["value"] == 0 assert trace["scores"][0]["stringValue"] == "high score" def test_create_trace(): langfuse = Langfuse(debug=False) - api_wrapper = LangfuseAPI() trace_name = create_uuid() trace = langfuse.trace( @@ -226,8 +226,9 @@ def 
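# --- Reviewer sketch (not part of the diff): the reshaped _parse_usage_model
# logic above, reduced to a self-contained function. CONVERSION_LIST below is
# a small illustrative subset, not the SDK's full translation table.
from typing import Union

import pydantic

CONVERSION_LIST = [
    ("promptTokenCount", "input"),
    ("candidatesTokenCount", "output"),
    ("totalTokenCount", "total"),
]

def parse_usage_model(usage: Union[pydantic.BaseModel, dict]) -> dict:
    usage_model = dict(usage) if isinstance(usage, dict) else usage.dict()

    # OpenAI-style usage is now passed through untouched; the server derives
    # input/output/total from the raw token counts itself.
    if all(
        key in usage_model
        for key in ("prompt_tokens", "completion_tokens", "total_tokens")
    ):
        return usage_model

    for model_key, langfuse_key in CONVERSION_LIST:
        if model_key in usage_model:
            captured = usage_model.pop(model_key)
            # Bedrock reports a list of token counts when streamed: sum it.
            usage_model[langfuse_key] = (
                sum(captured) if isinstance(captured, list) else captured
            )

    return usage_model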
test_create_trace(): ) langfuse.flush() + sleep(2) - trace = api_wrapper.get_trace(trace.id) + trace = LangfuseAPI().get_trace(trace.id) assert trace["name"] == trace_name assert trace["userId"] == "test" @@ -238,8 +239,8 @@ def test_create_trace(): def test_create_update_trace(): - langfuse = Langfuse(debug=True, flush_at=1) - api = get_api() + langfuse = Langfuse() + trace_name = create_uuid() trace = langfuse.trace( @@ -248,21 +249,21 @@ def test_create_update_trace(): metadata={"key": "value"}, public=True, ) - trace.update(metadata={"key": "value2"}, public=False) + sleep(1) + trace.update(metadata={"key2": "value2"}, public=False) langfuse.flush() - trace = api.trace.get(trace.id) + trace = get_api().trace.get(trace.id) assert trace.name == trace_name assert trace.user_id == "test" - assert trace.metadata == {"key": "value2"} + assert trace.metadata == {"key": "value", "key2": "value2"} assert trace.public is False def test_create_generation(): langfuse = Langfuse(debug=True) - api = get_api() timestamp = _get_timestamp() generation_id = create_uuid() @@ -294,11 +295,11 @@ def test_create_generation(): trace_id = langfuse.get_trace_id() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.name == "query-generation" assert trace.user_id is None - assert trace.metadata is None + assert trace.metadata == {} assert len(trace.observations) == 1 @@ -342,7 +343,6 @@ def test_create_generation(): None, None, ), - (LlmUsage(promptTokens=51, totalTokens=100), "TOKENS", None, None, None), ( { "input": 51, @@ -373,13 +373,6 @@ def test_create_generation(): 200, 300, ), - ( - {"input": 51, "total": 100}, - None, - None, - None, - None, - ), ( LlmUsageWithCost( promptTokens=51, @@ -394,20 +387,6 @@ def test_create_generation(): 200, 300, ), - ( - LlmUsageWithCost( - promptTokens=51, - completionTokens=0, - totalTokens=100, - inputCost=0.0021, - outputCost=0.00000000000021, - totalCost=None, - ), - "TOKENS", - 0.0021, - 0.00000000000021, - 0.00210000000021, - ), ], ) def test_create_generation_complex( @@ -418,7 +397,6 @@ def test_create_generation_complex( expected_total_cost, ): langfuse = Langfuse(debug=False) - api = get_api() generation_id = create_uuid() langfuse.generation( @@ -440,11 +418,11 @@ def test_create_generation_complex( trace_id = langfuse.get_trace_id() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.name == "query-generation" assert trace.user_id is None - assert trace.metadata is None + assert trace.metadata == {} assert len(trace.observations) == 1 @@ -460,20 +438,22 @@ def test_create_generation_complex( }, ] assert generation.output == [{"foo": "bar"}] - assert generation.metadata == [{"tags": ["yo"]}] + assert generation.metadata["metadata"] == [{"tags": ["yo"]}] assert generation.start_time is not None - assert generation.usage.input == 51 - assert generation.usage.output == 0 - assert generation.usage.total == 100 - assert generation.calculated_input_cost == expected_input_cost - assert generation.calculated_output_cost == expected_output_cost - assert generation.calculated_total_cost == expected_total_cost - assert generation.usage.unit == expected_usage + assert generation.usage_details == {"input": 51, "output": 0, "total": 100} + assert generation.cost_details == ( + { + "input": expected_input_cost, + "output": expected_output_cost, + "total": expected_total_cost, + } + if any([expected_input_cost, expected_output_cost, expected_total_cost]) + else {} + ) def test_create_span(): langfuse = 
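# --- Reviewer sketch (assumption-flagged): the assertions above move from the
# typed v2 `usage` object to plain dicts. For a generation ingested with
# promptTokens=51, completionTokens=0, totalTokens=100, the new read shape is:
from tests.utils import get_api

def check_usage_shape(trace_id: str) -> None:
    generation = get_api().trace.get(trace_id).observations[0]

    # v2: generation.usage.input / .output / .total plus a `unit` field
    # v3: flat dicts; the unit field is gone
    assert generation.usage_details == {"input": 51, "output": 0, "total": 100}

    # cost_details comes back as {} when no cost was ingested or inferred
    assert set(generation.cost_details) <= {"input", "output", "total"}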
Langfuse(debug=False) - api = get_api() timestamp = _get_timestamp() span_id = create_uuid() @@ -491,11 +471,11 @@ def test_create_span(): trace_id = langfuse.get_trace_id() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.name == "span" assert trace.user_id is None - assert trace.metadata is None + assert trace.metadata == {} assert len(trace.observations) == 1 @@ -546,7 +526,6 @@ def test_score_trace(): def test_score_trace_nested_trace(): langfuse = Langfuse(debug=False) - api = get_api() trace_name = create_uuid() @@ -562,7 +541,7 @@ def test_score_trace_nested_trace(): trace_id = langfuse.get_trace_id() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.name == trace_name @@ -579,7 +558,6 @@ def test_score_trace_nested_trace(): def test_score_trace_nested_observation(): langfuse = Langfuse(debug=False) - api = get_api() trace_name = create_uuid() @@ -596,7 +574,7 @@ def test_score_trace_nested_observation(): trace_id = langfuse.get_trace_id() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.name == trace_name @@ -655,7 +633,6 @@ def test_score_span(): def test_create_trace_and_span(): langfuse = Langfuse(debug=False) - api = get_api() trace_name = create_uuid() spanId = create_uuid() @@ -665,7 +642,7 @@ def test_create_trace_and_span(): langfuse.flush() - trace = api.trace.get(trace.id) + trace = get_api().trace.get(trace.id) assert trace.name == trace_name assert len(trace.observations) == 1 @@ -678,7 +655,6 @@ def test_create_trace_and_span(): def test_create_trace_and_generation(): langfuse = Langfuse(debug=False) - api = get_api() trace_name = create_uuid() generationId = create_uuid() @@ -695,8 +671,8 @@ def test_create_trace_and_generation(): langfuse.flush() + dbTrace = get_api().trace.get(trace.id) getTrace = langfuse.get_trace(trace.id) - dbTrace = api.trace.get(trace.id) assert dbTrace.name == trace_name assert len(dbTrace.observations) == 1 @@ -713,13 +689,12 @@ def test_create_trace_and_generation(): def backwards_compatibility_sessionId(): langfuse = Langfuse(debug=False) - api = get_api() trace = langfuse.trace(name="test", sessionId="test-sessionId") langfuse.flush() - trace = api.trace.get(trace.id) + trace = get_api().trace.get(trace.id) assert trace.name == "test" assert trace.session_id == "test-sessionId" @@ -755,6 +730,7 @@ def test_create_generation_and_trace(): langfuse.trace(id=trace_id, name=trace_name) langfuse.flush() + sleep(2) trace = api_wrapper.get_trace(trace_id) @@ -773,6 +749,7 @@ def test_create_span_and_get_observation(): langfuse.span(id=span_id, name="span") langfuse.flush() + sleep(2) observation = langfuse.get_observation(span_id) assert observation.name == "span" assert observation.id == span_id @@ -780,7 +757,7 @@ def test_create_span_and_get_observation(): def test_update_generation(): langfuse = Langfuse(debug=False) - api = get_api() + start = _get_timestamp() generation = langfuse.generation(name="generation") @@ -788,7 +765,7 @@ def test_update_generation(): langfuse.flush() - trace = api.trace.get(generation.trace_id) + trace = get_api().trace.get(generation.trace_id) assert trace.name == "generation" assert len(trace.observations) == 1 @@ -803,14 +780,13 @@ def test_update_generation(): def test_update_span(): langfuse = Langfuse(debug=False) - api = get_api() span = langfuse.span(name="span") span.update(metadata={"dict": "value"}) langfuse.flush() - trace = api.trace.get(span.trace_id) + trace = 
get_api().trace.get(span.trace_id) assert trace.name == "span" assert len(trace.observations) == 1 @@ -823,13 +799,12 @@ def test_update_span(): def test_create_event(): langfuse = Langfuse(debug=False) - api = get_api() event = langfuse.event(name="event") langfuse.flush() - observation = api.observations.get(event.id) + observation = get_api().observations.get(event.id) assert observation.type == "EVENT" assert observation.name == "event" @@ -837,7 +812,6 @@ def test_create_event(): def test_create_trace_and_event(): langfuse = Langfuse(debug=False) - api = get_api() trace_name = create_uuid() eventId = create_uuid() @@ -847,7 +821,7 @@ def test_create_trace_and_event(): langfuse.flush() - trace = api.trace.get(trace.id) + trace = get_api().trace.get(trace.id) assert trace.name == trace_name assert len(trace.observations) == 1 @@ -859,8 +833,6 @@ def test_create_trace_and_event(): def test_create_span_and_generation(): - api = get_api() - langfuse = Langfuse(debug=False) span = langfuse.span(name="span") @@ -868,7 +840,7 @@ def test_create_span_and_generation(): langfuse.flush() - trace = api.trace.get(span.trace_id) + trace = get_api().trace.get(span.trace_id) assert trace.name == "span" assert len(trace.observations) == 2 @@ -938,16 +910,14 @@ def test_end_generation(): def test_end_generation_with_data(): langfuse = Langfuse() - api = get_api() + trace = langfuse.trace() - generation = langfuse.generation( + generation = trace.generation( name="query-generation", ) generation.end( name="test_generation_end", - start_time=datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc), - end_time=datetime(2023, 1, 1, 12, 5, tzinfo=timezone.utc), metadata={"dict": "value"}, level="ERROR", status_message="Generation ended", @@ -970,13 +940,9 @@ def test_end_generation_with_data(): langfuse.flush() - trace_id = langfuse.get_trace_id() - - trace = api.trace.get(trace_id) + fetched_trace = get_api().trace.get(trace.id) - generation = trace.observations[0] - assert generation.start_time == datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc) - assert generation.end_time == datetime(2023, 1, 1, 12, 5, tzinfo=timezone.utc) + generation = fetched_trace.observations[0] assert generation.completion_start_time == datetime( 2023, 1, 1, 12, 3, tzinfo=timezone.utc ) @@ -992,7 +958,6 @@ def test_end_generation_with_data(): assert generation.usage.input == 100 assert generation.usage.output == 200 assert generation.usage.total == 500 - assert generation.usage.unit == "CHARACTERS" assert generation.calculated_input_cost == 111 assert generation.calculated_output_cost == 222 assert generation.calculated_total_cost == 444 @@ -1000,7 +965,6 @@ def test_end_generation_with_data(): def test_end_generation_with_openai_token_format(): langfuse = Langfuse() - api = get_api() generation = langfuse.generation( name="query-generation", @@ -1021,7 +985,7 @@ def test_end_generation_with_openai_token_format(): trace_id = langfuse.get_trace_id() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) print(trace.observations[0]) generation = trace.observations[0] @@ -1062,7 +1026,6 @@ def test_end_span(): def test_end_span_with_data(): langfuse = Langfuse() - api = get_api() timestamp = _get_timestamp() span = langfuse.span( @@ -1079,7 +1042,7 @@ def test_end_span_with_data(): trace_id = langfuse.get_trace_id() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) span = trace.observations[0] assert span.end_time is not None @@ -1108,7 +1071,10 @@ def test_get_generations(): ) langfuse.flush() + + 
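# --- Reviewer note: the fixed sleep() calls added before reads in this file
# paper over asynchronous ClickHouse ingestion. A hypothetical polling helper
# (not in the SDK or this diff) would express the same intent without a
# hard-coded delay:
import time

def wait_for(fetch, attempts=10, delay=0.5):
    """Retry `fetch` until it succeeds or the attempts are exhausted."""
    last_error = None
    for _ in range(attempts):
        try:
            return fetch()
        except Exception as error:  # e.g. LangfuseNotFoundError
            last_error = error
            time.sleep(delay)
    raise last_error

# usage: trace = wait_for(lambda: get_api().trace.get(trace_id))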
sleep(1) generations = langfuse.get_generations(name=generation_name, limit=10, page=1) + assert len(generations.data) == 1 assert generations.data[0].name == generation_name assert generations.data[0].input == "great-prompt" @@ -1138,6 +1104,8 @@ def test_get_generations_by_user(): ) langfuse.flush() + sleep(1) + generations = langfuse.get_generations(limit=10, page=1, user_id=user_id) assert len(generations.data) == 1 @@ -1148,7 +1116,6 @@ def test_get_generations_by_user(): def test_kwargs(): langfuse = Langfuse() - api = get_api() timestamp = _get_timestamp() @@ -1166,7 +1133,7 @@ def test_kwargs(): langfuse.flush() - observation = api.observations.get(span.id) + observation = get_api().observations.get(span.id) assert observation.start_time is not None assert observation.input == {"key": "value"} assert observation.output == {"key": "value"} @@ -1181,7 +1148,6 @@ def test_timezone_awareness(): assert utc_now.tzinfo is not None langfuse = Langfuse(debug=False) - api = get_api() trace = langfuse.trace(name="test") span = trace.span(name="span") @@ -1192,7 +1158,7 @@ def test_timezone_awareness(): langfuse.flush() - trace = api.trace.get(trace.id) + trace = get_api().trace.get(trace.id) assert len(trace.observations) == 3 for observation in trace.observations: @@ -1219,7 +1185,6 @@ def test_timezone_awareness_setting_timestamps(): print(utc_now) langfuse = Langfuse(debug=False) - api = get_api() trace = langfuse.trace(name="test") trace.span(name="span", start_time=now, end_time=now) @@ -1228,7 +1193,7 @@ def test_timezone_awareness_setting_timestamps(): langfuse.flush() - trace = api.trace.get(trace.id) + trace = get_api().trace.get(trace.id) assert len(trace.observations) == 3 for observation in trace.observations: @@ -1242,7 +1207,6 @@ def test_timezone_awareness_setting_timestamps(): def test_get_trace_by_session_id(): langfuse = Langfuse(debug=False) - api = get_api() # Create a trace with a session_id trace_name = create_uuid() @@ -1255,7 +1219,7 @@ def test_get_trace_by_session_id(): langfuse.flush() # Retrieve the trace using the session_id - traces = api.trace.list(session_id=session_id) + traces = get_api().trace.list(session_id=session_id) # Verify that the trace was retrieved correctly assert len(traces.data) == 1 @@ -1274,6 +1238,7 @@ def test_fetch_trace(): langfuse.flush() # Fetch the trace + sleep(1) response = langfuse.fetch_trace(trace.id) # Assert the structure of the response @@ -1307,6 +1272,7 @@ def test_fetch_traces(): timestamp=trace_param["timestamp"], ) langfuse.flush() + sleep(1) all_traces = langfuse.fetch_traces(limit=10, name=name) assert len(all_traces.data) == 3 @@ -1331,7 +1297,7 @@ def test_fetch_traces(): fetched_trace = response.data[0] assert fetched_trace.name == name assert fetched_trace.session_id == "session-1" - assert fetched_trace.input == {"key": "value"} + assert fetched_trace.input == '{"key":"value"}' assert fetched_trace.output == "output-value" # compare timestamps without microseconds and in UTC assert fetched_trace.timestamp.replace(microsecond=0) == trace_params[1][ @@ -1353,6 +1319,7 @@ def test_fetch_observation(): trace = langfuse.trace(name=name) generation = trace.generation(name=name) langfuse.flush() + sleep(1) # Fetch the observation response = langfuse.fetch_observation(generation.id) @@ -1374,6 +1341,7 @@ def test_fetch_observations(): gen1 = trace.generation(name=name) gen2 = trace.generation(name=name) langfuse.flush() + sleep(1) # Fetch observations response = langfuse.fetch_observations(limit=10, name=name) @@ -1448,6 
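# --- Reviewer sketch: the fetch_traces assertion above suggests the list
# endpoint now returns trace input as a serialized JSON string rather than a
# parsed object (the detail view still returns objects). Callers can recover
# the object explicitly:
import json

def parse_trace_input(fetched_input):
    # '{"key":"value"}' -> {"key": "value"}; pass dicts through unchanged
    if isinstance(fetched_input, str):
        return json.loads(fetched_input)
    return fetched_input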
+1416,7 @@ def test_fetch_sessions(): langfuse.flush() # Fetch traces + sleep(3) response = langfuse.fetch_sessions() # Assert the structure of the response, cannot check for the exact number of sessions as the table is not cleared between tests @@ -1455,12 +1424,10 @@ def test_fetch_sessions(): assert hasattr(response, "data") assert hasattr(response, "meta") assert isinstance(response.data, list) - assert response.data[0].id in [session1, session2, session3] # fetch only one, cannot check for the exact number of sessions as the table is not cleared between tests response = langfuse.fetch_sessions(limit=1, page=2) assert len(response.data) == 1 - assert response.data[0].id in [session1, session2, session3] def test_create_trace_sampling_zero(): @@ -1481,10 +1448,10 @@ def test_create_trace_sampling_zero(): langfuse.flush() - trace = api_wrapper.get_trace(trace.id) - assert trace == { + fetched_trace = api_wrapper.get_trace(trace.id) + assert fetched_trace == { "error": "LangfuseNotFoundError", - "message": "Trace not found within authorized project", + "message": f"Trace {trace.id} not found within authorized project", } diff --git a/tests/test_datasets.py b/tests/test_datasets.py index bd94079a1..be61e9ae7 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,13 +1,14 @@ import json import os -from typing import List +import time from concurrent.futures import ThreadPoolExecutor +from typing import List from langchain import LLMChain, OpenAI, PromptTemplate from langfuse import Langfuse -from langfuse.decorators import observe, langfuse_context from langfuse.api.resources.commons.types.observation import Observation +from langfuse.decorators import langfuse_context, observe from tests.utils import create_uuid, get_api, get_llama_index_index @@ -264,11 +265,14 @@ def test_linking_via_id_observation_arg_legacy(): generation = langfuse.generation(id=generation_id) trace_id = generation.trace_id langfuse.flush() + time.sleep(1) item.link(generation_id, run_name) langfuse.flush() + time.sleep(1) + run = langfuse.get_dataset_run(dataset_name, run_name) assert run.name == run_name @@ -435,9 +439,7 @@ def test_llama_index_dataset(): assert len(run.dataset_run_items) == 1 assert run.dataset_run_items[0].dataset_run_id == run.id - api = get_api() - - trace = api.trace.get(handler.get_trace_id()) + trace = get_api().trace.get(handler.get_trace_id()) sorted_observations = sorted_dependencies(trace.observations) @@ -452,7 +454,6 @@ def test_llama_index_dataset(): } assert sorted_observations[0].name == "query" - assert sorted_observations[1].name == "synthesize" def sorted_dependencies( @@ -533,11 +534,8 @@ def execute_dataset_item(item, run_name, trace_id): item.trace_id == trace_id for item in run.dataset_run_items ), f"Trace {trace_id} not found in run" - # Check trace - api = get_api() - for dataset_item_input, trace_id in items_data: - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.name == "run_llm_app_on_dataset_item" assert len(trace.observations) == 0 @@ -552,7 +550,7 @@ def execute_dataset_item(item, run_name, trace_id): langfuse_context.flush() - next_trace = api.trace.get(new_trace_id) + next_trace = get_api().trace.get(new_trace_id) assert next_trace.name == "run_llm_app_on_dataset_item" assert next_trace.input["args"][0] == "non-dataset-run-afterwards" assert next_trace.output == "non-dataset-run-afterwards" diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 0bd911c03..901d1c553 100644 --- 
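# --- Reviewer sketch of the dataset-run linking flow exercised above,
# condensed and self-contained; names are generated per run, and the sleep()
# calls mirror this diff's concession to async ingestion.
import time

from langfuse import Langfuse
from tests.utils import create_uuid

langfuse = Langfuse()
dataset_name = create_uuid()
langfuse.create_dataset(name=dataset_name)
langfuse.create_dataset_item(dataset_name=dataset_name, input={"q": "hello"})

item = langfuse.get_dataset(dataset_name).items[0]
run_name = create_uuid()

generation = langfuse.generation(name="dataset-run-item")
langfuse.flush()
time.sleep(1)

item.link(generation.id, run_name)  # legacy style: link via observation id
langfuse.flush()
time.sleep(1)

run = langfuse.get_dataset_run(dataset_name, run_name)
assert len(run.dataset_run_items) == 1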
a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -2,19 +2,20 @@ from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from contextvars import ContextVar +from time import sleep from typing import Optional import pytest from langchain.prompts import ChatPromptTemplate -from langchain_community.chat_models import ChatOpenAI +from langchain_openai import ChatOpenAI from langfuse.decorators import langfuse_context, observe from langfuse.media import LangfuseMedia from langfuse.openai import AsyncOpenAI from tests.utils import create_uuid, get_api, get_llama_index_index -mock_metadata = "mock_metadata" -mock_deep_metadata = "mock_deep_metadata" +mock_metadata = {"key": "metadata"} +mock_deep_metadata = {"key": "mock_deep_metadata"} mock_session_id = "session-id-1" mock_args = (1, 2, 3) mock_kwargs = {"a": 1, "b": 2, "c": 3} @@ -216,6 +217,7 @@ def level_2_function(): @observe() def level_1_function(*args, **kwargs): + sleep(1) level_2_function() return "level_1" @@ -249,7 +251,7 @@ def level_1_function(*args, **kwargs): level_3_observation = adjacencies[level_2_observation.id][0] assert ( - level_2_observation.metadata is None + level_2_observation.metadata == {} ) # Exception is raised before metadata is set assert level_3_observation.metadata == mock_deep_metadata assert level_3_observation.status_message == "Mock exception" @@ -284,6 +286,7 @@ def level_2_function(): @observe(name=mock_name) def level_1_function(*args, **kwargs): + sleep(1) level_2_function() return "level_1" @@ -309,9 +312,6 @@ def level_1_function(*args, **kwargs): langfuse_context.flush() - print("mock_id_1", mock_trace_id_1) - print("mock_id_2", mock_trace_id_2) - for mock_id in [mock_trace_id_1, mock_trace_id_2]: trace_data = get_api().trace.get(mock_id) assert ( @@ -1035,7 +1035,8 @@ async def level_1_function(*args, **kwargs): assert generation.usage.input is not None assert generation.usage.output is not None assert generation.usage.total is not None - assert "2" in generation.output + print(generation) + assert generation.output == 2 def test_generation_at_highest_level(): @@ -1408,6 +1409,7 @@ def test_top_level_generation(): @observe(as_type="generation") def main(): + sleep(1) langfuse_context.update_current_trace(name="updated_name") return mock_output @@ -1496,6 +1498,7 @@ def test_media(): @observe() def main(): + sleep(1) langfuse_context.update_current_trace( input={ "context": { diff --git a/tests/test_langchain.py b/tests/test_langchain.py index 83cca374b..bb43f5ba8 100644 --- a/tests/test_langchain.py +++ b/tests/test_langchain.py @@ -2,6 +2,7 @@ import random import string import time +from time import sleep from typing import Any, Dict, List, Mapping, Optional import pytest @@ -19,10 +20,10 @@ from langchain.prompts import ChatPromptTemplate, PromptTemplate from langchain.schema import Document, HumanMessage, SystemMessage from langchain.text_splitter import CharacterTextSplitter +from langchain_anthropic import Anthropic from langchain_community.agent_toolkits.load_tools import load_tools from langchain_community.document_loaders import TextLoader from langchain_community.embeddings import OpenAIEmbeddings -from langchain_community.llms.anthropic import Anthropic from langchain_community.llms.huggingface_hub import HuggingFaceHub from langchain_community.vectorstores import Chroma from langchain_core.callbacks.manager import CallbackManagerForLLMRun @@ -49,7 +50,6 @@ def test_callback_init(): def test_callback_kwargs(): - api = get_api() callback = 
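# --- Reviewer sketch: mock_metadata above changes from a bare string to a
# dict, and unset metadata now reads back as {} instead of None. The decorator
# pattern these tests exercise, in miniature:
from langfuse.decorators import langfuse_context, observe

@observe()
def nested():
    # metadata is now expected to be an object, not a scalar string
    langfuse_context.update_current_observation(
        metadata={"key": "mock_deep_metadata"}
    )
    return "nested"

@observe()
def main():
    return nested()

main()
langfuse_context.flush()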
CallbackHandler( trace_name="trace-name", release="release", @@ -68,7 +68,7 @@ def test_callback_kwargs(): trace_id = callback.get_trace_id() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.input is not None assert trace.output is not None assert trace.metadata == {"key": "value"} @@ -94,7 +94,6 @@ def test_langfuse_span(): def test_callback_generated_from_trace_chain(): - api = get_api() langfuse = Langfuse(debug=True) trace_id = create_uuid() @@ -115,7 +114,7 @@ def test_callback_generated_from_trace_chain(): langfuse.flush() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.input is None assert trace.output is None @@ -143,9 +142,9 @@ def test_callback_generated_from_trace_chain(): )[0] assert langchain_generation_span.parent_observation_id == langchain_span.id - assert langchain_generation_span.usage.input > 0 - assert langchain_generation_span.usage.output > 0 - assert langchain_generation_span.usage.total > 0 + assert langchain_generation_span.usage_details["input"] > 0 + assert langchain_generation_span.usage_details["output"] > 0 + assert langchain_generation_span.usage_details["total"] > 0 assert langchain_generation_span.input is not None assert langchain_generation_span.input != "" assert langchain_generation_span.output is not None @@ -153,7 +152,6 @@ def test_callback_generated_from_trace_chain(): def test_callback_generated_from_trace_chat(): - api = get_api() langfuse = Langfuse(debug=False) trace_id = create_uuid() @@ -176,7 +174,7 @@ def test_callback_generated_from_trace_chat(): langfuse.flush() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.input is None assert trace.output is None @@ -194,9 +192,9 @@ def test_callback_generated_from_trace_chat(): )[0] assert langchain_generation_span.parent_observation_id is None - assert langchain_generation_span.usage.input > 0 - assert langchain_generation_span.usage.output > 0 - assert langchain_generation_span.usage.total > 0 + assert langchain_generation_span.usage_details["input"] > 0 + assert langchain_generation_span.usage_details["output"] > 0 + assert langchain_generation_span.usage_details["total"] > 0 assert langchain_generation_span.input is not None assert langchain_generation_span.input != "" assert langchain_generation_span.output is not None @@ -204,7 +202,6 @@ def test_callback_generated_from_trace_chat(): def test_callback_generated_from_lcel_chain(): - api = get_api() langfuse = Langfuse(debug=False) run_name_override = "This is a custom Run Name" @@ -226,13 +223,12 @@ def test_callback_generated_from_lcel_chain(): langfuse.flush() handler.flush() trace_id = handler.get_trace_id() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.name == run_name_override def test_callback_generated_from_span_chain(): - api = get_api() langfuse = Langfuse(debug=False) trace_id = create_uuid() @@ -255,7 +251,7 @@ def test_callback_generated_from_span_chain(): langfuse.flush() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.input is None assert trace.output is None @@ -294,9 +290,9 @@ def test_callback_generated_from_span_chain(): )[0] assert langchain_generation_span.parent_observation_id == langchain_span.id - assert langchain_generation_span.usage.input > 0 - assert langchain_generation_span.usage.output > 0 - assert langchain_generation_span.usage.total > 0 + assert langchain_generation_span.usage_details["input"] > 0 + assert 
langchain_generation_span.usage_details["output"] > 0 + assert langchain_generation_span.usage_details["total"] > 0 assert langchain_generation_span.input is not None assert langchain_generation_span.input != "" assert langchain_generation_span.output is not None @@ -304,7 +300,6 @@ def test_callback_generated_from_span_chain(): def test_callback_generated_from_span_chat(): - api = get_api() langfuse = Langfuse(debug=False) trace_id = create_uuid() @@ -330,7 +325,7 @@ def test_callback_generated_from_span_chat(): langfuse.flush() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.input is None assert trace.output is None @@ -358,9 +353,9 @@ def test_callback_generated_from_span_chat(): )[0] assert langchain_generation_span.parent_observation_id == user_span.id - assert langchain_generation_span.usage.input > 0 - assert langchain_generation_span.usage.output > 0 - assert langchain_generation_span.usage.total > 0 + assert langchain_generation_span.usage_details["input"] > 0 + assert langchain_generation_span.usage_details["output"] > 0 + assert langchain_generation_span.usage_details["total"] > 0 assert langchain_generation_span.input is not None assert langchain_generation_span.input != "" assert langchain_generation_span.output is not None @@ -409,7 +404,6 @@ def test_mistral(): from langchain_core.messages import HumanMessage from langchain_mistralai.chat_models import ChatMistralAI - api = get_api() callback = CallbackHandler(debug=False) chat = ChatMistralAI(model="mistral-small", callbacks=[callback]) @@ -420,7 +414,7 @@ def test_mistral(): trace_id = callback.get_trace_id() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.id == trace_id assert len(trace.observations) == 2 @@ -433,7 +427,6 @@ def test_mistral(): def test_vertx(): from langchain.llms import VertexAI - api = get_api() callback = CallbackHandler(debug=False) llm = VertexAI(callbacks=[callback]) @@ -443,7 +436,7 @@ def test_vertx(): trace_id = callback.get_trace_id() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.id == trace_id assert len(trace.observations) == 2 @@ -475,17 +468,16 @@ def test_callback_generated_from_trace_anthropic(): langfuse.flush() - api = get_api() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert handler.get_trace_id() == trace_id assert len(trace.observations) == 2 assert trace.id == trace_id for observation in trace.observations: if observation.type == "GENERATION": - assert observation.usage.input > 0 - assert observation.usage.output > 0 - assert observation.usage.total > 0 + assert observation.usage_details["input"] > 0 + assert observation.usage_details["output"] > 0 + assert observation.usage_details["total"] > 0 assert observation.output is not None assert observation.output != "" assert isinstance(observation.input, str) is True @@ -513,9 +505,7 @@ def test_basic_chat_openai(): trace_id = callback.get_trace_id() - api = get_api() - - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.id == trace_id assert len(trace.observations) == 1 @@ -562,9 +552,7 @@ def test_basic_chat_openai_based_on_trace(): trace_id = callback.get_trace_id() - api = get_api() - - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.id == trace_id assert len(trace.observations) == 1 @@ -591,8 +579,7 @@ def test_callback_from_trace_with_trace_update(): trace_id = handler.get_trace_id() - api = get_api() - trace = 
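# --- Reviewer sketch of the span-scoped handler pattern these tests cover;
# the model and prompt are placeholders. Note that the assertions above also
# migrate from usage.* attributes to usage_details[...] lookups here.
from langchain_openai import ChatOpenAI

from langfuse import Langfuse

langfuse = Langfuse()
trace = langfuse.trace(name="chat-demo")
span = trace.span(name="user-interaction")

handler = span.get_langchain_handler()
ChatOpenAI().invoke("Hello", config={"callbacks": [handler]})

langfuse.flush()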
api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.input is not None assert trace.output is not None @@ -606,9 +593,9 @@ def test_callback_from_trace_with_trace_update(): for generation in generations: assert generation.input is not None assert generation.output is not None - assert generation.usage.total is not None - assert generation.usage.input is not None - assert generation.usage.output is not None + assert generation.usage_details["total"] is not None + assert generation.usage_details["input"] is not None + assert generation.usage_details["output"] is not None def test_callback_from_span_with_span_update(): @@ -634,12 +621,11 @@ def test_callback_from_span_with_span_update(): trace_id = handler.get_trace_id() - api = get_api() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.input is None assert trace.output is None - assert trace.metadata is None + assert trace.metadata == {} assert len(trace.observations) == 3 assert handler.get_trace_id() == trace_id @@ -655,9 +641,9 @@ def test_callback_from_span_with_span_update(): for generation in generations: assert generation.input is not None assert generation.output is not None - assert generation.usage.total is not None - assert generation.usage.input is not None - assert generation.usage.output is not None + assert generation.usage_details["total"] is not None + assert generation.usage_details["input"] is not None + assert generation.usage_details["output"] is not None def test_callback_from_trace_simple_chain(): @@ -681,8 +667,7 @@ def test_callback_from_trace_simple_chain(): trace_id = handler.get_trace_id() - api = get_api() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.input is None assert trace.output is None @@ -695,9 +680,9 @@ def test_callback_from_trace_simple_chain(): for generation in generations: assert generation.input is not None assert generation.output is not None - assert generation.usage.total is not None - assert generation.usage.input is not None - assert generation.usage.output is not None + assert generation.usage_details["total"] is not None + assert generation.usage_details["input"] is not None + assert generation.usage_details["output"] is not None def test_next_span_id_from_trace_simple_chain(): @@ -748,7 +733,6 @@ def test_next_span_id_from_trace_simple_chain(): def test_callback_sequential_chain(): - api = get_api() handler = CallbackHandler(debug=False) llm = OpenAI(openai_api_key=os.environ.get("OPENAI_API_KEY")) @@ -777,16 +761,16 @@ def test_callback_sequential_chain(): trace_id = handler.get_trace_id() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert len(trace.observations) == 5 assert trace.id == trace_id for observation in trace.observations: if observation.type == "GENERATION": - assert observation.usage.input > 0 - assert observation.usage.output > 0 - assert observation.usage.total > 0 + assert observation.usage_details["input"] > 0 + assert observation.usage_details["output"] > 0 + assert observation.usage_details["total"] > 0 assert observation.input is not None assert observation.input != "" assert observation.output is not None @@ -921,9 +905,7 @@ def test_callback_retriever_conversational_with_memory(): conversation.predict(input="Hi there!", callbacks=[handler]) handler.flush() - api = get_api() - - trace = api.trace.get(handler.get_trace_id()) + trace = get_api().trace.get(handler.get_trace_id()) generations = list(filter(lambda x: x.type == "GENERATION", 
trace.observations)) assert len(generations) == 1 @@ -933,9 +915,9 @@ def test_callback_retriever_conversational_with_memory(): assert generation.output is not None assert generation.input != "" assert generation.output != "" - assert generation.usage.total is not None - assert generation.usage.input is not None - assert generation.usage.output is not None + assert generation.usage_details["total"] is not None + assert generation.usage_details["input"] is not None + assert generation.usage_details["output"] is not None def test_callback_retriever_conversational(): @@ -983,8 +965,7 @@ def test_callback_retriever_conversational(): def test_callback_simple_openai(): - api = get_api() - handler = CallbackHandler(debug=False) + handler = CallbackHandler() llm = OpenAI(openai_api_key=os.environ.get("OPENAI_API_KEY")) @@ -996,15 +977,16 @@ def test_callback_simple_openai(): trace_id = handler.get_trace_id() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert len(trace.observations) == 1 for observation in trace.observations: if observation.type == "GENERATION": - assert observation.usage.input > 0 - assert observation.usage.output > 0 - assert observation.usage.total > 0 + print(observation.usage_details) + assert observation.usage_details["input"] > 0 + assert observation.usage_details["output"] > 0 + assert observation.usage_details["total"] > 0 assert observation.input is not None assert observation.input != "" assert observation.output is not None @@ -1012,7 +994,6 @@ def test_callback_simple_openai(): def test_callback_multiple_invocations_on_different_traces(): - api = get_api() handler = CallbackHandler(debug=False) llm = OpenAI(openai_api_key=os.environ.get("OPENAI_API_KEY")) @@ -1031,8 +1012,8 @@ def test_callback_multiple_invocations_on_different_traces(): assert trace_id_one != trace_id_two - trace_one = api.trace.get(trace_id_one) - trace_two = api.trace.get(trace_id_two) + trace_one = get_api().trace.get(trace_id_one) + trace_two = get_api().trace.get(trace_id_two) for test_data in [ {"trace": trace_one, "expected_trace_id": trace_id_one}, @@ -1042,9 +1023,9 @@ def test_callback_multiple_invocations_on_different_traces(): assert test_data["trace"].id == test_data["expected_trace_id"] for observation in test_data["trace"].observations: if observation.type == "GENERATION": - assert observation.usage.input > 0 - assert observation.usage.output > 0 - assert observation.usage.total > 0 + assert observation.usage_details["input"] > 0 + assert observation.usage_details["output"] > 0 + assert observation.usage_details["total"] > 0 assert observation.input is not None assert observation.input != "" assert observation.output is not None @@ -1104,9 +1085,8 @@ def test_tools(): handler.flush() trace_id = handler.get_trace_id() - api = get_api() - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.id == trace_id assert len(trace.observations) > 2 @@ -1227,8 +1207,7 @@ def record_dog(name: str, color: str, fav_food: OptionalFavFood) -> str: handler.langfuse.flush() - api = get_api() - trace = api.trace.get(handler.get_trace_id()) + trace = get_api().trace.get(handler.get_trace_id()) assert len(trace.observations) == 2 @@ -1263,9 +1242,9 @@ def record_dog(name: str, color: str, fav_food: OptionalFavFood) -> str: "refusal": None, }, } - assert generation.usage.total is not None - assert generation.usage.input is not None - assert generation.usage.output is not None + assert generation.usage_details["total"] is not None + assert 
generation.usage_details["input"] is not None + assert generation.usage_details["output"] is not None def test_agent_executor_chain(): @@ -1312,8 +1291,8 @@ def get_word_length(word: str) -> int: ) callback.flush() - api = get_api() - trace = api.trace.get(callback.get_trace_id()) + + trace = get_api().trace.get(callback.get_trace_id()) generations = list(filter(lambda x: x.type == "GENERATION", trace.observations)) assert len(generations) > 0 @@ -1323,9 +1302,9 @@ def get_word_length(word: str) -> int: assert generation.output is not None assert generation.input != "" assert generation.output != "" - assert generation.usage.total is not None - assert generation.usage.input is not None - assert generation.usage.output is not None + assert generation.usage_details["total"] is not None + assert generation.usage_details["input"] is not None + assert generation.usage_details["output"] is not None # def test_create_extraction_chain(): @@ -1387,9 +1366,9 @@ def get_word_length(word: str) -> int: # handler.flush() -# api = get_api() +# -# trace = api.trace.get(handler.get_trace_id()) +# trace = get_api().trace.get(handler.get_trace_id()) # generations = list(filter(lambda x: x.type == "GENERATION", trace.observations)) # assert len(generations) > 0 @@ -1399,9 +1378,9 @@ def get_word_length(word: str) -> int: # assert generation.output is not None # assert generation.input != "" # assert generation.output != "" -# assert generation.usage.total is not None -# assert generation.usage.input is not None -# assert generation.usage.output is not None +# assert generation.usage_details["total"] is not None +# assert generation.usage_details["input"] is not None +# assert generation.usage_details["output"] is not None @pytest.mark.skip(reason="inference cost") @@ -1511,9 +1490,7 @@ def _identifying_params(self) -> Mapping[str, Any]: callback.flush() - api = get_api() - - trace = api.trace.get(callback.get_trace_id()) + trace = get_api().trace.get(callback.get_trace_id()) assert len(trace.observations) == 5 @@ -1572,8 +1549,8 @@ def test_names_on_spans_lcel(): ) callback.flush() - api = get_api() - trace = api.trace.get(callback.get_trace_id()) + + trace = get_api().trace.get(callback.get_trace_id()) assert len(trace.observations) == 7 @@ -1646,9 +1623,9 @@ def test_openai_instruct_usage(): assert observation.input is not None assert observation.input != "" assert observation.usage is not None - assert observation.usage.input is not None - assert observation.usage.output is not None - assert observation.usage.total is not None + assert observation.usage_details["input"] is not None + assert observation.usage_details["output"] is not None + assert observation.usage_details["total"] is not None def test_get_langchain_prompt_with_jinja2(): @@ -1743,8 +1720,6 @@ def test_get_langchain_chat_prompt(): def test_disabled_langfuse(): - api = get_api() - run_name_override = "This is a custom Run Name" handler = CallbackHandler(enabled=False, debug=False) @@ -1768,7 +1743,7 @@ def test_disabled_langfuse(): trace_id = handler.get_trace_id() with pytest.raises(Exception): - api.trace.get(trace_id) + get_api().trace.get(trace_id) def test_link_langfuse_prompts_invoke(): @@ -1837,6 +1812,7 @@ def test_link_langfuse_prompts_invoke(): ) langfuse_handler.flush() + sleep(2) trace = get_api().trace.get(langfuse_handler.get_trace_id()) @@ -1861,7 +1837,7 @@ def test_link_langfuse_prompts_invoke(): assert generations[0].prompt_version == langfuse_joke_prompt.version assert generations[1].prompt_version == 
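# --- Reviewer sketch of the prompt-linking mechanism behind the
# prompt_version assertions above: passing the Langfuse prompt object via the
# "langfuse_prompt" metadata key ties resulting generations to a prompt
# version. Prompt name and content here are placeholders.
from langchain_core.prompts import ChatPromptTemplate

from langfuse import Langfuse

langfuse = Langfuse()
langfuse.create_prompt(
    name="joke-prompt",
    prompt="Tell me a joke about {{topic}}",
    labels=["production"],
)
langfuse_prompt = langfuse.get_prompt("joke-prompt")

langchain_prompt = ChatPromptTemplate.from_template(
    langfuse_prompt.get_langchain_prompt(),
    metadata={"langfuse_prompt": langfuse_prompt},
)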
langfuse_explain_prompt.version - assert generations[1].output == output.strip() + assert generations[1].output == (output.strip() if output else None) def test_link_langfuse_prompts_stream(): @@ -1914,7 +1890,7 @@ def test_link_langfuse_prompts_stream(): ) # Run chain - langfuse_handler = CallbackHandler(debug=True) + langfuse_handler = CallbackHandler() stream = chain.stream( {"animal": "dog"}, @@ -1934,6 +1910,7 @@ def test_link_langfuse_prompts_stream(): output += chunk langfuse_handler.flush() + sleep(2) trace = get_api().trace.get(langfuse_handler.get_trace_id()) @@ -1961,7 +1938,7 @@ def test_link_langfuse_prompts_stream(): assert generations[0].time_to_first_token is not None assert generations[1].time_to_first_token is not None - assert generations[1].output == output.strip() + assert generations[1].output == (output.strip() if output else None) def test_link_langfuse_prompts_batch(): @@ -2145,8 +2122,7 @@ class GetWeather(BaseModel): handler.flush() - api = get_api() - trace = api.trace.get(handler.get_trace_id()) + trace = get_api().trace.get(handler.get_trace_id()) generations = list(filter(lambda x: x.type == "GENERATION", trace.observations)) assert len(generations) > 0 @@ -2215,7 +2191,6 @@ def _generate_random_dict(n: int, key_length: int = 8) -> Dict[str, Any]: def test_multimodal(): - api = get_api() handler = CallbackHandler() model = ChatOpenAI(model="gpt-4o-mini") @@ -2237,7 +2212,7 @@ def test_multimodal(): handler.flush() - trace = api.trace.get(handler.get_trace_id()) + trace = get_api().trace.get(handler.get_trace_id()) assert len(trace.observations) == 1 assert trace.observations[0].type == "GENERATION" diff --git a/tests/test_media.py b/tests/test_media.py index df94b8c89..82211a37e 100644 --- a/tests/test_media.py +++ b/tests/test_media.py @@ -1,10 +1,12 @@ import base64 -import pytest -from langfuse.media import LangfuseMedia -from langfuse.client import Langfuse -from uuid import uuid4 import re +from uuid import uuid4 +import pytest + +from langfuse.client import Langfuse +from langfuse.media import LangfuseMedia +from tests.utils import get_api # Test data SAMPLE_JPEG_BYTES = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x00H\x00H\x00\x00" @@ -136,7 +138,7 @@ def test_replace_media_reference_string_in_object(): langfuse.flush() # Verify media reference string format - fetched_trace = langfuse.fetch_trace(trace.id).data + fetched_trace = get_api().trace.get(trace.id) media_ref = fetched_trace.metadata["context"]["nested"] assert re.match( r"^@@@langfuseMedia:type=audio/wav\|id=.+\|source=base64_data_uri@@@$", @@ -163,7 +165,7 @@ def test_replace_media_reference_string_in_object(): langfuse.flush() # Verify second trace has same media reference - fetched_trace2 = langfuse.fetch_trace(trace2.id).data + fetched_trace2 = get_api().trace.get(trace2.id) assert ( fetched_trace2.metadata["context"]["nested"] == fetched_trace.metadata["context"]["nested"] diff --git a/tests/test_openai.py b/tests/test_openai.py index 3de56cc05..5dc991ef4 100644 --- a/tests/test_openai.py +++ b/tests/test_openai.py @@ -34,7 +34,6 @@ def test_auth_check(): def test_openai_chat_completion(): - api = get_api() generation_name = create_uuid() completion = chat_func( name=generation_name, @@ -51,7 +50,9 @@ def test_openai_chat_completion(): openai.flush_langfuse() - generation = api.observations.get_many(name=generation_name, type="GENERATION") + generation = get_api().observations.get_many( + name=generation_name, type="GENERATION" + ) assert len(generation.data) != 0 assert 
generation.data[0].name == generation_name @@ -86,7 +87,7 @@ def test_openai_chat_completion(): assert "2" in generation.data[0].output["content"] assert generation.data[0].output["role"] == "assistant" - trace = api.trace.get(generation.data[0].trace_id) + trace = get_api().trace.get(generation.data[0].trace_id) assert trace.input == [ { "content": "You are an expert mathematician", @@ -103,7 +104,6 @@ def test_openai_chat_completion(): def test_openai_chat_completion_stream(): - api = get_api() generation_name = create_uuid() completion = chat_func( name=generation_name, @@ -125,7 +125,9 @@ def test_openai_chat_completion_stream(): openai.flush_langfuse() - generation = api.observations.get_many(name=generation_name, type="GENERATION") + generation = get_api().observations.get_many( + name=generation_name, type="GENERATION" + ) assert len(generation.data) != 0 assert generation.data[0].name == generation_name @@ -147,8 +149,7 @@ def test_openai_chat_completion_stream(): assert generation.data[0].usage.input is not None assert generation.data[0].usage.output is not None assert generation.data[0].usage.total is not None - assert generation.data[0].output == "2" - assert isinstance(generation.data[0].output, str) is True + assert generation.data[0].output == 2 assert generation.data[0].completion_start_time is not None # Completion start time for time-to-first-token @@ -156,13 +157,12 @@ def test_openai_chat_completion_stream(): assert generation.data[0].completion_start_time >= generation.data[0].start_time assert generation.data[0].completion_start_time <= generation.data[0].end_time - trace = api.trace.get(generation.data[0].trace_id) + trace = get_api().trace.get(generation.data[0].trace_id) assert trace.input == [{"role": "user", "content": "1 + 1 = "}] - assert trace.output == chat_content + assert str(trace.output) == chat_content def test_openai_chat_completion_stream_with_next_iteration(): - api = get_api() generation_name = create_uuid() completion = chat_func( name=generation_name, @@ -189,7 +189,9 @@ def test_openai_chat_completion_stream_with_next_iteration(): openai.flush_langfuse() - generation = api.observations.get_many(name=generation_name, type="GENERATION") + generation = get_api().observations.get_many( + name=generation_name, type="GENERATION" + ) assert len(generation.data) != 0 assert generation.data[0].name == generation_name @@ -211,8 +213,7 @@ def test_openai_chat_completion_stream_with_next_iteration(): assert generation.data[0].usage.input is not None assert generation.data[0].usage.output is not None assert generation.data[0].usage.total is not None - assert generation.data[0].output == "2" - assert isinstance(generation.data[0].output, str) is True + assert generation.data[0].output == 2 assert generation.data[0].completion_start_time is not None # Completion start time for time-to-first-token @@ -220,13 +221,12 @@ def test_openai_chat_completion_stream_with_next_iteration(): assert generation.data[0].completion_start_time >= generation.data[0].start_time assert generation.data[0].completion_start_time <= generation.data[0].end_time - trace = api.trace.get(generation.data[0].trace_id) + trace = get_api().trace.get(generation.data[0].trace_id) assert trace.input == [{"role": "user", "content": "1 + 1 = "}] - assert trace.output == chat_content + assert str(trace.output) == chat_content def test_openai_chat_completion_stream_fail(): - api = get_api() generation_name = create_uuid() openai.api_key = "" @@ -242,7 +242,9 @@ def 
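# --- Reviewer note: streamed outputs that are valid JSON scalars now
# apparently come back parsed (hence `output == 2` above, and `str(trace.output)`
# when comparing against the raw streamed text). The drop-in wrapper pattern
# these tests rely on, for reference; model choice is a placeholder:
from langfuse.openai import openai

completion = openai.chat.completions.create(
    name="stream-demo",  # Langfuse-specific kwarg
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "1 + 1 = "}],
    stream=True,
)
text = "".join(
    chunk.choices[0].delta.content or ""
    for chunk in completion
    if chunk.choices
)
openai.flush_langfuse()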
test_openai_chat_completion_stream_fail(): openai.flush_langfuse() - generation = api.observations.get_many(name=generation_name, type="GENERATION") + generation = get_api().observations.get_many( + name=generation_name, type="GENERATION" + ) assert len(generation.data) != 0 assert generation.data[0].name == generation_name @@ -270,13 +272,12 @@ def test_openai_chat_completion_stream_fail(): openai.api_key = os.environ["OPENAI_API_KEY"] - trace = api.trace.get(generation.data[0].trace_id) + trace = get_api().trace.get(generation.data[0].trace_id) assert trace.input == [{"role": "user", "content": "1 + 1 = "}] assert trace.output is None def test_openai_chat_completion_with_trace(): - api = get_api() generation_name = create_uuid() trace_id = create_uuid() langfuse = Langfuse() @@ -294,7 +295,9 @@ def test_openai_chat_completion_with_trace(): openai.flush_langfuse() - generation = api.observations.get_many(name=generation_name, type="GENERATION") + generation = get_api().observations.get_many( + name=generation_name, type="GENERATION" + ) assert len(generation.data) != 0 assert generation.data[0].name == generation_name @@ -302,7 +305,6 @@ def test_openai_chat_completion_with_trace(): def test_openai_chat_completion_with_langfuse_prompt(): - api = get_api() generation_name = create_uuid() langfuse = Langfuse() prompt_name = create_uuid() @@ -319,7 +321,9 @@ def test_openai_chat_completion_with_langfuse_prompt(): openai.flush_langfuse() - generation = api.observations.get_many(name=generation_name, type="GENERATION") + generation = get_api().observations.get_many( + name=generation_name, type="GENERATION" + ) assert len(generation.data) != 0 assert generation.data[0].name == generation_name @@ -327,7 +331,6 @@ def test_openai_chat_completion_with_langfuse_prompt(): def test_openai_chat_completion_with_parent_observation_id(): - api = get_api() generation_name = create_uuid() trace_id = create_uuid() span_id = create_uuid() @@ -348,7 +351,9 @@ def test_openai_chat_completion_with_parent_observation_id(): openai.flush_langfuse() - generation = api.observations.get_many(name=generation_name, type="GENERATION") + generation = get_api().observations.get_many( + name=generation_name, type="GENERATION" + ) assert len(generation.data) != 0 assert generation.data[0].name == generation_name @@ -357,7 +362,6 @@ def test_openai_chat_completion_with_parent_observation_id(): def test_openai_chat_completion_fail(): - api = get_api() generation_name = create_uuid() openai.api_key = "" @@ -373,7 +377,9 @@ def test_openai_chat_completion_fail(): openai.flush_langfuse() - generation = api.observations.get_many(name=generation_name, type="GENERATION") + generation = get_api().observations.get_many( + name=generation_name, type="GENERATION" + ) assert len(generation.data) != 0 assert generation.data[0].name == generation_name @@ -399,7 +405,6 @@ def test_openai_chat_completion_fail(): def test_openai_chat_completion_with_additional_params(): - api = get_api() user_id = create_uuid() session_id = create_uuid() tags = ["tag1", "tag2"] @@ -419,7 +424,7 @@ def test_openai_chat_completion_with_additional_params(): openai.flush_langfuse() assert len(completion.choices) != 0 - trace = api.trace.get(trace_id) + trace = get_api().trace.get(trace_id) assert trace.user_id == user_id assert trace.session_id == session_id @@ -438,7 +443,6 @@ def test_openai_chat_completion_without_extra_param(): def test_openai_chat_completion_two_calls(): - api = get_api() generation_name = create_uuid() completion = chat_func( 
name=generation_name, @@ -460,7 +464,9 @@ def test_openai_chat_completion_two_calls(): openai.flush_langfuse() - generation = api.observations.get_many(name=generation_name, type="GENERATION") + generation = get_api().observations.get_many( + name=generation_name, type="GENERATION" + ) assert len(generation.data) != 0 assert generation.data[0].name == generation_name @@ -468,7 +474,9 @@ def test_openai_chat_completion_two_calls(): assert generation.data[0].input == [{"content": "1 + 1 = ", "role": "user"}] - generation_2 = api.observations.get_many(name=generation_name_2, type="GENERATION") + generation_2 = get_api().observations.get_many( + name=generation_name_2, type="GENERATION" + ) assert len(generation_2.data) != 0 assert generation_2.data[0].name == generation_name_2 @@ -478,7 +486,6 @@ def test_openai_chat_completion_two_calls(): def test_openai_chat_completion_with_seed(): - api = get_api() generation_name = create_uuid() completion = chat_func( name=generation_name, @@ -491,7 +498,9 @@ def test_openai_chat_completion_with_seed(): openai.flush_langfuse() - generation = api.observations.get_many(name=generation_name, type="GENERATION") + generation = get_api().observations.get_many( + name=generation_name, type="GENERATION" + ) assert generation.data[0].model_parameters == { "temperature": 0, @@ -505,7 +514,6 @@ def test_openai_chat_completion_with_seed(): def test_openai_completion(): - api = get_api() generation_name = create_uuid() completion = completion_func( name=generation_name, @@ -517,7 +525,9 @@ def test_openai_completion(): openai.flush_langfuse() - generation = api.observations.get_many(name=generation_name, type="GENERATION") + generation = get_api().observations.get_many( + name=generation_name, type="GENERATION" + ) assert len(generation.data) != 0 assert generation.data[0].name == generation_name @@ -542,13 +552,12 @@ def test_openai_completion(): assert generation.data[0].usage.total is not None assert generation.data[0].output == "2\n\n1 + 2 = 3\n\n2 + 3 = " - trace = api.trace.get(generation.data[0].trace_id) + trace = get_api().trace.get(generation.data[0].trace_id) assert trace.input == "1 + 1 = " assert trace.output == completion.choices[0].text def test_openai_completion_stream(): - api = get_api() generation_name = create_uuid() completion = completion_func( name=generation_name, @@ -568,7 +577,9 @@ def test_openai_completion_stream(): assert len(content) > 0 - generation = api.observations.get_many(name=generation_name, type="GENERATION") + generation = get_api().observations.get_many( + name=generation_name, type="GENERATION" + ) assert len(generation.data) != 0 assert generation.data[0].name == generation_name @@ -598,13 +609,12 @@ def test_openai_completion_stream(): assert generation.data[0].completion_start_time >= generation.data[0].start_time assert generation.data[0].completion_start_time <= generation.data[0].end_time - trace = api.trace.get(generation.data[0].trace_id) + trace = get_api().trace.get(generation.data[0].trace_id) assert trace.input == "1 + 1 = " assert trace.output == content def test_openai_completion_fail(): - api = get_api() generation_name = create_uuid() openai.api_key = "" @@ -620,7 +630,9 @@ def test_openai_completion_fail(): openai.flush_langfuse() - generation = api.observations.get_many(name=generation_name, type="GENERATION") + generation = get_api().observations.get_many( + name=generation_name, type="GENERATION" + ) assert len(generation.data) != 0 assert generation.data[0].name == generation_name @@ -646,7 +658,6 @@ def 
 
 
 def test_openai_completion_stream_fail():
-    api = get_api()
     generation_name = create_uuid()
 
     openai.api_key = ""
@@ -662,7 +673,9 @@ def test_openai_completion_stream_fail():
 
     openai.flush_langfuse()
 
-    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+    generation = get_api().observations.get_many(
+        name=generation_name, type="GENERATION"
+    )
 
     assert len(generation.data) != 0
     assert generation.data[0].name == generation_name
@@ -692,7 +705,6 @@ def test_openai_completion_stream_fail():
 
 
 def test_openai_completion_with_languse_prompt():
-    api = get_api()
     generation_name = create_uuid()
     langfuse = Langfuse()
     prompt_name = create_uuid()
@@ -710,7 +722,9 @@ def test_openai_completion_with_languse_prompt():
 
     openai.flush_langfuse()
 
-    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+    generation = get_api().observations.get_many(
+        name=generation_name, type="GENERATION"
+    )
 
     assert len(generation.data) != 0
     assert generation.data[0].name == generation_name
@@ -749,7 +763,6 @@ def test_fails_wrong_trace_id():
 
 
 @pytest.mark.asyncio
 async def test_async_chat():
-    api = get_api()
     client = AsyncOpenAI()
     generation_name = create_uuid()
@@ -761,7 +774,9 @@ async def test_async_chat():
 
     openai.flush_langfuse()
 
-    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+    generation = get_api().observations.get_many(
+        name=generation_name, type="GENERATION"
+    )
 
     assert len(generation.data) != 0
     assert generation.data[0].name == generation_name
@@ -789,7 +804,6 @@ async def test_async_chat():
 
 
 @pytest.mark.asyncio
 async def test_async_chat_stream():
-    api = get_api()
     client = AsyncOpenAI()
     generation_name = create_uuid()
@@ -806,7 +820,9 @@ async def test_async_chat_stream():
 
     openai.flush_langfuse()
 
-    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+    generation = get_api().observations.get_many(
+        name=generation_name, type="GENERATION"
+    )
 
     assert len(generation.data) != 0
     assert generation.data[0].name == generation_name
@@ -826,7 +842,7 @@ async def test_async_chat_stream():
     assert generation.data[0].usage.input is not None
     assert generation.data[0].usage.output is not None
     assert generation.data[0].usage.total is not None
-    assert "2" in generation.data[0].output
+    assert "2" in str(generation.data[0].output)
 
     # Completion start time for time-to-first-token
     assert generation.data[0].completion_start_time is not None
@@ -836,7 +852,6 @@ async def test_async_chat_stream():
 
 
 @pytest.mark.asyncio
 async def test_async_chat_stream_with_anext():
-    api = get_api()
     client = AsyncOpenAI()
     generation_name = create_uuid()
@@ -863,7 +878,9 @@ async def test_async_chat_stream_with_anext():
 
     print(result)
 
-    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+    generation = get_api().observations.get_many(
+        name=generation_name, type="GENERATION"
+    )
 
     assert len(generation.data) != 0
     assert generation.data[0].name == generation_name
@@ -897,7 +914,6 @@
 def test_openai_function_call():
     from pydantic import BaseModel
 
-    api = get_api()
     generation_name = create_uuid()
 
     class StepByStepAIResponse(BaseModel):
@@ -924,7 +940,9 @@ class StepByStepAIResponse(BaseModel):
 
     openai.flush_langfuse()
 
-    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+    generation = get_api().observations.get_many(
+        name=generation_name, type="GENERATION"
+    )
 
     assert len(generation.data) != 0
     assert generation.data[0].name == generation_name
@@ -939,7 +957,6 @@
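# ---------------------------------------------------------------------------
# [editor's sketch] Two assertions above change from
# `assert "2" in generation.data[0].output` to `assert "2" in str(...)`.
# For streamed chat completions the recorded output can be a dict-shaped
# message rather than a plain string, and `in` against a dict only checks
# keys. A small standalone illustration (the dict shape here is an
# assumption for demonstration, not copied from this diff):
streamed_output = {"role": "assistant", "content": "1 + 1 = 2"}

assert "2" in str(streamed_output)  # substring check works on the repr
assert "2" not in streamed_output  # dict membership checks keys, not values
# ---------------------------------------------------------------------------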
 def test_openai_function_call_streamed():
     from pydantic import BaseModel
 
-    api = get_api()
     generation_name = create_uuid()
 
     class StepByStepAIResponse(BaseModel):
@@ -967,7 +984,9 @@ class StepByStepAIResponse(BaseModel):
 
     openai.flush_langfuse()
 
-    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+    generation = get_api().observations.get_many(
+        name=generation_name, type="GENERATION"
+    )
 
     assert len(generation.data) != 0
     assert generation.data[0].name == generation_name
@@ -976,7 +995,6 @@ class StepByStepAIResponse(BaseModel):
 
 
 def test_openai_tool_call():
-    api = get_api()
     generation_name = create_uuid()
 
     tools = [
@@ -1010,7 +1028,9 @@ def test_openai_tool_call():
 
     openai.flush_langfuse()
 
-    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+    generation = get_api().observations.get_many(
+        name=generation_name, type="GENERATION"
+    )
 
     assert len(generation.data) != 0
     assert generation.data[0].name == generation_name
@@ -1026,7 +1046,6 @@ def test_openai_tool_call():
 
 
 def test_openai_tool_call_streamed():
-    api = get_api()
     generation_name = create_uuid()
 
     tools = [
@@ -1065,7 +1084,9 @@ def test_openai_tool_call_streamed():
 
     openai.flush_langfuse()
 
-    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+    generation = get_api().observations.get_many(
+        name=generation_name, type="GENERATION"
+    )
 
     assert len(generation.data) != 0
     assert generation.data[0].name == generation_name
@@ -1082,7 +1103,6 @@ def test_openai_tool_call_streamed():
 
 
 def test_azure():
-    api = get_api()
     generation_name = create_uuid()
     azure = AzureOpenAI(
         api_key="missing",
@@ -1101,7 +1121,9 @@ def test_azure():
 
     openai.flush_langfuse()
 
-    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+    generation = get_api().observations.get_many(
+        name=generation_name, type="GENERATION"
+    )
 
     assert len(generation.data) != 0
     assert generation.data[0].name == generation_name
@@ -1127,7 +1149,6 @@ def test_azure():
 
 
 @pytest.mark.asyncio
 async def test_async_azure():
-    api = get_api()
     generation_name = create_uuid()
     azure = AsyncAzureOpenAI(
         api_key="missing",
@@ -1146,7 +1167,9 @@ async def test_async_azure():
 
     openai.flush_langfuse()
 
-    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+    generation = get_api().observations.get_many(
+        name=generation_name, type="GENERATION"
+    )
 
     assert len(generation.data) != 0
     assert generation.data[0].name == generation_name
@@ -1185,7 +1208,6 @@ def test_openai_with_existing_trace_id():
     langfuse.flush()
 
-    api = get_api()
     generation_name = create_uuid()
     completion = chat_func(
         name=generation_name,
@@ -1198,7 +1220,9 @@ def test_openai_with_existing_trace_id():
 
     openai.flush_langfuse()
 
-    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+    generation = get_api().observations.get_many(
+        name=generation_name, type="GENERATION"
+    )
 
     assert len(generation.data) != 0
     assert generation.data[0].name == generation_name
@@ -1223,7 +1247,7 @@ def test_openai_with_existing_trace_id():
     assert "2" in generation.data[0].output["content"]
     assert generation.data[0].output["role"] == "assistant"
 
-    trace = api.trace.get(generation.data[0].trace_id)
+    trace = get_api().trace.get(generation.data[0].trace_id)
     assert trace.output == "This is a standard output"
     assert trace.input == "My custom input"
@@ -1237,7 +1261,6 @@ def test_disabled_langfuse():
     openai.langfuse_enabled = False
 
-    api = get_api()
     generation_name = create_uuid()
 
     openai.chat.completions.create(
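# ---------------------------------------------------------------------------
# [editor's sketch] The tool-call tests above pass a `tools` payload to
# chat.completions.create. A minimal self-contained example of that payload
# shape; the function name and parameters are illustrative, not copied from
# the elided test bodies:
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]
# ---------------------------------------------------------------------------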
         name=generation_name,
@@ -1249,7 +1272,9 @@
 
     openai.flush_langfuse()
 
-    generations = api.observations.get_many(name=generation_name, type="GENERATION")
+    generations = get_api().observations.get_many(
+        name=generation_name, type="GENERATION"
+    )
 
     assert len(generations.data) == 0
@@ -1279,7 +1304,6 @@ def test_langchain_integration():
 
 
 def test_structured_output_response_format_kwarg():
-    api = get_api()
     generation_name = (
         "test_structured_output_response_format_kwarg" + create_uuid()[0:10]
     )
@@ -1325,7 +1349,9 @@ def test_structured_output_response_format_kwarg():
 
     openai.flush_langfuse()
 
-    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+    generation = get_api().observations.get_many(
+        name=generation_name, type="GENERATION"
+    )
 
     assert len(generation.data) != 0
     assert generation.data[0].name == generation_name
@@ -1355,7 +1381,7 @@ def test_structured_output_response_format_kwarg():
     assert generation.data[0].usage.total is not None
     assert generation.data[0].output["role"] == "assistant"
 
-    trace = api.trace.get(generation.data[0].trace_id)
+    trace = get_api().trace.get(generation.data[0].trace_id)
     assert trace.output is not None
     assert trace.input is not None
@@ -1371,7 +1397,6 @@ class CalendarEvent(BaseModel):
         participants: List[str]
 
     generation_name = create_uuid()
-    api = get_api()
 
     params = {
         "model": "gpt-4o-2024-08-06",
@@ -1396,7 +1421,9 @@ class CalendarEvent(BaseModel):
 
     if Version(openai.__version__) >= Version("1.50.0"):
         # Check the trace and observation properties
-        generation = api.observations.get_many(name=generation_name, type="GENERATION")
+        generation = get_api().observations.get_many(
+            name=generation_name, type="GENERATION"
+        )
 
         assert len(generation.data) == 1
         assert generation.data[0].name == generation_name
@@ -1421,7 +1448,7 @@ class CalendarEvent(BaseModel):
         assert generation.data[0].usage.total is not None
 
         # Check trace
-        trace = api.trace.get(generation.data[0].trace_id)
+        trace = get_api().trace.get(generation.data[0].trace_id)
         assert trace.input is not None
         assert trace.output is not None
@@ -1431,7 +1458,6 @@ class CalendarEvent(BaseModel):
 async def test_close_async_stream():
     client = AsyncOpenAI()
     generation_name = create_uuid()
-    api = get_api()
 
     stream = await client.chat.completions.create(
         messages=[{"role": "user", "content": "1 + 1 = "}],
@@ -1447,7 +1473,9 @@ async def test_close_async_stream():
 
     openai.flush_langfuse()
 
-    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+    generation = get_api().observations.get_many(
+        name=generation_name, type="GENERATION"
+    )
 
     assert len(generation.data) != 0
     assert generation.data[0].name == generation_name
@@ -1467,7 +1495,7 @@ async def test_close_async_stream():
     assert generation.data[0].usage.input is not None
     assert generation.data[0].usage.output is not None
     assert generation.data[0].usage.total is not None
-    assert "2" in generation.data[0].output
+    assert "2" in str(generation.data[0].output)
 
     # Completion start time for time-to-first-token
     assert generation.data[0].completion_start_time is not None
@@ -1476,7 +1504,6 @@
 
 
 def test_base_64_image_input():
-    api = get_api()
     client = openai.OpenAI()
     generation_name = "test_base_64_image_input" + create_uuid()[:8]
@@ -1507,7 +1534,9 @@ def test_base_64_image_input():
 
     openai.flush_langfuse()
 
-    generation = api.observations.get_many(name=generation_name, type="GENERATION")
+    generation = get_api().observations.get_many(
+        name=generation_name, type="GENERATION"
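# ---------------------------------------------------------------------------
# [editor's sketch] The structured-output hunk above gates its assertions on
# the installed SDK version, because the structured-output API it exercises
# needs a newer client; this diff gates at 1.50.0. The guard in isolation
# (Version presumably comes from the `packaging` library):
import openai
from packaging.version import Version

if Version(openai.__version__) >= Version("1.50.0"):
    print("structured-output assertions run")
else:
    print("structured-output assertions skipped on older SDKs")
# ---------------------------------------------------------------------------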
type="GENERATION" + ) assert len(generation.data) != 0 assert generation.data[0].name == generation_name @@ -1528,7 +1557,6 @@ def test_base_64_image_input(): def test_audio_input_and_output(): - api = get_api() client = openai.OpenAI() openai.langfuse_debug = True generation_name = "test_audio_input_and_output" + create_uuid()[:8] @@ -1557,7 +1585,9 @@ def test_audio_input_and_output(): openai.flush_langfuse() - generation = api.observations.get_many(name=generation_name, type="GENERATION") + generation = get_api().observations.get_many( + name=generation_name, type="GENERATION" + ) assert len(generation.data) != 0 assert generation.data[0].name == generation_name diff --git a/tests/utils.py b/tests/utils.py index 583770d3c..6b6849a6a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,7 @@ import base64 import os import typing +from time import sleep from uuid import uuid4 try: @@ -25,6 +26,8 @@ def create_uuid(): def get_api(): + sleep(2) + return FernLangfuse( username=os.environ.get("LANGFUSE_PUBLIC_KEY"), password=os.environ.get("LANGFUSE_SECRET_KEY"),