11# TODO
22# - [ ] Support text streaming
33# - [ ] Support file streaming
4- # - [ ] Support asyncio variant
54import hashlib
65import inspect
76import os
1211from pathlib import Path
1312from typing import (
1413 Any ,
14+ AsyncIterator ,
1515 Callable ,
1616 Generic ,
1717 Iterator ,
18+ Literal ,
1819 Optional ,
1920 ParamSpec ,
2021 Protocol ,
2122 Tuple ,
2223 TypeVar ,
24+ Union ,
2325 cast ,
2426 overload ,
2527)
@@ -211,27 +213,61 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any:
class OutputIterator:
    """
    An iterator wrapper that handles both regular iteration and string conversion.
    Supports both sync and async iteration patterns.

    The wrapper stores *factories* rather than iterators: a factory is called
    every time the output is iterated, stringified or awaited.

    When ``is_concatenate`` is true every chunk is converted with ``str()``;
    otherwise each chunk is passed through ``_process_iterator_item`` together
    with ``schema``.
    """

    def __init__(
        self,
        iterator_factory: Callable[[], Iterator[Any]],
        async_iterator_factory: Callable[[], AsyncIterator[Any]],
        schema: dict,
        *,
        is_concatenate: bool,
    ) -> None:
        """
        Store the factories and output settings.

        Args:
            iterator_factory: Zero-argument callable returning a sync iterator
                of raw output chunks.
            async_iterator_factory: Zero-argument callable returning an async
                iterator of raw output chunks.
            schema: Schema handed to ``_process_iterator_item`` for
                non-concatenate output.
            is_concatenate: Whether chunks are string segments that should be
                joined together.
        """
        self.iterator_factory = iterator_factory
        self.async_iterator_factory = async_iterator_factory
        self.schema = schema
        self.is_concatenate = is_concatenate

    def _convert_chunk(self, chunk: Any) -> Any:
        """Convert one raw chunk; shared by the sync and async iteration paths."""
        if self.is_concatenate:
            return str(chunk)
        return _process_iterator_item(chunk, self.schema)

    def __iter__(self) -> Iterator[Any]:
        """Iterate over output items synchronously."""
        for chunk in self.iterator_factory():
            yield self._convert_chunk(chunk)

    async def __aiter__(self) -> AsyncIterator[Any]:
        """Iterate over output items asynchronously."""
        async for chunk in self.async_iterator_factory():
            yield self._convert_chunk(chunk)

    def __str__(self) -> str:
        """
        Convert to string.

        Concatenate output is joined with the empty string. Any other output
        is rendered as the ``str()`` of the list of raw chunks (the chunks are
        not passed through ``_process_iterator_item`` here).
        """
        if self.is_concatenate:
            return "".join(str(segment) for segment in self.iterator_factory())
        return str(list(self.iterator_factory()))

    def __await__(self):
        """
        Make OutputIterator awaitable.

        Awaiting returns the joined string for concatenate output, or the list
        of converted items otherwise.
        """

        async def _collect_result():
            # Both modes consume the async iterator the same way; only the
            # final shape of the result differs.
            items = [item async for item in self]
            return "".join(items) if self.is_concatenate else items

        return _collect_result().__await__()
236272
237273class URLPath (os .PathLike ):
@@ -319,6 +355,7 @@ def output(self) -> O:
319355 O ,
320356 OutputIterator (
321357 lambda : self .prediction .output_iterator (),
358+ lambda : self .prediction .async_output_iterator (),
322359 self .schema ,
323360 is_concatenate = is_concatenate ,
324361 ),
@@ -435,21 +472,186 @@ def openapi_schema(self) -> dict[str, Any]:
435472 return schema
436473
437474
475+ @dataclass
476+ class AsyncRun [O ]:
477+ """
478+ Represents a running prediction with access to its version (async version).
479+ """
480+
481+ prediction : Prediction
482+ schema : dict
483+
484+ async def output (self ) -> O :
485+ """
486+ Wait for the prediction to complete and return its output asynchronously.
487+ """
488+ await self .prediction .async_wait ()
489+
490+ if self .prediction .status == "failed" :
491+ raise ModelError (self .prediction )
492+
493+ # Return an OutputIterator for iterator output types (including concatenate iterators)
494+ if _has_iterator_output_type (self .schema ):
495+ is_concatenate = _has_concatenate_iterator_output_type (self .schema )
496+ return cast (
497+ O ,
498+ OutputIterator (
499+ lambda : self .prediction .output_iterator (),
500+ lambda : self .prediction .async_output_iterator (),
501+ self .schema ,
502+ is_concatenate = is_concatenate ,
503+ ),
504+ )
505+
506+ # Process output for file downloads based on schema
507+ return _process_output_with_schema (self .prediction .output , self .schema )
508+
509+ async def logs (self ) -> Optional [str ]:
510+ """
511+ Fetch and return the logs from the prediction asynchronously.
512+ """
513+ await self .prediction .async_reload ()
514+
515+ return self .prediction .logs
516+
517+
@dataclass
class AsyncFunction(Generic[Input, Output]):
    """
    An async wrapper for a Replicate model that can be called as a function.
    """

    # Model reference string; parsed by ModelVersionIdentifier.parse into
    # (owner, name, optional version).
    function_ref: str

    def _client(self) -> Client:
        """Return a new Client instance."""
        return Client()

    @cached_property
    def _parsed_ref(self) -> Tuple[str, str, Optional[str]]:
        """Parse ``function_ref`` once into (owner, name, version-or-None)."""
        return ModelVersionIdentifier.parse(self.function_ref)

    async def _model(self) -> Model:
        """Fetch the model named by ``function_ref``."""
        client = self._client()
        model_owner, model_name, _ = self._parsed_ref
        return await client.models.async_get(f"{model_owner}/{model_name}")

    async def _version(self) -> Version | None:
        """
        Resolve the version to run, or None when the model should be called
        via the versionless API.

        Raises:
            ReplicateError: Re-raised for any status other than 404 while
                listing versions.
        """
        _, _, model_version = self._parsed_ref
        model = await self._model()
        try:
            versions = await model.versions.async_list()
            if len(versions) == 0:
                # if we got an empty list when getting model versions, this
                # model is possibly a procedure instead and should be called via
                # the versionless API
                return None
        except ReplicateError as e:
            if e.status == 404:
                # if we get a 404 when getting model versions, this is an official
                # model and doesn't have addressable versions (despite what
                # latest_version might tell us)
                return None
            raise

        if model_version:
            return await model.versions.async_get(model_version)
        return model.latest_version

    async def __call__(self, *args: Input.args, **inputs: Input.kwargs) -> Output:
        """Create a prediction with ``inputs`` and return its awaited output."""
        run = await self.create(*args, **inputs)
        return await run.output()

    async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Output]:
        """
        Start a prediction with the specified inputs asynchronously.

        Positional arguments are accepted but ignored; only keyword inputs are
        sent.
        """
        # Process inputs to convert concatenate OutputIterators to strings and URLPath to URLs
        processed_inputs = {}
        for key, value in inputs.items():
            if isinstance(value, OutputIterator) and value.is_concatenate:
                # Await the iterator rather than calling str() on it: str()
                # drives the synchronous iterator and would block the event
                # loop, while awaiting yields the same joined string.
                processed_inputs[key] = await value
            elif url := get_path_url(value):
                processed_inputs[key] = url
            else:
                processed_inputs[key] = value

        version = await self._version()

        if version:
            prediction = await self._client().predictions.async_create(
                version=version, input=processed_inputs
            )
        else:
            model = await self._model()
            prediction = await self._client().models.predictions.async_create(
                model=model, input=processed_inputs
            )

        return AsyncRun(prediction, await self.openapi_schema())

    @property
    def default_example(self) -> Optional[dict[str, Any]]:
        """
        Get the default example for this model.

        Raises:
            NotImplementedError: Always; the property is not implemented yet.
        """
        raise NotImplementedError("This property has not yet been implemented")

    async def openapi_schema(self) -> dict[str, Any]:
        """
        Get the OpenAPI schema for this model version asynchronously.

        Raises:
            ValueError: If the model has no latest version.
        """
        # NOTE(review): the schema is always read from the model's latest
        # version, even when function_ref pins a specific version - confirm
        # this is intended.
        model = await self._model()
        latest_version = model.latest_version
        if latest_version is None:
            msg = f"Model {model.owner}/{model.name} has no latest version"
            raise ValueError(msg)

        schema = latest_version.openapi_schema
        if cog_version := latest_version.cog_version:
            schema = make_schema_backwards_compatible(schema, cog_version)
        return schema
617+
# Typing overloads for ``use``: the return type depends on the kind of ``ref``
# and on the literal value of ``use_async``.


# FunctionRef, sync (default): covers both ``use(ref)`` and
# ``use(ref, use_async=False)``.
@overload
def use(
    ref: FunctionRef[Input, Output], *, use_async: Literal[False] = False
) -> Function[Input, Output]: ...


# FunctionRef, async.
@overload
def use(
    ref: FunctionRef[Input, Output], *, use_async: Literal[True]
) -> AsyncFunction[Input, Output]: ...


# String reference with optional type hint, async.
@overload
def use(
    ref: str, *, hint: Callable[Input, Output] | None = None, use_async: Literal[True]
) -> AsyncFunction[Input, Output]: ...


# String reference with optional type hint, sync (default).
@overload
def use(
    ref: str,
    *,
    hint: Callable[Input, Output] | None = None,
    use_async: Literal[False] = False,
) -> Function[Input, Output]: ...
446647
447648
448649def use (
449650 ref : str | FunctionRef [Input , Output ],
450651 * ,
451652 hint : Callable [Input , Output ] | None = None ,
452- ) -> Function [Input , Output ]:
653+ use_async : bool = False ,
654+ ) -> Function [Input , Output ] | AsyncFunction [Input , Output ]:
453655 """
454656 Use a Replicate model as a function.
455657
@@ -469,4 +671,29 @@ def use(
469671 except AttributeError :
470672 pass
471673
674+ if use_async :
675+ return AsyncFunction (function_ref = str (ref ))
676+
472677 return Function (str (ref ))
678+
679+
680+ # class Model:
681+ # name = "foo"
682+
683+ # def __call__(self) -> str: ...
684+
685+
686+ # def model() -> int: ...
687+
688+
689+ # flux = use("")
690+ # flux_sync = use("", use_async=False)
691+ # flux_async = use("", use_async=True)
692+
693+ # flux = use("", hint=model)
694+ # flux_sync = use("", hint=model, use_async=False)
695+ # flux_async = use("", hint=model, use_async=True)
696+
697+ # flux = use(Model())
698+ # flux_sync = use(Model(), use_async=False)
699+ # flux_async = use(Model(), use_async=True)
0 commit comments