Skip to content

Commit 85ba399

Browse files
committed
Support conversion of file outputs into Path in use()
1 parent e2d8e56 commit 85ba399

File tree

2 files changed

+256
-13
lines changed

2 files changed

+256
-13
lines changed

replicate/use.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,14 @@
1010
# - [ ] Support helpers for working with ContatenateIterator
1111
import inspect
1212
import os
13+
import tempfile
1314
from dataclasses import dataclass
1415
from functools import cached_property
16+
from pathlib import Path
1517
from typing import Any, Dict, Optional, Tuple
18+
from urllib.parse import urlparse
19+
20+
import httpx
1621

1722
from replicate.client import Client
1823
from replicate.exceptions import ModelError, ReplicateError
@@ -61,6 +66,90 @@ def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool:
6166
return True
6267

6368

69+
def _download_file(url: str) -> Path:
70+
"""
71+
Download a file from URL to a temporary location and return the Path.
72+
"""
73+
parsed_url = urlparse(url)
74+
filename = os.path.basename(parsed_url.path)
75+
76+
if not filename or "." not in filename:
77+
filename = "download"
78+
79+
_, ext = os.path.splitext(filename)
80+
with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
81+
with httpx.stream("GET", url) as response:
82+
response.raise_for_status()
83+
for chunk in response.iter_bytes():
84+
temp_file.write(chunk)
85+
86+
return Path(temp_file.name)
87+
88+
89+
def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any:
90+
"""
91+
Process output data, downloading files based on OpenAPI schema.
92+
"""
93+
output_schema = (
94+
openapi_schema.get("components", {}).get("schemas", {}).get("Output", {})
95+
)
96+
97+
# Handle direct string with format=uri
98+
if output_schema.get("type") == "string" and output_schema.get("format") == "uri":
99+
if isinstance(output, str) and output.startswith(("http://", "https://")):
100+
return _download_file(output)
101+
return output
102+
103+
# Handle array of strings with format=uri
104+
if output_schema.get("type") == "array":
105+
items = output_schema.get("items", {})
106+
if items.get("type") == "string" and items.get("format") == "uri":
107+
if isinstance(output, list):
108+
return [
109+
_download_file(url)
110+
if isinstance(url, str) and url.startswith(("http://", "https://"))
111+
else url
112+
for url in output
113+
]
114+
return output
115+
116+
# Handle object with properties
117+
if output_schema.get("type") == "object" and isinstance(output, dict):
118+
properties = output_schema.get("properties", {})
119+
result = output.copy()
120+
121+
for prop_name, prop_schema in properties.items():
122+
if prop_name in result:
123+
value = result[prop_name]
124+
125+
# Direct file property
126+
if (
127+
prop_schema.get("type") == "string"
128+
and prop_schema.get("format") == "uri"
129+
):
130+
if isinstance(value, str) and value.startswith(
131+
("http://", "https://")
132+
):
133+
result[prop_name] = _download_file(value)
134+
135+
# Array of files property
136+
elif prop_schema.get("type") == "array":
137+
items = prop_schema.get("items", {})
138+
if items.get("type") == "string" and items.get("format") == "uri":
139+
if isinstance(value, list):
140+
result[prop_name] = [
141+
_download_file(url)
142+
if isinstance(url, str)
143+
and url.startswith(("http://", "https://"))
144+
else url
145+
for url in value
146+
]
147+
148+
return result
149+
150+
return output
151+
152+
64153
@dataclass
65154
class Run:
66155
"""
@@ -82,7 +171,8 @@ def wait(self) -> Any:
82171
if _has_concatenate_iterator_output_type(self.schema):
83172
return "".join(self.prediction.output)
84173

85-
return self.prediction.output
174+
# Process output for file downloads based on schema
175+
return _process_output_with_schema(self.prediction.output, self.schema)
86176

87177
def logs(self) -> Optional[str]:
88178
"""

tests/test_use.py

Lines changed: 165 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -438,14 +438,23 @@ async def test_use_path_output(use_async_client):
438438
)
439439
mock_prediction_endpoints(output_data="https://example.com/output.jpg")
440440

441+
# Mock the file download
442+
respx.get("https://example.com/output.jpg").mock(
443+
return_value=httpx.Response(200, content=b"fake image data")
444+
)
445+
441446
# Call use with "acme/hotdog-detector"
442447
hotdog_detector = replicate.use("acme/hotdog-detector")
443448

444449
# Call function with prompt="hello world"
445450
output = hotdog_detector(prompt="hello world")
446451

447-
# Assert that output is returned as a string URL
448-
assert output == "https://example.com/output.jpg"
452+
# Assert that output is returned as a Path object
453+
from pathlib import Path
454+
455+
assert isinstance(output, Path)
456+
assert output.exists()
457+
assert output.read_bytes() == b"fake image data"
449458

450459

451460
@pytest.mark.asyncio
@@ -474,17 +483,29 @@ async def test_use_list_of_paths_output(use_async_client):
474483
]
475484
)
476485

486+
# Mock the file downloads
487+
respx.get("https://example.com/output1.jpg").mock(
488+
return_value=httpx.Response(200, content=b"fake image 1 data")
489+
)
490+
respx.get("https://example.com/output2.jpg").mock(
491+
return_value=httpx.Response(200, content=b"fake image 2 data")
492+
)
493+
477494
# Call use with "acme/hotdog-detector"
478495
hotdog_detector = replicate.use("acme/hotdog-detector")
479496

480497
# Call function with prompt="hello world"
481498
output = hotdog_detector(prompt="hello world")
482499

483-
# Assert that output is returned as a list of URLs
484-
assert output == [
485-
"https://example.com/output1.jpg",
486-
"https://example.com/output2.jpg",
487-
]
500+
# Assert that output is returned as a list of Path objects
501+
from pathlib import Path
502+
503+
assert isinstance(output, list)
504+
assert len(output) == 2
505+
assert all(isinstance(path, Path) for path in output)
506+
assert all(path.exists() for path in output)
507+
assert output[0].read_bytes() == b"fake image 1 data"
508+
assert output[1].read_bytes() == b"fake image 2 data"
488509

489510

490511
@pytest.mark.asyncio
@@ -514,17 +535,29 @@ async def test_use_iterator_of_paths_output(use_async_client):
514535
]
515536
)
516537

538+
# Mock the file downloads
539+
respx.get("https://example.com/output1.jpg").mock(
540+
return_value=httpx.Response(200, content=b"fake image 1 data")
541+
)
542+
respx.get("https://example.com/output2.jpg").mock(
543+
return_value=httpx.Response(200, content=b"fake image 2 data")
544+
)
545+
517546
# Call use with "acme/hotdog-detector"
518547
hotdog_detector = replicate.use("acme/hotdog-detector")
519548

520549
# Call function with prompt="hello world"
521550
output = hotdog_detector(prompt="hello world")
522551

523-
# Assert that output is returned as a list of URLs
524-
assert output == [
525-
"https://example.com/output1.jpg",
526-
"https://example.com/output2.jpg",
527-
]
552+
# Assert that output is returned as a list of Path objects
553+
from pathlib import Path
554+
555+
assert isinstance(output, list)
556+
assert len(output) == 2
557+
assert all(isinstance(path, Path) for path in output)
558+
assert all(path.exists() for path in output)
559+
assert output[0].read_bytes() == b"fake image 1 data"
560+
assert output[1].read_bytes() == b"fake image 2 data"
528561

529562

530563
@pytest.mark.asyncio
@@ -600,3 +633,123 @@ async def test_use_function_logs_method_polling(use_async_client):
600633
# Call logs method again to get updated logs (simulates polling)
601634
updated_logs = run.logs()
602635
assert updated_logs == "Starting prediction...\nProcessing input..."
636+
637+
638+
@pytest.mark.asyncio
639+
@pytest.mark.parametrize("use_async_client", [False])
640+
@respx.mock
641+
async def test_use_object_output_with_file_properties(use_async_client):
642+
mock_model_endpoints(
643+
version_overrides={
644+
"openapi_schema": {
645+
"components": {
646+
"schemas": {
647+
"Output": {
648+
"type": "object",
649+
"properties": {
650+
"text": {"type": "string", "title": "Text"},
651+
"image": {
652+
"type": "string",
653+
"format": "uri",
654+
"title": "Image",
655+
},
656+
"count": {"type": "integer", "title": "Count"},
657+
},
658+
"title": "Output",
659+
}
660+
}
661+
}
662+
}
663+
}
664+
)
665+
mock_prediction_endpoints(
666+
output_data={
667+
"text": "Generated text",
668+
"image": "https://example.com/generated.png",
669+
"count": 42,
670+
}
671+
)
672+
673+
# Mock the file download
674+
respx.get("https://example.com/generated.png").mock(
675+
return_value=httpx.Response(200, content=b"fake png data")
676+
)
677+
678+
# Call use with "acme/hotdog-detector"
679+
hotdog_detector = replicate.use("acme/hotdog-detector")
680+
681+
# Call function with prompt="hello world"
682+
output = hotdog_detector(prompt="hello world")
683+
684+
# Assert that output is returned as an object with file downloaded
685+
from pathlib import Path
686+
687+
assert isinstance(output, dict)
688+
assert output["text"] == "Generated text"
689+
assert output["count"] == 42
690+
assert isinstance(output["image"], Path)
691+
assert output["image"].exists()
692+
assert output["image"].read_bytes() == b"fake png data"
693+
694+
695+
@pytest.mark.asyncio
696+
@pytest.mark.parametrize("use_async_client", [False])
697+
@respx.mock
698+
async def test_use_object_output_with_file_list_property(use_async_client):
699+
mock_model_endpoints(
700+
version_overrides={
701+
"openapi_schema": {
702+
"components": {
703+
"schemas": {
704+
"Output": {
705+
"type": "object",
706+
"properties": {
707+
"text": {"type": "string", "title": "Text"},
708+
"images": {
709+
"type": "array",
710+
"items": {"type": "string", "format": "uri"},
711+
"title": "Images",
712+
},
713+
},
714+
"title": "Output",
715+
}
716+
}
717+
}
718+
}
719+
}
720+
)
721+
mock_prediction_endpoints(
722+
output_data={
723+
"text": "Generated text",
724+
"images": [
725+
"https://example.com/image1.png",
726+
"https://example.com/image2.png",
727+
],
728+
}
729+
)
730+
731+
# Mock the file downloads
732+
respx.get("https://example.com/image1.png").mock(
733+
return_value=httpx.Response(200, content=b"fake png 1 data")
734+
)
735+
respx.get("https://example.com/image2.png").mock(
736+
return_value=httpx.Response(200, content=b"fake png 2 data")
737+
)
738+
739+
# Call use with "acme/hotdog-detector"
740+
hotdog_detector = replicate.use("acme/hotdog-detector")
741+
742+
# Call function with prompt="hello world"
743+
output = hotdog_detector(prompt="hello world")
744+
745+
# Assert that output is returned as an object with files downloaded
746+
from pathlib import Path
747+
748+
assert isinstance(output, dict)
749+
assert output["text"] == "Generated text"
750+
assert isinstance(output["images"], list)
751+
assert len(output["images"]) == 2
752+
assert all(isinstance(path, Path) for path in output["images"])
753+
assert all(path.exists() for path in output["images"])
754+
assert output["images"][0].read_bytes() == b"fake png 1 data"
755+
assert output["images"][1].read_bytes() == b"fake png 2 data"

0 commit comments

Comments
 (0)