From c7240647808e89c0aa0463c27828c5b0d4fbf7ce Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 13 Nov 2024 11:50:33 +0000 Subject: [PATCH] Fix a couple of bugs in the base64 file_encoding_strategy This commit adds tests for the `file_encoding_strategy` argument for `replicate.run()` and fixes two bugs that surfaced: 1. `replicate.run()` would convert the file provided into base64 encoded data but not a valid data URL. We now use the `base64_encode_file` function used for outputs. 2. `replicate.async_run()` accepted but did not use the `file_encoding_strategy` flag at all. This is fixed, though it is worth noting that `base64_encode_file` is not optimized for async workflows and will block. This might be okay as the file sizes expected for data URL payloads should be very small. --- replicate/helpers.py | 10 +++- tests/test_run.py | 129 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 3 deletions(-) diff --git a/replicate/helpers.py b/replicate/helpers.py index e0bada5d..c6ac9072 100644 --- a/replicate/helpers.py +++ b/replicate/helpers.py @@ -43,7 +43,7 @@ def encode_json( return encode_json(file, client, file_encoding_strategy) if isinstance(obj, io.IOBase): if file_encoding_strategy == "base64": - return base64.b64encode(obj.read()).decode("utf-8") + return base64_encode_file(obj) else: return client.files.create(obj).urls["get"] if HAS_NUMPY: @@ -77,9 +77,13 @@ async def async_encode_json( ] if isinstance(obj, Path): with obj.open("rb") as file: - return encode_json(file, client, file_encoding_strategy) + return await async_encode_json(file, client, file_encoding_strategy) if isinstance(obj, io.IOBase): - return (await client.files.async_create(obj)).urls["get"] + if file_encoding_strategy == "base64": + # TODO: This should ideally use an async based file reader path. 
+ return base64_encode_file(obj) + else: + return (await client.files.async_create(obj)).urls["get"] if HAS_NUMPY: if isinstance(obj, np.integer): # type: ignore return int(obj) diff --git a/tests/test_run.py b/tests/test_run.py index beb7f6e2..93f7248b 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,5 +1,10 @@ import asyncio +import io +import json import sys +from email.message import EmailMessage +from email.parser import BytesParser +from email.policy import HTTP from typing import AsyncIterator, Iterator, Optional, cast import httpx @@ -581,6 +586,130 @@ async def test_run_with_model_error(mock_replicate_api_token): assert excinfo.value.prediction.status == "failed" +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_run_with_file_input_files_api(async_flag, mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + mock_predictions_create = router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status("processing"), + ) + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 200, + json=_version_with_schema(), + ) + ) + mock_files_create = router.route( + method="POST", + path="/files", + ).mock( + return_value=httpx.Response( + 200, + json={ + "id": "file1", + "name": "file.png", + "content_type": "image/png", + "size": 10, + "etag": "123", + "checksums": {}, + "metadata": {}, + "created_at": "", + "expires_at": "", + "urls": {"get": "https://api.replicate.com/files/file.txt"}, + }, + ) + ) + router.route(host="api.replicate.com").pass_through() + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + if async_flag: + await client.async_run( + "test/example:v1", + input={"file": io.BytesIO(initial_bytes=b"hello world")}, + ) + else: + client.run( + "test/example:v1", + input={"file": 
io.BytesIO(initial_bytes=b"hello world")}, + ) + + assert mock_predictions_create.called + prediction_payload = json.loads(mock_predictions_create.calls[0].request.content) + assert ( + prediction_payload.get("input", {}).get("file") + == "https://api.replicate.com/files/file.txt" + ) + + # Validate the Files API request + req = mock_files_create.calls[0].request + body = req.content + content_type = req.headers["Content-Type"] + + # Parse the multipart data + parser = BytesParser(EmailMessage, policy=HTTP) + headers = f"Content-Type: {content_type}\n\n".encode() + parsed_message_generator = parser.parsebytes(headers + body).walk() + next(parsed_message_generator) # wrapper + input_file = next(parsed_message_generator) + assert mock_files_create.called + assert input_file.get_content() == b"hello world" + assert input_file.get_content_type() == "application/octet-stream" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_run_with_file_input_data_url(async_flag, mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + mock_predictions_create = router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status("processing"), + ) + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 200, + json=_version_with_schema(), + ) + ) + router.route(host="api.replicate.com").pass_through() + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + + if async_flag: + await client.async_run( + "test/example:v1", + input={"file": io.BytesIO(initial_bytes=b"hello world")}, + file_encoding_strategy="base64", + ) + else: + client.run( + "test/example:v1", + input={"file": io.BytesIO(initial_bytes=b"hello world")}, + file_encoding_strategy="base64", + ) + + assert mock_predictions_create.called + prediction_payload = 
json.loads(mock_predictions_create.calls[0].request.content) + assert ( + prediction_payload.get("input", {}).get("file") + == "data:application/octet-stream;base64,aGVsbG8gd29ybGQ=" + ) + + @pytest.mark.asyncio async def test_run_with_file_output(mock_replicate_api_token): router = respx.Router(base_url="https://api.replicate.com/v1")