From d52f49b2fc59d9afcc86db3f96b3e08b2a1cf11e Mon Sep 17 00:00:00 2001 From: David Meadows Date: Wed, 7 May 2025 15:13:06 -0400 Subject: [PATCH 1/4] fix files uploading --- src/replicate/lib/_files.py | 10 ++++------ src/replicate/lib/_predictions.py | 2 +- src/replicate/resources/files.py | 4 ++-- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/replicate/lib/_files.py b/src/replicate/lib/_files.py index c14a944..7c6f485 100644 --- a/src/replicate/lib/_files.py +++ b/src/replicate/lib/_files.py @@ -51,9 +51,8 @@ def encode_json( if file_encoding_strategy == "base64": return base64_encode_file(obj) else: - # todo: support files endpoint - # return client.files.create(obj).urls["get"] - raise NotImplementedError("File upload is not supported yet") + response = client.files.create(content=obj.read()) + return response.urls.get if HAS_NUMPY: if isinstance(obj, np.integer): # type: ignore return int(obj) @@ -91,9 +90,8 @@ async def async_encode_json( # TODO: This should ideally use an async based file reader path. return base64_encode_file(obj) else: - # todo: support files endpoint - # return (await client.files.async_create(obj)).urls["get"] - raise NotImplementedError("File upload is not supported yet") + response = await client.files.create(content=obj.read()) + return response.urls.get if HAS_NUMPY: if isinstance(obj, np.integer): # type: ignore return int(obj) diff --git a/src/replicate/lib/_predictions.py b/src/replicate/lib/_predictions.py index 6d47e9a..e025f93 100644 --- a/src/replicate/lib/_predictions.py +++ b/src/replicate/lib/_predictions.py @@ -215,6 +215,6 @@ def output_iterator(prediction: Prediction, client: Replicate) -> Iterator[Any]: if prediction.status == "failed": raise ModelError(prediction=prediction) - output: list[Any] = prediction.output or [] # type: ignore + output = prediction.output or [] # type: ignore new_output = output[len(previous_output) :] yield from new_output diff --git a/src/replicate/resources/files.py b/src/replicate/resources/files.py index fdb90c0..a9f5034 100644 --- a/src/replicate/resources/files.py +++ b/src/replicate/resources/files.py @@ -58,7 +58,7 @@ def create( self, *, content: FileTypes, - filename: str, + filename: str | NotGiven = NOT_GIVEN, metadata: object | NotGiven = NOT_GIVEN, type: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. @@ -336,7 +336,7 @@ async def create( self, *, content: FileTypes, - filename: str, + filename: str | NotGiven = NOT_GIVEN, metadata: object | NotGiven = NOT_GIVEN, type: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. From 243dc78c9497d33b933e1a932dfc48a19f6274a0 Mon Sep 17 00:00:00 2001 From: David Meadows Date: Wed, 7 May 2025 15:13:06 -0400 Subject: [PATCH 2/4] fix files uploading From 091ff781a7bb18170dcd6202d29ff611024eaa1b Mon Sep 17 00:00:00 2001 From: David Meadows Date: Wed, 7 May 2025 15:23:18 -0400 Subject: [PATCH 3/4] add test for file upload --- tests/lib/test_run.py | 56 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/tests/lib/test_run.py b/tests/lib/test_run.py index 43df10d..168447d 100644 --- a/tests/lib/test_run.py +++ b/tests/lib/test_run.py @@ -10,9 +10,11 @@ from respx import MockRouter from replicate import Replicate, AsyncReplicate +from replicate._compat import model_dump from replicate.lib._files import FileOutput, AsyncFileOutput from replicate._exceptions import ModelError, NotFoundError, BadRequestError from replicate.lib._models import Model, Version, ModelVersionIdentifier +from replicate.types.file_create_response import URLs, Checksums, FileCreateResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") bearer_token = "My Bearer Token" @@ -89,6 +91,16 @@ class TestRun: # Common model reference format that will work with the new SDK model_ref = "owner/name:version" + file_create_response = FileCreateResponse( + id="test_file_id", + checksums=Checksums(sha256="test_sha256"), + content_type="application/octet-stream", + created_at=datetime.datetime.now(), + expires_at=datetime.datetime.now() + datetime.timedelta(days=1), + metadata={}, + size=1234, + urls=URLs(get="https://api.replicate.com/v1/files/test_file_id"), + ) @pytest.mark.respx(base_url=base_url) def test_run_basic(self, respx_mock: MockRouter) -> None: @@ -236,6 +248,23 @@ def test_run_with_base64_file(self, respx_mock: MockRouter) -> None: assert output == "test output" + @pytest.mark.respx(base_url=base_url) + def test_run_with_file_upload(self, respx_mock: MockRouter) -> None: + """Test run with base64 encoded file input.""" + # Create a simple file-like object + file_obj = io.BytesIO(b"test content") + + # Mock the prediction response + respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction())) + # Mock the file upload endpoint + respx_mock.post("/files").mock( + return_value=httpx.Response(201, json=model_dump(self.file_create_response, mode="json")) + ) + + output: Any = self.client.run(self.model_ref, input={"file": file_obj}) + + assert output == "test output" + def test_run_with_prefer_conflict(self) -> None: """Test run with conflicting wait and prefer parameters.""" with pytest.raises(TypeError, match="cannot mix and match prefer and wait"): @@ -349,6 +378,16 @@ class TestAsyncRun: # Common model reference format that will work with the new SDK model_ref = "owner/name:version" + file_create_response = FileCreateResponse( + id="test_file_id", + checksums=Checksums(sha256="test_sha256"), + content_type="application/octet-stream", + created_at=datetime.datetime.now(), + expires_at=datetime.datetime.now() + datetime.timedelta(days=1), + metadata={}, + size=1234, + urls=URLs(get="https://api.replicate.com/v1/files/test_file_id"), + ) @pytest.mark.respx(base_url=base_url) async def test_async_run_basic(self, respx_mock: MockRouter) -> None: @@ -501,6 +540,23 @@ async def test_async_run_with_base64_file(self, respx_mock: MockRouter) -> None: assert output == "test output" + @pytest.mark.respx(base_url=base_url) + async def test_async_run_with_file_upload(self, respx_mock: MockRouter) -> None: + """Test async run with base64 encoded file input.""" + # Create a simple file-like object + file_obj = io.BytesIO(b"test content") + + # Mock the prediction response + respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction())) + # Mock the file upload endpoint + respx_mock.post("/files").mock( + return_value=httpx.Response(201, json=model_dump(self.file_create_response, mode="json")) + ) + + output: Any = await self.client.run(self.model_ref, input={"file": file_obj}) + + assert output == "test output" + async def test_async_run_with_prefer_conflict(self) -> None: """Test async run with conflicting wait and prefer parameters.""" with pytest.raises(TypeError, match="cannot mix and match prefer and wait"): From bda7c30d833b4b7b30fb4716687ffcbbd95e1cd5 Mon Sep 17 00:00:00 2001 From: David Meadows Date: Fri, 9 May 2025 16:13:03 -0400 Subject: [PATCH 4/4] fixup! --- src/replicate/lib/_predictions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/replicate/lib/_predictions.py b/src/replicate/lib/_predictions.py index e025f93..6d47e9a 100644 --- a/src/replicate/lib/_predictions.py +++ b/src/replicate/lib/_predictions.py @@ -215,6 +215,6 @@ def output_iterator(prediction: Prediction, client: Replicate) -> Iterator[Any]: if prediction.status == "failed": raise ModelError(prediction=prediction) - output = prediction.output or [] # type: ignore + output: list[Any] = prediction.output or [] # type: ignore new_output = output[len(previous_output) :] yield from new_output