Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ exclude = [
".venv",
".nox",
".git",
"tests/test_models_backward_compat.py",
]

reportImplicitOverride = true
Expand Down
44 changes: 42 additions & 2 deletions src/replicate/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(
self,
*,
bearer_token: str | None = None,
api_token: str | None = None, # Legacy compatibility parameter
base_url: str | httpx.URL | None = None,
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
Expand All @@ -124,7 +125,17 @@ def __init__(
"""Construct a new synchronous Replicate client instance.

This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided.

For legacy compatibility, you can also pass `api_token` instead of `bearer_token`.
"""
# Handle legacy api_token parameter
if api_token is not None and bearer_token is not None:
raise ReplicateError(
"Cannot specify both 'bearer_token' and 'api_token'. Please use 'bearer_token' (recommended) or 'api_token' for legacy compatibility."
)
if api_token is not None:
bearer_token = api_token

if bearer_token is None:
bearer_token = _get_api_token_from_environment()
if bearer_token is None:
Expand Down Expand Up @@ -175,9 +186,11 @@ def account(self) -> AccountResource:

@cached_property
def models(self) -> ModelsResource:
from .lib.models import patch_models_resource
from .resources.models import ModelsResource

return ModelsResource(self)
models_resource = ModelsResource(self)
return patch_models_resource(models_resource)

@cached_property
def predictions(self) -> PredictionsResource:
Expand Down Expand Up @@ -324,6 +337,7 @@ def copy(
self,
*,
bearer_token: str | None = None,
api_token: str | None = None, # Legacy compatibility parameter
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.Client | None = None,
Expand All @@ -337,6 +351,12 @@ def copy(
"""
Create a new client instance re-using the same options given to the current client with optional overriding.
"""
# Handle legacy api_token parameter
if api_token is not None and bearer_token is not None:
raise ValueError("Cannot specify both 'bearer_token' and 'api_token'. Please use 'bearer_token' (recommended) or 'api_token' for legacy compatibility.")
if api_token is not None:
bearer_token = api_token

if default_headers is not None and set_default_headers is not None:
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")

Expand Down Expand Up @@ -477,6 +497,7 @@ def __init__(
self,
*,
bearer_token: str | None = None,
api_token: str | None = None, # Legacy compatibility parameter
base_url: str | httpx.URL | None = None,
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
Expand All @@ -499,7 +520,17 @@ def __init__(
"""Construct a new async AsyncReplicate client instance.

This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided.

For legacy compatibility, you can also pass `api_token` instead of `bearer_token`.
"""
# Handle legacy api_token parameter
if api_token is not None and bearer_token is not None:
raise ReplicateError(
"Cannot specify both 'bearer_token' and 'api_token'. Please use 'bearer_token' (recommended) or 'api_token' for legacy compatibility."
)
if api_token is not None:
bearer_token = api_token

if bearer_token is None:
bearer_token = _get_api_token_from_environment()
if bearer_token is None:
Expand Down Expand Up @@ -550,9 +581,11 @@ def account(self) -> AsyncAccountResource:

@cached_property
def models(self) -> AsyncModelsResource:
from .lib.models import patch_models_resource
from .resources.models import AsyncModelsResource

return AsyncModelsResource(self)
models_resource = AsyncModelsResource(self)
return patch_models_resource(models_resource)

@cached_property
def predictions(self) -> AsyncPredictionsResource:
Expand Down Expand Up @@ -699,6 +732,7 @@ def copy(
self,
*,
bearer_token: str | None = None,
api_token: str | None = None, # Legacy compatibility parameter
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.AsyncClient | None = None,
Expand All @@ -712,6 +746,12 @@ def copy(
"""
Create a new client instance re-using the same options given to the current client with optional overriding.
"""
# Handle legacy api_token parameter
if api_token is not None and bearer_token is not None:
raise ValueError("Cannot specify both 'bearer_token' and 'api_token'. Please use 'bearer_token' (recommended) or 'api_token' for legacy compatibility.")
if api_token is not None:
bearer_token = api_token

if default_headers is not None and set_default_headers is not None:
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")

Expand Down
131 changes: 131 additions & 0 deletions src/replicate/lib/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""
Custom models functionality with backward compatibility.
"""

from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Union, overload

from .._types import NOT_GIVEN, NotGiven
from ._models import ModelVersionIdentifier

if TYPE_CHECKING:
import httpx

from .._types import Body, Query, Headers
from ..resources.models.models import ModelsResource, AsyncModelsResource
from ..types.model_get_response import ModelGetResponse


def _parse_model_args(
model_or_owner: str | NotGiven,
model_owner: str | NotGiven,
model_name: str | NotGiven,
) -> tuple[str, str]:
"""Parse model arguments and return (owner, name)."""
# Handle legacy format: models.get("owner/name")
if model_or_owner is not NOT_GIVEN:
if model_owner is not NOT_GIVEN or model_name is not NOT_GIVEN:
raise ValueError(
"Cannot specify both positional 'model_or_owner' and keyword arguments "
"'model_owner'/'model_name'. Use either the legacy format "
"models.get('owner/name') or the new format models.get(model_owner='owner', model_name='name')."
)

# Type guard: model_or_owner is definitely a string here
assert isinstance(model_or_owner, str)

# Parse the owner/name format
if "/" not in model_or_owner:
raise ValueError(
f"Invalid model reference '{model_or_owner}'. "
"Expected format: 'owner/name' (e.g., 'stability-ai/stable-diffusion')"
)

try:
parsed = ModelVersionIdentifier.parse(model_or_owner)
return parsed.owner, parsed.name
except ValueError as e:
raise ValueError(
f"Invalid model reference '{model_or_owner}'. "
f"Expected format: 'owner/name' (e.g., 'stability-ai/stable-diffusion'). "
f"Error: {e}"
) from e

# Validate required parameters for new format
if model_owner is NOT_GIVEN or model_name is NOT_GIVEN:
raise ValueError(
"model_owner and model_name are required. "
"Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')"
)

return model_owner, model_name # type: ignore[return-value]


@overload
def patch_models_resource(models_resource: "ModelsResource") -> "ModelsResource": ...


@overload
def patch_models_resource(models_resource: "AsyncModelsResource") -> "AsyncModelsResource": ...


def patch_models_resource(
models_resource: Union["ModelsResource", "AsyncModelsResource"],
) -> Union["ModelsResource", "AsyncModelsResource"]:
"""Patch a models resource to add backward compatibility."""
original_get = models_resource.get
is_async = inspect.iscoroutinefunction(original_get)

if is_async:

async def async_get_wrapper(
model_or_owner: str | NotGiven = NOT_GIVEN,
*,
model_owner: str | NotGiven = NOT_GIVEN,
model_name: str | NotGiven = NOT_GIVEN,
extra_headers: "Headers | None" = None,
extra_query: "Query | None" = None,
extra_body: "Body | None" = None,
timeout: "float | httpx.Timeout | None | NotGiven" = NOT_GIVEN,
) -> "ModelGetResponse":
owner, name = _parse_model_args(model_or_owner, model_owner, model_name)
return await models_resource._original_get( # type: ignore[misc,no-any-return,attr-defined,union-attr]
model_owner=owner,
model_name=name,
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
)

wrapper = async_get_wrapper
else:

def sync_get_wrapper(
model_or_owner: str | NotGiven = NOT_GIVEN,
*,
model_owner: str | NotGiven = NOT_GIVEN,
model_name: str | NotGiven = NOT_GIVEN,
extra_headers: "Headers | None" = None,
extra_query: "Query | None" = None,
extra_body: "Body | None" = None,
timeout: "float | httpx.Timeout | None | NotGiven" = NOT_GIVEN,
) -> "ModelGetResponse":
owner, name = _parse_model_args(model_or_owner, model_owner, model_name)
return models_resource._original_get( # type: ignore[misc,return-value,attr-defined,union-attr,no-any-return]
model_owner=owner,
model_name=name,
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
)

wrapper = sync_get_wrapper # type: ignore[assignment]

# Store original method for tests and replace with wrapper
models_resource._original_get = original_get # type: ignore[attr-defined,union-attr]
models_resource.get = wrapper # type: ignore[method-assign]
return models_resource
88 changes: 88 additions & 0 deletions tests/test_api_token_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Tests for api_token legacy compatibility during client instantiation."""

from __future__ import annotations

import pytest

from replicate import Replicate, AsyncReplicate, ReplicateError
from replicate._client import Client


class TestApiTokenCompatibility:
"""Test that api_token parameter works as a legacy compatibility option."""

def test_sync_client_with_api_token(self) -> None:
"""Test that Replicate accepts api_token parameter."""
client = Replicate(api_token="test_token_123")
assert client.bearer_token == "test_token_123"

def test_async_client_with_api_token(self) -> None:
"""Test that AsyncReplicate accepts api_token parameter."""
client = AsyncReplicate(api_token="test_token_123")
assert client.bearer_token == "test_token_123"

def test_sync_client_with_bearer_token(self) -> None:
"""Test that Replicate still accepts bearer_token parameter."""
client = Replicate(bearer_token="test_token_123")
assert client.bearer_token == "test_token_123"

def test_async_client_with_bearer_token(self) -> None:
"""Test that AsyncReplicate still accepts bearer_token parameter."""
client = AsyncReplicate(bearer_token="test_token_123")
assert client.bearer_token == "test_token_123"

def test_sync_client_both_tokens_error(self) -> None:
"""Test that providing both api_token and bearer_token raises an error."""
with pytest.raises(ReplicateError, match="Cannot specify both 'bearer_token' and 'api_token'"):
Replicate(api_token="test_api", bearer_token="test_bearer")

def test_async_client_both_tokens_error(self) -> None:
"""Test that providing both api_token and bearer_token raises an error."""
with pytest.raises(ReplicateError, match="Cannot specify both 'bearer_token' and 'api_token'"):
AsyncReplicate(api_token="test_api", bearer_token="test_bearer")

def test_sync_client_no_token_with_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that client reads from environment when no token is provided."""
monkeypatch.setenv("REPLICATE_API_TOKEN", "env_token_123")
client = Replicate()
assert client.bearer_token == "env_token_123"

def test_async_client_no_token_with_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that async client reads from environment when no token is provided."""
monkeypatch.setenv("REPLICATE_API_TOKEN", "env_token_123")
client = AsyncReplicate()
assert client.bearer_token == "env_token_123"

def test_sync_client_no_token_no_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that client raises error when no token is provided and env is not set."""
monkeypatch.delenv("REPLICATE_API_TOKEN", raising=False)
with pytest.raises(ReplicateError, match="The bearer_token client option must be set"):
Replicate()

def test_async_client_no_token_no_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that async client raises error when no token is provided and env is not set."""
monkeypatch.delenv("REPLICATE_API_TOKEN", raising=False)
with pytest.raises(ReplicateError, match="The bearer_token client option must be set"):
AsyncReplicate()

def test_legacy_client_alias(self) -> None:
"""Test that legacy Client import still works as an alias."""
assert Client is Replicate

def test_legacy_client_with_api_token(self) -> None:
"""Test that legacy Client alias works with api_token parameter."""
client = Client(api_token="test_token_123")
assert client.bearer_token == "test_token_123"
assert isinstance(client, Replicate)

def test_api_token_overrides_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that explicit api_token overrides environment variable."""
monkeypatch.setenv("REPLICATE_API_TOKEN", "env_token")
client = Replicate(api_token="explicit_token")
assert client.bearer_token == "explicit_token"

def test_bearer_token_overrides_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that explicit bearer_token overrides environment variable."""
monkeypatch.setenv("REPLICATE_API_TOKEN", "env_token")
client = Replicate(bearer_token="explicit_token")
assert client.bearer_token == "explicit_token"
Loading