Skip to content

Commit c9f34c5

Browse files
committed
add support for use_file_output
1 parent 7b1a1cc commit c9f34c5

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

src/replicate/lib/_files.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import httpx
1010

11+
from replicate.types.prediction_output import PredictionOutput
12+
1113
from .._utils import is_mapping, is_sequence
1214
from .._client import ReplicateClient, AsyncReplicateClient
1315

@@ -124,7 +126,7 @@ def __repr__(self) -> str:
124126
return f'{self.__class__.__name__}("{self.url}")'
125127

126128

127-
def transform_output(value: Any, client: ReplicateClient | AsyncReplicateClient) -> Any:
129+
def transform_output(value: PredictionOutput, client: ReplicateClient | AsyncReplicateClient) -> Any:
128130
"""
129131
Transform the output of a prediction to a `FileOutput` object if it's a URL.
130132
"""

src/replicate/lib/_predictions.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def run(
2121
ref: Union[Model, Version, ModelVersionIdentifier, str],
2222
*,
2323
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
24-
_use_file_output: Optional[bool] = True,
24+
use_file_output: Optional[bool] = True,
2525
**params: Unpack[PredictionCreateParamsWithoutVersion],
2626
) -> PredictionOutput | FileOutput | Iterable[FileOutput] | Dict[str, FileOutput]:
2727
"""
@@ -109,15 +109,18 @@ def run(
109109

110110
# TODO: Return an iterator for completed output if the model has an output iterator array type.
111111

112-
return transform_output(prediction.output, client) # type: ignore[no-any-return]
112+
if use_file_output:
113+
return transform_output(prediction.output, client) # type: ignore[no-any-return]
114+
115+
return prediction.output
113116

114117

115118
async def async_run(
116119
client: "AsyncReplicateClient",
117120
ref: Union[Model, Version, ModelVersionIdentifier, str],
118121
*,
119122
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
120-
_use_file_output: Optional[bool] = True,
123+
use_file_output: Optional[bool] = True,
121124
**params: Unpack[PredictionCreateParamsWithoutVersion],
122125
) -> PredictionOutput | FileOutput | Iterable[FileOutput] | Dict[str, FileOutput]:
123126
"""
@@ -205,4 +208,7 @@ async def async_run(
205208

206209
# TODO: Return an iterator for completed output if the model has an output iterator array type.
207210

208-
return transform_output(prediction.output, client) # type: ignore[no-any-return]
211+
if use_file_output:
212+
return transform_output(prediction.output, client) # type: ignore[no-any-return]
213+
214+
return prediction.output

0 commit comments

Comments
 (0)