Skip to content

Commit 00d9513

Browse files
author
Brandon Meyerowitz
committed
fix: add support for provider api keys
1 parent 2b58a00 commit 00d9513

File tree

3 files changed

+43
-0
lines changed

3 files changed

+43
-0
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ jobs:
2626
env:
2727
CB_PROJECT_ID: ${{ secrets.CB_PROJECT_ID }}
2828
CB_API_KEY: ${{ secrets.CB_API_KEY }}
29+
CB_OPENAI_API_KEY: ${{ secrets.CB_OPENAI_API_KEY }}
2930
run: pytest
3031

3132
release:

commonbase/completion.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def _send_completion_request(
5353
chat_context: Optional[ChatContext],
5454
user_id: Optional[str],
5555
truncate_variable: Optional[TruncationConfig],
56+
provider_api_key: Optional[str],
5657
provider_config: Optional[ProviderConfig],
5758
stream: bool,
5859
) -> requests.Response:
@@ -80,6 +81,10 @@ def _send_completion_request(
8081
if stream:
8182
headers["Accept"] = "text/event-stream"
8283

84+
print(provider_api_key)
85+
if provider_api_key is not None:
86+
headers["Provider-API-Key"] = provider_api_key
87+
8388
return requests.post(
8489
"https://api.commonbase.com/completions",
8590
stream=stream,
@@ -99,6 +104,7 @@ def create(
99104
chat_context: Optional[ChatContext] = None,
100105
user_id: Optional[str] = None,
101106
truncate_variable: Optional[TruncationConfig] = None,
107+
provider_api_key: Optional[str] = None,
102108
provider_config: Optional[ProviderConfig] = None,
103109
) -> CompletionResponse:
104110
"""Creates a completion for the given prompt.
@@ -119,6 +125,8 @@ def create(
119125
The User ID that will be logged for the invocation.
120126
truncate_variable : TruncationConfig, optional
121127
Configures variable truncation.
128+
provider_api_key : str, optional
129+
The API Key used to authenticate with a provider.
122130
provider_config : ProviderConfig, optional
123131
Configures the completion provider to use, currently OpenAI or Anthropic.
124132
@@ -138,6 +146,7 @@ def create(
138146
chat_context=chat_context,
139147
user_id=user_id,
140148
truncate_variable=truncate_variable,
149+
provider_api_key=provider_api_key,
141150
provider_config=provider_config,
142151
stream=False,
143152
)
@@ -159,6 +168,7 @@ def stream(
159168
chat_context: Optional[ChatContext] = None,
160169
user_id: Optional[str] = None,
161170
truncate_variable: Optional[TruncationConfig] = None,
171+
provider_api_key: Optional[str] = None,
162172
provider_config: Optional[ProviderConfig] = None,
163173
) -> Generator[CompletionResponse, None, None]:
164174
"""Creates a completion stream for the given prompt.
@@ -174,6 +184,7 @@ def stream(
174184
chat_context=chat_context,
175185
user_id=user_id,
176186
truncate_variable=truncate_variable,
187+
provider_api_key=provider_api_key,
177188
provider_config=provider_config,
178189
stream=True,
179190
)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import os
2+
import pytest
3+
from commonbase.exceptions import CommonbaseApiException
4+
from commonbase.completion import Completion
5+
from commonbase.provider_config import ProviderConfig, OpenAIParams
6+
7+
8+
def test_provider_with_no_api_key():
9+
with pytest.raises(CommonbaseApiException):
10+
Completion.create(
11+
api_key=os.getenv("CB_API_KEY") or "",
12+
project_id=os.getenv("CB_PROJECT_ID") or "",
13+
provider_config=ProviderConfig(
14+
provider="openai", params=OpenAIParams(type="chat")
15+
),
16+
prompt="",
17+
)
18+
19+
20+
def test_provider_with_valid_api_key():
21+
result = Completion.create(
22+
api_key=os.getenv("CB_API_KEY") or "",
23+
project_id=os.getenv("CB_PROJECT_ID") or "",
24+
provider_api_key=os.getenv("CB_OPENAI_API_KEY") or "",
25+
provider_config=ProviderConfig(
26+
provider="openai", params=OpenAIParams(type="chat")
27+
),
28+
prompt="Please return the string '123abc' to me without the quotes.",
29+
)
30+
31+
assert result.completed and result.best_result.strip() == "123abc"

0 commit comments

Comments
 (0)