55import hashlib
66import os
77import tempfile
8- from dataclasses import dataclass
98from functools import cached_property
109from pathlib import Path
1110from typing import (
2524 cast ,
2625 overload ,
2726)
28- from urllib .parse import urlparse
2927
3028import httpx
3129
@@ -62,36 +60,6 @@ def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool:
6260 return True
6361
6462
65- def _has_iterator_output_type (openapi_schema : dict ) -> bool :
66- """
67- Returns true if the model output type is an iterator (non-concatenate).
68- """
69- output = openapi_schema .get ("components" , {}).get ("schemas" , {}).get ("Output" , {})
70- return (
71- output .get ("type" ) == "array" and output .get ("x-cog-array-type" ) == "iterator"
72- )
73-
74-
75- def _download_file (url : str ) -> Path :
76- """
77- Download a file from URL to a temporary location and return the Path.
78- """
79- parsed_url = urlparse (url )
80- filename = os .path .basename (parsed_url .path )
81-
82- if not filename or "." not in filename :
83- filename = "download"
84-
85- _ , ext = os .path .splitext (filename )
86- with tempfile .NamedTemporaryFile (delete = False , suffix = ext ) as temp_file :
87- with httpx .stream ("GET" , url ) as response :
88- response .raise_for_status ()
89- for chunk in response .iter_bytes ():
90- temp_file .write (chunk )
91-
92- return Path (temp_file .name )
93-
94-
9563def _process_iterator_item (item : Any , openapi_schema : dict ) -> Any :
9664 """
9765 Process a single item from an iterator output based on schema.
@@ -357,7 +325,6 @@ class FunctionRef(Protocol, Generic[Input, Output]):
357325 __call__ : Callable [Input , Output ]
358326
359327
360- @dataclass
361328class Run [O ]:
362329 """
363330 Represents a running prediction with access to the underlying schema.
@@ -416,13 +383,13 @@ def logs(self) -> Optional[str]:
416383 return self ._prediction .logs
417384
418385
419- @dataclass
420386class Function (Generic [Input , Output ]):
421387 """
422388 A wrapper for a Replicate model that can be called as a function.
423389 """
424390
425391 _ref : str
392+ _streaming : bool
426393
427394 def __init__ (self , ref : str , * , streaming : bool ) -> None :
428395 self ._ref = ref
@@ -460,7 +427,9 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
460427 )
461428
462429 return Run (
463- prediction = prediction , schema = self .openapi_schema , streaming = self ._streaming
430+ prediction = prediction ,
431+ schema = self .openapi_schema (),
432+ streaming = self ._streaming ,
464433 )
465434
466435 @property
@@ -470,18 +439,26 @@ def default_example(self) -> Optional[dict[str, Any]]:
470439 """
471440 raise NotImplementedError ("This property has not yet been implemented" )
472441
473- @cached_property
474442 def openapi_schema (self ) -> dict [str , Any ]:
475443 """
476444 Get the OpenAPI schema for this model version.
477445 """
478- latest_version = self ._model .latest_version
479- if latest_version is None :
446+ return self ._openapi_schema
447+
448+ @cached_property
449+ def _openapi_schema (self ) -> dict [str , Any ]:
450+ _ , _ , model_version = self ._parsed_ref
451+ model = self ._model
452+
453+ version = (
454+ model .versions .get (model_version ) if model_version else model .latest_version
455+ )
456+ if version is None :
480457 msg = f"Model { self ._model .owner } /{ self ._model .name } has no latest version"
481458 raise ValueError (msg )
482459
483- schema = latest_version .openapi_schema
484- if cog_version := latest_version .cog_version :
460+ schema = version .openapi_schema
461+ if cog_version := version .cog_version :
485462 schema = make_schema_backwards_compatible (schema , cog_version )
486463 return _dereference_schema (schema )
487464
@@ -524,7 +501,6 @@ def _version(self) -> Version | None:
524501 return version
525502
526503
527- @dataclass
528504class AsyncRun [O ]:
529505 """
530506 Represents a running prediction with access to its version (async version).
@@ -583,21 +559,25 @@ async def logs(self) -> Optional[str]:
583559 return self ._prediction .logs
584560
585561
586- @dataclass
587562class AsyncFunction (Generic [Input , Output ]):
588563 """
589564 An async wrapper for a Replicate model that can be called as a function.
590565 """
591566
592- function_ref : str
593- streaming : bool
567+ _ref : str
568+ _streaming : bool
569+ _openapi_schema : dict [str , Any ] | None = None
570+
571+ def __init__ (self , ref : str , * , streaming : bool ) -> None :
572+ self ._ref = ref
573+ self ._streaming = streaming
594574
595575 def _client (self ) -> Client :
596576 return Client ()
597577
598578 @cached_property
599579 def _parsed_ref (self ) -> Tuple [str , str , Optional [str ]]:
600- return ModelVersionIdentifier .parse (self .function_ref )
580+ return ModelVersionIdentifier .parse (self ._ref )
601581
602582 async def _model (self ) -> Model :
603583 client = self ._client ()
@@ -662,7 +642,7 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
662642 return AsyncRun (
663643 prediction = prediction ,
664644 schema = await self .openapi_schema (),
665- streaming = self .streaming ,
645+ streaming = self ._streaming ,
666646 )
667647
668648 @property
@@ -676,16 +656,26 @@ async def openapi_schema(self) -> dict[str, Any]:
676656 """
677657 Get the OpenAPI schema for this model version asynchronously.
678658 """
679- model = await self ._model ()
680- latest_version = model .latest_version
681- if latest_version is None :
682- msg = f"Model { model .owner } /{ model .name } has no latest version"
683- raise ValueError (msg )
659+ if not self ._openapi_schema :
660+ _ , _ , model_version = self ._parsed_ref
684661
685- schema = latest_version .openapi_schema
686- if cog_version := latest_version .cog_version :
687- schema = make_schema_backwards_compatible (schema , cog_version )
688- return _dereference_schema (schema )
662+ model = await self ._model ()
663+ if model_version :
664+ version = await model .versions .async_get (model_version )
665+ else :
666+ version = model .latest_version
667+
668+ if version is None :
669+ msg = f"Model { model .owner } /{ model .name } has no version"
670+ raise ValueError (msg )
671+
672+ schema = version .openapi_schema
673+ if cog_version := version .cog_version :
674+ schema = make_schema_backwards_compatible (schema , cog_version )
675+
676+ self ._openapi_schema = _dereference_schema (schema )
677+
678+ return self ._openapi_schema
689679
690680
691681@overload
0 commit comments