Feat: Add Load-Balancing for multiple API KEYS #446
Conversation
Summary of Changes

Hello @jorschac, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly refactors the application's interaction with Large Language Models by introducing a robust and unified LLM service layer. This new layer centralizes API key management, implements intelligent load balancing across multiple keys for various providers, and streamlines both synchronous and asynchronous LLM invocations. The changes enhance the system's maintainability, scalability, and reliability by abstracting provider-specific complexities and improving error handling, particularly for streaming responses and token usage tracking.

Highlights
Code Review
This pull request introduces a significant and valuable refactoring by adding a unified LLMService layer. This new service abstracts LLM API calls across multiple providers and implements load balancing for API keys, which greatly improves maintainability, scalability, and reliability. The code is well-structured, and the inclusion of configuration files and a new test suite is excellent. However, I have identified a critical thread-safety issue related to how API keys are handled for some providers, which could lead to race conditions and incorrect behavior in a concurrent environment. I've also pointed out another race condition in the load balancing logic and a suggestion to improve timeout handling for parallel requests. Addressing these points will make the new service robust and ready for production use.
# Initialize client for different providers
if provider == "openai":
    client = model_client_class(api_key=api_key)
elif provider == "google":
    # Google client reads key from environment variables
    if api_key:
        os.environ["GOOGLE_API_KEY"] = api_key
    client = model_client_class()
elif provider == "openrouter":
    client = model_client_class(api_key=api_key)
elif provider == "azure":
    client = model_client_class()  # Azure reads key from environment variables
elif provider == "bedrock":
    client = model_client_class()  # Bedrock uses AWS credentials
elif provider == "dashscope":
    if api_key:
        os.environ["DASHSCOPE_API_KEY"] = api_key
    client = model_client_class()
elif provider == "ollama":
    client = model_client_class()  # Ollama local service
else:
    raise ValueError(f"Unsupported provider: {provider}")
This method modifies os.environ to set API keys for providers like google and dashscope. Since os.environ is a global, process-wide state, modifying it without synchronization in a multi-threaded context (as used in parallel_invoke) will cause a race condition. One thread might set an API key, but another thread could overwrite it before the first thread's client is initialized or used, leading to incorrect key usage, failed requests, and broken load balancing.
Recommendation:
The best solution is to refactor the underlying client classes (GoogleGenAIClient, DashscopeClient, etc.) to accept the api_key directly in their constructor, similar to how OpenAIClient is handled. Since DashscopeClient is part of this repository, it can be updated to support this.
If modifying the client classes is not feasible, you must protect the environment variable modification and client instantiation with a lock. For example:
# At class level
import threading
...

class LLMService:
    ...
    def __init__(...):
        ...
        self._client_creation_lock = threading.Lock()
        ...

    def _get_client(self, provider: str, api_key: Optional[str] = None):
        ...
        # For providers that use environment variables
        if provider in ["google", "dashscope"]:
            with self._client_creation_lock:
                # Temporarily set env var, create client, then unset
                env_var = f"{provider.upper()}_API_KEY"
                original_value = os.environ.get(env_var)
                try:
                    if api_key:
                        os.environ[env_var] = api_key
                    client = model_client_class()
                finally:
                    # Restore original env var value
                    if original_value is not None:
                        os.environ[env_var] = original_value
                    elif api_key:
                        del os.environ[env_var]
        else:
            # Existing logic for other providers
            ...

However, passing the key to the constructor is strongly preferred for cleaner, safer code.
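A minimal sketch of the preferred constructor-based approach, assuming DashscopeClient (which lives in this repository) can be extended to accept the key directly; the attribute names below are illustrative:

import os
from typing import Optional

class DashscopeClient:
    def __init__(self, api_key: Optional[str] = None, **kwargs):
        # Prefer an explicitly passed key and fall back to the environment
        # variable only when none is given. No process-wide state is mutated,
        # so concurrent threads can each hold a client bound to a different key.
        self._api_key = api_key or os.environ.get("DASHSCOPE_API_KEY")
        if not self._api_key:
            raise ValueError("No Dashscope API key provided")
        self._client_kwargs = kwargs

With that change, _get_client could treat dashscope the same way as openai: client = model_client_class(api_key=api_key).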
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should address this.
# Load balancing logic: select the key with the least usage and the least recently used
current_time = time.time()
best_key = min(
    keys,
    key=lambda k: (
        self.provider_key_usage[provider][str(k)],
        self.provider_key_last_used[provider][str(k)]
    )
)

# Update statistics
best_key_str = str(best_key)
self.provider_key_usage[provider][best_key_str] += 1
self.provider_key_last_used[provider][best_key_str] = current_time

logger.debug(f"Selected API key for {provider}: {best_key[:8]}...{best_key[-4:]}")
return best_key
The logic for selecting the next API key and updating usage statistics is not thread-safe. In parallel_invoke, multiple threads can call this method concurrently. This can lead to a race condition where multiple threads read the same usage counts, increment them, and then one thread's update is overwritten by another's. This would result in an incorrect usage count and defeat the purpose of the load balancing strategy.
To ensure atomicity of the key selection and statistics update, you should use a threading.Lock.
- Add import threading at the top of the file.
- Initialize a lock in LLMService.__init__: self._key_selection_lock = threading.Lock()
- Use the lock in get_next_api_key to protect the critical section as shown in the suggestion.
The same lock should also be used in reset_key_usage_stats to ensure thread safety when resetting statistics.
Suggested change:

with self._key_selection_lock:
    # Load balancing logic: select the key with the least usage and the least recently used
    current_time = time.time()
    best_key = min(
        keys,
        key=lambda k: (
            self.provider_key_usage[provider][str(k)],
            self.provider_key_last_used[provider][str(k)]
        )
    )
    # Update statistics
    best_key_str = str(best_key)
    self.provider_key_usage[provider][best_key_str] += 1
    self.provider_key_last_used[provider][best_key_str] = current_time
    logger.debug(f"Selected API key for {provider}: {best_key[:8]}...{best_key[-4:]}")
    return best_key
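Along the same lines, reset_key_usage_stats can take the same lock. A minimal sketch, assuming the statistics dictionaries are structured as in the excerpts above and that an optional provider argument exists (both assumptions):

def reset_key_usage_stats(self, provider: Optional[str] = None):
    # Guard the shared statistics with the same lock used in get_next_api_key
    # so a reset cannot interleave with a concurrent key selection.
    with self._key_selection_lock:
        providers = [provider] if provider else list(self.provider_key_usage)
        for p in providers:
            for key_str in self.provider_key_usage.get(p, {}):
                self.provider_key_usage[p][key_str] = 0
                self.provider_key_last_used[p][key_str] = 0.0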
# Collect results
completed = 0
for future in as_completed(future_to_index, timeout=timeout):
    try:
        result = future.result()
        index = result.get("index", future_to_index[future])
        results[index] = result
        completed += 1

        if completed % 5 == 0 or completed == len(requests):
            logger.info(f"Completed {completed}/{len(requests)} requests")

    except Exception as e:
        index = future_to_index[future]
        logger.error(f"Request {index} failed: {str(e)}")
        results[index] = {
            "index": index,
            "error": str(e),
            "request_data": requests[index]
        }
The timeout parameter on as_completed applies to the entire batch of futures, not to each individual future. If any single request takes longer than the timeout, as_completed will raise a TimeoutError, terminating the loop and leaving unprocessed futures. This is likely not the intended behavior, which is typically a per-request timeout.
For a more robust per-request timeout, you should remove the timeout from as_completed and apply it to future.result(). This allows you to handle timeouts for individual requests gracefully.
Remember to import TimeoutError from concurrent.futures at the top of the file.
# Collect results
completed = 0
for future in as_completed(future_to_index):
    index = future_to_index[future]
    try:
        result = future.result(timeout=timeout)
        results[index] = result
        completed += 1
        if completed % 5 == 0 or completed == len(requests):
            logger.info(f"Completed {completed}/{len(requests)} requests")
    except TimeoutError:
        logger.error(f"Request {index} timed out after {timeout} seconds")
        results[index] = {
            "index": index,
            "error": f"Request timed out after {timeout} seconds",
            "request_data": requests[index]
        }
    except Exception as e:
        logger.error(f"Request {index} failed: {str(e)}")
        results[index] = {
            "index": index,
            "error": str(e),
            "request_data": requests[index]
        }
Description
This PR introduces a unified LLMService layer that abstracts LLM API calls across multiple providers with built-in load balancing for API keys. The refactoring improves maintainability, scalability, and reliability of the LLM integration. The relevant feature request can be found here: #445

Key Changes
1. New LLM Service Layer (api/llm.py)
- Added the LLMService class as a unified interface for all LLM providers
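A minimal usage sketch of the new service; the request-dict shape and parameter names are inferred from the review excerpts above and may differ from the actual API:

from api.llm import LLMService

service = LLMService()
requests = [
    {"provider": "openai", "prompt": "Summarize module A"},
    {"provider": "openai", "prompt": "Summarize module B"},
]
# Keys are selected by least usage, so the two requests can be served by
# different API keys in parallel.
results = service.parallel_invoke(requests, timeout=60)
for item in results:
    print(item.get("error") or item)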
2. Configuration Management (api/config.py, api/config/api_keys.json)
- Added load_api_keys_config() function to load API keys from configuration files or environment variables
- Added replace_env_placeholders() to support ${ENV_VAR} placeholder substitution in JSON configs
- Added api/config/api_keys.json for centralized API key management
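A rough sketch of what the placeholder substitution can look like; the exact schema of api/config/api_keys.json and the implementation of replace_env_placeholders() in this PR are assumptions here:

import os
import re

def replace_env_placeholders(value):
    # Recursively substitute ${ENV_VAR} placeholders in strings, lists, and dicts
    # loaded from the JSON config.
    if isinstance(value, str):
        return re.sub(r"\$\{(\w+)\}", lambda m: os.environ.get(m.group(1), ""), value)
    if isinstance(value, list):
        return [replace_env_placeholders(v) for v in value]
    if isinstance(value, dict):
        return {k: replace_env_placeholders(v) for k, v in value.items()}
    return value

# Illustrative config shape:
# {"google": {"api_keys": ["${GOOGLE_API_KEY}"]},
#  "dashscope": {"api_keys": ["${DASHSCOPE_API_KEY}"]}}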
3. OpenAI Client Improvements (api/openai_client.py)
- Fixed the usage field in non-streaming responses
- Handle None values gracefully when usage data is unavailable
4. Streaming Interface Integration
- api/simple_chat.py: Replaced direct client instantiation with LLMService.async_invoke_stream()
- api/websocket_wiki.py: Replaced direct client instantiation with LLMService.async_invoke_stream()
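A minimal sketch of consuming the streaming interface, assuming async_invoke_stream() yields text chunks as an async iterator (the parameter names here are assumptions):

import asyncio
from api.llm import LLMService

async def stream_answer(prompt: str) -> str:
    service = LLMService()
    chunks = []
    async for chunk in service.async_invoke_stream(provider="openai", prompt=prompt):
        chunks.append(chunk)
    return "".join(chunks)

if __name__ == "__main__":
    print(asyncio.run(stream_answer("Summarize this repository.")))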
5. Testing Infrastructure
- Added tests/unit/test_balance_loading.py to verify load balancing functionality
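A rough sketch of the kind of check such a test can make, assuming get_next_api_key(provider) selects the least-used key and that the statistics attributes match the review excerpts above; the direct setup of internal state shown here is hypothetical:

from api.llm import LLMService

def test_usage_is_spread_across_keys():
    service = LLMService()
    keys = ["key-aaaaaaaa-1111", "key-bbbbbbbb-2222"]
    # Hypothetical direct setup of per-provider key statistics.
    service.provider_keys = {"openai": keys}
    service.provider_key_usage = {"openai": {k: 0 for k in keys}}
    service.provider_key_last_used = {"openai": {k: 0.0 for k in keys}}

    picks = [service.get_next_api_key("openai") for _ in range(10)]
    # Least-used selection should spread calls evenly across both keys.
    assert picks.count(keys[0]) == picks.count(keys[1]) == 5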
6. Documentation
- Added .env.example template for environment variable setup
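An illustrative excerpt of what the template can contain; only variable names that appear elsewhere in this PR are shown, and the values are placeholders:

# .env.example (illustrative)
GOOGLE_API_KEY=your-google-api-key
DASHSCOPE_API_KEY=your-dashscope-api-key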
Type of Change

Testing
Checklist