Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/rotator_library/client/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,19 +562,21 @@ async def _execute_non_streaming(
# Pre-request callback
await self._run_pre_request_callback(context, kwargs)

# Make the API call
# Make the API call - determine function based on request type
is_embedding = context.request_type == "embedding"

if plugin and plugin.has_custom_logic():
kwargs["credential_identifier"] = cred
response = await plugin.acompletion(
self._http_client, **kwargs
)
call_fn = plugin.aembedding if is_embedding else plugin.acompletion
response = await call_fn(self._http_client, **kwargs)
else:
# Standard LiteLLM call
kwargs["api_key"] = cred
self._apply_litellm_logger(kwargs)
# Remove internal context before litellm call
kwargs.pop("transaction_context", None)
response = await litellm.acompletion(**kwargs)
call_fn = litellm.aembedding if is_embedding else litellm.acompletion
response = await call_fn(**kwargs)

# Success! Extract token usage if available
(
Expand Down
26 changes: 23 additions & 3 deletions src/rotator_library/client/rotating_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ async def acompletion(

return await self._executor.execute(context)

def aembedding(
async def aembedding(
self,
request: Optional[Any] = None,
pre_request_callback: Optional[callable] = None,
Expand All @@ -375,19 +375,39 @@ def aembedding(
f"Invalid model format or no credentials for provider: {model}"
)

# Extract internal logging parameters (not passed to API)
parent_log_dir = kwargs.pop("_parent_log_dir", None)

# Resolve model ID
resolved_model = self._model_resolver.resolve_model_id(model, provider)
kwargs["model"] = resolved_model

# Create transaction logger if enabled
transaction_logger = None
if self.enable_request_logging:
transaction_logger = TransactionLogger(
provider=provider,
model=resolved_model,
enabled=True,
parent_dir=parent_log_dir,
)
transaction_logger.log_request(kwargs)

# Build request context (embeddings are never streaming)
context = RequestContext(
model=model,
model=resolved_model,
provider=provider,
kwargs=kwargs,
streaming=False,
request_type="embedding",
credentials=self.all_credentials.get(provider, []),
deadline=time.time() + self.global_timeout,
request=request,
pre_request_callback=pre_request_callback,
transaction_logger=transaction_logger,
)

return self._executor.execute(context)
return await self._executor.execute(context)

def token_count(self, **kwargs) -> int:
"""Calculate token count for text or messages.
Expand Down
1 change: 1 addition & 0 deletions src/rotator_library/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class RequestContext:
streaming: bool
credentials: List[str]
deadline: float
request_type: Literal["completion", "embedding"] = "completion"
request: Optional[Any] = None # FastAPI Request object
pre_request_callback: Optional[Callable] = None
transaction_logger: Optional[Any] = None
Expand Down