11# TODO
2- # - [x] Support downloading files and conversion into Path when schema is URL
3- # - [x] Support list outputs
4- # - [x] Support iterator outputs
5- # - [x] Support helpers for working with ContatenateIterator
6- # - [ ] Support reusing output URL when passing to new method
7- # - [ ] Support lazy downloading of files into Path
82# - [ ] Support text streaming
93# - [ ] Support file streaming
104# - [ ] Support asyncio variant
2822from replicate .version import Version
2923
3024
25+ __all__ = ["use" , "get_path_url" ]
26+
27+
3128def _in_module_scope () -> bool :
3229 """
3330 Returns True when called from top level module scope.
@@ -41,9 +38,6 @@ def _in_module_scope() -> bool:
4138 return False
4239
4340
44- __all__ = ["use" ]
45-
46-
4741def _has_concatenate_iterator_output_type (openapi_schema : dict ) -> bool :
4842 """
4943 Returns true if the model output type is ConcatenateIterator or
@@ -218,29 +212,41 @@ def ensure_path() -> Path:
218212 path = _download_file (target )
219213 return path
220214
221- object .__setattr__ (self , "__target__ " , target )
222- object .__setattr__ (self , "__path__ " , ensure_path )
215+ object .__setattr__ (self , "__replicate_target__ " , target )
216+ object .__setattr__ (self , "__replicate_path__ " , ensure_path )
223217
224218 def __getattribute__ (self , name ) -> Any :
225- if name in ("__path__ " , "__target__ " ):
219+ if name in ("__replicate_path__ " , "__replicate_target__ " ):
226220 return object .__getattribute__ (self , name )
227221
228222 # TODO: We should cover other common properties on Path...
229223 if name == "__class__" :
230224 return Path
231225
232- return getattr (object .__getattribute__ (self , "__path__ " )(), name )
226+ return getattr (object .__getattribute__ (self , "__replicate_path__ " )(), name )
233227
234228 def __setattr__ (self , name , value ) -> None :
235- if name in ("__path__ " , "__target__ " ):
229+ if name in ("__replicate_path__ " , "__replicate_target__ " ):
236230 raise ValueError ()
237231
238- object .__setattr__ (object .__getattribute__ (self , "__path__" )(), name , value )
232+ object .__setattr__ (
233+ object .__getattribute__ (self , "__replicate_path__" )(), name , value
234+ )
239235
240236 def __delattr__ (self , name ) -> None :
241- if name in ("__path__ " , "__target__ " ):
237+ if name in ("__replicate_path__ " , "__replicate_target__ " ):
242238 raise ValueError ()
243- delattr (object .__getattribute__ (self , "__path__" )(), name )
239+ delattr (object .__getattribute__ (self , "__replicate_path__" )(), name )
240+
241+
242+ def get_path_url (path : Any ) -> str | None :
243+ """
244+ Return the remote URL (if any) for a Path output from a model.
245+ """
246+ try :
247+ return object .__getattribute__ (path , "__replicate_target__" )
248+ except AttributeError :
249+ return None
244250
245251
246252@dataclass
@@ -252,7 +258,7 @@ class Run:
252258 prediction : Prediction
253259 schema : dict
254260
255- def wait (self ) -> Union [Any , Iterator [Any ]]:
261+ def output (self ) -> Union [Any , Iterator [Any ]]:
256262 """
257263 Wait for the prediction to complete and return its output.
258264 """
@@ -330,7 +336,7 @@ def _version(self) -> Version | None:
330336
331337 def __call__ (self , ** inputs : Dict [str , Any ]) -> Any :
332338 run = self .create (** inputs )
333- return run .wait ()
339+ return run .output ()
334340
335341 def create (self , ** inputs : Dict [str , Any ]) -> Run :
336342 """
@@ -341,8 +347,8 @@ def create(self, **inputs: Dict[str, Any]) -> Run:
341347 for key , value in inputs .items ():
342348 if isinstance (value , OutputIterator ) and value .is_concatenate :
343349 processed_inputs [key ] = str (value )
344- elif isinstance (value , PathProxy ):
345- processed_inputs [key ] = object . __getattribute__ ( value , "__target__" )
350+ elif url := get_path_url (value ):
351+ processed_inputs [key ] = url
346352 else :
347353 processed_inputs [key ] = value
348354
0 commit comments