diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 256a45ca8..133fc17ec 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,6 +23,37 @@ jobs: with: args: check --config ci.ruff.toml + type-checking: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.9" + - name: Install poetry + uses: abatilo/actions-poetry@v2 + - name: Setup a local virtual environment + run: | + poetry config virtualenvs.create true --local + poetry config virtualenvs.in-project true --local + - uses: actions/cache@v3 + name: Define a cache for the virtual environment based on the dependencies lock file + with: + path: ./.venv + key: venv-type-check-${{ hashFiles('poetry.lock') }} + - uses: actions/cache@v3 + name: Cache mypy cache + with: + path: ./.mypy_cache + key: mypy-${{ hashFiles('**/*.py', 'pyproject.toml') }} + restore-keys: | + mypy- + - name: Install dependencies + run: poetry install --only=main,dev --no-extras + - name: Run mypy type checking + run: poetry run mypy langfuse --no-error-summary + ci: runs-on: ubuntu-latest timeout-minutes: 30 @@ -160,7 +191,7 @@ jobs: all-tests-passed: # This allows us to have a branch protection rule for tests and deploys with matrix runs-on: ubuntu-latest - needs: [ci, linting] + needs: [ci, linting, type-checking] if: always() steps: - name: Successful deploy diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 89c6e0512..594632c23 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,3 +11,13 @@ repos: - id: ruff-format types_or: [ python, pyi, jupyter ] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.8.0 + hooks: + - id: mypy + additional_dependencies: + - types-requests + - types-setuptools + args: [--no-error-summary] + files: ^langfuse/ + diff --git a/langfuse/_client/attributes.py b/langfuse/_client/attributes.py index 623506e69..950ec6873 100644 --- a/langfuse/_client/attributes.py +++ b/langfuse/_client/attributes.py @@ -68,7 +68,7 @@ def create_trace_attributes( metadata: Optional[Any] = None, tags: Optional[List[str]] = None, public: Optional[bool] = None, -): +) -> Dict[str, Any]: attributes = { LangfuseOtelSpanAttributes.TRACE_NAME: name, LangfuseOtelSpanAttributes.TRACE_USER_ID: user_id, @@ -93,7 +93,7 @@ def create_span_attributes( level: Optional[SpanLevel] = None, status_message: Optional[str] = None, version: Optional[str] = None, -): +) -> Dict[str, Any]: attributes = { LangfuseOtelSpanAttributes.OBSERVATION_TYPE: "span", LangfuseOtelSpanAttributes.OBSERVATION_LEVEL: level, @@ -122,7 +122,7 @@ def create_generation_attributes( usage_details: Optional[Dict[str, int]] = None, cost_details: Optional[Dict[str, float]] = None, prompt: Optional[PromptClient] = None, -): +) -> Dict[str, Any]: attributes = { LangfuseOtelSpanAttributes.OBSERVATION_TYPE: "generation", LangfuseOtelSpanAttributes.OBSERVATION_LEVEL: level, @@ -151,13 +151,13 @@ def create_generation_attributes( return {k: v for k, v in attributes.items() if v is not None} -def _serialize(obj): +def _serialize(obj: Any) -> Optional[str]: return json.dumps(obj, cls=EventSerializer) if obj is not None else None def _flatten_and_serialize_metadata( metadata: Any, type: Literal["observation", "trace"] -): +) -> Dict[str, Any]: prefix = ( LangfuseOtelSpanAttributes.OBSERVATION_METADATA if type == "observation" diff --git a/langfuse/_client/client.py b/langfuse/_client/client.py index ade8f3c6d..e15669da2 100644 --- 
a/langfuse/_client/client.py +++ b/langfuse/_client/client.py @@ -49,6 +49,7 @@ from langfuse.api.resources.commons.errors.error import Error from langfuse.api.resources.ingestion.types.score_body import ScoreBody from langfuse.api.resources.prompts.types import ( + ChatMessage, CreatePromptRequest_Chat, CreatePromptRequest_Text, Prompt_Chat, @@ -142,6 +143,8 @@ class Langfuse: ``` """ + _otel_tracer: Union[otel_trace_api.Tracer, otel_trace_api.NoOpTracer] + def __init__( self, *, @@ -159,11 +162,11 @@ def __init__( media_upload_thread_count: Optional[int] = None, sample_rate: Optional[float] = None, mask: Optional[MaskFunction] = None, - ): + ) -> None: self._host = host or os.environ.get(LANGFUSE_HOST, "https://cloud.langfuse.com") self._environment = environment or os.environ.get(LANGFUSE_TRACING_ENVIRONMENT) self._mask = mask - self._project_id = None + self._project_id: Optional[str] = None sample_rate = sample_rate or float(os.environ.get(LANGFUSE_SAMPLE_RATE, 1.0)) if not 0.0 <= sample_rate <= 1.0: raise ValueError( @@ -210,7 +213,7 @@ def __init__( self._resources = LangfuseResourceManager( public_key=public_key, secret_key=secret_key, - host=self._host, + host=self._host or "https://cloud.langfuse.com", timeout=timeout, environment=environment, release=release, @@ -223,7 +226,7 @@ def __init__( self._otel_tracer = ( self._resources.tracer - if self._tracing_enabled + if self._tracing_enabled and self._resources.tracer is not None else otel_trace_api.NoOpTracer() ) self.api = self._resources.api @@ -289,9 +292,7 @@ def start_span( trace_id=trace_id, parent_span_id=parent_span_id ) - with otel_trace_api.use_span( - cast(otel_trace_api.Span, remote_parent_span) - ): + with otel_trace_api.use_span(remote_parent_span): otel_span = self._otel_tracer.start_span( name=name, attributes=attributes ) @@ -503,9 +504,7 @@ def start_generation( trace_id=trace_id, parent_span_id=parent_span_id ) - with otel_trace_api.use_span( - cast(otel_trace_api.Span, remote_parent_span) - ): + with otel_trace_api.use_span(remote_parent_span): otel_span = self._otel_tracer.start_span( name=name, attributes=attributes ) @@ -652,16 +651,16 @@ def start_as_current_generation( def _create_span_with_parent_context( self, *, - name, - parent, - remote_parent_span, - attributes, + name: str, + parent: Optional[otel_trace_api.Span], + remote_parent_span: Optional[otel_trace_api.Span], + attributes: Dict[str, Any], as_type: Literal["generation", "span"], input: Optional[Any] = None, output: Optional[Any] = None, metadata: Optional[Any] = None, end_on_exit: Optional[bool] = None, - ): + ) -> Any: parent_span = parent or cast(otel_trace_api.Span, remote_parent_span) with otel_trace_api.use_span(parent_span): @@ -692,7 +691,7 @@ def _start_as_current_otel_span_with_processed_media( output: Optional[Any] = None, metadata: Optional[Any] = None, end_on_exit: Optional[bool] = None, - ): + ) -> Any: with self._otel_tracer.start_as_current_span( name=name, attributes=attributes, @@ -886,7 +885,7 @@ def update_current_trace( metadata: Optional[Any] = None, tags: Optional[List[str]] = None, public: Optional[bool] = None, - ): + ) -> None: """Update the current trace with additional information. This method updates the Langfuse trace that the current span belongs to. 
It's useful for @@ -1004,39 +1003,41 @@ def create_event( trace_id=trace_id, parent_span_id=parent_span_id ) - with otel_trace_api.use_span( - cast(otel_trace_api.Span, remote_parent_span) - ): + with otel_trace_api.use_span(remote_parent_span): otel_span = self._otel_tracer.start_span( name=name, attributes=attributes, start_time=timestamp ) otel_span.set_attribute(LangfuseOtelSpanAttributes.AS_ROOT, True) - return LangfuseEvent( + event = LangfuseEvent( otel_span=otel_span, langfuse_client=self, input=input, output=output, metadata=metadata, environment=self._environment, - ).end(end_time=timestamp) + ) + event.end(end_time=timestamp) + return event otel_span = self._otel_tracer.start_span( name=name, attributes=attributes, start_time=timestamp ) - return LangfuseEvent( + event = LangfuseEvent( otel_span=otel_span, langfuse_client=self, input=input, output=output, metadata=metadata, environment=self._environment, - ).end(end_time=timestamp) + ) + event.end(end_time=timestamp) + return event def _create_remote_parent_span( self, *, trace_id: str, parent_span_id: Optional[str] - ): + ) -> otel_trace_api.Span: if not self._is_valid_trace_id(trace_id): langfuse_logger.warning( f"Passed trace ID '{trace_id}' is not a valid 32 lowercase hex char Langfuse trace id. Ignoring trace ID." @@ -1063,12 +1064,12 @@ def _create_remote_parent_span( return trace.NonRecordingSpan(span_context) - def _is_valid_trace_id(self, trace_id): + def _is_valid_trace_id(self, trace_id: str) -> bool: pattern = r"^[0-9a-f]{32}$" return bool(re.match(pattern, trace_id)) - def _is_valid_span_id(self, span_id): + def _is_valid_span_id(self, span_id: str) -> bool: pattern = r"^[0-9a-f]{16}$" return bool(re.match(pattern, span_id)) @@ -1170,12 +1171,12 @@ def create_trace_id(*, seed: Optional[str] = None) -> str: return sha256(seed.encode("utf-8")).digest()[:16].hex() - def _get_otel_trace_id(self, otel_span: otel_trace_api.Span): + def _get_otel_trace_id(self, otel_span: otel_trace_api.Span) -> str: span_context = otel_span.get_span_context() return self._format_otel_trace_id(span_context.trace_id) - def _get_otel_span_id(self, otel_span: otel_trace_api.Span): + def _get_otel_span_id(self, otel_span: otel_trace_api.Span) -> str: span_context = otel_span.get_span_context() return self._format_otel_span_id(span_context.span_id) @@ -1304,22 +1305,20 @@ def create_score( score_id = score_id or self._create_observation_id() try: - score_event = { - "id": score_id, - "session_id": session_id, - "dataset_run_id": dataset_run_id, - "trace_id": trace_id, - "observation_id": observation_id, - "name": name, - "value": value, - "data_type": data_type, - "comment": comment, - "config_id": config_id, - "environment": self._environment, - "metadata": metadata, - } - - new_body = ScoreBody(**score_event) + new_body = ScoreBody( + id=score_id, + sessionId=session_id, + datasetRunId=dataset_run_id, + traceId=trace_id, + observationId=observation_id, + name=name, + value=value, + dataType=data_type, # type: ignore[arg-type] + comment=comment, + configId=config_id, + environment=self._environment, + metadata=metadata, + ) event = { "id": self.create_trace_id(), @@ -1501,7 +1500,7 @@ def score_current_trace( config_id=config_id, ) - def flush(self): + def flush(self) -> None: """Force flush all pending spans and events to the Langfuse API. 
This method manually flushes any pending spans, scores, and other events to the @@ -1523,7 +1522,7 @@ def flush(self): """ self._resources.flush() - def shutdown(self): + def shutdown(self) -> None: """Shut down the Langfuse client and flush all pending data. This method cleanly shuts down the Langfuse client, ensuring all pending data @@ -1810,7 +1809,7 @@ def resolve_media_references( resolve_with: Literal["base64_data_uri"], max_depth: int = 10, content_fetch_timeout_seconds: int = 10, - ): + ) -> Any: """Replace media reference strings in an object with base64 data URIs. This method recursively traverses an object (up to max_depth) looking for media reference strings @@ -1952,7 +1951,7 @@ def get_prompt( label=label, ttl_seconds=cache_ttl_seconds, max_retries=bounded_max_retries, - fetch_timeout_seconds=fetch_timeout_seconds, + fetch_timeout_seconds=fetch_timeout_seconds or 10, ) except Exception as e: if fallback: @@ -1960,7 +1959,7 @@ def get_prompt( f"Returning fallback prompt for '{cache_key}' due to fetch error: {e}" ) - fallback_client_args = { + fallback_client_args: Dict[str, Any] = { "name": name, "prompt": fallback, "type": type, @@ -1972,13 +1971,34 @@ def get_prompt( if type == "text": return TextPromptClient( - prompt=Prompt_Text(**fallback_client_args), + prompt=Prompt_Text( + name=name, + prompt=cast(str, fallback), + version=version or 0, + config=fallback_client_args["config"], + labels=fallback_client_args["labels"], + tags=fallback_client_args["tags"], + type="text", + ), is_fallback=True, ) if type == "chat": return ChatPromptClient( - prompt=Prompt_Chat(**fallback_client_args), + prompt=Prompt_Chat( + name=name, + prompt=[ + ChatMessage( + role=msg["role"], content=msg["content"] + ) + for msg in cast(List[ChatMessageDict], fallback) + ], + version=version or 0, + config=fallback_client_args["config"], + labels=fallback_client_args["labels"], + tags=fallback_client_args["tags"], + type="chat", + ), is_fallback=True, ) @@ -1997,7 +2017,7 @@ def get_prompt( label=label, ttl_seconds=cache_ttl_seconds, max_retries=bounded_max_retries, - fetch_timeout_seconds=fetch_timeout_seconds, + fetch_timeout_seconds=fetch_timeout_seconds or 10, ), ) langfuse_logger.debug( @@ -2023,7 +2043,7 @@ def _fetch_prompt_and_update_cache( label: Optional[str] = None, ttl_seconds: Optional[int] = None, max_retries: int, - fetch_timeout_seconds, + fetch_timeout_seconds: int, ) -> PromptClient: cache_key = PromptCache.generate_cache_key(name, version=version, label=label) langfuse_logger.debug(f"Fetching prompt '{cache_key}' from server...") @@ -2033,7 +2053,7 @@ def _fetch_prompt_and_update_cache( @backoff.on_exception( backoff.constant, Exception, max_tries=max_retries, logger=None ) - def fetch_prompts(): + def fetch_prompts() -> Any: return self.api.prompts.get( self._url_encode(name), version=version, @@ -2047,6 +2067,7 @@ def fetch_prompts(): prompt_response = fetch_prompts() + prompt: Union[ChatPromptClient, TextPromptClient] if prompt_response.type == "chat": prompt = ChatPromptClient(prompt_response) else: @@ -2140,7 +2161,7 @@ def create_prompt( raise ValueError( "For 'chat' type, 'prompt' must be a list of chat messages with role and content attributes." 
) - request = CreatePromptRequest_Chat( + request_chat = CreatePromptRequest_Chat( name=name, prompt=cast(Any, prompt), labels=labels, @@ -2149,7 +2170,7 @@ def create_prompt( commitMessage=commit_message, type="chat", ) - server_prompt = self.api.prompts.create(request=request) + server_prompt = self.api.prompts.create(request=request_chat) self._resources.prompt_cache.invalidate(name) @@ -2158,7 +2179,7 @@ def create_prompt( if not isinstance(prompt, str): raise ValueError("For 'text' type, 'prompt' must be a string.") - request = CreatePromptRequest_Text( + request_text = CreatePromptRequest_Text( name=name, prompt=prompt, labels=labels, @@ -2168,7 +2189,7 @@ def create_prompt( type="text", ) - server_prompt = self.api.prompts.create(request=request) + server_prompt = self.api.prompts.create(request=request_text) self._resources.prompt_cache.invalidate(name) @@ -2184,7 +2205,7 @@ def update_prompt( name: str, version: int, new_labels: List[str] = [], - ): + ) -> Any: """Update an existing prompt version in Langfuse. The Langfuse SDK prompt cache is invalidated for all prompts with the specified name. Args: diff --git a/langfuse/_client/datasets.py b/langfuse/_client/datasets.py index 404a3020b..2f9b44e7e 100644 --- a/langfuse/_client/datasets.py +++ b/langfuse/_client/datasets.py @@ -91,7 +91,7 @@ def run( run_name: str, run_metadata: Optional[Any] = None, run_description: Optional[str] = None, - ): + ) -> Any: """Create a context manager for the dataset item run that links the execution to a Langfuse trace. This method is a context manager that creates a trace for the dataset run and yields a span diff --git a/langfuse/_client/observe.py b/langfuse/_client/observe.py index cdfb04d7b..38d691e6d 100644 --- a/langfuse/_client/observe.py +++ b/langfuse/_client/observe.py @@ -196,7 +196,7 @@ def _async_observe( transform_to_string: Optional[Callable[[Iterable], str]] = None, ) -> F: @wraps(func) - async def async_wrapper(*args, **kwargs): + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: trace_id = kwargs.pop("langfuse_trace_id", None) parent_observation_id = kwargs.pop("langfuse_parent_observation_id", None) trace_context: Optional[TraceContext] = ( @@ -284,7 +284,7 @@ def _sync_observe( transform_to_string: Optional[Callable[[Iterable], str]] = None, ) -> F: @wraps(func) - def sync_wrapper(*args, **kwargs): + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: trace_id = kwargs.pop("langfuse_trace_id", None) parent_observation_id = kwargs.pop("langfuse_parent_observation_id", None) trace_context: Optional[TraceContext] = ( @@ -388,7 +388,7 @@ def _wrap_sync_generator_result( langfuse_span_or_generation: Union[LangfuseSpan, LangfuseGeneration], generator: Generator, transform_to_string: Optional[Callable[[Iterable], str]] = None, - ): + ) -> Generator: items = [] try: @@ -401,10 +401,10 @@ def _wrap_sync_generator_result( output = items if transform_to_string is not None: - output = transform_to_string(items) + output = cast(Any, transform_to_string(items)) elif all(isinstance(item, str) for item in items): - output = "".join(items) + output = cast(Any, "".join(items)) langfuse_span_or_generation.update(output=output) langfuse_span_or_generation.end() @@ -427,10 +427,10 @@ async def _wrap_async_generator_result( output = items if transform_to_string is not None: - output = transform_to_string(items) + output = cast(Any, transform_to_string(items)) elif all(isinstance(item, str) for item in items): - output = "".join(items) + output = cast(Any, "".join(items))
langfuse_span_or_generation.update(output=output) langfuse_span_or_generation.end() diff --git a/langfuse/_client/resource_manager.py b/langfuse/_client/resource_manager.py index 548a637d9..9065e1537 100644 --- a/langfuse/_client/resource_manager.py +++ b/langfuse/_client/resource_manager.py @@ -18,7 +18,7 @@ import os import threading from queue import Full, Queue -from typing import Dict, Optional, cast +from typing import Any, Dict, Optional, cast import httpx from opentelemetry import trace as otel_trace_api @@ -75,6 +75,7 @@ class LangfuseResourceManager: _instances: Dict[str, "LangfuseResourceManager"] = {} _lock = threading.RLock() + _otel_tracer: Optional[otel_trace_api.Tracer] def __new__( cls, @@ -130,7 +131,7 @@ def _initialize_instance( media_upload_thread_count: Optional[int] = None, httpx_client: Optional[httpx.Client] = None, sample_rate: Optional[float] = None, - ): + ) -> None: self.public_key = public_key # OTEL Tracer @@ -148,7 +149,7 @@ def _initialize_instance( ) tracer_provider.add_span_processor(langfuse_processor) - tracer_provider = otel_trace_api.get_tracer_provider() + tracer_provider = cast(TracerProvider, otel_trace_api.get_tracer_provider()) self._otel_tracer = tracer_provider.get_tracer( LANGFUSE_TRACER_NAME, langfuse_version, @@ -195,7 +196,7 @@ def _initialize_instance( LANGFUSE_MEDIA_UPLOAD_ENABLED, "True" ).lower() not in ("false", "0") - self._media_upload_queue = Queue(100_000) + self._media_upload_queue: Queue[Any] = Queue(100_000) self._media_manager = MediaManager( api_client=self.api, media_upload_queue=self._media_upload_queue, @@ -220,7 +221,7 @@ def _initialize_instance( self.prompt_cache = PromptCache() # Score ingestion - self._score_ingestion_queue = Queue(100_000) + self._score_ingestion_queue: Queue[Any] = Queue(100_000) self._ingestion_consumers = [] ingestion_consumer = ScoreIngestionConsumer( @@ -248,10 +249,10 @@ def _initialize_instance( ) @classmethod - def reset(cls): + def reset(cls) -> None: cls._instances.clear() - def add_score_task(self, event: dict): + def add_score_task(self, event: dict) -> None: try: # Sample scores with the same sampler that is used for tracing tracer_provider = cast(TracerProvider, otel_trace_api.get_tracer_provider()) @@ -291,14 +292,14 @@ def add_score_task(self, event: dict): return @property - def tracer(self): + def tracer(self) -> Optional[otel_trace_api.Tracer]: return self._otel_tracer @staticmethod - def get_current_span(): + def get_current_span() -> otel_trace_api.Span: return otel_trace_api.get_current_span() - def _stop_and_join_consumer_threads(self): + def _stop_and_join_consumer_threads(self) -> None: """End the consumer threads once the queue is empty. 
Blocks execution until finished @@ -337,7 +338,7 @@ def _stop_and_join_consumer_threads(self): f"Shutdown: Score ingestion thread #{score_ingestion_consumer._identifier} successfully terminated" ) - def flush(self): + def flush(self) -> None: tracer_provider = cast(TracerProvider, otel_trace_api.get_tracer_provider()) if isinstance(tracer_provider, otel_trace_api.ProxyTracerProvider): return @@ -351,7 +352,7 @@ def flush(self): self._media_upload_queue.join() langfuse_logger.debug("Successfully flushed media upload queue") - def shutdown(self): + def shutdown(self) -> None: # Unregister the atexit handler first atexit.unregister(self.shutdown) diff --git a/langfuse/_client/span.py b/langfuse/_client/span.py index 39d5c62eb..9419c9bae 100644 --- a/langfuse/_client/span.py +++ b/langfuse/_client/span.py @@ -97,7 +97,7 @@ def __init__( ) # Handle media only if span is sampled - if self._otel_span.is_recording: + if self._otel_span.is_recording(): media_processed_input = self._process_media_and_apply_mask( data=input, field="input", span=self._otel_span ) @@ -119,7 +119,7 @@ def __init__( {k: v for k, v in attributes.items() if v is not None} ) - def end(self, *, end_time: Optional[int] = None): + def end(self, *, end_time: Optional[int] = None) -> "LangfuseSpanWrapper": """End the span, marking it as completed. This method ends the wrapped OpenTelemetry span, marking the end of the @@ -145,7 +145,7 @@ def update_trace( metadata: Optional[Any] = None, tags: Optional[List[str]] = None, public: Optional[bool] = None, - ): + ) -> None: """Update the trace that this span belongs to. This method updates trace-level attributes of the trace that this span @@ -344,7 +344,7 @@ def _set_processed_span_attributes( input: Optional[Any] = None, output: Optional[Any] = None, metadata: Optional[Any] = None, - ): + ) -> None: """Set span attributes after processing media and applying masks. Internal method that processes media in the input, output, and metadata @@ -395,7 +395,7 @@ def _process_media_and_apply_mask( data: Optional[Any] = None, span: otel_trace_api.Span, field: Union[Literal["input"], Literal["output"], Literal["metadata"]], - ): + ) -> Any: """Process media in an attribute and apply masking. Internal method that processes any media content in the data and applies @@ -413,7 +413,7 @@ def _process_media_and_apply_mask( data=self._process_media_in_attribute(data=data, span=span, field=field) ) - def _mask_attribute(self, *, data): + def _mask_attribute(self, *, data: Any) -> Any: """Apply the configured mask function to data. Internal method that applies the client's configured masking function to @@ -443,7 +443,7 @@ def _process_media_in_attribute( data: Optional[Any] = None, span: otel_trace_api.Span, field: Union[Literal["input"], Literal["output"], Literal["metadata"]], - ): + ) -> Any: """Process any media content in the attribute data. Internal method that identifies and processes any media content in the @@ -517,7 +517,7 @@ def update( version: Optional[str] = None, level: Optional[SpanLevel] = None, status_message: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> "LangfuseSpan": """Update this span with new information. 
@@ -632,7 +632,7 @@ def start_span( name=name, attributes=attributes ) - if new_otel_span.is_recording: + if new_otel_span.is_recording(): self._set_processed_span_attributes( span=new_otel_span, as_type="span", @@ -731,7 +731,7 @@ def start_generation( usage_details: Optional[Dict[str, int]] = None, cost_details: Optional[Dict[str, float]] = None, prompt: Optional[PromptClient] = None, - ): + ) -> "LangfuseGeneration": """Create a new child generation span. This method creates a new child generation span with this span as the parent. @@ -809,7 +809,7 @@ def start_generation( name=name, attributes=attributes ) - if new_otel_span.is_recording: + if new_otel_span.is_recording(): self._set_processed_span_attributes( span=new_otel_span, as_type="generation", @@ -967,7 +967,7 @@ def create_event( name=name, attributes=attributes, start_time=timestamp ) - if new_otel_span.is_recording: + if new_otel_span.is_recording(): self._set_processed_span_attributes( span=new_otel_span, as_type="event", @@ -976,14 +976,16 @@ def create_event( metadata=metadata, ) - return LangfuseEvent( + event = LangfuseEvent( otel_span=new_otel_span, langfuse_client=self._langfuse_client, input=input, output=output, metadata=metadata, environment=self._environment, - ).end(end_time=timestamp) + ) + event.end(end_time=timestamp) + return event class LangfuseGeneration(LangfuseSpanWrapper): @@ -1039,7 +1041,7 @@ def update( usage_details: Optional[Dict[str, int]] = None, cost_details: Optional[Dict[str, float]] = None, prompt: Optional[PromptClient] = None, - **kwargs, + **kwargs: Any, ) -> "LangfuseGeneration": """Update this generation span with new information. diff --git a/langfuse/_client/utils.py b/langfuse/_client/utils.py index 670e40c4b..dac7a3f1b 100644 --- a/langfuse/_client/utils.py +++ b/langfuse/_client/utils.py @@ -11,7 +11,7 @@ from opentelemetry.sdk.trace import ReadableSpan -def span_formatter(span: ReadableSpan): +def span_formatter(span: ReadableSpan) -> str: parent_id = ( otel_trace_api.format_span_id(span.parent.span_id) if span.parent else None ) diff --git a/langfuse/_task_manager/media_manager.py b/langfuse/_task_manager/media_manager.py index 43a50f8c6..7d1747b29 100644 --- a/langfuse/_task_manager/media_manager.py +++ b/langfuse/_task_manager/media_manager.py @@ -39,7 +39,7 @@ def __init__( LANGFUSE_MEDIA_UPLOAD_ENABLED, "True" ).lower() not in ("false", "0") - def process_next_media_upload(self): + def process_next_media_upload(self) -> None: try: upload_job = self._queue.get(block=True, timeout=1) self._log.debug( @@ -64,14 +64,14 @@ def _find_and_process_media( trace_id: str, observation_id: Optional[str], field: str, - ): + ) -> Any: if not self._enabled: return data seen = set() max_levels = 10 - def _process_data_recursively(data: Any, level: int): + def _process_data_recursively(data: Any, level: int) -> Any: if id(data) in seen or level > max_levels: return data @@ -168,7 +168,7 @@ def _process_media( trace_id: str, observation_id: Optional[str], field: str, - ): + ) -> None: if ( media._content_length is None or media._content_type is None @@ -215,7 +215,7 @@ def _process_upload_media_job( self, *, data: UploadMediaJob, - ): + ) -> None: upload_url_response = self._request_with_backoff( self._api_client.media.get_upload_url, request=GetMediaUploadUrlRequest( diff --git a/langfuse/_task_manager/media_upload_consumer.py b/langfuse/_task_manager/media_upload_consumer.py index ccfad2c20..182170864 100644 --- a/langfuse/_task_manager/media_upload_consumer.py +++ 
b/langfuse/_task_manager/media_upload_consumer.py @@ -28,7 +28,7 @@ def __init__( self._identifier = identifier self._media_manager = media_manager - def run(self): + def run(self) -> None: """Run the media upload consumer.""" self._log.debug( f"Thread: Media upload consumer thread #{self._identifier} started and actively processing queue items" @@ -36,7 +36,7 @@ def run(self): while self.running: self._media_manager.process_next_media_upload() - def pause(self): + def pause(self) -> None: """Pause the media upload consumer.""" self._log.debug( f"Thread: Pausing media upload consumer thread #{self._identifier}" diff --git a/langfuse/_task_manager/score_ingestion_consumer.py b/langfuse/_task_manager/score_ingestion_consumer.py index 9543c12d9..3e14d82e1 100644 --- a/langfuse/_task_manager/score_ingestion_consumer.py +++ b/langfuse/_task_manager/score_ingestion_consumer.py @@ -11,9 +11,9 @@ from ..version import __version__ as langfuse_version try: - import pydantic.v1 as pydantic + import pydantic.v1 as pydantic_lib except ImportError: - import pydantic + import pydantic as pydantic_lib # type: ignore[no-redef] from langfuse._utils.parse_error import handle_exception from langfuse._utils.request import APIError, LangfuseClient @@ -23,7 +23,7 @@ MAX_BATCH_SIZE_BYTES = int(os.environ.get("LANGFUSE_MAX_BATCH_SIZE_BYTES", 2_500_000)) -class ScoreIngestionMetadata(pydantic.BaseModel): +class ScoreIngestionMetadata(pydantic_lib.BaseModel): batch_size: int sdk_name: str sdk_version: str @@ -61,9 +61,9 @@ def __init__( self._max_retries = max_retries or 3 self._public_key = public_key - def _next(self): + def _next(self) -> List[Any]: """Return the next batch of items to upload.""" - events = [] + events: List[Any] = [] start_time = time.monotonic() total_size = 0 @@ -78,7 +78,9 @@ def _next(self): ) # convert pydantic models to dicts - if "body" in event and isinstance(event["body"], pydantic.BaseModel): + if "body" in event and isinstance( + event["body"], pydantic_lib.BaseModel + ): event["body"] = event["body"].dict(exclude_none=True) item_size = self._get_item_size(event) @@ -119,7 +121,7 @@ def _get_item_size(self, item: Any) -> int: """Return the size of the item in bytes.""" return len(json.dumps(item, cls=EventSerializer).encode()) - def run(self): + def run(self) -> None: """Run the consumer.""" self._log.debug( f"Startup: Score ingestion consumer thread #{self._identifier} started with batch size {self._flush_at} and interval {self._flush_interval}s" @@ -127,7 +129,7 @@ def run(self): while self.running: self.upload() - def upload(self): + def upload(self) -> None: """Upload the next batch of items, return whether successful.""" batch = self._next() if len(batch) == 0: @@ -142,11 +144,11 @@ def upload(self): for _ in batch: self._ingestion_queue.task_done() - def pause(self): + def pause(self) -> None: """Pause the consumer.""" self.running = False - def _upload_batch(self, batch: List[Any]): + def _upload_batch(self, batch: List[Any]) -> None: self._log.debug( f"API: Uploading batch of {len(batch)} score events to Langfuse API" ) @@ -161,7 +163,7 @@ def _upload_batch(self, batch: List[Any]): @backoff.on_exception( backoff.expo, Exception, max_tries=self._max_retries, logger=None ) - def execute_task_with_backoff(batch: List[Any]): + def execute_task_with_backoff(batch: List[Any]) -> None: try: self._client.batch_post(batch=batch, metadata=metadata) except Exception as e: diff --git a/langfuse/_utils/__init__.py b/langfuse/_utils/__init__.py index 036a40be4..3fe71b5dc 100644 --- 
a/langfuse/_utils/__init__.py +++ b/langfuse/_utils/__init__.py @@ -9,14 +9,14 @@ log = logging.getLogger("langfuse") -def _get_timestamp(): +def _get_timestamp() -> datetime: return datetime.now(timezone.utc) def _create_prompt_context( prompt: typing.Optional[PromptClient] = None, -): +) -> typing.Dict[str, typing.Optional[str]]: if prompt is not None and not prompt.is_fallback: - return {"prompt_version": prompt.version, "prompt_name": prompt.name} + return {"prompt_version": str(prompt.version), "prompt_name": prompt.name} return {"prompt_version": None, "prompt_name": None} diff --git a/langfuse/_utils/environment.py b/langfuse/_utils/environment.py index bd7d6021d..a696b3a59 100644 --- a/langfuse/_utils/environment.py +++ b/langfuse/_utils/environment.py @@ -1,6 +1,7 @@ """@private""" import os +from typing import Optional common_release_envs = [ # Render @@ -26,7 +27,7 @@ ] -def get_common_release_envs(): +def get_common_release_envs() -> Optional[str]: for env in common_release_envs: if env in os.environ: return os.environ[env] diff --git a/langfuse/_utils/error_logging.py b/langfuse/_utils/error_logging.py index ef3507fe1..a60feab65 100644 --- a/langfuse/_utils/error_logging.py +++ b/langfuse/_utils/error_logging.py @@ -1,15 +1,15 @@ import functools import logging -from typing import List, Optional +from typing import Any, Callable, List, Optional logger = logging.getLogger("langfuse") -def catch_and_log_errors(func): +def catch_and_log_errors(func: Callable[..., Any]) -> Callable[..., Any]: """Catch all exceptions and log them. Do NOT re-raise the exception.""" @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: try: return func(*args, **kwargs) except Exception as e: @@ -18,14 +18,17 @@ def wrapper(*args, **kwargs): return wrapper -def auto_decorate_methods_with(decorator, exclude: Optional[List[str]] = []): +def auto_decorate_methods_with( + decorator: Callable[[Any], Any], exclude: Optional[List[str]] = None +) -> Callable[[Any], Any]: """Class decorator to automatically apply a given decorator to all methods of a class. 
""" - def class_decorator(cls): + def class_decorator(cls: Any) -> Any: + exclude_list = exclude or [] for attr_name, attr_value in cls.__dict__.items(): - if attr_name in exclude: + if attr_name in exclude_list: continue if callable(attr_value): # Wrap callable attributes (methods) with the decorator diff --git a/langfuse/_utils/parse_error.py b/langfuse/_utils/parse_error.py index 12d891606..0f50ce89a 100644 --- a/langfuse/_utils/parse_error.py +++ b/langfuse/_utils/parse_error.py @@ -61,9 +61,10 @@ def generate_error_message_fern(error: Error) -> str: if isinstance(error.status_code, str) else error.status_code ) - return errorResponseByCode.get(status_code, defaultErrorResponse) - else: - return defaultErrorResponse + if status_code is not None: + return errorResponseByCode.get(status_code, defaultErrorResponse) + + return defaultErrorResponse def handle_fern_exception(exception: Error) -> None: diff --git a/langfuse/_utils/prompt_cache.py b/langfuse/_utils/prompt_cache.py index 67611d50d..e02cd8bbe 100644 --- a/langfuse/_utils/prompt_cache.py +++ b/langfuse/_utils/prompt_cache.py @@ -5,7 +5,7 @@ from datetime import datetime from queue import Empty, Queue from threading import Thread -from typing import Dict, List, Optional, Set +from typing import Any, Dict, List, Optional, Set from langfuse.model import PromptClient @@ -39,7 +39,7 @@ def __init__(self, queue: Queue, identifier: int): self._queue = queue self._identifier = identifier - def run(self): + def run(self) -> None: while self.running: try: task = self._queue.get(timeout=1) @@ -58,7 +58,7 @@ def run(self): except Empty: pass - def pause(self): + def pause(self) -> None: """Pause the consumer.""" self.running = False @@ -83,7 +83,7 @@ def __init__(self, threads: int = 1): atexit.register(self.shutdown) - def add_task(self, key: str, task): + def add_task(self, key: str, task: Any) -> None: if key not in self._processing_keys: self._log.debug(f"Adding prompt cache refresh task for key: {key}") self._processing_keys.add(key) @@ -97,8 +97,8 @@ def add_task(self, key: str, task): def active_tasks(self) -> int: return len(self._processing_keys) - def _wrap_task(self, key: str, task): - def wrapped(): + def _wrap_task(self, key: str, task: Any) -> Any: + def wrapped() -> None: self._log.debug(f"Refreshing prompt cache for key: {key}") try: task() @@ -108,7 +108,7 @@ def wrapped(): return wrapped - def shutdown(self): + def shutdown(self) -> None: self._log.debug( f"Shutting down prompt refresh task manager, {len(self._consumers)} consumers,..." 
) @@ -146,19 +146,19 @@ def __init__( def get(self, key: str) -> Optional[PromptCacheItem]: return self._cache.get(key, None) - def set(self, key: str, value: PromptClient, ttl_seconds: Optional[int]): + def set(self, key: str, value: PromptClient, ttl_seconds: Optional[int]) -> None: if ttl_seconds is None: ttl_seconds = DEFAULT_PROMPT_CACHE_TTL_SECONDS self._cache[key] = PromptCacheItem(value, ttl_seconds) - def invalidate(self, prompt_name: str): + def invalidate(self, prompt_name: str) -> None: """Invalidate all cached prompts with the given prompt name.""" for key in list(self._cache): if key.startswith(prompt_name): del self._cache[key] - def add_refresh_prompt_task(self, key: str, fetch_func): + def add_refresh_prompt_task(self, key: str, fetch_func: Any) -> None: self._log.debug(f"Submitting refresh task for key: {key}") self._task_manager.add_task(key, fetch_func) diff --git a/langfuse/_utils/request.py b/langfuse/_utils/request.py index d420a3a13..9b4cd6b38 100644 --- a/langfuse/_utils/request.py +++ b/langfuse/_utils/request.py @@ -3,7 +3,7 @@ import json import logging from base64 import b64encode -from typing import Any, List, Union +from typing import Any, Dict, List, Union import httpx @@ -34,7 +34,7 @@ def __init__( self._timeout = timeout self._session = session - def generate_headers(self): + def generate_headers(self) -> Dict[str, str]: return { "Authorization": "Basic " + b64encode( @@ -46,7 +46,7 @@ def generate_headers(self): "x_langfuse_public_key": self._public_key, } - def batch_post(self, **kwargs) -> httpx.Response: + def batch_post(self, **kwargs: Any) -> httpx.Response: """Post the `kwargs` to the batch API endpoint for events""" log = logging.getLogger("langfuse") log.debug("uploading data: %s", kwargs) @@ -56,7 +56,7 @@ def batch_post(self, **kwargs) -> httpx.Response: res, success_message="data uploaded successfully", return_json=False ) - def post(self, **kwargs) -> httpx.Response: + def post(self, **kwargs: Any) -> httpx.Response: """Post the `kwargs` to the API""" log = logging.getLogger("langfuse") url = self._remove_trailing_slash(self._base_url) + "/api/public/ingestion" @@ -125,7 +125,7 @@ def __init__(self, status: Union[int, str], message: str, details: Any = None): self.status = status self.details = details - def __str__(self): + def __str__(self) -> str: msg = "{0} ({1}): {2}" return msg.format(self.message, self.status, self.details) @@ -134,7 +134,7 @@ class APIErrors(Exception): def __init__(self, errors: List[APIError]): self.errors = errors - def __str__(self): + def __str__(self) -> str: errors = ", ".join(str(error) for error in self.errors) return f"[Langfuse] {errors}" diff --git a/langfuse/_utils/serializer.py b/langfuse/_utils/serializer.py index 8f9665711..16bc8b8c1 100644 --- a/langfuse/_utils/serializer.py +++ b/langfuse/_utils/serializer.py @@ -20,25 +20,29 @@ # Attempt to import Serializable try: from langchain.load.serializable import Serializable + + SERIALIZABLE_AVAILABLE = True except ImportError: - # If Serializable is not available, set it to NoneType - Serializable = type(None) + # If Serializable is not available, set it to a placeholder type + Serializable = None # type: ignore + SERIALIZABLE_AVAILABLE = False + # Attempt to import numpy try: import numpy as np except ImportError: - np = None + np = None # type: ignore logger = getLogger(__name__) class EventSerializer(JSONEncoder): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self.seen 
= set() # Track seen objects to detect circular references + self.seen: set[int] = set() # Track seen objects to detect circular references - def default(self, obj: Any): + def default(self, obj: Any) -> Any: try: if isinstance(obj, (datetime)): # Timezone-awareness check @@ -75,7 +79,7 @@ def default(self, obj: Any): if isinstance(obj, Queue): return type(obj).__name__ - if is_dataclass(obj): + if is_dataclass(obj) and not isinstance(obj, type): return asdict(obj) if isinstance(obj, UUID): @@ -106,8 +110,8 @@ def default(self, obj: Any): if isinstance(obj, Path): return str(obj) - # if langchain is not available, the Serializable type is NoneType - if Serializable is not type(None) and isinstance(obj, Serializable): + # if langchain is not available, the Serializable type is None + if SERIALIZABLE_AVAILABLE and isinstance(obj, Serializable): return obj.to_json() # 64-bit integers might overflow the JavaScript safe integer range. diff --git a/langfuse/api/core/jsonable_encoder.py b/langfuse/api/core/jsonable_encoder.py index 7a05e9190..7e0b4e75d 100644 --- a/langfuse/api/core/jsonable_encoder.py +++ b/langfuse/api/core/jsonable_encoder.py @@ -59,7 +59,7 @@ def jsonable_encoder( if "__root__" in obj_dict: obj_dict = obj_dict["__root__"] return jsonable_encoder(obj_dict, custom_encoder=encoder) - if dataclasses.is_dataclass(obj): + if dataclasses.is_dataclass(obj) and not isinstance(obj, type): obj_dict = dataclasses.asdict(obj) return jsonable_encoder(obj_dict, custom_encoder=custom_encoder) if isinstance(obj, bytes): diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index 576bfdebd..e7de4aa7a 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -19,6 +19,8 @@ from langfuse._utils import _get_timestamp from langfuse.langchain.utils import _extract_model_name +from langfuse.model import PromptClient +from langfuse.types import SpanLevel try: from langchain.callbacks.base import ( @@ -60,8 +62,8 @@ def __init__(self, *, public_key: Optional[str] = None) -> None: self.client = get_client(public_key=public_key) self.runs: Dict[UUID, Union[LangfuseSpan, LangfuseGeneration]] = {} - self.prompt_to_parent_run_map = {} - self.updated_completion_start_time_memo = set() + self.prompt_to_parent_run_map: Dict[UUID, str] = {} + self.updated_completion_start_time_memo: set[UUID] = set() def on_llm_new_token( self, @@ -104,18 +106,18 @@ def get_langchain_run_name( str: The determined name of the Langchain runnable. 
""" if "name" in kwargs and kwargs["name"] is not None: - return kwargs["name"] + return str(kwargs["name"]) if serialized is None: return "" try: - return serialized["name"] + return str(serialized["name"]) except (KeyError, TypeError): pass try: - return serialized["id"][-1] + return str(serialized["id"][-1]) except (KeyError, TypeError): pass @@ -174,11 +176,21 @@ def on_chain_start( } if parent_run_id is None: - self.runs[run_id] = self.client.start_span(**content) + self.runs[run_id] = self.client.start_span( + name=str(content["name"]), + metadata=content["metadata"], + input=content["input"], + level=cast(Optional[SpanLevel], content["level"]), + ) else: self.runs[run_id] = cast( LangfuseSpan, self.runs[parent_run_id] - ).start_span(**content) + ).start_span( + name=str(content["name"]), + metadata=content["metadata"], + input=content["input"], + level=cast(Optional[SpanLevel], content["level"]), + ) except Exception as e: langfuse_logger.exception(e) @@ -186,10 +198,10 @@ def on_chain_start( def _register_langfuse_prompt( self, *, - run_id, + run_id: UUID, parent_run_id: Optional[UUID], metadata: Optional[Dict[str, Any]], - ): + ) -> None: """We need to register any passed Langfuse prompt to the parent_run_id so that we can link following generations with that prompt. If parent_run_id is None, we are at the root of a trace and should not attempt to register the prompt, as there will be no LLM invocation following it. @@ -209,7 +221,7 @@ def _register_langfuse_prompt( registered_prompt = self.prompt_to_parent_run_map[parent_run_id] self.prompt_to_parent_run_map[run_id] = registered_prompt - def _deregister_langfuse_prompt(self, run_id: Optional[UUID]): + def _deregister_langfuse_prompt(self, run_id: Optional[UUID]) -> None: if run_id in self.prompt_to_parent_run_map: del self.prompt_to_parent_run_map[run_id] @@ -306,7 +318,7 @@ def on_chain_error( level = "ERROR" self.runs[run_id].update( - level=level, + level=cast(Optional[SpanLevel], level), status_message=str(error) if level else None, input=kwargs.get("inputs"), ).end() @@ -338,11 +350,8 @@ def on_chat_model_start( self.__on_llm_action( serialized, run_id, - cast( - List, - _flatten_comprehension( - [self._create_message_dicts(m) for m in messages] - ), + _flatten_comprehension( + [self._create_message_dicts(m) for m in messages] ), parent_run_id, tags=tags, @@ -439,7 +448,12 @@ def on_retriever_start( "level": "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None, } - self.runs[run_id] = self.client.start_span(**content) + self.runs[run_id] = self.client.start_span( + name=str(content["name"]), + metadata=content["metadata"], + input=content["input"], + level=cast(Optional[SpanLevel], content["level"]), + ) else: self.runs[run_id] = cast( LangfuseSpan, self.runs[parent_run_id] @@ -529,21 +543,27 @@ def __on_llm_action( self, serialized: Optional[Dict[str, Any]], run_id: UUID, - prompts: List[str], + prompts: List[Any], parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, - ): + ) -> None: try: tools = kwargs.get("invocation_params", {}).get("tools", None) if tools and isinstance(tools, list): - prompts.extend([{"role": "tool", "content": tool} for tool in tools]) + prompts.extend( + [{"role": "tool", "content": str(tool)} for tool in tools] + ) model_name = self._parse_model_and_log_errors( serialized=serialized, metadata=metadata, kwargs=kwargs ) - registered_prompt = self.prompt_to_parent_run_map.get(parent_run_id, None) + registered_prompt = ( 
+ self.prompt_to_parent_run_map.get(parent_run_id) + if parent_run_id + else None + ) if registered_prompt: self._deregister_langfuse_prompt(parent_run_id) @@ -560,15 +580,33 @@ def __on_llm_action( if parent_run_id is not None and parent_run_id in self.runs: self.runs[run_id] = cast( LangfuseSpan, self.runs[parent_run_id] - ).start_generation(**content) + ).start_generation( + name=str(content["name"]), + input=content["input"], + metadata=content["metadata"], + model=cast(Optional[str], content["model"]), + model_parameters=cast( + Optional[Dict[str, Any]], content["model_parameters"] + ), + prompt=cast(Optional[PromptClient], content["prompt"]), + ) else: - self.runs[run_id] = self.client.start_generation(**content) + self.runs[run_id] = self.client.start_generation( + name=str(content["name"]), + input=content["input"], + metadata=content["metadata"], + model=cast(Optional[str], content["model"]), + model_parameters=cast( + Optional[Dict[str, Any]], content["model_parameters"] + ), + prompt=cast(Optional[PromptClient], content["prompt"]), + ) except Exception as e: langfuse_logger.exception(e) @staticmethod - def _parse_model_parameters(kwargs): + def _parse_model_parameters(kwargs: Dict[str, Any]) -> Dict[str, Any]: """Parse the model parameters from the kwargs.""" if kwargs["invocation_params"].get("_type") == "IBM watsonx.ai" and kwargs[ "invocation_params" @@ -600,7 +638,13 @@ def _parse_model_parameters(kwargs): if value is not None } - def _parse_model_and_log_errors(self, *, serialized, metadata, kwargs): + def _parse_model_and_log_errors( + self, + *, + serialized: Optional[Dict[str, Any]], + metadata: Optional[Dict[str, Any]], + kwargs: Dict[str, Any], + ) -> Optional[str]: """Parse the model name and log errors if parsing fails.""" try: model_name = _parse_model_name_from_metadata( @@ -614,8 +658,9 @@ def _parse_model_and_log_errors(self, *, serialized, metadata, kwargs): langfuse_logger.exception(e) self._log_model_parse_warning() + return None - def _log_model_parse_warning(self): + def _log_model_parse_warning(self) -> None: if not hasattr(self, "_model_parse_warning_logged"): langfuse_logger.warning( "Langfuse was not able to parse the LLM model. The LLM call will be recorded without model name. 
Please create an issue: https://github.com/langfuse/langfuse/issues/new/choose" @@ -638,11 +683,11 @@ def on_llm_end( if run_id not in self.runs: raise Exception("Run not found, see docs what to do in this case.") else: - generation = response.generations[-1][-1] + llm_generation = response.generations[-1][-1] extracted_response = ( - self._convert_message_to_dict(generation.message) - if isinstance(generation, ChatGeneration) - else _extract_raw_response(generation) + self._convert_message_to_dict(llm_generation.message) + if isinstance(llm_generation, ChatGeneration) + else _extract_raw_response(llm_generation) ) llm_usage = _parse_usage(response) @@ -730,7 +775,7 @@ def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]: message_dict["name"] = message.additional_kwargs["name"] if message.additional_kwargs: - message_dict["additional_kwargs"] = message.additional_kwargs + message_dict["additional_kwargs"] = cast(Any, message.additional_kwargs) return message_dict @@ -744,14 +789,14 @@ def _log_debug_event( event_name: str, run_id: UUID, parent_run_id: Optional[UUID] = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: langfuse_logger.debug( f"Event: {event_name}, run_id: {str(run_id)[:5]}, parent_run_id: {str(parent_run_id)[:5]}" ) -def _extract_raw_response(last_response): +def _extract_raw_response(last_response: Any) -> Any: """Extract the response from the last response of the LLM call.""" # We return the text of the response if not empty if last_response.text is not None and last_response.text.strip() != "": @@ -764,11 +809,13 @@ def _extract_raw_response(last_response): return "" -def _flatten_comprehension(matrix): +def _flatten_comprehension(matrix: Any) -> List[Any]: return [item for row in matrix for item in row] -def _parse_usage_model(usage: typing.Union[pydantic.BaseModel, dict]): +def _parse_usage_model( + usage: typing.Union[pydantic.BaseModel, dict], +) -> typing.Optional[typing.Dict[str, typing.Any]]: # maintains a list of key translations. For each key, the usage model is checked # and a new object will be created with the new key if the key exists in the usage model # All non matched keys will remain on the object. 
@@ -891,7 +938,7 @@ def _parse_usage_model(usage: typing.Union[pydantic.BaseModel, dict]): return usage_model if usage_model else None -def _parse_usage(response: LLMResult): +def _parse_usage(response: LLMResult) -> typing.Optional[typing.Dict[str, typing.Any]]: # langchain-anthropic uses the usage field llm_usage_keys = ["token_usage", "usage"] llm_usage = None @@ -938,7 +985,7 @@ def _parse_usage(response: LLMResult): return llm_usage -def _parse_model(response: LLMResult): +def _parse_model(response: LLMResult) -> typing.Optional[str]: # langchain-anthropic uses the usage field llm_model_keys = ["model_name"] llm_model = None @@ -951,14 +998,18 @@ def _parse_model(response: LLMResult): return llm_model -def _parse_model_name_from_metadata(metadata: Optional[Dict[str, Any]]): +def _parse_model_name_from_metadata( + metadata: Optional[Dict[str, Any]], +) -> typing.Optional[str]: if metadata is None or not isinstance(metadata, dict): return None return metadata.get("ls_model_name", None) -def _strip_langfuse_keys_from_dict(metadata: Optional[Dict[str, Any]]): +def _strip_langfuse_keys_from_dict( + metadata: Optional[Dict[str, Any]], +) -> Optional[Dict[str, Any]]: if metadata is None or not isinstance(metadata, dict): return metadata diff --git a/langfuse/langchain/utils.py b/langfuse/langchain/utils.py index abbd9e70d..1f95034df 100644 --- a/langfuse/langchain/utils.py +++ b/langfuse/langchain/utils.py @@ -1,7 +1,7 @@ """@private""" import re -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, cast # NOTE ON DEPENDENCIES: # - since Jan 2024, there is https://pypi.org/project/langchain-openai/ which is a separate package and imports openai models. @@ -12,7 +12,7 @@ def _extract_model_name( serialized: Optional[Dict[str, Any]], **kwargs: Any, -): +) -> Optional[str]: """Extracts the model name from the serialized or kwargs object. 
This is used to get the model names for Langfuse.""" # In this function we return on the first match, so the order of operations is important @@ -39,27 +39,39 @@ def _extract_model_name( for model_name, keys, select_from in models_by_id: model = _extract_model_by_path_for_id( - model_name, serialized, kwargs, keys, select_from + model_name, + serialized, + kwargs, + keys, + cast(Literal["serialized", "kwargs"], select_from), ) if model: return model # Second, we match AzureOpenAI as we need to extract the model name, deployment version and deployment name - if serialized.get("id")[-1] == "AzureOpenAI": - if kwargs.get("invocation_params").get("model"): - return kwargs.get("invocation_params").get("model") - - if kwargs.get("invocation_params").get("model_name"): - return kwargs.get("invocation_params").get("model_name") + serialized_id = serialized.get("id") if serialized else None + if ( + serialized + and serialized_id + and isinstance(serialized_id, list) + and serialized_id[-1] == "AzureOpenAI" + ): + invocation_params = kwargs.get("invocation_params") + if invocation_params and invocation_params.get("model"): + return str(invocation_params.get("model")) + + if invocation_params and invocation_params.get("model_name"): + return str(invocation_params.get("model_name")) deployment_name = None deployment_version = None - if serialized.get("kwargs").get("openai_api_version"): - deployment_version = serialized.get("kwargs").get("deployment_version") + serialized_kwargs = serialized.get("kwargs") if serialized else None + if serialized_kwargs and serialized_kwargs.get("openai_api_version"): + deployment_version = serialized_kwargs.get("deployment_version") - if serialized.get("kwargs").get("deployment_name"): - deployment_name = serialized.get("kwargs").get("deployment_name") + if serialized_kwargs and serialized_kwargs.get("deployment_name"): + deployment_name = serialized_kwargs.get("deployment_name") if not isinstance(deployment_name, str): return None @@ -111,7 +123,9 @@ def _extract_model_name( ] for select in ["kwargs", "serialized"]: for path in random_paths: - model = _extract_model_by_path(serialized, kwargs, path, select) + model = _extract_model_by_path( + serialized, kwargs, path, cast(Literal["serialized", "kwargs"], select) + ) if model: return model @@ -123,19 +137,21 @@ def _extract_model_from_repr_by_pattern( serialized: Optional[Dict[str, Any]], pattern: str, default: Optional[str] = None, -): +) -> Optional[str]: if serialized is None: return None - if serialized.get("id")[-1] == id: - if serialized.get("repr"): - extracted = _extract_model_with_regex(pattern, serialized.get("repr")) + serialized_id = serialized.get("id") if serialized else None + if serialized_id and isinstance(serialized_id, list) and serialized_id[-1] == id: + repr_str = serialized.get("repr") + if repr_str: + extracted = _extract_model_with_regex(pattern, repr_str) return extracted if extracted else default if default else None return None -def _extract_model_with_regex(pattern: str, text: str): +def _extract_model_with_regex(pattern: str, text: str) -> Optional[str]: match = re.search(rf"{pattern}='(.*?)'", text) if match: return match.group(1) @@ -145,31 +161,42 @@ def _extract_model_with_regex(pattern: str, text: str): def _extract_model_by_path_for_id( id: str, serialized: Optional[Dict[str, Any]], - kwargs: dict, + kwargs: Dict[str, Any], keys: List[str], select_from: Literal["serialized", "kwargs"], -): +) -> Optional[str]: if serialized is None and select_from == "serialized": return None - if 
serialized.get("id")[-1] == id: + serialized_id = serialized.get("id") if serialized else None + if ( + serialized + and serialized_id + and isinstance(serialized_id, list) + and serialized_id[-1] == id + ): return _extract_model_by_path(serialized, kwargs, keys, select_from) + return None + def _extract_model_by_path( serialized: Optional[Dict[str, Any]], - kwargs: dict, + kwargs: Dict[str, Any], keys: List[str], select_from: Literal["serialized", "kwargs"], -): +) -> Optional[str]: if serialized is None and select_from == "serialized": return None current_obj = kwargs if select_from == "kwargs" else serialized for key in keys: - current_obj = current_obj.get(key) + if current_obj and isinstance(current_obj, dict): + current_obj = current_obj.get(key) + else: + return None if not current_obj: return None - return current_obj if current_obj else None + return str(current_obj) if current_obj else None diff --git a/langfuse/media.py b/langfuse/media.py index e0be5d7c5..c7a890dae 100644 --- a/langfuse/media.py +++ b/langfuse/media.py @@ -106,11 +106,11 @@ def _read_file(self, file_path: str) -> Optional[bytes]: return None - def _get_media_id(self): + def _get_media_id(self) -> Optional[str]: content_hash = self._content_sha256_hash if content_hash is None: - return + return None # Convert hash to base64Url url_safe_content_hash = content_hash.replace("+", "-").replace("/", "_") @@ -187,7 +187,7 @@ def parse_reference_string(reference_string: str) -> ParsedMediaReference: return ParsedMediaReference( media_id=parsed_data["id"], source=parsed_data["source"], - content_type=parsed_data["type"], + content_type=cast(MediaContentType, parsed_data["type"]), ) def _parse_base64_data_uri( @@ -293,14 +293,14 @@ def traverse(obj: Any, depth: int) -> Any: media_data = langfuse_client.api.media.get( parsed_media_reference["media_id"] ) - media_content = requests.get( + media_response = requests.get( media_data.url, timeout=content_fetch_timeout_seconds ) - if not media_content.ok: + if not media_response.ok: raise Exception("Failed to fetch media content") base64_media_content = base64.b64encode( - media_content.content + media_response.content ).decode() base64_data_uri = f"data:{media_data.content_type};base64,{base64_media_content}" @@ -336,4 +336,4 @@ def traverse(obj: Any, depth: int) -> Any: return obj - return traverse(obj, 0) + return cast(T, traverse(obj, 0)) diff --git a/langfuse/model.py b/langfuse/model.py index 6380bf5f2..112b3a872 100644 --- a/langfuse/model.py +++ b/langfuse/model.py @@ -138,7 +138,9 @@ def __init__(self, prompt: Prompt, is_fallback: bool = False): self.is_fallback = is_fallback @abstractmethod - def compile(self, **kwargs) -> Union[str, List[ChatMessage]]: + def compile( + self, **kwargs: Any + ) -> Union[str, List[ChatMessage], List[ChatMessageDict]]: pass @property @@ -147,15 +149,15 @@ def variables(self) -> List[str]: pass @abstractmethod - def __eq__(self, other): + def __eq__(self, other: object) -> bool: pass @abstractmethod - def get_langchain_prompt(self): + def get_langchain_prompt(self) -> Any: pass @staticmethod - def _get_langchain_prompt_string(content: str): + def _get_langchain_prompt_string(content: str) -> str: return re.sub(r"{{\s*(\w+)\s*}}", r"{\g<1>}", content) @@ -164,7 +166,7 @@ def __init__(self, prompt: Prompt_Text, is_fallback: bool = False): super().__init__(prompt, is_fallback) self.prompt = prompt.prompt - def compile(self, **kwargs) -> str: + def compile(self, **kwargs: Any) -> str: return TemplateParser.compile_template(self.prompt, kwargs) 

     @property
@@ -172,8 +174,8 @@ def variables(self) -> List[str]:
         """Return all the variable names in the prompt template."""
         return TemplateParser.find_variable_names(self.prompt)

-    def __eq__(self, other):
-        if isinstance(self, other.__class__):
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, self.__class__):
             return (
                 self.name == other.name
                 and self.version == other.version
@@ -183,7 +185,7 @@ def __eq__(self, other):

         return False

-    def get_langchain_prompt(self, **kwargs) -> str:
+    def get_langchain_prompt(self, **kwargs: Any) -> str:
         """Convert Langfuse prompt into string compatible with Langchain PromptTemplate.

         This method adapts the mustache-style double curly braces {{variable}} used in Langfuse
@@ -212,7 +214,7 @@ def __init__(self, prompt: Prompt_Chat, is_fallback: bool = False):
             ChatMessageDict(role=p.role, content=p.content) for p in prompt.prompt
         ]

-    def compile(self, **kwargs) -> List[ChatMessageDict]:
+    def compile(self, **kwargs: Any) -> List[ChatMessageDict]:
         return [
             ChatMessageDict(
                 content=TemplateParser.compile_template(
@@ -232,8 +234,8 @@ def variables(self) -> List[str]:
             for variable in TemplateParser.find_variable_names(chat_message["content"])
         ]

-    def __eq__(self, other):
-        if isinstance(self, other.__class__):
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, self.__class__):
             return (
                 self.name == other.name
                 and self.version == other.version
@@ -246,7 +248,7 @@ def __eq__(self, other):

         return False

-    def get_langchain_prompt(self, **kwargs):
+    def get_langchain_prompt(self, **kwargs: Any) -> Any:
         """Convert Langfuse prompt into string compatible with Langchain ChatPromptTemplate.

         It specifically adapts the mustache-style double curly braces {{variable}} used in Langfuse
diff --git a/langfuse/openai.py b/langfuse/openai.py
index 7d0f2a454..224f76de0 100644
--- a/langfuse/openai.py
+++ b/langfuse/openai.py
@@ -22,7 +22,7 @@
 from collections import defaultdict
 from dataclasses import dataclass
 from inspect import isclass
-from typing import Optional, cast
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

 from openai._types import NotGiven
 from packaging.version import Version
@@ -44,10 +44,10 @@
 try:
     from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI  # noqa: F401
 except ImportError:
-    AsyncAzureOpenAI = None
-    AsyncOpenAI = None
-    AzureOpenAI = None
-    OpenAI = None
+    AsyncAzureOpenAI = None  # type: ignore[misc,assignment]
+    AsyncOpenAI = None  # type: ignore[misc,assignment]
+    AzureOpenAI = None  # type: ignore[misc,assignment]
+    OpenAI = None  # type: ignore[misc,assignment]

 log = logging.getLogger("langfuse")

@@ -147,13 +147,15 @@ class OpenAiDefinition:
 class OpenAiArgsExtractor:
     def __init__(
         self,
-        metadata=None,
-        name=None,
-        langfuse_prompt=None,  # we cannot use prompt because it's an argument of the old OpenAI completions API
-        langfuse_public_key=None,
-        **kwargs,
-    ):
-        self.args = {}
+        metadata: Optional[Dict[str, Any]] = None,
+        name: Optional[str] = None,
+        langfuse_prompt: Optional[
+            Any
+        ] = None,  # we cannot use prompt because it's an argument of the old OpenAI completions API
+        langfuse_public_key: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        self.args: Dict[str, Any] = {}
         self.args["metadata"] = (
             metadata
             if "response_format" not in kwargs
@@ -171,10 +173,10 @@ def __init__(

         self.kwargs = kwargs

-    def get_langfuse_args(self):
+    def get_langfuse_args(self) -> Dict[str, Any]:
         return {**self.args, **self.kwargs}

-    def get_openai_args(self):
+    def get_openai_args(self) -> Dict[str, Any]:
         # If OpenAI model distillation is enabled, we need to add the metadata to the kwargs
         # https://platform.openai.com/docs/guides/distillation
         if self.kwargs.get("store", False):
@@ -189,9 +191,9 @@ def get_openai_args(self):
         return self.kwargs


-def _langfuse_wrapper(func):
-    def _with_langfuse(open_ai_definitions):
-        def wrapper(wrapped, instance, args, kwargs):
+def _langfuse_wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
+    def _with_langfuse(open_ai_definitions: Any) -> Callable[..., Any]:
+        def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any:
             return func(open_ai_definitions, wrapped, args, kwargs)

         return wrapper
@@ -199,9 +201,9 @@ def wrapper(wrapped, instance, args, kwargs):
     return _with_langfuse


-def _extract_chat_prompt(kwargs: any):
+def _extract_chat_prompt(kwargs: Dict[str, Any]) -> Union[List[Any], Dict[str, Any]]:
     """Extracts the user input from prompts. Returns an array of messages or dict with messages and functions"""
-    prompt = {}
+    prompt: Dict[str, Any] = {}

     if kwargs.get("functions") is not None:
         prompt.update({"functions": kwargs["functions"]})
@@ -227,7 +229,7 @@ def _extract_chat_prompt(kwargs: any):
         return [_process_message(message) for message in kwargs.get("messages", [])]


-def _process_message(message):
+def _process_message(message: Any) -> Any:
     if not isinstance(message, dict):
         return message

@@ -237,7 +239,7 @@ def _process_message(message):
     if not isinstance(content, list):
         return processed_message

-    processed_content = []
+    processed_content: List[Any] = []

     for content_part in content:
         if content_part.get("type") == "input_audio":
@@ -264,7 +266,7 @@ def _process_message(message):

     return processed_message


-def _extract_chat_response(kwargs: any):
+def _extract_chat_response(kwargs: Any) -> Dict[str, Any]:
     """Extracts the llm output from the response."""
     response = {
         "role": kwargs.get("role", None),
@@ -297,7 +299,7 @@ def _extract_chat_response(kwargs: any):
     return response


-def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, kwargs):
+def _get_langfuse_data_from_kwargs(
+    resource: OpenAiDefinition, kwargs: Dict[str, Any]
+) -> Dict[str, Any]:
     name = kwargs.get("name", "OpenAI-generation")

     if name is None:
@@ -418,13 +422,13 @@ def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, kwargs):


 def _create_langfuse_update(
-    completion,
+    completion: Any,
     generation: LangfuseGeneration,
-    completion_start_time,
-    model=None,
-    usage=None,
-    metadata=None,
-):
+    completion_start_time: Any,
+    model: Optional[str] = None,
+    usage: Optional[Any] = None,
+    metadata: Optional[Dict[str, Any]] = None,
+) -> None:
     update = {
         "output": completion,
         "completion_start_time": completion_start_time,
@@ -441,9 +445,9 @@ def _create_langfuse_update(
     generation.update(**update)


-def _parse_usage(usage=None):
+def _parse_usage(usage: Optional[Any] = None) -> Optional[Dict[str, Any]]:
     if usage is None:
-        return
+        return None

     usage_dict = usage.copy() if isinstance(usage, dict) else usage.__dict__.copy()

@@ -466,7 +470,9 @@ def _parse_usage(usage=None):
     return usage_dict


-def _extract_streamed_response_api_response(chunks):
+def _extract_streamed_response_api_response(
+    chunks: List[Any],
+) -> Tuple[Optional[str], Any, Optional[Any], Dict[str, Any]]:
     completion, model, usage = None, None, None
     metadata = {}

@@ -493,8 +499,10 @@ def _extract_streamed_response_api_response(chunks):
     return (model, completion, usage, metadata)


-def _extract_streamed_openai_response(resource, chunks):
-    completion = defaultdict(str) if resource.type == "chat" else ""
+def _extract_streamed_openai_response(
+    resource: OpenAiDefinition, chunks: List[Any]
+) -> Tuple[Optional[str], Any, Optional[Any], Optional[Any]]:
+    completion: Any = defaultdict(str) if resource.type == "chat" else ""
     model, usage = None, None

     for chunk in chunks:
@@ -575,7 +583,7 @@ def _extract_streamed_openai_response(resource, chunks):
         if resource.type == "completion":
             completion += choice.get("text", "")

-    def get_response_for_chat():
+    def get_response_for_chat() -> Optional[Any]:
         return (
             completion["content"]
             or (
@@ -606,7 +614,9 @@ def get_response_for_chat():
     )


-def _get_langfuse_data_from_default_response(resource: OpenAiDefinition, response):
+def _get_langfuse_data_from_default_response(
+    resource: OpenAiDefinition, response: Optional[Dict[str, Any]]
+) -> Tuple[Optional[str], Any, Optional[Dict[str, Any]]]:
     if response is None:
         return None, "", None

@@ -655,11 +665,11 @@ def _get_langfuse_data_from_default_response(resource: OpenAiDefinition, respons
     return (model, completion, usage)


-def _is_openai_v1():
+def _is_openai_v1() -> bool:
     return Version(openai.__version__) >= Version("1.0.0")


-def _is_streaming_response(response):
+def _is_streaming_response(response: Any) -> bool:
     return (
         isinstance(response, types.GeneratorType)
         or isinstance(response, types.AsyncGeneratorType)
@@ -669,7 +679,9 @@ def _is_streaming_response(response):


 @_langfuse_wrapper
-def _wrap(open_ai_resource: OpenAiDefinition, wrapped, args, kwargs):
+def _wrap(
+    open_ai_resource: OpenAiDefinition, wrapped: Any, args: Any, kwargs: Any
+) -> Any:
     arg_extractor = OpenAiArgsExtractor(*args, **kwargs)

     langfuse_args = arg_extractor.get_langfuse_args()
@@ -730,7 +742,9 @@ def _wrap(open_ai_resource: OpenAiDefinition, wrapped, args, kwargs):


 @_langfuse_wrapper
-async def _wrap_async(open_ai_resource: OpenAiDefinition, wrapped, args, kwargs):
+async def _wrap_async(
+    open_ai_resource: OpenAiDefinition, wrapped: Any, args: Any, kwargs: Any
+) -> Any:
     arg_extractor = OpenAiArgsExtractor(*args, **kwargs)

     langfuse_args = arg_extractor.get_langfuse_args()
@@ -790,7 +804,7 @@ async def _wrap_async(open_ai_resource: OpenAiDefinition, wrapped, args, kwargs)
         raise ex


-def register_tracing():
+def register_tracing() -> None:
     resources = OPENAI_METHODS_V1 if _is_openai_v1() else OPENAI_METHODS_V0

     for resource in resources:
@@ -813,18 +827,18 @@ class LangfuseResponseGeneratorSync:
     def __init__(
         self,
         *,
-        resource,
-        response,
-        generation,
-    ):
-        self.items = []
+        resource: OpenAiDefinition,
+        response: Any,
+        generation: LangfuseGeneration,
+    ) -> None:
+        self.items: List[Any] = []
         self.resource = resource
         self.response = response
         self.generation = generation
-        self.completion_start_time = None
+        self.completion_start_time: Optional[Any] = None

-    def __iter__(self):
+    def __iter__(self) -> Any:
         try:
             for i in self.response:
                 self.items.append(i)
@@ -836,7 +850,7 @@ def __iter__(self):
         finally:
             self._finalize()

-    def __next__(self):
+    def __next__(self) -> Any:
         try:
             item = self.response.__next__()
             self.items.append(item)
@@ -851,13 +865,13 @@ def __next__(self):

             raise

-    def __enter__(self):
+    def __enter__(self) -> Any:
         return self.__iter__()

-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         pass

-    def _finalize(self):
+    def _finalize(self) -> None:
         try:
             model, completion, usage, metadata = (
                 _extract_streamed_response_api_response(self.items)
@@ -883,18 +897,18 @@ class LangfuseResponseGeneratorAsync:
     def __init__(
         self,
         *,
-        resource,
-        response,
-        generation,
-    ):
-        self.items = []
+        resource: OpenAiDefinition,
+        response: Any,
+        generation: LangfuseGeneration,
+    ) -> None:
+        self.items: List[Any] = []
         self.resource = resource
         self.response = response
         self.generation = generation
-        self.completion_start_time = None
+        self.completion_start_time: Optional[Any] = None

-    async def __aiter__(self):
+    async def __aiter__(self) -> Any:
         try:
             async for i in self.response:
                 self.items.append(i)
@@ -906,7 +920,7 @@ async def __aiter__(self):
         finally:
             await self._finalize()

-    async def __anext__(self):
+    async def __anext__(self) -> Any:
         try:
             item = await self.response.__anext__()
             self.items.append(item)
@@ -921,13 +935,13 @@ async def __anext__(self):

             raise

-    async def __aenter__(self):
+    async def __aenter__(self) -> Any:
         return self.__aiter__()

-    async def __aexit__(self, exc_type, exc_value, traceback):
+    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         pass

-    async def _finalize(self):
+    async def _finalize(self) -> None:
         try:
             model, completion, usage, metadata = (
                 _extract_streamed_response_api_response(self.items)
diff --git a/langfuse/py.typed b/langfuse/py.typed
new file mode 100644
index 000000000..e69de29bb
diff --git a/poetry.lock b/poetry.lock
index 10681d69b..1c52d300c 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2792,6 +2792,60 @@ files = [
     {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"},
 ]

+[[package]]
+name = "mypy"
+version = "1.16.0"
+description = "Optional static typing for Python"
+optional = false
+python-versions = ">=3.9"
+files = [
+    {file = "mypy-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7909541fef256527e5ee9c0a7e2aeed78b6cda72ba44298d1334fe7881b05c5c"},
+    {file = "mypy-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e71d6f0090c2256c713ed3d52711d01859c82608b5d68d4fa01a3fe30df95571"},
+    {file = "mypy-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:936ccfdd749af4766be824268bfe22d1db9eb2f34a3ea1d00ffbe5b5265f5491"},
+    {file = "mypy-1.16.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4086883a73166631307fdd330c4a9080ce24913d4f4c5ec596c601b3a4bdd777"},
+    {file = "mypy-1.16.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:feec38097f71797da0231997e0de3a58108c51845399669ebc532c815f93866b"},
+    {file = "mypy-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:09a8da6a0ee9a9770b8ff61b39c0bb07971cda90e7297f4213741b48a0cc8d93"},
+    {file = "mypy-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9f826aaa7ff8443bac6a494cf743f591488ea940dd360e7dd330e30dd772a5ab"},
+    {file = "mypy-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:82d056e6faa508501af333a6af192c700b33e15865bda49611e3d7d8358ebea2"},
+    {file = "mypy-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:089bedc02307c2548eb51f426e085546db1fa7dd87fbb7c9fa561575cf6eb1ff"},
+    {file = "mypy-1.16.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6a2322896003ba66bbd1318c10d3afdfe24e78ef12ea10e2acd985e9d684a666"},
+    {file = "mypy-1.16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:021a68568082c5b36e977d54e8f1de978baf401a33884ffcea09bd8e88a98f4c"},
+    {file = "mypy-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:54066fed302d83bf5128632d05b4ec68412e1f03ef2c300434057d66866cea4b"},
+    {file = "mypy-1.16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c5436d11e89a3ad16ce8afe752f0f373ae9620841c50883dc96f8b8805620b13"},
+    {file = "mypy-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f2622af30bf01d8fc36466231bdd203d120d7a599a6d88fb22bdcb9dbff84090"},
+    {file = "mypy-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d045d33c284e10a038f5e29faca055b90eee87da3fc63b8889085744ebabb5a1"},
+    {file = "mypy-1.16.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b4968f14f44c62e2ec4a038c8797a87315be8df7740dc3ee8d3bfe1c6bf5dba8"},
+    {file = "mypy-1.16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eb14a4a871bb8efb1e4a50360d4e3c8d6c601e7a31028a2c79f9bb659b63d730"},
+    {file = "mypy-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:bd4e1ebe126152a7bbaa4daedd781c90c8f9643c79b9748caa270ad542f12bec"},
+    {file = "mypy-1.16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a9e056237c89f1587a3be1a3a70a06a698d25e2479b9a2f57325ddaaffc3567b"},
+    {file = "mypy-1.16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0b07e107affb9ee6ce1f342c07f51552d126c32cd62955f59a7db94a51ad12c0"},
+    {file = "mypy-1.16.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c6fb60cbd85dc65d4d63d37cb5c86f4e3a301ec605f606ae3a9173e5cf34997b"},
+    {file = "mypy-1.16.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a7e32297a437cc915599e0578fa6bc68ae6a8dc059c9e009c628e1c47f91495d"},
+    {file = "mypy-1.16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:afe420c9380ccec31e744e8baff0d406c846683681025db3531b32db56962d52"},
+    {file = "mypy-1.16.0-cp313-cp313-win_amd64.whl", hash = "sha256:55f9076c6ce55dd3f8cd0c6fff26a008ca8e5131b89d5ba6d86bd3f47e736eeb"},
+    {file = "mypy-1.16.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f56236114c425620875c7cf71700e3d60004858da856c6fc78998ffe767b73d3"},
+    {file = "mypy-1.16.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:15486beea80be24ff067d7d0ede673b001d0d684d0095803b3e6e17a886a2a92"},
+    {file = "mypy-1.16.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f2ed0e0847a80655afa2c121835b848ed101cc7b8d8d6ecc5205aedc732b1436"},
+    {file = "mypy-1.16.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eb5fbc8063cb4fde7787e4c0406aa63094a34a2daf4673f359a1fb64050e9cb2"},
+    {file = "mypy-1.16.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:a5fcfdb7318c6a8dd127b14b1052743b83e97a970f0edb6c913211507a255e20"},
+    {file = "mypy-1.16.0-cp39-cp39-win_amd64.whl", hash = "sha256:2e7e0ad35275e02797323a5aa1be0b14a4d03ffdb2e5f2b0489fa07b89c67b21"},
+    {file = "mypy-1.16.0-py3-none-any.whl", hash = "sha256:29e1499864a3888bca5c1542f2d7232c6e586295183320caa95758fc84034031"},
+    {file = "mypy-1.16.0.tar.gz", hash = "sha256:84b94283f817e2aa6350a14b4a8fb2a35a53c286f97c9d30f53b63620e7af8ab"},
+]
+
+[package.dependencies]
+mypy_extensions = ">=1.0.0"
+pathspec = ">=0.9.0"
+tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
+typing_extensions = ">=4.6.0"
+
+[package.extras]
+dmypy = ["psutil (>=4.0)"]
+faster-cache = ["orjson"]
+install-types = ["pip"]
+mypyc = ["setuptools (>=50)"]
+reports = ["lxml"]
+
 [[package]]
 name = "mypy-extensions"
 version = "1.0.0"
@@ -3374,6 +3428,17 @@ files = [
 [package.extras]
 dev = ["jinja2"]

+[[package]]
+name = "pathspec"
+version = "0.12.1"
+description = "Utility library for gitignore style pattern matching of file paths."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"},
+    {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"},
+]
+
 [[package]]
 name = "pdoc"
 version = "14.6.0"
@@ -5484,4 +5549,4 @@ openai = ["openai"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<4.0"
-content-hash = "d0353d0579f11dc634c6da45be88e85310fd9f1172500b9670b2cbc428370eb9"
+content-hash = "acdcf5642aba80585f46a67189b0ae2931af42d63551ce3061c09353d6dbf230"
diff --git a/pyproject.toml b/pyproject.toml
index 14937f0de..162efa57a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,6 +38,7 @@ pytest-asyncio = ">=0.21.1,<0.24.0"
 pytest-httpserver = "^1.0.8"
 boto3 = "^1.28.59"
 ruff = ">=0.1.8,<0.6.0"
+mypy = "^1.0.0"
 langchain-mistralai = ">=0.0.1,<0.3"
 google-cloud-aiplatform = "^1.38.1"
 cohere = ">=4.46,<6.0"
@@ -72,5 +73,53 @@ log_cli = true

 [tool.poetry_bumpversion.file."langfuse/version.py"]

+[tool.mypy]
+python_version = "3.9"
+warn_return_any = true
+warn_unused_configs = true
+disallow_untyped_defs = true
+disallow_incomplete_defs = true
+check_untyped_defs = true
+disallow_untyped_decorators = true
+no_implicit_optional = true
+warn_redundant_casts = true
+warn_unused_ignores = false
+warn_no_return = true
+warn_unreachable = true
+strict_equality = true
+show_error_codes = true
+
+# Performance optimizations for CI
+cache_dir = ".mypy_cache"
+sqlite_cache = true
+incremental = true
+show_column_numbers = true
+
+[[tool.mypy.overrides]]
+module = [
+    "langchain.*",
+    "openai.*",
+    "chromadb.*",
+    "tiktoken.*",
+    "google.*",
+    "anthropic.*",
+    "cohere.*",
+    "dashscope.*",
+    "pymongo.*",
+    "bson.*",
+    "boto3.*",
+    "llama_index.*",
+    "respx.*",
+    "bs4.*",
+    "lark.*",
+    "huggingface_hub.*",
+    "backoff.*",
+    "wrapt.*",
+    "packaging.*",
+    "requests.*",
+    "opentelemetry.*"
+]
+ignore_missing_imports = true
+
 [tool.poetry.scripts]
 release = "scripts.release:main"
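
Illustrative sketch, not part of the patch: under the [tool.mypy] settings added to pyproject.toml, `disallow_untyped_defs`, `disallow_incomplete_defs`, and `no_implicit_optional` reject the kind of signatures this diff removes throughout the SDK. The `_lookup` helper below is hypothetical and only demonstrates the annotation style that satisfies those flags.

from typing import Any, Dict, Optional

# Hypothetical helper, not from the langfuse codebase.
#
# Rejected by `disallow_untyped_defs` (missing parameter/return annotations)
# and by `no_implicit_optional` (a None default without an Optional type):
#
#     def _lookup(data, key=None):
#         return data.get(key)

def _lookup(data: Dict[str, Any], key: Optional[str] = None) -> Optional[Any]:
    """Return the value stored under `key`, or None when the key is absent or not given."""
    if key is None:
        return None
    return data.get(key)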
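
Also illustrative and not part of the patch: the optional-import fallback in langfuse/openai.py needs targeted ignores because assigning None to an imported class name conflicts with that name's type. A generic sketch of the pattern, with a made-up `somepkg` dependency and `make_client` helper:

from typing import Any

try:
    from somepkg import SomeClient  # hypothetical optional dependency
except ImportError:
    # Without the ignore, mypy flags the None assignment as incompatible with
    # the imported class's type (the same reason the OpenAI fallbacks above
    # carry `# type: ignore[misc,assignment]`).
    SomeClient = None  # type: ignore[misc,assignment]


def make_client() -> Any:
    """Instantiate the optional client, or raise if the dependency is missing."""
    if SomeClient is None:
        raise ImportError("somepkg is not installed")
    return SomeClient()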