Skip to content

Commit e4398b9

Browse files
committed
Add initial test for use() functionality
1 parent 0110c2c commit e4398b9

File tree

4 files changed

+445
-2
lines changed

4 files changed

+445
-2
lines changed

replicate/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from replicate.client import Client
22
from replicate.pagination import async_paginate as _async_paginate
33
from replicate.pagination import paginate as _paginate
4+
from replicate.use import use
45

56
default_client = Client()
67

replicate/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ def version_has_no_array_type(cog_version: str) -> Optional[bool]:
1515

1616
def make_schema_backwards_compatible(
1717
schema: dict,
18-
cog_version: str,
18+
cog_version: str | None,
1919
) -> dict:
2020
"""A place to add backwards compatibility logic for our openapi schema"""
2121

2222
# If the top-level output is an array, assume it is an iterator in old versions which didn't have an array type
23-
if version_has_no_array_type(cog_version):
23+
if cog_version and version_has_no_array_type(cog_version):
2424
output = schema["components"]["schemas"]["Output"]
2525
if output.get("type") == "array":
2626
output["x-cog-array-type"] = "iterator"

replicate/use.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
# TODO
2+
# - [ ] Support downloading files and conversion into Path when schema is URL
3+
# - [ ] Support asyncio variant
4+
# - [ ] Support list outputs
5+
# - [ ] Support iterator outputs
6+
# - [ ] Support text streaming
7+
# - [ ] Support file streaming
8+
# - [ ] Support reusing output URL when passing to new method
9+
# - [ ] Support lazy downloading of files into Path
10+
# - [ ] Support helpers for working with ContatenateIterator
11+
import inspect
12+
from dataclasses import dataclass
13+
from functools import cached_property
14+
from typing import Any, Dict, Optional, Tuple
15+
16+
from replicate.client import Client
17+
from replicate.exceptions import ModelError, ReplicateError
18+
from replicate.identifier import ModelVersionIdentifier
19+
from replicate.model import Model
20+
from replicate.prediction import Prediction
21+
from replicate.run import make_schema_backwards_compatible
22+
from replicate.version import Version
23+
24+
25+
def _in_module_scope() -> bool:
26+
"""
27+
Returns True when called from top level module scope.
28+
"""
29+
import os
30+
if os.getenv("REPLICATE_ALWAYS_ALLOW_USE"):
31+
return True
32+
33+
if frame := inspect.currentframe():
34+
if caller := frame.f_back:
35+
return caller.f_code.co_name == "<module>"
36+
return False
37+
38+
39+
__all__ = ["use"]
40+
41+
42+
def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool:
43+
"""
44+
Returns true if the model output type is ConcatenateIterator or
45+
AsyncConcatenateIterator.
46+
"""
47+
output = openapi_schema.get("components", {}).get("schemas", {}).get("Output", {})
48+
49+
if output.get("type") != "array":
50+
return False
51+
52+
if output.get("items", {}).get("type") != "string":
53+
return False
54+
55+
if output.get("x-cog-array-type") != "iterator":
56+
return False
57+
58+
if output.get("x-cog-array-display") != "concatenate":
59+
return False
60+
61+
return True
62+
63+
64+
@dataclass
65+
class Run:
66+
"""
67+
Represents a running prediction with access to its version.
68+
"""
69+
70+
prediction: Prediction
71+
schema: dict
72+
73+
def wait(self) -> Any:
74+
"""
75+
Wait for the prediction to complete and return its output.
76+
"""
77+
self.prediction.wait()
78+
79+
if self.prediction.status == "failed":
80+
raise ModelError(self.prediction)
81+
82+
if _has_concatenate_iterator_output_type(self.schema):
83+
return "".join(self.prediction.output)
84+
85+
return self.prediction.output
86+
87+
def logs(self) -> Optional[str]:
88+
"""
89+
Fetch and return the logs from the prediction.
90+
"""
91+
self.prediction.reload()
92+
93+
return self.prediction.logs
94+
95+
96+
@dataclass
97+
class Function:
98+
"""
99+
A wrapper for a Replicate model that can be called as a function.
100+
"""
101+
102+
function_ref: str
103+
104+
def _client(self) -> Client:
105+
return Client()
106+
107+
@cached_property
108+
def _parsed_ref(self) -> Tuple[str, str, Optional[str]]:
109+
return ModelVersionIdentifier.parse(self.function_ref)
110+
111+
@cached_property
112+
def _model(self) -> Model:
113+
client = self._client()
114+
model_owner, model_name, _ = self._parsed_ref
115+
return client.models.get(f"{model_owner}/{model_name}")
116+
117+
@cached_property
118+
def _version(self) -> Version | None:
119+
_, _, model_version = self._parsed_ref
120+
model = self._model
121+
try:
122+
versions = model.versions.list()
123+
if len(versions) == 0:
124+
# if we got an empty list when getting model versions, this
125+
# model is possibly a procedure instead and should be called via
126+
# the versionless API
127+
return None
128+
except ReplicateError as e:
129+
if e.status == 404:
130+
# if we get a 404 when getting model versions, this is an official
131+
# model and doesn't have addressable versions (despite what
132+
# latest_version might tell us)
133+
return None
134+
raise
135+
136+
version = (
137+
model.versions.get(model_version) if model_version else model.latest_version
138+
)
139+
140+
return version
141+
142+
def __call__(self, **inputs: Dict[str, Any]) -> Any:
143+
run = self.create(**inputs)
144+
return run.wait()
145+
146+
def create(self, **inputs: Dict[str, Any]) -> Run:
147+
"""
148+
Start a prediction with the specified inputs.
149+
"""
150+
version = self._version
151+
152+
if version:
153+
prediction = self._client().predictions.create(
154+
version=version, input=inputs
155+
)
156+
else:
157+
prediction = self._client().models.predictions.create(
158+
model=self._model, input=inputs
159+
)
160+
161+
return Run(prediction, self.openapi_schema)
162+
163+
@property
164+
def default_example(self) -> Optional[Prediction]:
165+
"""
166+
Get the default example for this model.
167+
"""
168+
raise NotImplementedError("This property has not yet been implemented")
169+
170+
@cached_property
171+
def openapi_schema(self) -> dict[Any, Any]:
172+
"""
173+
Get the OpenAPI schema for this model version.
174+
"""
175+
schema = self._model.latest_version.openapi_schema
176+
if cog_version := self._model.latest_version.cog_version:
177+
schema = make_schema_backwards_compatible(schema, cog_version)
178+
return schema
179+
180+
181+
def use(function_ref: str) -> Function:
182+
"""
183+
Use a Replicate model as a function.
184+
185+
This function can only be called at the top level of a module.
186+
187+
Example:
188+
189+
flux_dev = replicate.use("black-forest-labs/flux-dev")
190+
output = flux_dev(prompt="make me a sandwich")
191+
192+
"""
193+
if not _in_module_scope():
194+
raise RuntimeError(
195+
"You may only call cog.ext.pipelines.include at the top level."
196+
)
197+
198+
return Function(function_ref)

0 commit comments

Comments
 (0)