|
1 | 1 | # TODO |
2 | 2 | # - [x] Support downloading files and conversion into Path when schema is URL |
3 | 3 | # - [x] Support list outputs |
4 | | -# - [ ] Support iterator outputs |
5 | | -# - [ ] Support helpers for working with ContatenateIterator |
| 4 | +# - [x] Support iterator outputs |
| 5 | +# - [x] Support helpers for working with ContatenateIterator |
6 | 6 | # - [ ] Support reusing output URL when passing to new method |
7 | 7 | # - [ ] Support lazy downloading of files into Path |
8 | 8 | # - [ ] Support text streaming |
@@ -187,12 +187,12 @@ class OutputIterator: |
187 | 187 | An iterator wrapper that handles both regular iteration and string conversion. |
188 | 188 | """ |
189 | 189 |
|
190 | | - def __init__(self, iterator_factory, schema: dict, is_concatenate: bool): |
| 190 | + def __init__(self, iterator_factory, schema: dict, *, is_concatenate: bool) -> None: |
191 | 191 | self.iterator_factory = iterator_factory |
192 | 192 | self.schema = schema |
193 | 193 | self.is_concatenate = is_concatenate |
194 | 194 |
|
195 | | - def __iter__(self): |
| 195 | + def __iter__(self) -> Iterator[Any]: |
196 | 196 | """Iterate over output items.""" |
197 | 197 | for chunk in self.iterator_factory(): |
198 | 198 | if self.is_concatenate: |
@@ -230,7 +230,9 @@ def wait(self) -> Union[Any, Iterator[Any]]: |
230 | 230 | if _has_iterator_output_type(self.schema): |
231 | 231 | is_concatenate = _has_concatenate_iterator_output_type(self.schema) |
232 | 232 | return OutputIterator( |
233 | | - lambda: self.prediction.output_iterator(), self.schema, is_concatenate |
| 233 | + lambda: self.prediction.output_iterator(), |
| 234 | + self.schema, |
| 235 | + is_concatenate=is_concatenate, |
234 | 236 | ) |
235 | 237 |
|
236 | 238 | # Process output for file downloads based on schema |
@@ -299,15 +301,23 @@ def create(self, **inputs: Dict[str, Any]) -> Run: |
299 | 301 | """ |
300 | 302 | Start a prediction with the specified inputs. |
301 | 303 | """ |
| 304 | + # Process inputs to convert concatenate OutputIterators to strings |
| 305 | + processed_inputs = {} |
| 306 | + for key, value in inputs.items(): |
| 307 | + if isinstance(value, OutputIterator) and value.is_concatenate: |
| 308 | + processed_inputs[key] = str(value) |
| 309 | + else: |
| 310 | + processed_inputs[key] = value |
| 311 | + |
302 | 312 | version = self._version |
303 | 313 |
|
304 | 314 | if version: |
305 | 315 | prediction = self._client().predictions.create( |
306 | | - version=version, input=inputs |
| 316 | + version=version, input=processed_inputs |
307 | 317 | ) |
308 | 318 | else: |
309 | 319 | prediction = self._client().models.predictions.create( |
310 | | - model=self._model, input=inputs |
| 320 | + model=self._model, input=processed_inputs |
311 | 321 | ) |
312 | 322 |
|
313 | 323 | return Run(prediction, self.openapi_schema) |
|
0 commit comments