From 85bfd9218aadde9b0fbd20fe60b8a3c1066750d9 Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Wed, 24 Sep 2025 12:36:06 -0700 Subject: [PATCH 1/2] feat: add backward compatibility for models.get("owner/name") format This adds support for the legacy models.get("owner/name") format while maintaining compatibility with the new models.get(model_owner="owner", model_name="name") format. The implementation directly modifies the generated get() method to: - Accept an optional positional argument for "owner/name" strings - Parse and validate the format with clear error messages - Support both sync and async versions - Maintain all existing functionality and parameters This approach is 50% less code than a patching system and integrates cleanly with the existing codebase structure. --- src/replicate/resources/models/models.py | 70 ++++++++-- tests/test_models_backward_compat.py | 156 +++++++++++++++++++++++ 2 files changed, 216 insertions(+), 10 deletions(-) create mode 100644 tests/test_models_backward_compat.py diff --git a/src/replicate/resources/models/models.py b/src/replicate/resources/models/models.py index 671f482..58e66e1 100644 --- a/src/replicate/resources/models/models.py +++ b/src/replicate/resources/models/models.py @@ -299,9 +299,10 @@ def delete( def get( self, + model_or_owner: str | NotGiven = NOT_GIVEN, # Legacy positional arg *, - model_owner: str, - model_name: str, + model_owner: str | NotGiven = NOT_GIVEN, + model_name: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -384,15 +385,39 @@ def get( The `latest_version` object is the model's most recently pushed [version](#models.versions.get). + Supports both legacy and new formats: + - Legacy: models.get("owner/name") + - New: models.get(model_owner="owner", model_name="name") + Args: + model_or_owner: Legacy format string "owner/name" (positional argument) + model_owner: Model owner (keyword argument) + model_name: Model name (keyword argument) extra_headers: Send extra headers - extra_query: Add additional query parameters to the request - extra_body: Add additional JSON properties to the request - timeout: Override the client-level default timeout for this request, in seconds """ + # Handle legacy format: models.get("owner/name") + if model_or_owner is not NOT_GIVEN: + if model_owner is not NOT_GIVEN or model_name is not NOT_GIVEN: + raise ValueError( + "Cannot specify both positional and keyword arguments. " + "Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')" + ) + + # Parse the owner/name format + if "/" not in model_or_owner: + raise ValueError(f"Invalid model reference '{model_or_owner}'. Expected format: 'owner/name'") + + parts = model_or_owner.split("/", 1) + model_owner = parts[0] + model_name = parts[1] + + # Validate required parameters + if model_owner is NOT_GIVEN or model_name is NOT_GIVEN: + raise ValueError("model_owner and model_name are required") + if not model_owner: raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") if not model_name: @@ -698,9 +723,10 @@ async def delete( async def get( self, + model_or_owner: str | NotGiven = NOT_GIVEN, # Legacy positional arg *, - model_owner: str, - model_name: str, + model_owner: str | NotGiven = NOT_GIVEN, + model_name: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -783,15 +809,39 @@ async def get( The `latest_version` object is the model's most recently pushed [version](#models.versions.get). + Supports both legacy and new formats: + - Legacy: models.get("owner/name") + - New: models.get(model_owner="owner", model_name="name") + Args: + model_or_owner: Legacy format string "owner/name" (positional argument) + model_owner: Model owner (keyword argument) + model_name: Model name (keyword argument) extra_headers: Send extra headers - extra_query: Add additional query parameters to the request - extra_body: Add additional JSON properties to the request - timeout: Override the client-level default timeout for this request, in seconds """ + # Handle legacy format: models.get("owner/name") + if model_or_owner is not NOT_GIVEN: + if model_owner is not NOT_GIVEN or model_name is not NOT_GIVEN: + raise ValueError( + "Cannot specify both positional and keyword arguments. " + "Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')" + ) + + # Parse the owner/name format + if "/" not in model_or_owner: + raise ValueError(f"Invalid model reference '{model_or_owner}'. Expected format: 'owner/name'") + + parts = model_or_owner.split("/", 1) + model_owner = parts[0] + model_name = parts[1] + + # Validate required parameters + if model_owner is NOT_GIVEN or model_name is NOT_GIVEN: + raise ValueError("model_owner and model_name are required") + if not model_owner: raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") if not model_name: diff --git a/tests/test_models_backward_compat.py b/tests/test_models_backward_compat.py new file mode 100644 index 0000000..c104883 --- /dev/null +++ b/tests/test_models_backward_compat.py @@ -0,0 +1,156 @@ +""" +Tests for backward compatibility in models.get() method. +""" + +from unittest.mock import Mock, patch + +import pytest + +from replicate import Replicate, AsyncReplicate +from replicate.types.model_get_response import ModelGetResponse + + +@pytest.fixture +def mock_model_response(): + """Mock response for model.get requests.""" + return ModelGetResponse( + url="https://replicate.com/stability-ai/stable-diffusion", + owner="stability-ai", + name="stable-diffusion", + description="A model for generating images from text prompts", + visibility="public", + github_url=None, + paper_url=None, + license_url=None, + run_count=0, + cover_image_url=None, + default_example=None, + latest_version=None, + ) + + +class TestModelGetBackwardCompatibility: + """Test backward compatibility for models.get() method.""" + + @pytest.fixture + def client(self): + """Create a Replicate client with mocked token.""" + return Replicate(bearer_token="test-token") + + def test_legacy_format_owner_name(self, client, mock_model_response): + """Test legacy format: models.get('owner/name').""" + # Mock the underlying _get method + with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get: + # Call with legacy format + result = client.models.get("stability-ai/stable-diffusion") + + # Verify underlying method was called with correct parameters + mock_get.assert_called_once_with("/models/stability-ai/stable-diffusion", options=Mock()) + assert result == mock_model_response + + def test_new_format_keyword_args(self, client, mock_model_response): + """Test new format: models.get(model_owner='owner', model_name='name').""" + # Mock the underlying _get method + with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get: + # Call with new format + result = client.models.get(model_owner="stability-ai", model_name="stable-diffusion") + + # Verify underlying method was called with correct parameters + mock_get.assert_called_once_with("/models/stability-ai/stable-diffusion", options=Mock()) + assert result == mock_model_response + + def test_legacy_format_with_extra_params(self, client, mock_model_response): + """Test legacy format with extra parameters.""" + # Mock the underlying _get method + with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get: + # Call with legacy format and extra parameters + result = client.models.get( + "stability-ai/stable-diffusion", extra_headers={"X-Custom": "test"}, timeout=30.0 + ) + + # Verify underlying method was called + mock_get.assert_called_once() + assert result == mock_model_response + + def test_error_mixed_formats(self, client): + """Test error when mixing legacy and new formats.""" + with pytest.raises(ValueError) as exc_info: + client.models.get("stability-ai/stable-diffusion", model_owner="other-owner") + + assert "Cannot specify both positional and keyword arguments" in str(exc_info.value) + + def test_error_invalid_legacy_format(self, client): + """Test error for invalid legacy format (no slash).""" + with pytest.raises(ValueError) as exc_info: + client.models.get("invalid-format") + + assert "Invalid model reference 'invalid-format'" in str(exc_info.value) + assert "Expected format: 'owner/name'" in str(exc_info.value) + + def test_error_missing_parameters(self, client): + """Test error when no parameters are provided.""" + with pytest.raises(ValueError) as exc_info: + client.models.get() + + assert "model_owner and model_name are required" in str(exc_info.value) + + def test_legacy_format_with_complex_names(self, client, mock_model_response): + """Test legacy format with complex owner/model names.""" + # Mock the underlying _get method + with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get: + # Test with hyphenated names and numbers + result = client.models.get("black-forest-labs/flux-1.1-pro") + + # Verify parsing + mock_get.assert_called_once_with("/models/black-forest-labs/flux-1.1-pro", options=Mock()) + + def test_legacy_format_multiple_slashes(self, client): + """Test legacy format with multiple slashes (should split on first slash only).""" + # Mock the underlying _get method + with patch.object(client.models, "_get", return_value=Mock()) as mock_get: + # This should work - split on first slash only + client.models.get("owner/name/with/slashes") + + # Verify it was parsed correctly + mock_get.assert_called_once_with("/models/owner/name/with/slashes", options=Mock()) + + +class TestAsyncModelGetBackwardCompatibility: + """Test backward compatibility for async models.get() method.""" + + @pytest.fixture + async def async_client(self): + """Create an async Replicate client with mocked token.""" + return AsyncReplicate(bearer_token="test-token") + + @pytest.mark.asyncio + async def test_async_legacy_format_owner_name(self, async_client, mock_model_response): + """Test async legacy format: models.get('owner/name').""" + # Mock the underlying _get method + with patch.object(async_client.models, "_get", return_value=mock_model_response) as mock_get: + # Call with legacy format + result = await async_client.models.get("stability-ai/stable-diffusion") + + # Verify underlying method was called with correct parameters + mock_get.assert_called_once_with("/models/stability-ai/stable-diffusion", options=Mock()) + assert result == mock_model_response + + @pytest.mark.asyncio + async def test_async_new_format_keyword_args(self, async_client, mock_model_response): + """Test async new format: models.get(model_owner='owner', model_name='name').""" + # Mock the underlying _get method + with patch.object(async_client.models, "_get", return_value=mock_model_response) as mock_get: + # Call with new format + result = await async_client.models.get(model_owner="stability-ai", model_name="stable-diffusion") + + # Verify underlying method was called with correct parameters + mock_get.assert_called_once_with("/models/stability-ai/stable-diffusion", options=Mock()) + assert result == mock_model_response + + @pytest.mark.asyncio + async def test_async_error_mixed_formats(self, async_client): + """Test async error when mixing legacy and new formats.""" + with pytest.raises(ValueError) as exc_info: + await async_client.models.get("stability-ai/stable-diffusion", model_owner="other-owner") + + assert "Cannot specify both positional and keyword arguments" in str(exc_info.value) From 4c9e0cb398cb6e127cad04ee7a4aab6331f8adea Mon Sep 17 00:00:00 2001 From: Samuel El-Borai Date: Thu, 9 Oct 2025 17:23:03 +0200 Subject: [PATCH 2/2] implement models.get("owner/model") support via overloads --- src/replicate/resources/models/models.py | 95 +++++++++-- tests/lib/test_models_get_backward_compat.py | 161 +++++++++++++++++++ tests/test_models_backward_compat.py | 156 ------------------ 3 files changed, 244 insertions(+), 168 deletions(-) create mode 100644 tests/lib/test_models_get_backward_compat.py delete mode 100644 tests/test_models_backward_compat.py diff --git a/src/replicate/resources/models/models.py b/src/replicate/resources/models/models.py index 58e66e1..238eca0 100644 --- a/src/replicate/resources/models/models.py +++ b/src/replicate/resources/models/models.py @@ -2,6 +2,7 @@ from __future__ import annotations +from typing import overload from typing_extensions import Literal import httpx @@ -297,9 +298,36 @@ def delete( cast_to=NoneType, ) + @overload def get( self, - model_or_owner: str | NotGiven = NOT_GIVEN, # Legacy positional arg + model_owner_and_name: str, + *, + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ModelGetResponse: + """Legacy format: models.get("owner/name")""" + ... + + @overload + def get( + self, + *, + model_owner: str, + model_name: str, + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ModelGetResponse: + """New format: models.get(model_owner="owner", model_name="name")""" + ... + + def get( + self, + model_owner_and_name: str | NotGiven = NOT_GIVEN, *, model_owner: str | NotGiven = NOT_GIVEN, model_name: str | NotGiven = NOT_GIVEN, @@ -390,7 +418,7 @@ def get( - New: models.get(model_owner="owner", model_name="name") Args: - model_or_owner: Legacy format string "owner/name" (positional argument) + model_owner_and_name: Legacy format string "owner/name" (positional argument) model_owner: Model owner (keyword argument) model_name: Model name (keyword argument) extra_headers: Send extra headers @@ -399,18 +427,22 @@ def get( timeout: Override the client-level default timeout for this request, in seconds """ # Handle legacy format: models.get("owner/name") - if model_or_owner is not NOT_GIVEN: + if model_owner_and_name is not NOT_GIVEN: if model_owner is not NOT_GIVEN or model_name is not NOT_GIVEN: raise ValueError( "Cannot specify both positional and keyword arguments. " "Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')" ) + # Type narrowing - at this point model_owner_and_name must be a string + if not isinstance(model_owner_and_name, str): + raise TypeError("model_owner_and_name must be a string") + # Parse the owner/name format - if "/" not in model_or_owner: - raise ValueError(f"Invalid model reference '{model_or_owner}'. Expected format: 'owner/name'") + if "/" not in model_owner_and_name: + raise ValueError(f"Invalid model reference '{model_owner_and_name}'. Expected format: 'owner/name'") - parts = model_or_owner.split("/", 1) + parts = model_owner_and_name.split("/", 1) model_owner = parts[0] model_name = parts[1] @@ -418,6 +450,10 @@ def get( if model_owner is NOT_GIVEN or model_name is NOT_GIVEN: raise ValueError("model_owner and model_name are required") + # Type narrowing - at this point both must be strings + if not isinstance(model_owner, str) or not isinstance(model_name, str): + raise TypeError("model_owner and model_name must be strings") + if not model_owner: raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") if not model_name: @@ -721,9 +757,36 @@ async def delete( cast_to=NoneType, ) + @overload async def get( self, - model_or_owner: str | NotGiven = NOT_GIVEN, # Legacy positional arg + model_owner_and_name: str, + *, + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ModelGetResponse: + """Legacy format: models.get("owner/name")""" + ... + + @overload + async def get( + self, + *, + model_owner: str, + model_name: str, + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ModelGetResponse: + """New format: models.get(model_owner="owner", model_name="name")""" + ... + + async def get( + self, + model_owner_and_name: str | NotGiven = NOT_GIVEN, *, model_owner: str | NotGiven = NOT_GIVEN, model_name: str | NotGiven = NOT_GIVEN, @@ -814,7 +877,7 @@ async def get( - New: models.get(model_owner="owner", model_name="name") Args: - model_or_owner: Legacy format string "owner/name" (positional argument) + model_owner_and_name: Legacy format string "owner/name" (positional argument) model_owner: Model owner (keyword argument) model_name: Model name (keyword argument) extra_headers: Send extra headers @@ -823,18 +886,22 @@ async def get( timeout: Override the client-level default timeout for this request, in seconds """ # Handle legacy format: models.get("owner/name") - if model_or_owner is not NOT_GIVEN: + if model_owner_and_name is not NOT_GIVEN: if model_owner is not NOT_GIVEN or model_name is not NOT_GIVEN: raise ValueError( "Cannot specify both positional and keyword arguments. " "Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')" ) + # Type narrowing - at this point model_owner_and_name must be a string + if not isinstance(model_owner_and_name, str): + raise TypeError("model_owner_and_name must be a string") + # Parse the owner/name format - if "/" not in model_or_owner: - raise ValueError(f"Invalid model reference '{model_or_owner}'. Expected format: 'owner/name'") + if "/" not in model_owner_and_name: + raise ValueError(f"Invalid model reference '{model_owner_and_name}'. Expected format: 'owner/name'") - parts = model_or_owner.split("/", 1) + parts = model_owner_and_name.split("/", 1) model_owner = parts[0] model_name = parts[1] @@ -842,6 +909,10 @@ async def get( if model_owner is NOT_GIVEN or model_name is NOT_GIVEN: raise ValueError("model_owner and model_name are required") + # Type narrowing - at this point both must be strings + if not isinstance(model_owner, str) or not isinstance(model_name, str): + raise TypeError("model_owner and model_name must be strings") + if not model_owner: raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") if not model_name: diff --git a/tests/lib/test_models_get_backward_compat.py b/tests/lib/test_models_get_backward_compat.py new file mode 100644 index 0000000..e1d9c99 --- /dev/null +++ b/tests/lib/test_models_get_backward_compat.py @@ -0,0 +1,161 @@ +"""Tests for models.get() backward compatibility with legacy owner/name format.""" + +import os + +import httpx +import pytest +from respx import MockRouter + +from replicate import Replicate, AsyncReplicate + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") +bearer_token = "My Bearer Token" + + +def mock_model_response(): + """Mock model response data.""" + return { + "url": "https://replicate.com/stability-ai/stable-diffusion", + "owner": "stability-ai", + "name": "stable-diffusion", + "description": "A model for generating images from text prompts", + "visibility": "public", + "github_url": None, + "paper_url": None, + "license_url": None, + "run_count": 12345, + "cover_image_url": "https://example.com/cover.jpg", + "default_example": None, + "latest_version": None, + } + + +class TestModelsGetLegacyFormat: + """Test legacy format: models.get('owner/name').""" + + client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + + @pytest.mark.respx(base_url=base_url) + def test_legacy_format_basic(self, respx_mock: MockRouter): + """Test basic legacy format with owner/name.""" + respx_mock.get("/models/stability-ai/stable-diffusion").mock( + return_value=httpx.Response(200, json=mock_model_response()) + ) + + model = self.client.models.get("stability-ai/stable-diffusion") + + assert model.owner == "stability-ai" + assert model.name == "stable-diffusion" + + @pytest.mark.respx(base_url=base_url) + def test_legacy_format_with_hyphens_and_dots(self, respx_mock: MockRouter): + """Test legacy format with hyphenated names and dots.""" + response_data = {**mock_model_response(), "owner": "black-forest-labs", "name": "flux-1.1-pro"} + respx_mock.get("/models/black-forest-labs/flux-1.1-pro").mock( + return_value=httpx.Response(200, json=response_data) + ) + + model = self.client.models.get("black-forest-labs/flux-1.1-pro") + + assert model.owner == "black-forest-labs" + assert model.name == "flux-1.1-pro" + + @pytest.mark.respx(base_url=base_url) + def test_legacy_format_splits_on_first_slash_only(self, respx_mock: MockRouter): + """Test legacy format splits on first slash only.""" + response_data = {**mock_model_response(), "owner": "owner", "name": "name/with/slashes"} + respx_mock.get("/models/owner/name/with/slashes").mock(return_value=httpx.Response(200, json=response_data)) + + model = self.client.models.get("owner/name/with/slashes") + + assert model.owner == "owner" + assert model.name == "name/with/slashes" + + def test_legacy_format_error_no_slash(self): + """Test error when legacy format has no slash.""" + with pytest.raises(ValueError, match="Invalid model reference 'invalid-format'.*Expected format: 'owner/name'"): + self.client.models.get("invalid-format") + + def test_legacy_format_error_mixed_with_kwargs(self): + """Test error when mixing positional and keyword arguments.""" + with pytest.raises(ValueError, match="Cannot specify both positional and keyword arguments"): + self.client.models.get("owner/name", model_owner="other-owner") # type: ignore[call-overload] + + +class TestModelsGetNewFormat: + """Test new format: models.get(model_owner='owner', model_name='name').""" + + client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + + @pytest.mark.respx(base_url=base_url) + def test_new_format_basic(self, respx_mock: MockRouter): + """Test basic new format with keyword arguments.""" + respx_mock.get("/models/stability-ai/stable-diffusion").mock( + return_value=httpx.Response(200, json=mock_model_response()) + ) + + model = self.client.models.get(model_owner="stability-ai", model_name="stable-diffusion") + + assert model.owner == "stability-ai" + assert model.name == "stable-diffusion" + + def test_new_format_error_missing_params(self): + """Test error when required parameters are missing.""" + with pytest.raises(ValueError, match="model_owner and model_name are required"): + self.client.models.get() # type: ignore[call-overload] + + +class TestAsyncModelsGetLegacyFormat: + """Test async legacy format.""" + + client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + + @pytest.mark.respx(base_url=base_url) + @pytest.mark.asyncio + async def test_async_legacy_format_basic(self, respx_mock: MockRouter): + """Test async basic legacy format.""" + respx_mock.get("/models/stability-ai/stable-diffusion").mock( + return_value=httpx.Response(200, json=mock_model_response()) + ) + + model = await self.client.models.get("stability-ai/stable-diffusion") + + assert model.owner == "stability-ai" + assert model.name == "stable-diffusion" + + @pytest.mark.asyncio + async def test_async_legacy_format_error_no_slash(self): + """Test async error when legacy format has no slash.""" + with pytest.raises(ValueError, match="Invalid model reference 'invalid-format'.*Expected format: 'owner/name'"): + await self.client.models.get("invalid-format") + + @pytest.mark.asyncio + async def test_async_legacy_format_error_mixed(self): + """Test async error when mixing formats.""" + with pytest.raises(ValueError, match="Cannot specify both positional and keyword arguments"): + await self.client.models.get("owner/name", model_owner="other") # type: ignore[call-overload] + + +class TestAsyncModelsGetNewFormat: + """Test async new format.""" + + client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + + @pytest.mark.respx(base_url=base_url) + @pytest.mark.asyncio + async def test_async_new_format_basic(self, respx_mock: MockRouter): + """Test async new format.""" + respx_mock.get("/models/stability-ai/stable-diffusion").mock( + return_value=httpx.Response(200, json=mock_model_response()) + ) + + model = await self.client.models.get(model_owner="stability-ai", model_name="stable-diffusion") + + assert model.owner == "stability-ai" + assert model.name == "stable-diffusion" + + @pytest.mark.asyncio + async def test_async_new_format_error_missing_params(self): + """Test async error when required parameters are missing.""" + with pytest.raises(ValueError, match="model_owner and model_name are required"): + await self.client.models.get() # type: ignore[call-overload] diff --git a/tests/test_models_backward_compat.py b/tests/test_models_backward_compat.py deleted file mode 100644 index c104883..0000000 --- a/tests/test_models_backward_compat.py +++ /dev/null @@ -1,156 +0,0 @@ -""" -Tests for backward compatibility in models.get() method. -""" - -from unittest.mock import Mock, patch - -import pytest - -from replicate import Replicate, AsyncReplicate -from replicate.types.model_get_response import ModelGetResponse - - -@pytest.fixture -def mock_model_response(): - """Mock response for model.get requests.""" - return ModelGetResponse( - url="https://replicate.com/stability-ai/stable-diffusion", - owner="stability-ai", - name="stable-diffusion", - description="A model for generating images from text prompts", - visibility="public", - github_url=None, - paper_url=None, - license_url=None, - run_count=0, - cover_image_url=None, - default_example=None, - latest_version=None, - ) - - -class TestModelGetBackwardCompatibility: - """Test backward compatibility for models.get() method.""" - - @pytest.fixture - def client(self): - """Create a Replicate client with mocked token.""" - return Replicate(bearer_token="test-token") - - def test_legacy_format_owner_name(self, client, mock_model_response): - """Test legacy format: models.get('owner/name').""" - # Mock the underlying _get method - with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get: - # Call with legacy format - result = client.models.get("stability-ai/stable-diffusion") - - # Verify underlying method was called with correct parameters - mock_get.assert_called_once_with("/models/stability-ai/stable-diffusion", options=Mock()) - assert result == mock_model_response - - def test_new_format_keyword_args(self, client, mock_model_response): - """Test new format: models.get(model_owner='owner', model_name='name').""" - # Mock the underlying _get method - with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get: - # Call with new format - result = client.models.get(model_owner="stability-ai", model_name="stable-diffusion") - - # Verify underlying method was called with correct parameters - mock_get.assert_called_once_with("/models/stability-ai/stable-diffusion", options=Mock()) - assert result == mock_model_response - - def test_legacy_format_with_extra_params(self, client, mock_model_response): - """Test legacy format with extra parameters.""" - # Mock the underlying _get method - with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get: - # Call with legacy format and extra parameters - result = client.models.get( - "stability-ai/stable-diffusion", extra_headers={"X-Custom": "test"}, timeout=30.0 - ) - - # Verify underlying method was called - mock_get.assert_called_once() - assert result == mock_model_response - - def test_error_mixed_formats(self, client): - """Test error when mixing legacy and new formats.""" - with pytest.raises(ValueError) as exc_info: - client.models.get("stability-ai/stable-diffusion", model_owner="other-owner") - - assert "Cannot specify both positional and keyword arguments" in str(exc_info.value) - - def test_error_invalid_legacy_format(self, client): - """Test error for invalid legacy format (no slash).""" - with pytest.raises(ValueError) as exc_info: - client.models.get("invalid-format") - - assert "Invalid model reference 'invalid-format'" in str(exc_info.value) - assert "Expected format: 'owner/name'" in str(exc_info.value) - - def test_error_missing_parameters(self, client): - """Test error when no parameters are provided.""" - with pytest.raises(ValueError) as exc_info: - client.models.get() - - assert "model_owner and model_name are required" in str(exc_info.value) - - def test_legacy_format_with_complex_names(self, client, mock_model_response): - """Test legacy format with complex owner/model names.""" - # Mock the underlying _get method - with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get: - # Test with hyphenated names and numbers - result = client.models.get("black-forest-labs/flux-1.1-pro") - - # Verify parsing - mock_get.assert_called_once_with("/models/black-forest-labs/flux-1.1-pro", options=Mock()) - - def test_legacy_format_multiple_slashes(self, client): - """Test legacy format with multiple slashes (should split on first slash only).""" - # Mock the underlying _get method - with patch.object(client.models, "_get", return_value=Mock()) as mock_get: - # This should work - split on first slash only - client.models.get("owner/name/with/slashes") - - # Verify it was parsed correctly - mock_get.assert_called_once_with("/models/owner/name/with/slashes", options=Mock()) - - -class TestAsyncModelGetBackwardCompatibility: - """Test backward compatibility for async models.get() method.""" - - @pytest.fixture - async def async_client(self): - """Create an async Replicate client with mocked token.""" - return AsyncReplicate(bearer_token="test-token") - - @pytest.mark.asyncio - async def test_async_legacy_format_owner_name(self, async_client, mock_model_response): - """Test async legacy format: models.get('owner/name').""" - # Mock the underlying _get method - with patch.object(async_client.models, "_get", return_value=mock_model_response) as mock_get: - # Call with legacy format - result = await async_client.models.get("stability-ai/stable-diffusion") - - # Verify underlying method was called with correct parameters - mock_get.assert_called_once_with("/models/stability-ai/stable-diffusion", options=Mock()) - assert result == mock_model_response - - @pytest.mark.asyncio - async def test_async_new_format_keyword_args(self, async_client, mock_model_response): - """Test async new format: models.get(model_owner='owner', model_name='name').""" - # Mock the underlying _get method - with patch.object(async_client.models, "_get", return_value=mock_model_response) as mock_get: - # Call with new format - result = await async_client.models.get(model_owner="stability-ai", model_name="stable-diffusion") - - # Verify underlying method was called with correct parameters - mock_get.assert_called_once_with("/models/stability-ai/stable-diffusion", options=Mock()) - assert result == mock_model_response - - @pytest.mark.asyncio - async def test_async_error_mixed_formats(self, async_client): - """Test async error when mixing legacy and new formats.""" - with pytest.raises(ValueError) as exc_info: - await async_client.models.get("stability-ai/stable-diffusion", model_owner="other-owner") - - assert "Cannot specify both positional and keyword arguments" in str(exc_info.value)