diff --git a/ai_diffusion/__init__.py b/ai_diffusion/__init__.py index e343c39b8..98c95e69a 100644 --- a/ai_diffusion/__init__.py +++ b/ai_diffusion/__init__.py @@ -12,7 +12,6 @@ " https://github.com/Acly/krita-ai-diffusion/releases" ) - # The following imports depend on the code running inside Krita, so the cannot be imported in tests. if importlib.util.find_spec("krita"): from .extension import AIToolsExtension as AIToolsExtension diff --git a/ai_diffusion/client.py b/ai_diffusion/client.py index bf33e6182..818da781a 100644 --- a/ai_diffusion/client.py +++ b/ai_diffusion/client.py @@ -40,12 +40,16 @@ class TextOutput(NamedTuple): mime: str +class ResizeCommand(NamedTuple): + resize_canvas: bool = False + + class SharedWorkflow(NamedTuple): publisher: str workflow: dict -ClientOutput = dict | SharedWorkflow | TextOutput +ClientOutput = dict | SharedWorkflow | TextOutput | ResizeCommand class ClientMessage(NamedTuple): diff --git a/ai_diffusion/comfy_client.py b/ai_diffusion/comfy_client.py index fb472dbca..988e6091b 100644 --- a/ai_diffusion/comfy_client.py +++ b/ai_diffusion/comfy_client.py @@ -11,7 +11,14 @@ from .api import WorkflowInput from .client import Client, CheckpointInfo, ClientMessage, ClientEvent, DeviceInfo, ClientModels -from .client import SharedWorkflow, TranslationPackage, ClientFeatures, ClientJobQueue, TextOutput +from .client import ( + SharedWorkflow, + TranslationPackage, + ClientFeatures, + ClientJobQueue, + TextOutput, + ResizeCommand, +) from .client import Quantization, MissingResources, filter_supported_styles, loras_to_upload from .comfy_workflow import ComfyObjectInfo from .files import FileFormat @@ -385,6 +392,9 @@ async def _listen_websocket(self, websocket: websockets.ClientConnection): text_output = _extract_text_output(job.id, msg) if text_output is not None: await self._messages.put(text_output) + resize_cmd = _extract_resize_output(job.id, msg) + if resize_cmd is not None: + await self._messages.put(resize_cmd) pose_json = _extract_pose_json(msg) if pose_json is not None: result = pose_json @@ -890,3 +900,25 @@ def _extract_text_output(job_id: str, msg: dict): except Exception as e: log.warning(f"Error processing message, error={str(e)}, msg={msg}") return None + + +def _extract_resize_output(job_id: str, msg: dict): + """Extract a Krita canvas resize toggle encoded directly in the UI output.""" + try: + output = msg["data"]["output"] + if output is None: + return None + + resize = output.get("resize_canvas") + if isinstance(resize, list): + active = any(bool(item) for item in resize) + else: + active = bool(resize) + + if not active: + return None + + return ClientMessage(ClientEvent.output, job_id, result=ResizeCommand(True)) + except Exception as e: + log.warning(f"Error processing Krita resize output: {e}, msg={msg}") + return None diff --git a/ai_diffusion/document.py b/ai_diffusion/document.py index 60647eeaf..173c81f25 100644 --- a/ai_diffusion/document.py +++ b/ai_diffusion/document.py @@ -50,6 +50,10 @@ def get_image( def resize(self, extent: Extent): raise NotImplementedError + def resize_canvas(self, width: int, height: int): + """Resize the underlying canvas if supported by the implementation.""" + pass + def annotate(self, key: str, value: QByteArray): pass @@ -230,6 +234,9 @@ def resize(self, extent: Extent): res = self._doc.resolution() self._doc.scaleImage(extent.width, extent.height, res, res, "Bilinear") + def resize_canvas(self, width: int, height: int): + self._doc.resizeImage(0, 0, width, height) + def annotate(self, key: str, value: QByteArray): self._doc.setAnnotation(f"ai_diffusion/{key}", f"AI Diffusion Plugin: {key}", value) diff --git a/ai_diffusion/jobs.py b/ai_diffusion/jobs.py index 2f091c66c..e2b93b56a 100644 --- a/ai_diffusion/jobs.py +++ b/ai_diffusion/jobs.py @@ -55,6 +55,7 @@ class JobParams: has_mask: bool = False frame: tuple[int, int, int] = (0, 0, 0) animation_id: str = "" + resize_canvas: bool = False @staticmethod def from_dict(data: dict[str, Any]): diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index 1e3ade128..3ea6fb52d 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -23,7 +23,14 @@ from .settings import settings from .network import NetworkError from .image import Extent, Image, Mask, Bounds, DummyImage -from .client import Client, ClientMessage, ClientEvent, ClientOutput, is_style_supported +from .client import ( + Client, + ClientMessage, + ClientEvent, + ClientOutput, + is_style_supported, + ResizeCommand, +) from .client import filter_supported_styles, resolve_arch from .custom_workflow import CustomWorkspace, WorkflowCollection, CustomGenerationMode from .document import Document, KritaDocument @@ -582,7 +589,10 @@ def handle_message(self, message: ClientMessage): self.progress_kind = ProgressKind.upload self.progress = message.progress elif message.event is ClientEvent.output: - self.custom.show_output(message.result) + if isinstance(message.result, ResizeCommand): + self._apply_resize_command(message.result, job) + else: + self.custom.show_output(message.result) elif message.event is ClientEvent.finished: if message.error: # successful jobs may have encountered some warnings self.report_error(Error.from_string(message.error, ErrorKind.warning)) @@ -622,6 +632,9 @@ def _finish_job(self, job: Job, event: ClientEvent): self.jobs.notify_cancelled(job) self.progress = 0 + def _apply_resize_command(self, cmd: ResizeCommand, job: Job): + job.params.resize_canvas = cmd.resize_canvas + def update_preview(self): if selection := self.jobs.selection: self.show_preview(selection[0].job, selection[0].image) @@ -665,6 +678,9 @@ def apply_result( region_behavior=ApplyRegionBehavior.layer_group, prefix="", ): + if params.resize_canvas and self.document.extent != image.extent: + self.document.resize_canvas(*image.extent) + bounds = Bounds(*params.bounds.offset, *image.extent) if len(params.regions) == 0 or region_behavior is ApplyRegionBehavior.none: if behavior is ApplyBehavior.replace: