Skip to content

Commit 6fd9ff3

Browse files
committed
Improve transformation of inputs passed to Function.create()
Currently we only convert `URLPath` instances into URL strings if they are passed directly as an input parameter. We want to ensure that we convert every `URLPath` instance provided; otherwise we get JSON encoding errors. The most common case is passing a list of files that were output from another model. This PR attempts to transform the most common Python data structures into either lists or dicts and, in doing so, extracts the underlying URL value from any `URLPath` instances. This might be better implemented as a custom JSON encoder.
1 parent ff26a73 commit 6fd9ff3

File tree

2 files changed

+345
-18
lines changed

2 files changed

+345
-18
lines changed

src/replicate/lib/_predictions_use.py

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Union,
1515
Generic,
1616
Literal,
17+
Mapping,
1718
TypeVar,
1819
Callable,
1920
Iterator,
@@ -26,6 +27,7 @@
2627
)
2728
from pathlib import Path
2829
from functools import cached_property
30+
from collections.abc import Iterable, AsyncIterable
2931
from typing_extensions import ParamSpec, override
3032

3133
import httpx
@@ -456,21 +458,34 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
456458
"""
457459
Start a prediction with the specified inputs.
458460
"""
461+
459462
# Process inputs to convert concatenate SyncOutputIterators to strings and URLPath to URLs
460-
processed_inputs = {}
461-
for key, value in inputs.items():
463+
def _process_input(value: Any) -> Any:
464+
if isinstance(value, bytes) or isinstance(value, str):
465+
return value
466+
462467
if isinstance(value, SyncOutputIterator):
463468
if value.is_concatenate:
464469
# TODO: Fix type inference for str() conversion of generic iterator
465-
processed_inputs[key] = str(value) # type: ignore[arg-type]
466-
else:
467-
# TODO: Fix type inference for SyncOutputIterator iteration
468-
processed_inputs[key] = list(value) # type: ignore[arg-type, misc, assignment]
469-
elif url := get_path_url(value):
470-
processed_inputs[key] = url
471-
else:
472-
# TODO: Fix type inference for generic value assignment
473-
processed_inputs[key] = value # type: ignore[assignment]
470+
return str(value) # type: ignore[arg-type]
471+
472+
# TODO: Fix type inference for SyncOutputIterator iteration
473+
return [_process_input(v) for v in value]
474+
475+
if isinstance(value, Mapping):
476+
return {k: _process_input(v) for k, v in value.items()}
477+
478+
if isinstance(value, Iterable):
479+
return [_process_input(v) for v in value]
480+
481+
if url := get_path_url(value):
482+
return url
483+
484+
return value
485+
486+
processed_inputs = {}
487+
for key, value in inputs.items():
488+
processed_inputs[key] = _process_input(value)
474489

475490
version = self._version
476491

@@ -731,15 +746,31 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
731746
"""
732747
# Process inputs to convert concatenate AsyncOutputIterators to strings and URLPath to URLs
733748
processed_inputs = {}
734-
for key, value in inputs.items():
749+
750+
async def _process_input(value: Any) -> Any:
751+
if isinstance(value, bytes) or isinstance(value, str):
752+
return value
753+
735754
if isinstance(value, AsyncOutputIterator):
736755
# TODO: Fix type inference for AsyncOutputIterator await
737-
processed_inputs[key] = await value # type: ignore[misc]
738-
elif url := get_path_url(value):
739-
processed_inputs[key] = url
740-
else:
741-
# TODO: Fix type inference for generic value assignment
742-
processed_inputs[key] = value # type: ignore[assignment]
756+
return await _process_input(await value)
757+
758+
if isinstance(value, Mapping):
759+
return {k: await _process_input(v) for k, v in value.items()}
760+
761+
if isinstance(value, Iterable):
762+
return [await _process_input(v) for v in value]
763+
764+
if isinstance(value, AsyncIterable):
765+
return [await _process_input(v) async for v in value]
766+
767+
if url := get_path_url(value):
768+
return url
769+
770+
return value
771+
772+
for key, value in inputs.items():
773+
processed_inputs[key] = await _process_input(value)
743774

744775
version = await self._version()
745776

tests/lib/test_use.py

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from typing import Any, Dict, Optional
5+
6+
import httpx
7+
import pytest
8+
from respx import MockRouter
9+
10+
import replicate
11+
from replicate.lib._predictions_use import URLPath, SyncOutputIterator, AsyncOutputIterator
12+
13+
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
14+
bearer_token = "My Bearer Token"
15+
16+
17+
# Mock prediction data for testing
18+
def create_mock_prediction(
    status: str = "succeeded",
    output: Any = "test output",
    error: Optional[str] = None,
    logs: Optional[str] = None,
    urls: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    """Build a prediction payload dict to serve as a mocked response body.

    ``status``, ``output``, ``error`` and ``logs`` are copied into the payload
    unchanged. When ``urls`` is ``None`` a default set of ``get`` / ``cancel``
    / ``web`` links for ``test_prediction_id`` is used instead. The
    ``completed_at`` timestamp is filled in only when ``status`` is
    ``"succeeded"`` or ``"failed"``; for any other status it is ``None``.
    """
    resolved_urls = (
        urls
        if urls is not None
        else {
            "get": "https://api.replicate.com/v1/predictions/test_prediction_id",
            "cancel": "https://api.replicate.com/v1/predictions/test_prediction_id/cancel",
            "web": "https://replicate.com/p/test_prediction_id",
        }
    )

    # Only these two statuses carry a completion timestamp in the mock.
    has_completed = status in ["succeeded", "failed"]
    completed_at = "2023-01-01T00:00:02Z" if has_completed else None

    payload: Dict[str, Any] = {
        "id": "test_prediction_id",
        "version": "test_version",
        "status": status,
        "input": {"prompt": "test prompt"},
        "output": output,
        "error": error,
        "logs": logs,
        "created_at": "2023-01-01T00:00:00Z",
        "started_at": "2023-01-01T00:00:01Z",
        "completed_at": completed_at,
        "urls": resolved_urls,
        "model": "test-model",
        "data_removed": False,
    }
    return payload
47+
48+
49+
def create_mock_version() -> Dict[str, Any]:
    """Return a model payload dict (with an embedded latest version) for mocking.

    The payload describes the ``andreasjansson/video-stitcher`` model. Its
    ``latest_version.openapi_schema`` declares an ``Input`` with one required
    ``videos`` array of URI strings and an ``Output`` that is a single URI
    string. A fresh dict is built on every call.
    """
    # NOTE(review): the "Noneable" keys below are reproduced verbatim from the
    # fixture; they look like they may have been meant as OpenAPI "nullable"
    # -- confirm before relying on them.
    input_schema: Dict[str, Any] = {
        "type": "object",
        "title": "Input",
        "required": ["videos"],
        "properties": {
            "videos": {
                "type": "array",
                "items": {"type": "string", "format": "uri"},
                "title": "Videos",
                "description": "Videos to stitch together (can be uploaded files or URLs)",
            },
        },
    }
    output_schema: Dict[str, Any] = {"type": "string", "title": "Output", "format": "uri"}
    status_schema: Dict[str, Any] = {
        "enum": ["starting", "processing", "succeeded", "canceled", "failed"],
        "type": "string",
        "title": "Status",
        "description": "An enumeration.",
    }
    preset_schema: Dict[str, Any] = {
        "enum": [
            "ultrafast",
            "superfast",
            "veryfast",
            "faster",
            "fast",
            "medium",
            "slow",
            "slower",
            "veryslow",
        ],
        "type": "string",
        "title": "preset",
        "description": "An enumeration.",
    }
    webhook_event_schema: Dict[str, Any] = {
        "enum": ["start", "output", "logs", "completed"],
        "type": "string",
        "title": "WebhookEvent",
        "description": "An enumeration.",
    }
    validation_error_schema: Dict[str, Any] = {
        "type": "object",
        "title": "ValidationError",
        "required": ["loc", "msg", "type"],
        "properties": {
            "loc": {
                "type": "array",
                "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
                "title": "Location",
            },
            "msg": {"type": "string", "title": "Message"},
            "type": {"type": "string", "title": "Error Type"},
        },
    }
    prediction_request_schema: Dict[str, Any] = {
        "type": "object",
        "title": "PredictionRequest",
        "properties": {
            "id": {"type": "string", "title": "Id", "Noneable": True},
            "input": {"$ref": "#/components/schemas/Input", "Noneable": True},
            "context": {
                "type": "object",
                "title": "Context",
                "Noneable": True,
                "additionalProperties": {"type": "string"},
            },
            "webhook": {
                "type": "string",
                "title": "Webhook",
                "format": "uri",
                "Noneable": True,
                "maxLength": 65536,
                "minLength": 1,
            },
            "created_at": {
                "type": "string",
                "title": "Created At",
                "format": "date-time",
                "Noneable": True,
            },
            "output_file_prefix": {
                "type": "string",
                "title": "Output File Prefix",
                "Noneable": True,
            },
            "webhook_events_filter": {
                "type": "array",
                "items": {"$ref": "#/components/schemas/WebhookEvent"},
                "default": ["start", "output", "logs", "completed"],
                "Noneable": True,
            },
        },
    }
    prediction_response_schema: Dict[str, Any] = {
        "type": "object",
        "title": "PredictionResponse",
        "properties": {
            "id": {"type": "string", "title": "Id", "Noneable": True},
            "logs": {"type": "string", "title": "Logs", "default": ""},
            "error": {"type": "string", "title": "Error", "Noneable": True},
            "input": {"$ref": "#/components/schemas/Input", "Noneable": True},
            "output": {"$ref": "#/components/schemas/Output"},
            "status": {"$ref": "#/components/schemas/Status", "Noneable": True},
            "metrics": {
                "type": "object",
                "title": "Metrics",
                "Noneable": True,
                "additionalProperties": True,
            },
            "version": {"type": "string", "title": "Version", "Noneable": True},
            "created_at": {
                "type": "string",
                "title": "Created At",
                "format": "date-time",
                "Noneable": True,
            },
            "started_at": {
                "type": "string",
                "title": "Started At",
                "format": "date-time",
                "Noneable": True,
            },
            "completed_at": {
                "type": "string",
                "title": "Completed At",
                "format": "date-time",
                "Noneable": True,
            },
        },
    }
    http_validation_error_schema: Dict[str, Any] = {
        "type": "object",
        "title": "HTTPValidationError",
        "properties": {
            "detail": {
                "type": "array",
                "items": {"$ref": "#/components/schemas/ValidationError"},
                "title": "Detail",
            }
        },
    }

    # Key order mirrors the original fixture so serialized output is unchanged.
    openapi_schema: Dict[str, Any] = {
        "info": {"title": "Cog", "version": "0.1.0"},
        "paths": {},
        "openapi": "3.0.2",
        "components": {
            "schemas": {
                "Input": input_schema,
                "Output": output_schema,
                "Status": status_schema,
                "preset": preset_schema,
                "WebhookEvent": webhook_event_schema,
                "ValidationError": validation_error_schema,
                "PredictionRequest": prediction_request_schema,
                "PredictionResponse": prediction_response_schema,
                "HTTPValidationError": http_validation_error_schema,
            }
        },
    }
    latest_version: Dict[str, Any] = {
        "id": "11365b52712fbf76932e83bfef43a7ccb1af898fbefcd3da00ecea25d2a40f5e",
        "created_at": "2025-10-31T17:37:27.465191Z",
        "cog_version": "0.16.6",
        "openapi_schema": openapi_schema,
    }

    return {
        "cover_image_url": "https://replicate.delivery/xezq/7i7baf9dE93AP6bjmBZzqh3ZBkcB4pEtIb5dK9LajHbF0UyKA/output.mp4",
        "created_at": "2025-10-31T12:36:16.373813Z",
        "default_example": None,
        "description": "Fast GPU-powered concatenation of multiple videos, with short audio crossfades",
        "github_url": None,
        "latest_version": latest_version,
        "license_url": None,
        "name": "video-stitcher",
        "owner": "andreasjansson",
        "is_official": False,
        "paper_url": None,
        "run_count": 73,
        "url": "https://replicate.com/andreasjansson/video-stitcher",
        "visibility": "public",
        "weights_url": None,
    }
223+
224+
225+
def async_list_fixture():
    """Return a new async generator that yields one example image URL."""

    async def _emit_urls():
        # A single-element source; each URL is yielded in order.
        sample_urls = ["https://example.com/image.png"]
        for url in sample_urls:
            yield url

    return _emit_urls()
231+
232+
233+
class TestUse:
    """Tests for ``replicate.use()`` with inputs containing ``URLPath`` values or output iterators."""

    @staticmethod
    def _mock_routes(respx_mock: MockRouter) -> Any:
        """Register the mocked API routes shared by both tests.

        Returns the route for the prediction-create POST so a test can inspect
        the request that was sent to it.
        """
        create_route = respx_mock.post(
            "https://api.replicate.com/v1/models/andreasjansson/video-stitcher/predictions"
        ).mock(return_value=httpx.Response(201, json=create_mock_prediction()))
        respx_mock.get("https://api.replicate.com/v1/predictions/test_prediction_id").mock(
            return_value=httpx.Response(200, json=create_mock_prediction())
        )
        respx_mock.get("https://api.replicate.com/v1/models/andreasjansson/video-stitcher").mock(
            return_value=httpx.Response(200, json=create_mock_version())
        )
        respx_mock.get("https://api.replicate.com/v1/models/andreasjansson/video-stitcher/versions").mock(
            return_value=httpx.Response(404, json={})
        )
        return create_route

    @staticmethod
    def _posted_input(create_route: Any) -> Any:
        """Return the decoded ``input`` field of the last request sent to ``create_route``."""
        import json  # local import: only this helper needs it

        assert create_route.called
        return json.loads(create_route.calls.last.request.content)["input"]

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.parametrize(
        ("inputs", "expected"),
        [
            (URLPath("https://example.com/image.png"), "https://example.com/image.png"),
            ([URLPath("https://example.com/image.png")], ["https://example.com/image.png"]),
            ({URLPath("https://example.com/image.png")}, ["https://example.com/image.png"]),
            ((x for x in [URLPath("https://example.com/image.png")]), ["https://example.com/image.png"]),
            ({"file": URLPath("https://example.com/image.png")}, {"file": "https://example.com/image.png"}),
            (
                SyncOutputIterator(
                    lambda: (x for x in ["https://example.com/image.png"]), schema={}, is_concatenate=False
                ),
                ["https://example.com/image.png"],
            ),
        ],
    )
    def test_run_with_url_path(self, respx_mock: MockRouter, inputs: Any, expected: Any) -> None:
        """URLPath values and output iterators in inputs are sent as plain URL strings (sync)."""
        create_route = self._mock_routes(respx_mock)

        model = replicate.use("andreasjansson/video-stitcher")
        output: Any = model(prompt=inputs)

        assert output == "test output"
        # The request must carry the transformed value, not merely be JSON-serializable.
        assert self._posted_input(create_route) == {"prompt": expected}

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.parametrize(
        ("inputs", "expected"),
        [
            (URLPath("https://example.com/image.png"), "https://example.com/image.png"),
            ([URLPath("https://example.com/image.png")], ["https://example.com/image.png"]),
            ({URLPath("https://example.com/image.png")}, ["https://example.com/image.png"]),
            ((x for x in [URLPath("https://example.com/image.png")]), ["https://example.com/image.png"]),
            ({"file": URLPath("https://example.com/image.png")}, {"file": "https://example.com/image.png"}),
            (
                AsyncOutputIterator(async_list_fixture, schema={}, is_concatenate=False),
                ["https://example.com/image.png"],
            ),
        ],
    )
    async def test_run_with_url_path_async(self, respx_mock: MockRouter, inputs: Any, expected: Any) -> None:
        """URLPath values and output iterators in inputs are sent as plain URL strings (async)."""
        create_route = self._mock_routes(respx_mock)

        model = replicate.use("andreasjansson/video-stitcher", use_async=True)
        output: Any = await model(prompt=inputs)

        assert output == "test output"
        # The request must carry the transformed value, not merely be JSON-serializable.
        assert self._posted_input(create_route) == {"prompt": expected}

0 commit comments

Comments
 (0)