Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/replicate/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(
self,
*,
bearer_token: str | None = None,
api_token: str | None = None, # Legacy compatibility parameter
base_url: str | httpx.URL | None = None,
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
Expand All @@ -125,6 +126,14 @@ def __init__(

This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided.
"""
# Handle legacy api_token parameter
if api_token is not None and bearer_token is not None:
raise ReplicateError(
"Cannot specify both 'bearer_token' and 'api_token'. Please use 'bearer_token' (recommended) or 'api_token' for legacy compatibility."
)
if api_token is not None:
bearer_token = api_token

if bearer_token is None:
bearer_token = _get_api_token_from_environment()
if bearer_token is None:
Expand Down Expand Up @@ -477,6 +486,7 @@ def __init__(
self,
*,
bearer_token: str | None = None,
api_token: str | None = None, # Legacy compatibility parameter
base_url: str | httpx.URL | None = None,
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
Expand All @@ -500,6 +510,14 @@ def __init__(

This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided.
"""
# Handle legacy api_token parameter
if api_token is not None and bearer_token is not None:
raise ReplicateError(
"Cannot specify both 'bearer_token' and 'api_token'. Please use 'bearer_token' (recommended) or 'api_token' for legacy compatibility."
)
if api_token is not None:
bearer_token = api_token

if bearer_token is None:
bearer_token = _get_api_token_from_environment()
if bearer_token is None:
Expand Down
88 changes: 88 additions & 0 deletions tests/test_api_token_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Tests for api_token legacy compatibility during client instantiation."""

from __future__ import annotations

import pytest

from replicate import Replicate, AsyncReplicate, ReplicateError
from replicate._client import Client


class TestApiTokenCompatibility:
"""Test that api_token parameter works as a legacy compatibility option."""

def test_sync_client_with_api_token(self) -> None:
"""Test that Replicate accepts api_token parameter."""
client = Replicate(api_token="test_token_123")
assert client.bearer_token == "test_token_123"

def test_async_client_with_api_token(self) -> None:
"""Test that AsyncReplicate accepts api_token parameter."""
client = AsyncReplicate(api_token="test_token_123")
assert client.bearer_token == "test_token_123"

def test_sync_client_with_bearer_token(self) -> None:
"""Test that Replicate still accepts bearer_token parameter."""
client = Replicate(bearer_token="test_token_123")
assert client.bearer_token == "test_token_123"

def test_async_client_with_bearer_token(self) -> None:
"""Test that AsyncReplicate still accepts bearer_token parameter."""
client = AsyncReplicate(bearer_token="test_token_123")
assert client.bearer_token == "test_token_123"

def test_sync_client_both_tokens_error(self) -> None:
"""Test that providing both api_token and bearer_token raises an error."""
with pytest.raises(ReplicateError, match="Cannot specify both 'bearer_token' and 'api_token'"):
Replicate(api_token="test_api", bearer_token="test_bearer")

def test_async_client_both_tokens_error(self) -> None:
"""Test that providing both api_token and bearer_token raises an error."""
with pytest.raises(ReplicateError, match="Cannot specify both 'bearer_token' and 'api_token'"):
AsyncReplicate(api_token="test_api", bearer_token="test_bearer")

def test_sync_client_no_token_with_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that client reads from environment when no token is provided."""
monkeypatch.setenv("REPLICATE_API_TOKEN", "env_token_123")
client = Replicate()
assert client.bearer_token == "env_token_123"

def test_async_client_no_token_with_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that async client reads from environment when no token is provided."""
monkeypatch.setenv("REPLICATE_API_TOKEN", "env_token_123")
client = AsyncReplicate()
assert client.bearer_token == "env_token_123"

def test_sync_client_no_token_no_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that client raises error when no token is provided and env is not set."""
monkeypatch.delenv("REPLICATE_API_TOKEN", raising=False)
with pytest.raises(ReplicateError, match="The bearer_token client option must be set"):
Replicate()

def test_async_client_no_token_no_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that async client raises error when no token is provided and env is not set."""
monkeypatch.delenv("REPLICATE_API_TOKEN", raising=False)
with pytest.raises(ReplicateError, match="The bearer_token client option must be set"):
AsyncReplicate()

def test_legacy_client_alias(self) -> None:
"""Test that legacy Client import still works as an alias."""
assert Client is Replicate

def test_legacy_client_with_api_token(self) -> None:
"""Test that legacy Client alias works with api_token parameter."""
client = Client(api_token="test_token_123")
assert client.bearer_token == "test_token_123"
assert isinstance(client, Replicate)

def test_api_token_overrides_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that explicit api_token overrides environment variable."""
monkeypatch.setenv("REPLICATE_API_TOKEN", "env_token")
client = Replicate(api_token="explicit_token")
assert client.bearer_token == "explicit_token"

def test_bearer_token_overrides_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that explicit bearer_token overrides environment variable."""
monkeypatch.setenv("REPLICATE_API_TOKEN", "env_token")
client = Replicate(bearer_token="explicit_token")
assert client.bearer_token == "explicit_token"
Loading