@@ -336,46 +336,54 @@ class FunctionRef(Protocol, Generic[Input, Output]):
336336@dataclass
337337class Run [O ]:
338338 """
339- Represents a running prediction with access to its version .
339+ Represents a running prediction with access to the underlying schema .
340340 """
341341
342- prediction : Prediction
343- schema : dict
342+ _prediction : Prediction
343+ _schema : dict
344+
345+ def __init__ (
346+ self , * , prediction : Prediction , schema : dict , streaming : bool
347+ ) -> None :
348+ self ._prediction = prediction
349+ self ._schema = schema
350+ self ._streaming = streaming
344351
345352 def output (self ) -> O :
346353 """
347354 Return the output. For iterator types, returns immediately without waiting.
348355 For non-iterator types, waits for completion.
349356 """
350- # Return an OutputIterator immediately for iterator output types
351- if _has_iterator_output_type (self .schema ):
352- is_concatenate = _has_concatenate_iterator_output_type (self .schema )
357+ # Return an OutputIterator immediately when streaming, we do this for all
358+ # model return types regardless of whether they return an iterator.
359+ if self ._streaming :
360+ is_concatenate = _has_concatenate_iterator_output_type (self ._schema )
353361 return cast (
354362 O ,
355363 OutputIterator (
356- lambda : self .prediction .output_iterator (),
357- lambda : self .prediction .async_output_iterator (),
358- self .schema ,
364+ lambda : self ._prediction .output_iterator (),
365+ lambda : self ._prediction .async_output_iterator (),
366+ self ._schema ,
359367 is_concatenate = is_concatenate ,
360368 ),
361369 )
362370
363371 # For non-iterator types, wait for completion and process output
364- self .prediction .wait ()
372+ self ._prediction .wait ()
365373
366- if self .prediction .status == "failed" :
367- raise ModelError (self .prediction )
374+ if self ._prediction .status == "failed" :
375+ raise ModelError (self ._prediction )
368376
369377 # Process output for file downloads based on schema
370- return _process_output_with_schema (self .prediction .output , self .schema )
378+ return _process_output_with_schema (self ._prediction .output , self ._schema )
371379
372380 def logs (self ) -> Optional [str ]:
373381 """
374382 Fetch and return the logs from the prediction.
375383 """
376- self .prediction .reload ()
384+ self ._prediction .reload ()
377385
378- return self .prediction .logs
386+ return self ._prediction .logs
379387
380388
381389@dataclass
@@ -384,45 +392,11 @@ class Function(Generic[Input, Output]):
384392 A wrapper for a Replicate model that can be called as a function.
385393 """
386394
387- function_ref : str
388-
389- def _client (self ) -> Client :
390- return Client ()
391-
392- @cached_property
393- def _parsed_ref (self ) -> Tuple [str , str , Optional [str ]]:
394- return ModelVersionIdentifier .parse (self .function_ref )
395+ _ref : str
395396
396- @cached_property
397- def _model (self ) -> Model :
398- client = self ._client ()
399- model_owner , model_name , _ = self ._parsed_ref
400- return client .models .get (f"{ model_owner } /{ model_name } " )
401-
402- @cached_property
403- def _version (self ) -> Version | None :
404- _ , _ , model_version = self ._parsed_ref
405- model = self ._model
406- try :
407- versions = model .versions .list ()
408- if len (versions ) == 0 :
409- # if we got an empty list when getting model versions, this
410- # model is possibly a procedure instead and should be called via
411- # the versionless API
412- return None
413- except ReplicateError as e :
414- if e .status == 404 :
415- # if we get a 404 when getting model versions, this is an official
416- # model and doesn't have addressable versions (despite what
417- # latest_version might tell us)
418- return None
419- raise
420-
421- version = (
422- model .versions .get (model_version ) if model_version else model .latest_version
423- )
424-
425- return version
397+ def __init__ (self , ref : str , * , streaming : bool ) -> None :
398+ self ._ref = ref
399+ self ._streaming = streaming
426400
427401 def __call__ (self , * args : Input .args , ** inputs : Input .kwargs ) -> Output :
428402 return self .create (* args , ** inputs ).output ()
@@ -455,7 +429,9 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
455429 model = self ._model , input = processed_inputs
456430 )
457431
458- return Run (prediction , self .openapi_schema )
432+ return Run (
433+ prediction = prediction , schema = self .openapi_schema , streaming = self ._streaming
434+ )
459435
460436 @property
461437 def default_example (self ) -> Optional [dict [str , Any ]]:
@@ -479,50 +455,96 @@ def openapi_schema(self) -> dict[str, Any]:
479455 schema = make_schema_backwards_compatible (schema , cog_version )
480456 return schema
481457
458+ def _client (self ) -> Client :
459+ return Client ()
460+
461+ @cached_property
462+ def _parsed_ref (self ) -> Tuple [str , str , Optional [str ]]:
463+ return ModelVersionIdentifier .parse (self ._ref )
464+
465+ @cached_property
466+ def _model (self ) -> Model :
467+ client = self ._client ()
468+ model_owner , model_name , _ = self ._parsed_ref
469+ return client .models .get (f"{ model_owner } /{ model_name } " )
470+
471+ @cached_property
472+ def _version (self ) -> Version | None :
473+ _ , _ , model_version = self ._parsed_ref
474+ model = self ._model
475+ try :
476+ versions = model .versions .list ()
477+ if len (versions ) == 0 :
478+ # if we got an empty list when getting model versions, this
479+ # model is possibly a procedure instead and should be called via
480+ # the versionless API
481+ return None
482+ except ReplicateError as e :
483+ if e .status == 404 :
484+ # if we get a 404 when getting model versions, this is an official
485+ # model and doesn't have addressable versions (despite what
486+ # latest_version might tell us)
487+ return None
488+ raise
489+
490+ version = (
491+ model .versions .get (model_version ) if model_version else model .latest_version
492+ )
493+
494+ return version
495+
482496
483497@dataclass
484498class AsyncRun [O ]:
485499 """
486500 Represents a running prediction with access to its version (async version).
487501 """
488502
489- prediction : Prediction
490- schema : dict
503+ _prediction : Prediction
504+ _schema : dict
505+
506+ def __init__ (
507+ self , * , prediction : Prediction , schema : dict , streaming : bool
508+ ) -> None :
509+ self ._prediction = prediction
510+ self ._schema = schema
511+ self ._streaming = streaming
491512
492513 async def output (self ) -> O :
493514 """
494515 Return the output. For iterator types, returns immediately without waiting.
495516 For non-iterator types, waits for completion.
496517 """
497- # Return an OutputIterator immediately for iterator output types
498- if _has_iterator_output_type (self .schema ):
499- is_concatenate = _has_concatenate_iterator_output_type (self .schema )
518+ # Return an OutputIterator immediately when streaming, we do this for all
519+ # model return types regardless of whether they return an iterator.
520+ if self ._streaming :
521+ is_concatenate = _has_concatenate_iterator_output_type (self ._schema )
500522 return cast (
501523 O ,
502524 OutputIterator (
503- lambda : self .prediction .output_iterator (),
504- lambda : self .prediction .async_output_iterator (),
505- self .schema ,
525+ lambda : self ._prediction .output_iterator (),
526+ lambda : self ._prediction .async_output_iterator (),
527+ self ._schema ,
506528 is_concatenate = is_concatenate ,
507529 ),
508530 )
509531
510532 # For non-iterator types, wait for completion and process output
511- await self .prediction .async_wait ()
533+ await self ._prediction .async_wait ()
512534
513- if self .prediction .status == "failed" :
514- raise ModelError (self .prediction )
535+ if self ._prediction .status == "failed" :
536+ raise ModelError (self ._prediction )
515537
516538 # Process output for file downloads based on schema
517- return _process_output_with_schema (self .prediction .output , self .schema )
539+ return _process_output_with_schema (self ._prediction .output , self ._schema )
518540
519541 async def logs (self ) -> Optional [str ]:
520542 """
521543 Fetch and return the logs from the prediction asynchronously.
522544 """
523- await self .prediction .async_reload ()
545+ await self ._prediction .async_reload ()
524546
525- return self .prediction .logs
547+ return self ._prediction .logs
526548
527549
528550@dataclass
@@ -532,6 +554,7 @@ class AsyncFunction(Generic[Input, Output]):
532554 """
533555
534556 function_ref : str
557+ streaming : bool
535558
536559 def _client (self ) -> Client :
537560 return Client ()
@@ -600,7 +623,11 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
600623 model = model , input = processed_inputs
601624 )
602625
603- return AsyncRun (prediction , await self .openapi_schema ())
626+ return AsyncRun (
627+ prediction = prediction ,
628+ schema = await self .openapi_schema (),
629+ streaming = self .streaming ,
630+ )
604631
605632 @property
606633 def default_example (self ) -> Optional [dict [str , Any ]]:
@@ -629,6 +656,12 @@ async def openapi_schema(self) -> dict[str, Any]:
629656def use (ref : FunctionRef [Input , Output ]) -> Function [Input , Output ]: ...
630657
631658
659+ @overload
660+ def use (
661+ ref : FunctionRef [Input , Output ], * , streaming : Literal [False ]
662+ ) -> Function [Input , Output ]: ...
663+
664+
632665@overload
633666def use (
634667 ref : FunctionRef [Input , Output ], * , use_async : Literal [False ]
@@ -643,25 +676,82 @@ def use(
643676
644677@overload
645678def use (
646- ref : str , * , hint : Callable [Input , Output ] | None = None , use_async : Literal [True ]
679+ ref : FunctionRef [Input , Output ],
680+ * ,
681+ streaming : Literal [False ],
682+ use_async : Literal [True ],
647683) -> AsyncFunction [Input , Output ]: ...
648684
649685
686+ @overload
687+ def use (
688+ ref : FunctionRef [Input , Output ],
689+ * ,
690+ streaming : Literal [True ],
691+ use_async : Literal [True ],
692+ ) -> AsyncFunction [Input , AsyncIterator [Output ]]: ...
693+
694+
# Explicitly synchronous, non-streaming call with a FunctionRef: at runtime
# `use_async=False` yields a Function, and without streaming the output type
# is the model's own Output (not an async iterator).
@overload
def use(
    ref: FunctionRef[Input, Output],
    *,
    streaming: Literal[False],
    use_async: Literal[False],
) -> Function[Input, Output]: ...


# Synchronous streaming call with a FunctionRef: mirrors the `ref: str`
# overload that returns Function[Input, Iterator[Output]].
@overload
def use(
    ref: FunctionRef[Input, Output],
    *,
    streaming: Literal[True],
    use_async: Literal[False] = False,
) -> Function[Input, Iterator[Output]]: ...
702+
703+
650704@overload
651705def use (
652706 ref : str ,
653707 * ,
654708 hint : Callable [Input , Output ] | None = None ,
709+ streaming : Literal [False ] = False ,
655710 use_async : Literal [False ] = False ,
656711) -> Function [Input , Output ]: ...
657712
658713
714+ @overload
715+ def use (
716+ ref : str ,
717+ * ,
718+ hint : Callable [Input , Output ] | None = None ,
719+ streaming : Literal [True ],
720+ use_async : Literal [False ] = False ,
721+ ) -> Function [Input , Iterator [Output ]]: ...
722+
723+
724+ @overload
725+ def use (
726+ ref : str ,
727+ * ,
728+ hint : Callable [Input , Output ] | None = None ,
729+ use_async : Literal [True ],
730+ ) -> AsyncFunction [Input , Output ]: ...
731+
732+
733+ @overload
734+ def use (
735+ ref : str ,
736+ * ,
737+ hint : Callable [Input , Output ] | None = None ,
738+ streaming : Literal [True ],
739+ use_async : Literal [True ],
740+ ) -> AsyncFunction [Input , AsyncIterator [Output ]]: ...
741+
742+
659743def use (
660744 ref : str | FunctionRef [Input , Output ],
661745 * ,
662746 hint : Callable [Input , Output ] | None = None ,
747+ streaming : bool = False ,
663748 use_async : bool = False ,
664- ) -> Function [Input , Output ] | AsyncFunction [Input , Output ]:
749+ ) -> (
750+ Function [Input , Output ]
751+ | AsyncFunction [Input , Output ]
752+ | Function [Input , Iterator [Output ]]
753+ | AsyncFunction [Input , AsyncIterator [Output ]]
754+ ):
665755 """
666756 Use a Replicate model as a function.
667757
@@ -682,9 +772,9 @@ def use(
682772 pass
683773
684774 if use_async :
685- return AsyncFunction (function_ref = str (ref ))
775+ return AsyncFunction (str (ref ), streaming = streaming )
686776
687- return Function (str (ref ))
777+ return Function (str (ref ), streaming = streaming )
688778
689779
690780# class Model:
@@ -693,17 +783,23 @@ def use(
693783# def __call__(self) -> str: ...
694784
695785
696- # def model() -> int: ...
786+ # def model() -> AsyncIterator[ int] : ...
697787
698788
699789# flux = use("")
700790# flux_sync = use("", use_async=False)
791+ # streaming_flux_sync = use("", streaming=True, use_async=False)
701792# flux_async = use("", use_async=True)
793+ # streaming_flux_async = use("", streaming=True, use_async=True)
702794
703795# flux = use("", hint=model)
704796# flux_sync = use("", hint=model, use_async=False)
797+ # streaming_flux_sync = use("", hint=model, streaming=True, use_async=False)
705798# flux_async = use("", hint=model, use_async=True)
799+ # streaming_flux_async = use("", hint=model, streaming=True, use_async=True)
706800
707801# flux = use(Model())
708802# flux_sync = use(Model(), use_async=False)
803+ # streaming_flux_sync = use(Model(), streaming=True, use_async=False)
709804# flux_async = use(Model(), use_async=True)
805+ # streaming_flux_async = use(Model(), streaming=True, use_async=True)
0 commit comments