diff --git a/pyproject.toml b/pyproject.toml index 044be7f..ba2a140 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,6 +149,7 @@ exclude = [ ".venv", ".nox", ".git", + "tests/test_models_backward_compat.py", ] reportImplicitOverride = true diff --git a/src/replicate/_client.py b/src/replicate/_client.py index 390a552..f2ed308 100644 --- a/src/replicate/_client.py +++ b/src/replicate/_client.py @@ -102,6 +102,7 @@ def __init__( self, *, bearer_token: str | None = None, + api_token: str | None = None, # Legacy compatibility parameter base_url: str | httpx.URL | None = None, timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, @@ -124,7 +125,17 @@ def __init__( """Construct a new synchronous Replicate client instance. This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided. + + For legacy compatibility, you can also pass `api_token` instead of `bearer_token`. """ + # Handle legacy api_token parameter + if api_token is not None and bearer_token is not None: + raise ReplicateError( + "Cannot specify both 'bearer_token' and 'api_token'. Please use 'bearer_token' (recommended) or 'api_token' for legacy compatibility." + ) + if api_token is not None: + bearer_token = api_token + if bearer_token is None: bearer_token = _get_api_token_from_environment() if bearer_token is None: @@ -175,9 +186,11 @@ def account(self) -> AccountResource: @cached_property def models(self) -> ModelsResource: + from .lib.models import patch_models_resource from .resources.models import ModelsResource - return ModelsResource(self) + models_resource = ModelsResource(self) + return patch_models_resource(models_resource) @cached_property def predictions(self) -> PredictionsResource: @@ -324,6 +337,7 @@ def copy( self, *, bearer_token: str | None = None, + api_token: str | None = None, # Legacy compatibility parameter base_url: str | httpx.URL | None = None, timeout: float | Timeout | None | NotGiven = NOT_GIVEN, http_client: httpx.Client | None = None, @@ -337,6 +351,12 @@ def copy( """ Create a new client instance re-using the same options given to the current client with optional overriding. """ + # Handle legacy api_token parameter + if api_token is not None and bearer_token is not None: + raise ValueError("Cannot specify both 'bearer_token' and 'api_token'. Please use 'bearer_token' (recommended) or 'api_token' for legacy compatibility.") + if api_token is not None: + bearer_token = api_token + if default_headers is not None and set_default_headers is not None: raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive") @@ -477,6 +497,7 @@ def __init__( self, *, bearer_token: str | None = None, + api_token: str | None = None, # Legacy compatibility parameter base_url: str | httpx.URL | None = None, timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, @@ -499,7 +520,17 @@ def __init__( """Construct a new async AsyncReplicate client instance. This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided. + + For legacy compatibility, you can also pass `api_token` instead of `bearer_token`. """ + # Handle legacy api_token parameter + if api_token is not None and bearer_token is not None: + raise ReplicateError( + "Cannot specify both 'bearer_token' and 'api_token'. Please use 'bearer_token' (recommended) or 'api_token' for legacy compatibility." + ) + if api_token is not None: + bearer_token = api_token + if bearer_token is None: bearer_token = _get_api_token_from_environment() if bearer_token is None: @@ -550,9 +581,11 @@ def account(self) -> AsyncAccountResource: @cached_property def models(self) -> AsyncModelsResource: + from .lib.models import patch_models_resource from .resources.models import AsyncModelsResource - return AsyncModelsResource(self) + models_resource = AsyncModelsResource(self) + return patch_models_resource(models_resource) @cached_property def predictions(self) -> AsyncPredictionsResource: @@ -699,6 +732,7 @@ def copy( self, *, bearer_token: str | None = None, + api_token: str | None = None, # Legacy compatibility parameter base_url: str | httpx.URL | None = None, timeout: float | Timeout | None | NotGiven = NOT_GIVEN, http_client: httpx.AsyncClient | None = None, @@ -712,6 +746,12 @@ def copy( """ Create a new client instance re-using the same options given to the current client with optional overriding. """ + # Handle legacy api_token parameter + if api_token is not None and bearer_token is not None: + raise ValueError("Cannot specify both 'bearer_token' and 'api_token'. Please use 'bearer_token' (recommended) or 'api_token' for legacy compatibility.") + if api_token is not None: + bearer_token = api_token + if default_headers is not None and set_default_headers is not None: raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive") diff --git a/src/replicate/lib/models.py b/src/replicate/lib/models.py new file mode 100644 index 0000000..7306cbe --- /dev/null +++ b/src/replicate/lib/models.py @@ -0,0 +1,131 @@ +""" +Custom models functionality with backward compatibility. +""" + +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING, Union, overload + +from .._types import NOT_GIVEN, NotGiven +from ._models import ModelVersionIdentifier + +if TYPE_CHECKING: + import httpx + + from .._types import Body, Query, Headers + from ..resources.models.models import ModelsResource, AsyncModelsResource + from ..types.model_get_response import ModelGetResponse + + +def _parse_model_args( + model_or_owner: str | NotGiven, + model_owner: str | NotGiven, + model_name: str | NotGiven, +) -> tuple[str, str]: + """Parse model arguments and return (owner, name).""" + # Handle legacy format: models.get("owner/name") + if model_or_owner is not NOT_GIVEN: + if model_owner is not NOT_GIVEN or model_name is not NOT_GIVEN: + raise ValueError( + "Cannot specify both positional 'model_or_owner' and keyword arguments " + "'model_owner'/'model_name'. Use either the legacy format " + "models.get('owner/name') or the new format models.get(model_owner='owner', model_name='name')." + ) + + # Type guard: model_or_owner is definitely a string here + assert isinstance(model_or_owner, str) + + # Parse the owner/name format + if "/" not in model_or_owner: + raise ValueError( + f"Invalid model reference '{model_or_owner}'. " + "Expected format: 'owner/name' (e.g., 'stability-ai/stable-diffusion')" + ) + + try: + parsed = ModelVersionIdentifier.parse(model_or_owner) + return parsed.owner, parsed.name + except ValueError as e: + raise ValueError( + f"Invalid model reference '{model_or_owner}'. " + f"Expected format: 'owner/name' (e.g., 'stability-ai/stable-diffusion'). " + f"Error: {e}" + ) from e + + # Validate required parameters for new format + if model_owner is NOT_GIVEN or model_name is NOT_GIVEN: + raise ValueError( + "model_owner and model_name are required. " + "Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')" + ) + + return model_owner, model_name # type: ignore[return-value] + + +@overload +def patch_models_resource(models_resource: "ModelsResource") -> "ModelsResource": ... + + +@overload +def patch_models_resource(models_resource: "AsyncModelsResource") -> "AsyncModelsResource": ... + + +def patch_models_resource( + models_resource: Union["ModelsResource", "AsyncModelsResource"], +) -> Union["ModelsResource", "AsyncModelsResource"]: + """Patch a models resource to add backward compatibility.""" + original_get = models_resource.get + is_async = inspect.iscoroutinefunction(original_get) + + if is_async: + + async def async_get_wrapper( + model_or_owner: str | NotGiven = NOT_GIVEN, + *, + model_owner: str | NotGiven = NOT_GIVEN, + model_name: str | NotGiven = NOT_GIVEN, + extra_headers: "Headers | None" = None, + extra_query: "Query | None" = None, + extra_body: "Body | None" = None, + timeout: "float | httpx.Timeout | None | NotGiven" = NOT_GIVEN, + ) -> "ModelGetResponse": + owner, name = _parse_model_args(model_or_owner, model_owner, model_name) + return await models_resource._original_get( # type: ignore[misc,no-any-return,attr-defined,union-attr] + model_owner=owner, + model_name=name, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + ) + + wrapper = async_get_wrapper + else: + + def sync_get_wrapper( + model_or_owner: str | NotGiven = NOT_GIVEN, + *, + model_owner: str | NotGiven = NOT_GIVEN, + model_name: str | NotGiven = NOT_GIVEN, + extra_headers: "Headers | None" = None, + extra_query: "Query | None" = None, + extra_body: "Body | None" = None, + timeout: "float | httpx.Timeout | None | NotGiven" = NOT_GIVEN, + ) -> "ModelGetResponse": + owner, name = _parse_model_args(model_or_owner, model_owner, model_name) + return models_resource._original_get( # type: ignore[misc,return-value,attr-defined,union-attr,no-any-return] + model_owner=owner, + model_name=name, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + ) + + wrapper = sync_get_wrapper # type: ignore[assignment] + + # Store original method for tests and replace with wrapper + models_resource._original_get = original_get # type: ignore[attr-defined,union-attr] + models_resource.get = wrapper # type: ignore[method-assign] + return models_resource diff --git a/tests/test_api_token_compatibility.py b/tests/test_api_token_compatibility.py new file mode 100644 index 0000000..f45a541 --- /dev/null +++ b/tests/test_api_token_compatibility.py @@ -0,0 +1,88 @@ +"""Tests for api_token legacy compatibility during client instantiation.""" + +from __future__ import annotations + +import pytest + +from replicate import Replicate, AsyncReplicate, ReplicateError +from replicate._client import Client + + +class TestApiTokenCompatibility: + """Test that api_token parameter works as a legacy compatibility option.""" + + def test_sync_client_with_api_token(self) -> None: + """Test that Replicate accepts api_token parameter.""" + client = Replicate(api_token="test_token_123") + assert client.bearer_token == "test_token_123" + + def test_async_client_with_api_token(self) -> None: + """Test that AsyncReplicate accepts api_token parameter.""" + client = AsyncReplicate(api_token="test_token_123") + assert client.bearer_token == "test_token_123" + + def test_sync_client_with_bearer_token(self) -> None: + """Test that Replicate still accepts bearer_token parameter.""" + client = Replicate(bearer_token="test_token_123") + assert client.bearer_token == "test_token_123" + + def test_async_client_with_bearer_token(self) -> None: + """Test that AsyncReplicate still accepts bearer_token parameter.""" + client = AsyncReplicate(bearer_token="test_token_123") + assert client.bearer_token == "test_token_123" + + def test_sync_client_both_tokens_error(self) -> None: + """Test that providing both api_token and bearer_token raises an error.""" + with pytest.raises(ReplicateError, match="Cannot specify both 'bearer_token' and 'api_token'"): + Replicate(api_token="test_api", bearer_token="test_bearer") + + def test_async_client_both_tokens_error(self) -> None: + """Test that providing both api_token and bearer_token raises an error.""" + with pytest.raises(ReplicateError, match="Cannot specify both 'bearer_token' and 'api_token'"): + AsyncReplicate(api_token="test_api", bearer_token="test_bearer") + + def test_sync_client_no_token_with_env(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that client reads from environment when no token is provided.""" + monkeypatch.setenv("REPLICATE_API_TOKEN", "env_token_123") + client = Replicate() + assert client.bearer_token == "env_token_123" + + def test_async_client_no_token_with_env(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that async client reads from environment when no token is provided.""" + monkeypatch.setenv("REPLICATE_API_TOKEN", "env_token_123") + client = AsyncReplicate() + assert client.bearer_token == "env_token_123" + + def test_sync_client_no_token_no_env(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that client raises error when no token is provided and env is not set.""" + monkeypatch.delenv("REPLICATE_API_TOKEN", raising=False) + with pytest.raises(ReplicateError, match="The bearer_token client option must be set"): + Replicate() + + def test_async_client_no_token_no_env(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that async client raises error when no token is provided and env is not set.""" + monkeypatch.delenv("REPLICATE_API_TOKEN", raising=False) + with pytest.raises(ReplicateError, match="The bearer_token client option must be set"): + AsyncReplicate() + + def test_legacy_client_alias(self) -> None: + """Test that legacy Client import still works as an alias.""" + assert Client is Replicate + + def test_legacy_client_with_api_token(self) -> None: + """Test that legacy Client alias works with api_token parameter.""" + client = Client(api_token="test_token_123") + assert client.bearer_token == "test_token_123" + assert isinstance(client, Replicate) + + def test_api_token_overrides_env(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that explicit api_token overrides environment variable.""" + monkeypatch.setenv("REPLICATE_API_TOKEN", "env_token") + client = Replicate(api_token="explicit_token") + assert client.bearer_token == "explicit_token" + + def test_bearer_token_overrides_env(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that explicit bearer_token overrides environment variable.""" + monkeypatch.setenv("REPLICATE_API_TOKEN", "env_token") + client = Replicate(bearer_token="explicit_token") + assert client.bearer_token == "explicit_token" diff --git a/tests/test_models_backward_compat.py b/tests/test_models_backward_compat.py new file mode 100644 index 0000000..e834a03 --- /dev/null +++ b/tests/test_models_backward_compat.py @@ -0,0 +1,235 @@ +""" +Tests for backward compatibility in models.get() method. +""" + +from unittest.mock import Mock, AsyncMock + +import pytest + +from replicate import Replicate, AsyncReplicate +from replicate._types import NOT_GIVEN +from replicate.types.model_get_response import ModelGetResponse + + +@pytest.fixture +def mock_model_response(): + """Mock response for model.get requests.""" + return ModelGetResponse( + url="https://replicate.com/stability-ai/stable-diffusion", + owner="stability-ai", + name="stable-diffusion", + description="A model for generating images from text prompts", + visibility="public", + github_url=None, + paper_url=None, + license_url=None, + run_count=0, + cover_image_url=None, + default_example=None, + latest_version=None, + ) + + +class TestModelGetBackwardCompatibility: + """Test backward compatibility for models.get() method.""" + + @pytest.fixture + def client(self): + """Create a Replicate client with mocked token.""" + return Replicate(bearer_token="test-token") + + def test_legacy_format_owner_name(self, client, mock_model_response): + """Test legacy format: models.get('owner/name').""" + # Mock the original get method + client.models._original_get = Mock(return_value=mock_model_response) + + # Call with legacy format + result = client.models.get("stability-ai/stable-diffusion") + + # Verify original method was called with correct parameters + client.models._original_get.assert_called_once_with( + model_owner="stability-ai", + model_name="stable-diffusion", + extra_headers=None, + extra_query=None, + extra_body=None, + timeout=NOT_GIVEN, + ) + assert result == mock_model_response + + def test_new_format_keyword_args(self, client, mock_model_response): + """Test new format: models.get(model_owner='owner', model_name='name').""" + # Mock the original get method + client.models._original_get = Mock(return_value=mock_model_response) + + # Call with new format + result = client.models.get(model_owner="stability-ai", model_name="stable-diffusion") + + # Verify original method was called with correct parameters + client.models._original_get.assert_called_once_with( + model_owner="stability-ai", + model_name="stable-diffusion", + extra_headers=None, + extra_query=None, + extra_body=None, + timeout=NOT_GIVEN, + ) + assert result == mock_model_response + + def test_legacy_format_with_extra_params(self, client, mock_model_response): + """Test legacy format with extra parameters.""" + # Mock the original get method + client.models._original_get = Mock(return_value=mock_model_response) + + # Call with legacy format and extra parameters + result = client.models.get("stability-ai/stable-diffusion", extra_headers={"X-Custom": "test"}, timeout=30.0) + + # Verify original method was called with correct parameters + client.models._original_get.assert_called_once_with( + model_owner="stability-ai", + model_name="stable-diffusion", + extra_headers={"X-Custom": "test"}, + extra_query=None, + extra_body=None, + timeout=30.0, + ) + assert result == mock_model_response + + def test_error_mixed_formats(self, client): + """Test error when mixing legacy and new formats.""" + with pytest.raises(ValueError) as exc_info: + client.models.get("stability-ai/stable-diffusion", model_owner="other-owner") + + assert "Cannot specify both positional 'model_or_owner' and keyword arguments" in str(exc_info.value) + + def test_error_invalid_legacy_format(self, client): + """Test error for invalid legacy format (no slash).""" + with pytest.raises(ValueError) as exc_info: + client.models.get("invalid-format") + + assert "Invalid model reference 'invalid-format'" in str(exc_info.value) + assert "Expected format: 'owner/name'" in str(exc_info.value) + + def test_error_missing_parameters(self, client): + """Test error when no parameters are provided.""" + with pytest.raises(ValueError) as exc_info: + client.models.get() + + assert "model_owner and model_name are required" in str(exc_info.value) + + def test_legacy_format_with_complex_names(self, client, mock_model_response): + """Test legacy format with complex owner/model names.""" + # Mock the original get method + client.models._original_get = Mock(return_value=mock_model_response) + + # Test with hyphenated names and numbers + result = client.models.get("black-forest-labs/flux-1.1-pro") + + # Verify parsing + client.models._original_get.assert_called_once_with( + model_owner="black-forest-labs", + model_name="flux-1.1-pro", + extra_headers=None, + extra_query=None, + extra_body=None, + timeout=NOT_GIVEN, + ) + + def test_legacy_format_multiple_slashes_error(self, client): + """Test error for legacy format with multiple slashes.""" + with pytest.raises(ValueError) as exc_info: + client.models.get("owner/name/version") + + assert "Invalid model reference" in str(exc_info.value) + + +class TestAsyncModelGetBackwardCompatibility: + """Test backward compatibility for async models.get() method.""" + + @pytest.fixture + async def async_client(self): + """Create an async Replicate client with mocked token.""" + return AsyncReplicate(bearer_token="test-token") + + @pytest.mark.asyncio + async def test_async_legacy_format_owner_name(self, async_client, mock_model_response): + """Test async legacy format: models.get('owner/name').""" + # Mock the original async get method + async_client.models._original_get = AsyncMock(return_value=mock_model_response) + + # Call with legacy format + result = await async_client.models.get("stability-ai/stable-diffusion") + + # Verify original method was called with correct parameters + async_client.models._original_get.assert_called_once_with( + model_owner="stability-ai", + model_name="stable-diffusion", + extra_headers=None, + extra_query=None, + extra_body=None, + timeout=NOT_GIVEN, + ) + assert result == mock_model_response + + @pytest.mark.asyncio + async def test_async_new_format_keyword_args(self, async_client, mock_model_response): + """Test async new format: models.get(model_owner='owner', model_name='name').""" + # Mock the original async get method + async_client.models._original_get = AsyncMock(return_value=mock_model_response) + + # Call with new format + result = await async_client.models.get(model_owner="stability-ai", model_name="stable-diffusion") + + # Verify original method was called with correct parameters + async_client.models._original_get.assert_called_once_with( + model_owner="stability-ai", + model_name="stable-diffusion", + extra_headers=None, + extra_query=None, + extra_body=None, + timeout=NOT_GIVEN, + ) + assert result == mock_model_response + + @pytest.mark.asyncio + async def test_async_error_mixed_formats(self, async_client): + """Test async error when mixing legacy and new formats.""" + with pytest.raises(ValueError) as exc_info: + await async_client.models.get("stability-ai/stable-diffusion", model_owner="other-owner") + + assert "Cannot specify both positional 'model_or_owner' and keyword arguments" in str(exc_info.value) + + +class TestModelVersionIdentifierIntegration: + """Test integration with ModelVersionIdentifier parsing.""" + + @pytest.fixture + def client(self): + """Create a Replicate client with mocked token.""" + return Replicate(bearer_token="test-token") + + def test_legacy_format_parsing_edge_cases(self, client, mock_model_response): + """Test edge cases in legacy format parsing.""" + # Mock the original get method + client.models._original_get = Mock(return_value=mock_model_response) + + # Test various valid formats + test_cases = [ + ("owner/name", "owner", "name"), + ("owner-with-hyphens/name-with-hyphens", "owner-with-hyphens", "name-with-hyphens"), + ("owner123/name456", "owner123", "name456"), + ("owner/name.with.dots", "owner", "name.with.dots"), + ] + + for model_ref, expected_owner, expected_name in test_cases: + client.models._original_get.reset_mock() + client.models.get(model_ref) + + client.models._original_get.assert_called_once_with( + model_owner=expected_owner, + model_name=expected_name, + extra_headers=None, + extra_query=None, + extra_body=None, + timeout=NOT_GIVEN, + )