Skip to content

Commit dd64e91

Browse files
committed
Implement use(ref, streaming=True) to return iterators
1 parent 2afd364 commit dd64e91

File tree

2 files changed

+192
-84
lines changed

2 files changed

+192
-84
lines changed

replicate/use.py

Lines changed: 170 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -336,46 +336,54 @@ class FunctionRef(Protocol, Generic[Input, Output]):
336336
@dataclass
337337
class Run[O]:
338338
"""
339-
Represents a running prediction with access to its version.
339+
Represents a running prediction with access to the underlying schema.
340340
"""
341341

342-
prediction: Prediction
343-
schema: dict
342+
_prediction: Prediction
343+
_schema: dict
344+
345+
def __init__(
346+
self, *, prediction: Prediction, schema: dict, streaming: bool
347+
) -> None:
348+
self._prediction = prediction
349+
self._schema = schema
350+
self._streaming = streaming
344351

345352
def output(self) -> O:
346353
"""
347354
Return the output. For iterator types, returns immediately without waiting.
348355
For non-iterator types, waits for completion.
349356
"""
350-
# Return an OutputIterator immediately for iterator output types
351-
if _has_iterator_output_type(self.schema):
352-
is_concatenate = _has_concatenate_iterator_output_type(self.schema)
357+
# Return an OutputIterator immediately when streaming; we do this for all
358+
# model return types regardless of whether they return an iterator.
359+
if self._streaming:
360+
is_concatenate = _has_concatenate_iterator_output_type(self._schema)
353361
return cast(
354362
O,
355363
OutputIterator(
356-
lambda: self.prediction.output_iterator(),
357-
lambda: self.prediction.async_output_iterator(),
358-
self.schema,
364+
lambda: self._prediction.output_iterator(),
365+
lambda: self._prediction.async_output_iterator(),
366+
self._schema,
359367
is_concatenate=is_concatenate,
360368
),
361369
)
362370

363371
# For non-iterator types, wait for completion and process output
364-
self.prediction.wait()
372+
self._prediction.wait()
365373

366-
if self.prediction.status == "failed":
367-
raise ModelError(self.prediction)
374+
if self._prediction.status == "failed":
375+
raise ModelError(self._prediction)
368376

369377
# Process output for file downloads based on schema
370-
return _process_output_with_schema(self.prediction.output, self.schema)
378+
return _process_output_with_schema(self._prediction.output, self._schema)
371379

372380
def logs(self) -> Optional[str]:
373381
"""
374382
Fetch and return the logs from the prediction.
375383
"""
376-
self.prediction.reload()
384+
self._prediction.reload()
377385

378-
return self.prediction.logs
386+
return self._prediction.logs
379387

380388

381389
@dataclass
@@ -384,45 +392,11 @@ class Function(Generic[Input, Output]):
384392
A wrapper for a Replicate model that can be called as a function.
385393
"""
386394

387-
function_ref: str
388-
389-
def _client(self) -> Client:
390-
return Client()
391-
392-
@cached_property
393-
def _parsed_ref(self) -> Tuple[str, str, Optional[str]]:
394-
return ModelVersionIdentifier.parse(self.function_ref)
395+
_ref: str
395396

396-
@cached_property
397-
def _model(self) -> Model:
398-
client = self._client()
399-
model_owner, model_name, _ = self._parsed_ref
400-
return client.models.get(f"{model_owner}/{model_name}")
401-
402-
@cached_property
403-
def _version(self) -> Version | None:
404-
_, _, model_version = self._parsed_ref
405-
model = self._model
406-
try:
407-
versions = model.versions.list()
408-
if len(versions) == 0:
409-
# if we got an empty list when getting model versions, this
410-
# model is possibly a procedure instead and should be called via
411-
# the versionless API
412-
return None
413-
except ReplicateError as e:
414-
if e.status == 404:
415-
# if we get a 404 when getting model versions, this is an official
416-
# model and doesn't have addressable versions (despite what
417-
# latest_version might tell us)
418-
return None
419-
raise
420-
421-
version = (
422-
model.versions.get(model_version) if model_version else model.latest_version
423-
)
424-
425-
return version
397+
def __init__(self, ref: str, *, streaming: bool) -> None:
398+
self._ref = ref
399+
self._streaming = streaming
426400

427401
def __call__(self, *args: Input.args, **inputs: Input.kwargs) -> Output:
428402
return self.create(*args, **inputs).output()
@@ -455,7 +429,9 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
455429
model=self._model, input=processed_inputs
456430
)
457431

458-
return Run(prediction, self.openapi_schema)
432+
return Run(
433+
prediction=prediction, schema=self.openapi_schema, streaming=self._streaming
434+
)
459435

460436
@property
461437
def default_example(self) -> Optional[dict[str, Any]]:
@@ -479,50 +455,96 @@ def openapi_schema(self) -> dict[str, Any]:
479455
schema = make_schema_backwards_compatible(schema, cog_version)
480456
return schema
481457

458+
def _client(self) -> Client:
459+
return Client()
460+
461+
@cached_property
462+
def _parsed_ref(self) -> Tuple[str, str, Optional[str]]:
463+
return ModelVersionIdentifier.parse(self._ref)
464+
465+
@cached_property
466+
def _model(self) -> Model:
467+
client = self._client()
468+
model_owner, model_name, _ = self._parsed_ref
469+
return client.models.get(f"{model_owner}/{model_name}")
470+
471+
@cached_property
472+
def _version(self) -> Version | None:
473+
_, _, model_version = self._parsed_ref
474+
model = self._model
475+
try:
476+
versions = model.versions.list()
477+
if len(versions) == 0:
478+
# if we got an empty list when getting model versions, this
479+
# model is possibly a procedure instead and should be called via
480+
# the versionless API
481+
return None
482+
except ReplicateError as e:
483+
if e.status == 404:
484+
# if we get a 404 when getting model versions, this is an official
485+
# model and doesn't have addressable versions (despite what
486+
# latest_version might tell us)
487+
return None
488+
raise
489+
490+
version = (
491+
model.versions.get(model_version) if model_version else model.latest_version
492+
)
493+
494+
return version
495+
482496

483497
@dataclass
484498
class AsyncRun[O]:
485499
"""
486500
Represents a running prediction with access to its version (async version).
487501
"""
488502

489-
prediction: Prediction
490-
schema: dict
503+
_prediction: Prediction
504+
_schema: dict
505+
506+
def __init__(
507+
self, *, prediction: Prediction, schema: dict, streaming: bool
508+
) -> None:
509+
self._prediction = prediction
510+
self._schema = schema
511+
self._streaming = streaming
491512

492513
async def output(self) -> O:
493514
"""
494515
Return the output. For iterator types, returns immediately without waiting.
495516
For non-iterator types, waits for completion.
496517
"""
497-
# Return an OutputIterator immediately for iterator output types
498-
if _has_iterator_output_type(self.schema):
499-
is_concatenate = _has_concatenate_iterator_output_type(self.schema)
518+
# Return an OutputIterator immediately when streaming; we do this for all
519+
# model return types regardless of whether they return an iterator.
520+
if self._streaming:
521+
is_concatenate = _has_concatenate_iterator_output_type(self._schema)
500522
return cast(
501523
O,
502524
OutputIterator(
503-
lambda: self.prediction.output_iterator(),
504-
lambda: self.prediction.async_output_iterator(),
505-
self.schema,
525+
lambda: self._prediction.output_iterator(),
526+
lambda: self._prediction.async_output_iterator(),
527+
self._schema,
506528
is_concatenate=is_concatenate,
507529
),
508530
)
509531

510532
# For non-iterator types, wait for completion and process output
511-
await self.prediction.async_wait()
533+
await self._prediction.async_wait()
512534

513-
if self.prediction.status == "failed":
514-
raise ModelError(self.prediction)
535+
if self._prediction.status == "failed":
536+
raise ModelError(self._prediction)
515537

516538
# Process output for file downloads based on schema
517-
return _process_output_with_schema(self.prediction.output, self.schema)
539+
return _process_output_with_schema(self._prediction.output, self._schema)
518540

519541
async def logs(self) -> Optional[str]:
520542
"""
521543
Fetch and return the logs from the prediction asynchronously.
522544
"""
523-
await self.prediction.async_reload()
545+
await self._prediction.async_reload()
524546

525-
return self.prediction.logs
547+
return self._prediction.logs
526548

527549

528550
@dataclass
@@ -532,6 +554,7 @@ class AsyncFunction(Generic[Input, Output]):
532554
"""
533555

534556
function_ref: str
557+
streaming: bool
535558

536559
def _client(self) -> Client:
537560
return Client()
@@ -600,7 +623,11 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
600623
model=model, input=processed_inputs
601624
)
602625

603-
return AsyncRun(prediction, await self.openapi_schema())
626+
return AsyncRun(
627+
prediction=prediction,
628+
schema=await self.openapi_schema(),
629+
streaming=self.streaming,
630+
)
604631

605632
@property
606633
def default_example(self) -> Optional[dict[str, Any]]:
@@ -629,6 +656,12 @@ async def openapi_schema(self) -> dict[str, Any]:
629656
def use(ref: FunctionRef[Input, Output]) -> Function[Input, Output]: ...
630657

631658

659+
@overload
660+
def use(
661+
ref: FunctionRef[Input, Output], *, streaming: Literal[False]
662+
) -> Function[Input, Output]: ...
663+
664+
632665
@overload
633666
def use(
634667
ref: FunctionRef[Input, Output], *, use_async: Literal[False]
@@ -643,25 +676,82 @@ def use(
643676

644677
@overload
645678
def use(
646-
ref: str, *, hint: Callable[Input, Output] | None = None, use_async: Literal[True]
679+
ref: FunctionRef[Input, Output],
680+
*,
681+
streaming: Literal[False],
682+
use_async: Literal[True],
647683
) -> AsyncFunction[Input, Output]: ...
648684

649685

686+
@overload
687+
def use(
688+
ref: FunctionRef[Input, Output],
689+
*,
690+
streaming: Literal[True],
691+
use_async: Literal[True],
692+
) -> AsyncFunction[Input, AsyncIterator[Output]]: ...
693+
694+
695+
@overload
696+
def use(
697+
ref: FunctionRef[Input, Output],
698+
*,
699+
streaming: Literal[True],
700+
use_async: Literal[False],
701+
) -> Function[Input, Iterator[Output]]: ...
702+
703+
650704
@overload
651705
def use(
652706
ref: str,
653707
*,
654708
hint: Callable[Input, Output] | None = None,
709+
streaming: Literal[False] = False,
655710
use_async: Literal[False] = False,
656711
) -> Function[Input, Output]: ...
657712

658713

714+
@overload
715+
def use(
716+
ref: str,
717+
*,
718+
hint: Callable[Input, Output] | None = None,
719+
streaming: Literal[True],
720+
use_async: Literal[False] = False,
721+
) -> Function[Input, Iterator[Output]]: ...
722+
723+
724+
@overload
725+
def use(
726+
ref: str,
727+
*,
728+
hint: Callable[Input, Output] | None = None,
729+
use_async: Literal[True],
730+
) -> AsyncFunction[Input, Output]: ...
731+
732+
733+
@overload
734+
def use(
735+
ref: str,
736+
*,
737+
hint: Callable[Input, Output] | None = None,
738+
streaming: Literal[True],
739+
use_async: Literal[True],
740+
) -> AsyncFunction[Input, AsyncIterator[Output]]: ...
741+
742+
659743
def use(
660744
ref: str | FunctionRef[Input, Output],
661745
*,
662746
hint: Callable[Input, Output] | None = None,
747+
streaming: bool = False,
663748
use_async: bool = False,
664-
) -> Function[Input, Output] | AsyncFunction[Input, Output]:
749+
) -> (
750+
Function[Input, Output]
751+
| AsyncFunction[Input, Output]
752+
| Function[Input, Iterator[Output]]
753+
| AsyncFunction[Input, AsyncIterator[Output]]
754+
):
665755
"""
666756
Use a Replicate model as a function.
667757
@@ -682,9 +772,9 @@ def use(
682772
pass
683773

684774
if use_async:
685-
return AsyncFunction(function_ref=str(ref))
775+
return AsyncFunction(str(ref), streaming=streaming)
686776

687-
return Function(str(ref))
777+
return Function(str(ref), streaming=streaming)
688778

689779

690780
# class Model:
@@ -693,17 +783,23 @@ def use(
693783
# def __call__(self) -> str: ...
694784

695785

696-
# def model() -> int: ...
786+
# def model() -> AsyncIterator[int]: ...
697787

698788

699789
# flux = use("")
700790
# flux_sync = use("", use_async=False)
791+
# streaming_flux_sync = use("", streaming=True, use_async=False)
701792
# flux_async = use("", use_async=True)
793+
# streaming_flux_async = use("", streaming=True, use_async=True)
702794

703795
# flux = use("", hint=model)
704796
# flux_sync = use("", hint=model, use_async=False)
797+
# streaming_flux_sync = use("", hint=model, streaming=False, use_async=False)
705798
# flux_async = use("", hint=model, use_async=True)
799+
# streaming_flux_async = use("", hint=model, streaming=True, use_async=True)
706800

707801
# flux = use(Model())
708802
# flux_sync = use(Model(), use_async=False)
803+
# streaming_flux_sync = use(Model(), streaming=False, use_async=False)
709804
# flux_async = use(Model(), use_async=True)
805+
# streaming_flux_async = use(Model(), streaming=True, use_async=True)

0 commit comments

Comments
 (0)