diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index e7055abb27..f0ca578c53 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -113,6 +113,11 @@ def predict( else: payload = prediction.input.copy() + if prediction.context is None: + prediction.context = {} + if prediction.id is not None: + prediction.context["id"] = prediction.id + sid = self._worker.subscribe(task.handle_event, tag=tag) task.track(self._worker.predict(payload, context=prediction.context, tag=tag)) task.add_done_callback(self._task_done_callback(tag, sid))