11# TODO
2- # - [ ] Support downloading files and conversion into Path when schema is URL
3- # - [ ] Support asyncio variant
4- # - [ ] Support list outputs
2+ # - [x] Support downloading files and conversion into Path when schema is URL
3+ # - [x] Support list outputs
54# - [ ] Support iterator outputs
6- # - [ ] Support text streaming
7- # - [ ] Support file streaming
5+ # - [ ] Support helpers for working with ContatenateIterator
86# - [ ] Support reusing output URL when passing to new method
97# - [ ] Support lazy downloading of files into Path
10- # - [ ] Support helpers for working with ContatenateIterator
8+ # - [ ] Support text streaming
9+ # - [ ] Support file streaming
10+ # - [ ] Support asyncio variant
1111import inspect
1212import os
1313import tempfile
1414from dataclasses import dataclass
1515from functools import cached_property
1616from pathlib import Path
17- from typing import Any , Dict , Optional , Tuple
17+ from typing import Any , Dict , Iterator , Optional , Tuple , Union
1818from urllib .parse import urlparse
1919
2020import httpx
@@ -66,6 +66,16 @@ def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool:
6666 return True
6767
6868
69+ def _has_iterator_output_type (openapi_schema : dict ) -> bool :
70+ """
71+ Returns true if the model output type is an iterator (non-concatenate).
72+ """
73+ output = openapi_schema .get ("components" , {}).get ("schemas" , {}).get ("Output" , {})
74+ return (
75+ output .get ("type" ) == "array" and output .get ("x-cog-array-type" ) == "iterator"
76+ )
77+
78+
6979def _download_file (url : str ) -> Path :
7080 """
7181 Download a file from URL to a temporary location and return the Path.
@@ -86,6 +96,23 @@ def _download_file(url: str) -> Path:
8696 return Path (temp_file .name )
8797
8898
99+ def _process_iterator_item (item : Any , openapi_schema : dict ) -> Any :
100+ """
101+ Process a single item from an iterator output based on schema.
102+ """
103+ output_schema = openapi_schema .get ("components" , {}).get ("schemas" , {}).get ("Output" , {})
104+
105+ # For array/iterator types, check the items schema
106+ if output_schema .get ("type" ) == "array" and output_schema .get ("x-cog-array-type" ) == "iterator" :
107+ items_schema = output_schema .get ("items" , {})
108+ # If items are file URLs, download them
109+ if items_schema .get ("type" ) == "string" and items_schema .get ("format" ) == "uri" :
110+ if isinstance (item , str ) and item .startswith (("http://" , "https://" )):
111+ return _download_file (item )
112+
113+ return item
114+
115+
89116def _process_output_with_schema (output : Any , openapi_schema : dict ) -> Any :
90117 """
91118 Process output data, downloading files based on OpenAPI schema.
@@ -159,7 +186,7 @@ class Run:
159186 prediction : Prediction
160187 schema : dict
161188
162- def wait (self ) -> Any :
189+ def wait (self ) -> Union [ Any , Iterator [ Any ]] :
163190 """
164191 Wait for the prediction to complete and return its output.
165192 """
@@ -171,6 +198,13 @@ def wait(self) -> Any:
171198 if _has_concatenate_iterator_output_type (self .schema ):
172199 return "" .join (self .prediction .output )
173200
201+ # Return an iterator for iterator output types
202+ if _has_iterator_output_type (self .schema ) and self .prediction .output is not None :
203+ return (
204+ _process_iterator_item (chunk , self .schema )
205+ for chunk in self .prediction .output
206+ )
207+
174208 # Process output for file downloads based on schema
175209 return _process_output_with_schema (self .prediction .output , self .schema )
176210
@@ -286,8 +320,6 @@ def use(function_ref: str) -> Function:
286320
287321 """
288322 if not _in_module_scope ():
289- raise RuntimeError (
290- "You may only call cog.ext.pipelines.include at the top level."
291- )
323+ raise RuntimeError ("You may only call replicate.use() at the top level." )
292324
293325 return Function (function_ref )
0 commit comments