diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 8f50bacb..602da189 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -562,19 +562,21 @@ async def _execute_non_streaming( # Pre-request callback await self._run_pre_request_callback(context, kwargs) - # Make the API call + # Make the API call - determine function based on request type + is_embedding = context.request_type == "embedding" + if plugin and plugin.has_custom_logic(): kwargs["credential_identifier"] = cred - response = await plugin.acompletion( - self._http_client, **kwargs - ) + call_fn = plugin.aembedding if is_embedding else plugin.acompletion + response = await call_fn(self._http_client, **kwargs) else: # Standard LiteLLM call kwargs["api_key"] = cred self._apply_litellm_logger(kwargs) # Remove internal context before litellm call kwargs.pop("transaction_context", None) - response = await litellm.acompletion(**kwargs) + call_fn = litellm.aembedding if is_embedding else litellm.acompletion + response = await call_fn(**kwargs) # Success! Extract token usage if available ( diff --git a/src/rotator_library/client/rotating_client.py b/src/rotator_library/client/rotating_client.py index a5cee0fc..21578834 100644 --- a/src/rotator_library/client/rotating_client.py +++ b/src/rotator_library/client/rotating_client.py @@ -358,7 +358,7 @@ async def acompletion( return await self._executor.execute(context) - def aembedding( + async def aembedding( self, request: Optional[Any] = None, pre_request_callback: Optional[callable] = None, @@ -375,19 +375,39 @@ def aembedding( f"Invalid model format or no credentials for provider: {model}" ) + # Extract internal logging parameters (not passed to API) + parent_log_dir = kwargs.pop("_parent_log_dir", None) + + # Resolve model ID + resolved_model = self._model_resolver.resolve_model_id(model, provider) + kwargs["model"] = resolved_model + + # Create transaction logger if enabled + transaction_logger = None + if self.enable_request_logging: + transaction_logger = TransactionLogger( + provider=provider, + model=resolved_model, + enabled=True, + parent_dir=parent_log_dir, + ) + transaction_logger.log_request(kwargs) + # Build request context (embeddings are never streaming) context = RequestContext( - model=model, + model=resolved_model, provider=provider, kwargs=kwargs, streaming=False, + request_type="embedding", credentials=self.all_credentials.get(provider, []), deadline=time.time() + self.global_timeout, request=request, pre_request_callback=pre_request_callback, + transaction_logger=transaction_logger, ) - return self._executor.execute(context) + return await self._executor.execute(context) def token_count(self, **kwargs) -> int: """Calculate token count for text or messages. diff --git a/src/rotator_library/core/types.py b/src/rotator_library/core/types.py index f6cc72e2..c0220e31 100644 --- a/src/rotator_library/core/types.py +++ b/src/rotator_library/core/types.py @@ -63,6 +63,7 @@ class RequestContext: streaming: bool credentials: List[str] deadline: float + request_type: Literal["completion", "embedding"] = "completion" request: Optional[Any] = None # FastAPI Request object pre_request_callback: Optional[Callable] = None transaction_logger: Optional[Any] = None