From caf4c4efa2be271144b22b93a38ea490b10ad86b Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Thu, 28 Aug 2025 10:03:11 -0700 Subject: [PATCH 1/2] fix: implement lazy client creation in replicate.use() (#57) * fix: implement lazy client creation in replicate.use() Fixes issue where replicate.use() would fail if no API token was available at call time, even when token becomes available later (e.g., from cog.current_scope). Changes: - Modified Function/AsyncFunction classes to accept client factories - Added _client property that creates client on demand - Updated module client to pass factory functions instead of instances - Token is now retrieved from current scope when model is called This maintains full backward compatibility while enabling use in Cog pipelines where tokens are provided through the execution context. * style: fix linter issues - Remove unused *args parameter in test function - Fix formatting issues from linter * fix: resolve async detection and test issues - Fix async detection to not call client factory prematurely - Add use_async parameter to explicitly indicate async mode - Update test to avoid creating client during verification - Fix test mocking to use correct module path * test: simplify lazy client test Replace complex mocking test with simpler verification that: - use() works without token initially - Lazy client factory is properly configured - Client can be created when needed This avoids complex mocking while still verifying the core functionality. * lint * fix: add type ignore for final linter warning * fix: add arg-type ignore for type checker warnings * refactor: simplify lazy client creation to use Type[Client] only Address PR feedback by removing Union types and using a single consistent approach: - Change Function/AsyncFunction constructors to accept Type[Client] only - Remove Union[Client, Type[Client]] in favor of just Type[Client] - Simplify _client property logic by removing isinstance checks - Update all use() overloads to accept class types only - Use issubclass() for async client detection instead of complex logic - Update tests to check for _client_class attribute This maintains the same lazy client creation behavior while being much simpler and more consistent. * Update tests/test_simple_lazy.py Co-authored-by: Aron Carroll * test: improve lazy client test to follow project conventions - Remove verbose comments and print statements - Focus on observable behavior rather than internal implementation - Use proper mocking that matches actual cog integration - Test that cog.current_scope() is called on client creation - Address code review feedback from PR discussion * lint * lint --------- Co-authored-by: Aron Carroll --- src/replicate/_module_client.py | 10 ++++-- src/replicate/lib/_predictions_use.py | 31 +++++++++------- tests/test_simple_lazy.py | 52 +++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 15 deletions(-) create mode 100644 tests/test_simple_lazy.py diff --git a/src/replicate/_module_client.py b/src/replicate/_module_client.py index 817c605..a3e8ab4 100644 --- a/src/replicate/_module_client.py +++ b/src/replicate/_module_client.py @@ -88,13 +88,17 @@ def _run(*args, **kwargs): return _load_client().run(*args, **kwargs) def _use(ref, *, hint=None, streaming=False, use_async=False, **kwargs): + from .lib._predictions_use import use + if use_async: # For async, we need to use AsyncReplicate instead from ._client import AsyncReplicate - client = AsyncReplicate() - return client.use(ref, hint=hint, streaming=streaming, **kwargs) - return _load_client().use(ref, hint=hint, streaming=streaming, **kwargs) + return use(AsyncReplicate, ref, hint=hint, streaming=streaming, **kwargs) + + from ._client import Replicate + + return use(Replicate, ref, hint=hint, streaming=streaming, **kwargs) run = _run use = _use diff --git a/src/replicate/lib/_predictions_use.py b/src/replicate/lib/_predictions_use.py index 606bbee..1cd085c 100644 --- a/src/replicate/lib/_predictions_use.py +++ b/src/replicate/lib/_predictions_use.py @@ -9,6 +9,7 @@ Any, Dict, List, + Type, Tuple, Union, Generic, @@ -436,15 +437,18 @@ class Function(Generic[Input, Output]): A wrapper for a Replicate model that can be called as a function. """ - _client: Client _ref: str _streaming: bool - def __init__(self, client: Client, ref: str, *, streaming: bool) -> None: - self._client = client + def __init__(self, client: Type[Client], ref: str, *, streaming: bool) -> None: + self._client_class = client self._ref = ref self._streaming = streaming + @property + def _client(self) -> Client: + return self._client_class() + def __call__(self, *args: Input.args, **inputs: Input.kwargs) -> Output: return self.create(*args, **inputs).output() @@ -666,16 +670,19 @@ class AsyncFunction(Generic[Input, Output]): An async wrapper for a Replicate model that can be called as a function. """ - _client: AsyncClient _ref: str _streaming: bool _openapi_schema: Optional[Dict[str, Any]] = None - def __init__(self, client: AsyncClient, ref: str, *, streaming: bool) -> None: - self._client = client + def __init__(self, client: Type[AsyncClient], ref: str, *, streaming: bool) -> None: + self._client_class = client self._ref = ref self._streaming = streaming + @property + def _client(self) -> AsyncClient: + return self._client_class() + @cached_property def _parsed_ref(self) -> Tuple[str, str, Optional[str]]: return ModelVersionIdentifier.parse(self._ref) @@ -804,7 +811,7 @@ async def openapi_schema(self) -> Dict[str, Any]: @overload def use( - client: Client, + client: Type[Client], ref: Union[str, FunctionRef[Input, Output]], *, hint: Optional[Callable[Input, Output]] = None, @@ -814,7 +821,7 @@ def use( @overload def use( - client: Client, + client: Type[Client], ref: Union[str, FunctionRef[Input, Output]], *, hint: Optional[Callable[Input, Output]] = None, @@ -824,7 +831,7 @@ def use( @overload def use( - client: AsyncClient, + client: Type[AsyncClient], ref: Union[str, FunctionRef[Input, Output]], *, hint: Optional[Callable[Input, Output]] = None, @@ -834,7 +841,7 @@ def use( @overload def use( - client: AsyncClient, + client: Type[AsyncClient], ref: Union[str, FunctionRef[Input, Output]], *, hint: Optional[Callable[Input, Output]] = None, @@ -843,7 +850,7 @@ def use( def use( - client: Union[Client, AsyncClient], + client: Union[Type[Client], Type[AsyncClient]], ref: Union[str, FunctionRef[Input, Output]], *, hint: Optional[Callable[Input, Output]] = None, # pylint: disable=unused-argument # noqa: ARG001 # required for type inference @@ -868,7 +875,7 @@ def use( except AttributeError: pass - if isinstance(client, AsyncClient): + if issubclass(client, AsyncClient): # TODO: Fix type inference for AsyncFunction return type return AsyncFunction(client, str(ref), streaming=streaming) # type: ignore[return-value] diff --git a/tests/test_simple_lazy.py b/tests/test_simple_lazy.py new file mode 100644 index 0000000..312c9fe --- /dev/null +++ b/tests/test_simple_lazy.py @@ -0,0 +1,52 @@ +"""Test lazy client creation in replicate.use().""" + +import os +import sys +from unittest.mock import MagicMock, patch + + +def test_use_does_not_raise_without_token(): + """Test that replicate.use() works even when no API token is available.""" + sys.path.insert(0, "src") + + with patch.dict(os.environ, {}, clear=True): + with patch.dict(sys.modules, {"cog": None}): + import replicate + + # Should not raise an exception + model = replicate.use("test/model") # type: ignore[misc] + assert model is not None + + +def test_cog_current_scope(): + """Test that cog.current_scope().context is read on each client creation.""" + sys.path.insert(0, "src") + + mock_context = MagicMock() + mock_context.items.return_value = [("REPLICATE_API_TOKEN", "test-token-1")] + + mock_scope = MagicMock() + mock_scope.context = mock_context + + mock_cog = MagicMock() + mock_cog.current_scope.return_value = mock_scope + + with patch.dict(os.environ, {}, clear=True): + with patch.dict(sys.modules, {"cog": mock_cog}): + import replicate + + model = replicate.use("test/model") # type: ignore[misc] + + # Access the client property - this should trigger client creation and cog.current_scope call + _ = model._client + + assert mock_cog.current_scope.call_count == 1 + + # Change the token and access client again - should trigger another call + mock_context.items.return_value = [("REPLICATE_API_TOKEN", "test-token-2")] + + # Create a new model to trigger another client creation + model2 = replicate.use("test/model2") # type: ignore[misc] + _ = model2._client + + assert mock_cog.current_scope.call_count == 2 From 29be63ecc4596a1080ac05d4d8c44bcef64b7d7e Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Thu, 28 Aug 2025 17:03:27 +0000 Subject: [PATCH 2/2] release: 2.0.0-alpha.22 --- .release-please-manifest.json | 2 +- CHANGELOG.md | 8 ++++++++ pyproject.toml | 2 +- src/replicate/_version.py | 2 +- 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index d700ccd..588fa32 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "2.0.0-alpha.21" + ".": "2.0.0-alpha.22" } \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 71ae0d9..a057384 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## 2.0.0-alpha.22 (2025-08-28) + +Full Changelog: [v2.0.0-alpha.21...v2.0.0-alpha.22](https://github.com/replicate/replicate-python-stainless/compare/v2.0.0-alpha.21...v2.0.0-alpha.22) + +### Bug Fixes + +* implement lazy client creation in replicate.use() ([#57](https://github.com/replicate/replicate-python-stainless/issues/57)) ([caf4c4e](https://github.com/replicate/replicate-python-stainless/commit/caf4c4efa2be271144b22b93a38ea490b10ad86b)) + ## 2.0.0-alpha.21 (2025-08-26) Full Changelog: [v2.0.0-alpha.20...v2.0.0-alpha.21](https://github.com/replicate/replicate-python-stainless/compare/v2.0.0-alpha.20...v2.0.0-alpha.21) diff --git a/pyproject.toml b/pyproject.toml index a19c570..b64ee78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "replicate" -version = "2.0.0-alpha.21" +version = "2.0.0-alpha.22" description = "The official Python library for the replicate API" dynamic = ["readme"] license = "Apache-2.0" diff --git a/src/replicate/_version.py b/src/replicate/_version.py index 06fdf6e..3f19536 100644 --- a/src/replicate/_version.py +++ b/src/replicate/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "replicate" -__version__ = "2.0.0-alpha.21" # x-release-please-version +__version__ = "2.0.0-alpha.22" # x-release-please-version