diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index aa421a53727b..f3b12d716160 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -18,6 +18,7 @@ from dataclasses import dataclass, field, fields from typing import Any, Dict, List, Literal, Optional, Type, Union +import PIL.Image import torch from ..configuration_utils import ConfigMixin, FrozenDict @@ -323,11 +324,192 @@ class ConfigSpec: description: Optional[str] = None -# YiYi Notes: both inputs and intermediate_inputs are InputParam objects -# however some fields are not relevant for intermediate_inputs -# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed -# default is not used for intermediate_inputs, we only use default from inputs, so it is ignored if it is set for intermediate_inputs -# -> should we use different class for inputs and intermediate_inputs? +# ====================================================== +# InputParam and OutputParam templates +# ====================================================== + +INPUT_PARAM_TEMPLATES = { + "prompt": { + "type_hint": str, + "required": True, + "description": "The prompt or prompts to guide image generation.", + }, + "negative_prompt": { + "type_hint": str, + "description": "The prompt or prompts not to guide the image generation.", + }, + "max_sequence_length": { + "type_hint": int, + "default": 512, + "description": "Maximum sequence length for prompt encoding.", + }, + "height": { + "type_hint": int, + "description": "The height in pixels of the generated image.", + }, + "width": { + "type_hint": int, + "description": "The width in pixels of the generated image.", + }, + "num_inference_steps": { + "type_hint": int, + "default": 50, + "description": "The number of denoising steps.", + }, + "num_images_per_prompt": { + "type_hint": int, + "default": 1, + "description": "The number of images to generate per prompt.", + }, + "generator": { + "type_hint": torch.Generator, + "description": "Torch generator for deterministic generation.", + }, + "sigmas": { + "type_hint": List[float], + "description": "Custom sigmas for the denoising process.", + }, + "strength": { + "type_hint": float, + "default": 0.9, + "description": "Strength for img2img/inpainting.", + }, + "image": { + "type_hint": Union[PIL.Image.Image, List[PIL.Image.Image]], + "required": True, + "description": "Reference image(s) for denoising. Can be a single image or list of images.", + }, + "latents": { + "type_hint": torch.Tensor, + "description": "Pre-generated noisy latents for image generation.", + }, + "timesteps": { + "type_hint": torch.Tensor, + "description": "Timesteps for the denoising process.", + }, + "output_type": { + "type_hint": str, + "default": "pil", + "description": "Output format: 'pil', 'np', 'pt'.", + }, + "attention_kwargs": { + "type_hint": Dict[str, Any], + "description": "Additional kwargs for attention processors.", + }, + "denoiser_input_fields": { + "name": None, + "kwargs_type": "denoiser_input_fields", + "description": "conditional model inputs for the denoiser: e.g. 
prompt_embeds, negative_prompt_embeds, etc.", + }, + # inpainting + "mask_image": { + "type_hint": PIL.Image.Image, + "required": True, + "description": "Mask image for inpainting.", + }, + "padding_mask_crop": { + "type_hint": int, + "description": "Padding for mask cropping in inpainting.", + }, + # controlnet + "control_image": { + "type_hint": PIL.Image.Image, + "required": True, + "description": "Control image for ControlNet conditioning.", + }, + "control_guidance_start": { + "type_hint": float, + "default": 0.0, + "description": "When to start applying ControlNet.", + }, + "control_guidance_end": { + "type_hint": float, + "default": 1.0, + "description": "When to stop applying ControlNet.", + }, + "controlnet_conditioning_scale": { + "type_hint": float, + "default": 1.0, + "description": "Scale for ControlNet conditioning.", + }, + "layers": { + "type_hint": int, + "default": 4, + "description": "Number of layers to extract from the image", + }, + # common intermediate inputs + "prompt_embeds": { + "type_hint": torch.Tensor, + "required": True, + "description": "text embeddings used to guide the image generation. Can be generated from text_encoder step.", + }, + "prompt_embeds_mask": { + "type_hint": torch.Tensor, + "required": True, + "description": "mask for the text embeddings. Can be generated from text_encoder step.", + }, + "negative_prompt_embeds": { + "type_hint": torch.Tensor, + "description": "negative text embeddings used to guide the image generation. Can be generated from text_encoder step.", + }, + "negative_prompt_embeds_mask": { + "type_hint": torch.Tensor, + "description": "mask for the negative text embeddings. Can be generated from text_encoder step.", + }, + "image_latents": { + "type_hint": torch.Tensor, + "required": True, + "description": "image latents used to guide the image generation. Can be generated from vae_encoder step.", + }, + "batch_size": { + "type_hint": int, + "default": 1, + "description": "Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step.", + }, + "dtype": { + "type_hint": torch.dtype, + "default": torch.float32, + "description": "The dtype of the model inputs, can be generated in input step.", + }, +} + +OUTPUT_PARAM_TEMPLATES = { + "images": { + "type_hint": List[PIL.Image.Image], + "description": "Generated images.", + }, + "latents": { + "type_hint": torch.Tensor, + "description": "Denoised latents.", + }, + # intermediate outputs + "prompt_embeds": { + "type_hint": torch.Tensor, + "kwargs_type": "denoiser_input_fields", + "description": "The prompt embeddings.", + }, + "prompt_embeds_mask": { + "type_hint": torch.Tensor, + "kwargs_type": "denoiser_input_fields", + "description": "The encoder attention mask.", + }, + "negative_prompt_embeds": { + "type_hint": torch.Tensor, + "kwargs_type": "denoiser_input_fields", + "description": "The negative prompt embeddings.", + }, + "negative_prompt_embeds_mask": { + "type_hint": torch.Tensor, + "kwargs_type": "denoiser_input_fields", + "description": "The negative prompt embeddings mask.", + }, + "image_latents": { + "type_hint": torch.Tensor, + "description": "The latent representation of the input image.", + }, +} + + @dataclass class InputParam: """Specification for an input parameter.""" @@ -337,11 +519,31 @@ class InputParam: default: Any = None required: bool = False description: str = "" - kwargs_type: str = None # YiYi Notes: remove this feature (maybe) + kwargs_type: str = None def __repr__(self): return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" + @classmethod + def template(cls, template_name: str, note: str = None, **overrides) -> "InputParam": + """Get template for name if exists, otherwise raise ValueError.""" + if template_name not in INPUT_PARAM_TEMPLATES: + raise ValueError(f"InputParam template for {template_name} not found") + + template_kwargs = INPUT_PARAM_TEMPLATES[template_name].copy() + + # Determine the actual param name: + # 1. From overrides if provided + # 2. From template if present + # 3. Fall back to template_name + name = overrides.pop("name", template_kwargs.pop("name", template_name)) + + if note and "description" in template_kwargs: + template_kwargs["description"] = f"{template_kwargs['description']} ({note})" + + template_kwargs.update(overrides) + return cls(name=name, **template_kwargs) + @dataclass class OutputParam: @@ -350,13 +552,33 @@ class OutputParam: name: str type_hint: Any = None description: str = "" - kwargs_type: str = None # YiYi notes: remove this feature (maybe) + kwargs_type: str = None def __repr__(self): return ( f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" ) + @classmethod + def template(cls, template_name: str, note: str = None, **overrides) -> "OutputParam": + """Get template for name if exists, otherwise raise ValueError.""" + if template_name not in OUTPUT_PARAM_TEMPLATES: + raise ValueError(f"OutputParam template for {template_name} not found") + + template_kwargs = OUTPUT_PARAM_TEMPLATES[template_name].copy() + + # Determine the actual param name: + # 1. From overrides if provided + # 2. From template if present + # 3. 
Fall back to template_name + name = overrides.pop("name", template_kwargs.pop("name", template_name)) + + if note and "description" in template_kwargs: + template_kwargs["description"] = f"{template_kwargs['description']} ({note})" + + template_kwargs.update(overrides) + return cls(name=name, **template_kwargs) + def format_inputs_short(inputs): """ @@ -509,10 +731,12 @@ def wrap_text(text, indent, max_length): desc = re.sub(r"\[(.*?)\]\((https?://[^\s\)]+)\)", r"[\1](\2)", param.description) wrapped_desc = wrap_text(desc, desc_indent, max_line_length) param_str += f"\n{desc_indent}{wrapped_desc}" + else: + param_str += f"\n{desc_indent}TODO: Add description." formatted_params.append(param_str) - return "\n\n".join(formatted_params) + return "\n".join(formatted_params) def format_input_params(input_params, indent_level=4, max_line_length=115): @@ -582,7 +806,7 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty loading_field_values = [] for field_name in component.loading_fields(): field_value = getattr(component, field_name) - if field_value is not None: + if field_value: loading_field_values.append(f"{field_name}={field_value}") # Add loading field information if available @@ -669,17 +893,17 @@ def make_doc_string( # Add description if description: desc_lines = description.strip().split("\n") - aligned_desc = "\n".join(" " + line for line in desc_lines) + aligned_desc = "\n".join(" " + line.rstrip() for line in desc_lines) output += aligned_desc + "\n\n" # Add components section if provided if expected_components and len(expected_components) > 0: - components_str = format_components(expected_components, indent_level=2) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) output += components_str + "\n\n" # Add configs section if provided if expected_configs and len(expected_configs) > 0: - configs_str = format_configs(expected_configs, indent_level=2) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) output += configs_str + "\n\n" # Add inputs section diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index d9c8cbb01d18..338caf514b1d 100644 --- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -118,7 +118,40 @@ def get_timesteps(scheduler, num_inference_steps, strength): # ==================== +# auto_docstring class QwenImagePrepareLatentsStep(ModularPipelineBlocks): + """ + Prepare initial random noise for the generation process + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. 
+ + Outputs: + height (`int`): + if not set, updated to default value + width (`int`): + if not set, updated to default value + latents (`Tensor`): + The initial latents to use for the denoising process + """ + model_name = "qwenimage" @property @@ -134,28 +167,20 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("latents"), - InputParam(name="height"), - InputParam(name="width"), - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="generator"), - InputParam( - name="batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", - ), - InputParam( - name="dtype", - required=True, - type_hint=torch.dtype, - description="The dtype of the model inputs, can be generated in input step.", - ), + InputParam.template("latents"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("num_images_per_prompt"), + InputParam.template("generator"), + InputParam.template("batch_size"), + InputParam.template("dtype"), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ + OutputParam(name="height", type_hint=int, description="if not set, updated to default value"), + OutputParam(name="width", type_hint=int, description="if not set, updated to default value"), OutputParam( name="latents", type_hint=torch.Tensor, @@ -209,7 +234,42 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks): + """ + Prepare initial random noise (B, layers+1, C, H, W) for the generation process + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + height (`int`): + if not set, updated to default value + width (`int`): + if not set, updated to default value + latents (`Tensor`): + The initial latents to use for the denoising process + """ + model_name = "qwenimage-layered" @property @@ -225,29 +285,21 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("latents"), - InputParam(name="height"), - InputParam(name="width"), - InputParam(name="layers", default=4), - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="generator"), - InputParam( - name="batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step.", - ), - InputParam( - name="dtype", - required=True, - type_hint=torch.dtype, - description="The dtype of the model inputs, can be generated in input step.", - ), + InputParam.template("latents"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("layers"), + InputParam.template("num_images_per_prompt"), + InputParam.template("generator"), + InputParam.template("batch_size"), + InputParam.template("dtype"), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ + OutputParam(name="height", type_hint=int, description="if not set, updated to default value"), + OutputParam(name="width", type_hint=int, description="if not set, updated to default value"), OutputParam( name="latents", type_hint=torch.Tensor, @@ -301,7 +353,31 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks): + """ + Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, + prepare_latents. Both noise and image latents should alreadybe patchified. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + latents (`Tensor`): + The initial random noised, can be generated in prepare latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be + generated from vae encoder and updated in input step.) + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + + Outputs: + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + latents (`Tensor`): + The scaled noisy latents to use for inpainting/image-to-image denoising. + """ + model_name = "qwenimage" @property @@ -323,12 +399,7 @@ def inputs(self) -> List[InputParam]: type_hint=torch.Tensor, description="The initial random noised, can be generated in prepare latent step.", ), - InputParam( - name="image_latents", - required=True, - type_hint=torch.Tensor, - description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.", - ), + InputParam.template("image_latents", note="Can be generated from vae encoder and updated in input step."), InputParam( name="timesteps", required=True, @@ -345,6 +416,11 @@ def intermediate_outputs(self) -> List[OutputParam]: type_hint=torch.Tensor, description="The initial random noised used for inpainting denoising.", ), + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The scaled noisy latents to use for inpainting/image-to-image denoising.", + ), ] @staticmethod @@ -382,7 +458,29 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks): + """ + Step that creates mask latents from preprocessed mask_image by interpolating to latent space. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + processed_mask_image (`Tensor`): + The processed mask to use for the inpainting process. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. 
+ + Outputs: + mask (`Tensor`): + The mask to use for the inpainting process. + """ + model_name = "qwenimage" @property @@ -404,9 +502,9 @@ def inputs(self) -> List[InputParam]: type_hint=torch.Tensor, description="The processed mask to use for the inpainting process.", ), - InputParam(name="height", required=True), - InputParam(name="width", required=True), - InputParam(name="dtype", required=True), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("dtype"), ] @property @@ -450,7 +548,28 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - # ==================== +# auto_docstring class QwenImageSetTimestepsStep(ModularPipelineBlocks): + """ + Step that sets the the scheduler's timesteps for text-to-image generation. Should be run after prepare latents + step. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + latents (`Tensor`): + The initial random noised latents for the denoising process. Can be generated in prepare latents step. + + Outputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process + """ + model_name = "qwenimage" @property @@ -466,13 +585,13 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="num_inference_steps", default=50), - InputParam(name="sigmas"), + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), InputParam( name="latents", required=True, type_hint=torch.Tensor, - description="The latents to use for the denoising process, used to calculate the image sequence length.", + description="The initial random noised latents for the denoising process. Can be generated in prepare latents step.", ), ] @@ -516,7 +635,27 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks): + """ + Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. + """ + model_name = "qwenimage-layered" @property @@ -532,15 +671,17 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("num_inference_steps", default=50, type_hint=int), - InputParam("sigmas", type_hint=List[float]), - InputParam("image_latents", required=True, type_hint=torch.Tensor), + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), + InputParam.template("image_latents"), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam(name="timesteps", type_hint=torch.Tensor), + OutputParam( + name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process." 
+ ), ] @torch.no_grad() @@ -574,7 +715,32 @@ def __call__(self, components, state: PipelineState) -> PipelineState: return components, state +# auto_docstring class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks): + """ + Step that sets the the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after + prepare latents step. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + latents (`Tensor`): + The latents to use for the denoising process. Can be generated in prepare latents step. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + + Outputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. + num_inference_steps (`int`): + The number of denoising steps to perform at inference time. Updated based on strength. + """ + model_name = "qwenimage" @property @@ -590,15 +756,15 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="num_inference_steps", default=50), - InputParam(name="sigmas"), + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), InputParam( - name="latents", + "latents", required=True, type_hint=torch.Tensor, - description="The latents to use for the denoising process, used to calculate the image sequence length.", + description="The latents to use for the denoising process. Can be generated in prepare latents step.", ), - InputParam(name="strength", default=0.9), + InputParam.template("strength", default=0.9), ] @property @@ -607,7 +773,12 @@ def intermediate_outputs(self) -> List[OutputParam]: OutputParam( name="timesteps", type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + description="The timesteps to use for the denoising process.", + ), + OutputParam( + name="num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time. Updated based on strength.", ), ] @@ -654,7 +825,33 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - ## RoPE inputs for denoiser +# auto_docstring class QwenImageRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. 
+ + Outputs: + img_shapes (`List`): + The shapes of the images latents, used for RoPE calculation + txt_seq_lens (`List`): + The sequence lengths of the prompt embeds, used for RoPE calculation + negative_txt_seq_lens (`List`): + The sequence lengths of the negative prompt embeds, used for RoPE calculation + """ + model_name = "qwenimage" @property @@ -666,11 +863,11 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="batch_size", required=True), - InputParam(name="height", required=True), - InputParam(name="width", required=True), - InputParam(name="prompt_embeds_mask"), - InputParam(name="negative_prompt_embeds_mask"), + InputParam.template("batch_size"), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), ] @property @@ -702,7 +899,38 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring class QwenImageEditRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be placed after + prepare_latents step + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + image_height (`int`): + The height of the reference image. Can be generated in input step. + image_width (`int`): + The width of the reference image. Can be generated in input step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + img_shapes (`List`): + The shapes of the images latents, used for RoPE calculation + txt_seq_lens (`List`): + The sequence lengths of the prompt embeds, used for RoPE calculation + negative_txt_seq_lens (`List`): + The sequence lengths of the negative prompt embeds, used for RoPE calculation + """ + model_name = "qwenimage" @property @@ -712,13 +940,23 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="batch_size", required=True), - InputParam(name="image_height", required=True), - InputParam(name="image_width", required=True), - InputParam(name="height", required=True), - InputParam(name="width", required=True), - InputParam(name="prompt_embeds_mask"), - InputParam(name="negative_prompt_embeds_mask"), + InputParam.template("batch_size"), + InputParam( + name="image_height", + required=True, + type_hint=int, + description="The height of the reference image. Can be generated in input step.", + ), + InputParam( + name="image_width", + required=True, + type_hint=int, + description="The width of the reference image. 
Can be generated in input step.", + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), ] @property @@ -756,7 +994,39 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit Plus. + Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images. Should be placed + after prepare_latents step. + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + image_height (`List`): + The heights of the reference images. Can be generated in input step. + image_width (`List`): + The widths of the reference images. Can be generated in input step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + img_shapes (`List`): + The shapes of the image latents, used for RoPE calculation + txt_seq_lens (`List`): + The sequence lengths of the prompt embeds, used for RoPE calculation + negative_txt_seq_lens (`List`): + The sequence lengths of the negative prompt embeds, used for RoPE calculation + """ + model_name = "qwenimage-edit-plus" @property @@ -770,13 +1040,23 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="batch_size", required=True), - InputParam(name="image_height", required=True, type_hint=List[int]), - InputParam(name="image_width", required=True, type_hint=List[int]), - InputParam(name="height", required=True), - InputParam(name="width", required=True), - InputParam(name="prompt_embeds_mask"), - InputParam(name="negative_prompt_embeds_mask"), + InputParam.template("batch_size"), + InputParam( + name="image_height", + required=True, + type_hint=List[int], + description="The heights of the reference images. Can be generated in input step.", + ), + InputParam( + name="image_width", + required=True, + type_hint=List[int], + description="The widths of the reference images. Can be generated in input step.", + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), ] @property @@ -832,7 +1112,37 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + height (`int`): + The height in pixels of the generated image. 
+ width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + img_shapes (`List`): + The shapes of the image latents, used for RoPE calculation + txt_seq_lens (`List`): + The sequence lengths of the prompt embeds, used for RoPE calculation + negative_txt_seq_lens (`List`): + The sequence lengths of the negative prompt embeds, used for RoPE calculation + additional_t_cond (`Tensor`): + The additional t cond, used for RoPE calculation + """ + model_name = "qwenimage-layered" @property @@ -844,12 +1154,12 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="batch_size", required=True), - InputParam(name="layers", required=True), - InputParam(name="height", required=True), - InputParam(name="width", required=True), - InputParam(name="prompt_embeds_mask"), - InputParam(name="negative_prompt_embeds_mask"), + InputParam.template("batch_size"), + InputParam.template("layers"), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), ] @property @@ -914,7 +1224,34 @@ def __call__(self, components, state: PipelineState) -> PipelineState: ## ControlNet inputs for denoiser + + +# auto_docstring class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks): + """ + step that prepare inputs for controlnet. Insert before the Denoise Step, after set_timesteps step. + + Components: + controlnet (`QwenImageControlNetModel`) + + Inputs: + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + + Outputs: + controlnet_keep (`List`): + The controlnet keep values + """ + model_name = "qwenimage" @property @@ -930,12 +1267,17 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam("control_guidance_start", default=0.0), - InputParam("control_guidance_end", default=1.0), - InputParam("controlnet_conditioning_scale", default=1.0), - InputParam("control_image_latents", required=True), + InputParam.template("control_guidance_start"), + InputParam.template("control_guidance_end"), + InputParam.template("controlnet_conditioning_scale"), InputParam( - "timesteps", + name="control_image_latents", + required=True, + type_hint=torch.Tensor, + description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step.", + ), + InputParam( + name="timesteps", required=True, type_hint=torch.Tensor, description="The timesteps to use for the denoising process. 
Can be generated in set_timesteps step.", diff --git a/src/diffusers/modular_pipelines/qwenimage/decoders.py b/src/diffusers/modular_pipelines/qwenimage/decoders.py index 24a88ebfca3c..1adbf6bdd355 100644 --- a/src/diffusers/modular_pipelines/qwenimage/decoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/decoders.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Any, Dict, List -import numpy as np -import PIL import torch from ...configuration_utils import FrozenDict @@ -31,7 +29,30 @@ # after denoising loop (unpack latents) + + +# auto_docstring class QwenImageAfterDenoiseStep(ModularPipelineBlocks): + """ + Step that unpack the latents from 3D tensor (batch_size, sequence_length, channels) into 5D tensor (batch_size, + channels, 1, height, width) + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + latents (`Tensor`): + The latents to decode, can be generated in the denoise step. + + Outputs: + latents (`Tensor`): + The denoisedlatents unpacked to B, C, 1, H, W + """ + model_name = "qwenimage" @property @@ -49,13 +70,21 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="height", required=True), - InputParam(name="width", required=True), + InputParam.template("height", required=True), + InputParam.template("width", required=True), InputParam( name="latents", required=True, type_hint=torch.Tensor, - description="The latents to decode, can be generated in the denoise step", + description="The latents to decode, can be generated in the denoise step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="latents", type_hint=torch.Tensor, description="The denoisedlatents unpacked to B, C, 1, H, W" ), ] @@ -72,7 +101,29 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks): + """ + Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W) after denoising. + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + + Outputs: + latents (`Tensor`): + Denoised latents. 
(unpacked to B, C, layers+1, H, W) + """ + model_name = "qwenimage-layered" @property @@ -88,10 +139,21 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("latents", required=True, type_hint=torch.Tensor), - InputParam("height", required=True, type_hint=int), - InputParam("width", required=True, type_hint=int), - InputParam("layers", required=True, type_hint=int), + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents to decode, can be generated in the denoise step.", + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("layers"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam.template("latents", note="unpacked to B, C, layers+1, H, W"), ] @torch.no_grad() @@ -112,7 +174,26 @@ def __call__(self, components, state: PipelineState) -> PipelineState: # decode step + + +# auto_docstring class QwenImageDecoderStep(ModularPipelineBlocks): + """ + Step that decodes the latents to images + + Components: + vae (`AutoencoderKLQwenImage`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + + Outputs: + images (`List`): + Generated images. (tensor output of the vae decoder.) + """ + model_name = "qwenimage" @property @@ -134,19 +215,13 @@ def inputs(self) -> List[InputParam]: name="latents", required=True, type_hint=torch.Tensor, - description="The latents to decode, can be generated in the denoise step", + description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step.", ), ] @property - def intermediate_outputs(self) -> List[str]: - return [ - OutputParam( - "images", - type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], - description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array", - ) - ] + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam.template("images", note="tensor output of the vae decoder.")] @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -176,7 +251,26 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring class QwenImageLayeredDecoderStep(ModularPipelineBlocks): + """ + Decode unpacked latents (B, C, layers+1, H, W) into layer images. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. 
+ """ + model_name = "qwenimage-layered" @property @@ -198,14 +292,19 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("latents", required=True, type_hint=torch.Tensor), - InputParam("output_type", default="pil", type_hint=str), + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step.", + ), + InputParam.template("output_type"), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]), + OutputParam.template("images"), ] @torch.no_grad() @@ -251,7 +350,27 @@ def __call__(self, components, state: PipelineState) -> PipelineState: # postprocess the decoded images + + +# auto_docstring class QwenImageProcessImagesOutputStep(ModularPipelineBlocks): + """ + postprocess the generated image + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + images (`Tensor`): + the generated image tensor from decoders step + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage" @property @@ -272,15 +391,19 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("images", required=True, description="the generated image from decoders step"), InputParam( - name="output_type", - default="pil", - type_hint=str, - description="The type of the output images, can be 'pil', 'np', 'pt'", + name="images", + required=True, + type_hint=torch.Tensor, + description="the generated image tensor from decoders step", ), + InputParam.template("output_type"), ] + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam.template("images")] + @staticmethod def check_inputs(output_type): if output_type not in ["pil", "np", "pt"]: @@ -301,7 +424,28 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state +# auto_docstring class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks): + """ + postprocess the generated image, optional apply the mask overally to the original image.. + + Components: + image_mask_processor (`InpaintProcessor`) + + Inputs: + images (`Tensor`): + the generated image tensor from decoders step + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`Dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage" @property @@ -322,16 +466,24 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("images", required=True, description="the generated image from decoders step"), InputParam( - name="output_type", - default="pil", - type_hint=str, - description="The type of the output images, can be 'pil', 'np', 'pt'", + name="images", + required=True, + type_hint=torch.Tensor, + description="the generated image tensor from decoders step", + ), + InputParam.template("output_type"), + InputParam( + name="mask_overlay_kwargs", + type_hint=Dict[str, Any], + description="The kwargs for the postprocess step to apply the mask overlay. 
generated in InpaintProcessImagesInputStep.", ), - InputParam("mask_overlay_kwargs"), ] + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam.template("images")] + @staticmethod def check_inputs(output_type, mask_overlay_kwargs): if output_type not in ["pil", "np", "pt"]: diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py index d6bcb4a94f80..8579c9843a89 100644 --- a/src/diffusers/modular_pipelines/qwenimage/denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py @@ -50,7 +50,7 @@ def description(self) -> str: def inputs(self) -> List[InputParam]: return [ InputParam( - "latents", + name="latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", @@ -80,17 +80,12 @@ def description(self) -> str: def inputs(self) -> List[InputParam]: return [ InputParam( - "latents", + name="latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", ), - InputParam( - "image_latents", - required=True, - type_hint=torch.Tensor, - description="The initial image latents to use for the denoising process. Can be encoded in vae_encoder step and packed in prepare_image_latents step.", - ), + InputParam.template("image_latents"), ] @torch.no_grad() @@ -134,29 +129,12 @@ def inputs(self) -> List[InputParam]: type_hint=torch.Tensor, description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", ), + InputParam.template("controlnet_conditioning_scale", note="updated in prepare_controlnet_inputs step."), InputParam( - "controlnet_conditioning_scale", - type_hint=float, - description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", - ), - InputParam( - "controlnet_keep", + name="controlnet_keep", required=True, type_hint=List[float], - description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", - ), - InputParam( - kwargs_type="denoiser_input_fields", - description=( - "All conditional model inputs for the denoiser. " - "It should contain prompt_embeds/negative_prompt_embeds." - ), + description="The controlnet keep values. Can be generated in prepare_controlnet_inputs step.", ), ] @@ -217,28 +195,13 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("attention_kwargs"), - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The latents to use for the denoising process. Can be generated in prepare_latents step.", - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", - ), - InputParam( - kwargs_type="denoiser_input_fields", - description="conditional model inputs for the denoiser: e.g. 
prompt_embeds, negative_prompt_embeds, etc.", - ), + InputParam.template("attention_kwargs"), + InputParam.template("denoiser_input_fields"), InputParam( "img_shapes", required=True, type_hint=List[Tuple[int, int]], - description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.", + description="The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.", ), ] @@ -317,23 +280,8 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("attention_kwargs"), - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The latents to use for the denoising process. Can be generated in prepare_latents step.", - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", - ), - InputParam( - kwargs_type="denoiser_input_fields", - description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", - ), + InputParam.template("attention_kwargs"), + InputParam.template("denoiser_input_fields"), InputParam( "img_shapes", required=True, @@ -415,7 +363,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents."), + OutputParam.template("latents"), ] @torch.no_grad() @@ -456,24 +404,19 @@ def inputs(self) -> List[InputParam]: type_hint=torch.Tensor, description="The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.", ), - InputParam( - "image_latents", - required=True, - type_hint=torch.Tensor, - description="The image latents to use for the inpainting process. Can be generated in inpaint prepare latents step.", - ), + InputParam.template("image_latents"), InputParam( "initial_noise", required=True, type_hint=torch.Tensor, description="The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.", ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", - ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam.template("latents"), ] @torch.no_grad() @@ -515,17 +458,12 @@ def loop_expected_components(self) -> List[ComponentSpec]: def loop_inputs(self) -> List[InputParam]: return [ InputParam( - "timesteps", + name="timesteps", required=True, type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", - ), + InputParam.template("num_inference_steps", required=True), ] @torch.no_grad() @@ -557,7 +495,42 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - # Qwen Image (text2image, image2image) + + +# auto_docstring class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. 
+ Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports text2image and image2image tasks for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ @@ -570,8 +543,8 @@ class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper): @property def description(self) -> str: return ( - "Denoise step that iteratively denoise the latents. \n" - "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "Denoise step that iteratively denoise the latents.\n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method\n" "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" " - `QwenImageLoopBeforeDenoiser`\n" " - `QwenImageLoopDenoiser`\n" @@ -581,7 +554,47 @@ def description(self) -> str: # Qwen Image (inpainting) +# auto_docstring class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + - `QwenImageLoopAfterDenoiserInpaint` + This block supports inpainting tasks for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + mask (`Tensor`): + The mask to use for the inpainting process. Can be generated in inpaint prepare latents step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + initial_noise (`Tensor`): + The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step. 
+ + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageLoopBeforeDenoiser, @@ -606,7 +619,47 @@ def description(self) -> str: # Qwen Image (text2image, image2image) with controlnet +# auto_docstring class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopBeforeDenoiserControlNet` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports text2img/img2img tasks with controlnet for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) controlnet (`QwenImageControlNetModel`) transformer + (`QwenImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + control_image_latents (`Tensor`): + The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.) + controlnet_keep (`List`): + The controlnet keep values. Can be generated in prepare_controlnet_inputs step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageLoopBeforeDenoiser, @@ -631,7 +684,54 @@ def description(self) -> str: # Qwen Image (inpainting) with controlnet +# auto_docstring class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopBeforeDenoiserControlNet` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + - `QwenImageLoopAfterDenoiserInpaint` + This block supports inpainting tasks with controlnet for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) controlnet (`QwenImageControlNetModel`) transformer + (`QwenImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + control_image_latents (`Tensor`): + The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. 
(updated in prepare_controlnet_inputs step.) + controlnet_keep (`List`): + The controlnet keep values. Can be generated in prepare_controlnet_inputs step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + mask (`Tensor`): + The mask to use for the inpainting process. Can be generated in inpaint prepare latents step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + initial_noise (`Tensor`): + The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageLoopBeforeDenoiser, @@ -664,7 +764,42 @@ def description(self) -> str: # Qwen Image Edit (image2image) +# auto_docstring class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageEditLoopBeforeDenoiser` + - `QwenImageEditLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports QwenImage Edit. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageEditLoopBeforeDenoiser, @@ -687,7 +822,47 @@ def description(self) -> str: # Qwen Image Edit (inpainting) +# auto_docstring class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageEditLoopBeforeDenoiser` + - `QwenImageEditLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + - `QwenImageLoopAfterDenoiserInpaint` + This block supports inpainting tasks for QwenImage Edit. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. 
+ latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step. + mask (`Tensor`): + The mask to use for the inpainting process. Can be generated in inpaint prepare latents step. + initial_noise (`Tensor`): + The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageEditLoopBeforeDenoiser, @@ -712,7 +887,42 @@ def description(self) -> str: # Qwen Image Layered (image2image) +# auto_docstring class QwenImageLayeredDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageEditLoopBeforeDenoiser` + - `QwenImageEditLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports QwenImage Layered. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage-layered" block_classes = [ QwenImageEditLoopBeforeDenoiser, diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py index 4b66dd32e521..5e1821cca5c0 100644 --- a/src/diffusers/modular_pipelines/qwenimage/encoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py @@ -30,7 +30,7 @@ from ...utils import logging from ...utils.torch_utils import unwrap_module from ..modular_pipeline import ModularPipelineBlocks, PipelineState -from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import QwenImageModularPipeline from .prompt_templates import ( QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE, @@ -259,33 +259,47 @@ def encode_vae_image( # ==================== # 1. 
RESIZE # ==================== +# In QwenImage pipelines, resize is a separate step because the resized image is used in VL encoding and vae encoder blocks: +# +# image (PIL.Image.Image) +# │ +# ▼ +# resized_image ([PIL.Image.Image]) +# │ +# ├──► text_encoder ──► prompt_embeds, prompt_embeds_mask +# │ (VL encoding needs the resized image for vision-language fusion) +# │ +# └──► image_processor ──► processed_image (torch.Tensor, pixel space) +# │ +# ▼ +# vae_encoder ──► image_latents (torch.Tensor, latent space) +# +# In most of our other pipelines, resizing is done as part of the image preprocessing step. +# ==================== + + +# auto_docstring class QwenImageEditResizeStep(ModularPipelineBlocks): - model_name = "qwenimage-edit" + """ + Image Resize step that resize the image to target area while maintaining the aspect ratio. - def __init__( - self, - input_name: str = "image", - output_name: str = "resized_image", - ): - """Create a configurable step for resizing images to the target area while maintaining the aspect ratio. + Components: + image_resize_processor (`VaeImageProcessor`) - Args: - input_name (str, optional): Name of the image field to read from the - pipeline state. Defaults to "image". - output_name (str, optional): Name of the resized image field to write - back to the pipeline state. Defaults to "resized_image". - """ - if not isinstance(input_name, str) or not isinstance(output_name, str): - raise ValueError( - f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}" - ) - self._image_input_name = input_name - self._resized_image_output_name = output_name - super().__init__() + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + + Outputs: + resized_image (`List`): + The resized images + """ + + model_name = "qwenimage-edit" @property def description(self) -> str: - return f"Image Resize step that resize the {self._image_input_name} to target area while maintaining the aspect ratio." + return "Image Resize step that resize the image to target area while maintaining the aspect ratio." 
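# The resize steps above fit the image to a target area (for example 1024*1024, as the
# Edit Plus step uses for its VAE branch) while keeping the aspect ratio. A minimal,
# self-contained sketch of that idea follows, for orientation only: the real
# `calculate_dimensions` helper may round differently, and the multiple-of-32 snapping
# here is an assumption, not the library's rule.
import math

import PIL.Image


def resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024, multiple: int = 32) -> PIL.Image.Image:
    """Resize `image` so that width * height is close to `target_area`, keeping the aspect ratio."""
    width, height = image.size
    aspect_ratio = width / height
    # Solve w * h = target_area subject to w / h = aspect_ratio.
    new_width = int(math.sqrt(target_area * aspect_ratio))
    new_height = int(math.sqrt(target_area / aspect_ratio))
    # Snap to a multiple so downstream latent packing sees whole patches (assumed factor).
    new_width = max(multiple, new_width // multiple * multiple)
    new_height = max(multiple, new_height // multiple * multiple)
    return image.resize((new_width, new_height), PIL.Image.LANCZOS)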
@property def expected_components(self) -> List[ComponentSpec]: @@ -300,17 +314,15 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: - return [ - InputParam( - name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize" - ), - ] + return [InputParam.template("image")] @property def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam( - name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images" + name="resized_image", + type_hint=List[PIL.Image.Image], + description="The resized images", ), ] @@ -318,7 +330,7 @@ def intermediate_outputs(self) -> List[OutputParam]: def __call__(self, components: QwenImageModularPipeline, state: PipelineState): block_state = self.get_block_state(state) - images = getattr(block_state, self._image_input_name) + images = block_state.image if not is_valid_image_imagelist(images): raise ValueError(f"Images must be image or list of images but are {type(images)}") @@ -334,38 +346,36 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): for image in images ] - setattr(block_state, self._resized_image_output_name, resized_images) + block_state.resized_image = resized_images self.set_block_state(state, block_state) return components, state +# auto_docstring class QwenImageLayeredResizeStep(ModularPipelineBlocks): - model_name = "qwenimage-layered" + """ + Image Resize step that resize the image to a target area (defined by the resolution parameter from user) while + maintaining the aspect ratio. - def __init__( - self, - input_name: str = "image", - output_name: str = "resized_image", - ): - """Create a configurable step for resizing images to the target area while maintaining the aspect ratio. + Components: + image_resize_processor (`VaeImageProcessor`) - Args: - input_name (str, optional): Name of the image field to read from the - pipeline state. Defaults to "image". - output_name (str, optional): Name of the resized image field to write - back to the pipeline state. Defaults to "resized_image". - """ - if not isinstance(input_name, str) or not isinstance(output_name, str): - raise ValueError( - f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}" - ) - self._image_input_name = input_name - self._resized_image_output_name = output_name - super().__init__() + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + + Outputs: + resized_image (`List`): + The resized images + """ + + model_name = "qwenimage-layered" @property def description(self) -> str: - return f"Image Resize step that resize the {self._image_input_name} to target area while maintaining the aspect ratio." + return "Image Resize step that resize the image to a target area (defined by the resolution parameter from user) while maintaining the aspect ratio." 
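# For orientation, a toy block that follows the same state-passing pattern as the resize
# steps above: declare `inputs`/`intermediate_outputs`, read from `block_state`, write
# results back, and return `(components, state)`. This is a hedged sketch for illustration
# only, not a block that ships with diffusers; the import paths are spelled out from the
# relative imports used at the top of encoders.py.
from typing import List

import PIL.Image
import torch

from diffusers.modular_pipelines.modular_pipeline import ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam


class ToyGrayscaleStep(ModularPipelineBlocks):
    model_name = "qwenimage"

    @property
    def description(self) -> str:
        return "Toy step that converts the input image(s) to grayscale."

    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam.template("image")]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                name="grayscale_image",
                type_hint=List[PIL.Image.Image],
                description="Grayscale copies of the input image(s)",
            )
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState):
        block_state = self.get_block_state(state)
        images = block_state.image if isinstance(block_state.image, list) else [block_state.image]
        block_state.grayscale_image = [img.convert("L") for img in images]
        self.set_block_state(state, block_state)
        return components, state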
@property def expected_components(self) -> List[ComponentSpec]: @@ -381,9 +391,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam( - name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize" - ), + InputParam.template("image"), InputParam( name="resolution", default=640, @@ -396,8 +404,10 @@ def inputs(self) -> List[InputParam]: def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam( - name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images" - ), + name="resized_image", + type_hint=List[PIL.Image.Image], + description="The resized images", + ) ] @staticmethod @@ -411,7 +421,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): self.check_inputs(resolution=block_state.resolution) - images = getattr(block_state, self._image_input_name) + images = block_state.image if not is_valid_image_imagelist(images): raise ValueError(f"Images must be image or list of images but are {type(images)}") @@ -428,45 +438,40 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): for image in images ] - setattr(block_state, self._resized_image_output_name, resized_images) + block_state.resized_image = resized_images self.set_block_state(state, block_state) return components, state +# auto_docstring class QwenImageEditPlusResizeStep(ModularPipelineBlocks): - """Resize each image independently based on its own aspect ratio. For QwenImage Edit Plus.""" + """ + Resize images for QwenImage Edit Plus pipeline. + Produces two outputs: resized_image (1024x1024) for VAE encoding, resized_cond_image (384x384) for VL text + encoding. Each image is resized independently based on its own aspect ratio. + + Components: + image_resize_processor (`VaeImageProcessor`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + + Outputs: + resized_image (`List`): + Images resized to 1024x1024 target area for VAE encoding + resized_cond_image (`List`): + Images resized to 384x384 target area for VL text encoding + """ model_name = "qwenimage-edit-plus" - def __init__( - self, - input_name: str = "image", - output_name: str = "resized_image", - target_area: int = 1024 * 1024, - ): - """Create a step for resizing images to a target area. - - Each image is resized independently based on its own aspect ratio. This is suitable for Edit Plus where - multiple reference images can have different dimensions. - - Args: - input_name (str, optional): Name of the image field to read. Defaults to "image". - output_name (str, optional): Name of the resized image field to write. Defaults to "resized_image". - target_area (int, optional): Target area in pixels. Defaults to 1024*1024. 
- """ - if not isinstance(input_name, str) or not isinstance(output_name, str): - raise ValueError( - f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}" - ) - self._image_input_name = input_name - self._resized_image_output_name = output_name - self._target_area = target_area - super().__init__() - @property def description(self) -> str: return ( - f"Image Resize step that resizes {self._image_input_name} to target area {self._target_area}.\n" + "Resize images for QwenImage Edit Plus pipeline.\n" + "Produces two outputs: resized_image (1024x1024) for VAE encoding, " + "resized_cond_image (384x384) for VL text encoding.\n" "Each image is resized independently based on its own aspect ratio." ) @@ -483,20 +488,21 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: - return [ - InputParam( - name=self._image_input_name, - required=True, - type_hint=torch.Tensor, - description="The image(s) to resize", - ), - ] + # image + return [InputParam.template("image")] @property def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam( - name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images" + name="resized_image", + type_hint=List[PIL.Image.Image], + description="Images resized to 1024x1024 target area for VAE encoding", + ), + OutputParam( + name="resized_cond_image", + type_hint=List[PIL.Image.Image], + description="Images resized to 384x384 target area for VL text encoding", ), ] @@ -504,7 +510,7 @@ def intermediate_outputs(self) -> List[OutputParam]: def __call__(self, components: QwenImageModularPipeline, state: PipelineState): block_state = self.get_block_state(state) - images = getattr(block_state, self._image_input_name) + images = block_state.image if not is_valid_image_imagelist(images): raise ValueError(f"Images must be image or list of images but are {type(images)}") @@ -514,16 +520,22 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): # Resize each image independently based on its own aspect ratio resized_images = [] + resized_cond_images = [] for image in images: image_width, image_height = image.size - calculated_width, calculated_height, _ = calculate_dimensions( - self._target_area, image_width / image_height - ) - resized_images.append( - components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width) + + # For VAE encoder (1024x1024 target area) + vae_width, vae_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height) + resized_images.append(components.image_resize_processor.resize(image, height=vae_height, width=vae_width)) + + # For VL text encoder (384x384 target area) + vl_width, vl_height, _ = calculate_dimensions(384 * 384, image_width / image_height) + resized_cond_images.append( + components.image_resize_processor.resize(image, height=vl_height, width=vl_width) ) - setattr(block_state, self._resized_image_output_name, resized_images) + block_state.resized_image = resized_images + block_state.resized_cond_image = resized_cond_images self.set_block_state(state, block_state) return components, state @@ -531,14 +543,38 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): # ==================== # 2. 
GET IMAGE PROMPT # ==================== + + +# auto_docstring class QwenImageLayeredGetImagePromptStep(ModularPipelineBlocks): """ - Auto-caption step that generates a text prompt from the input image if none is provided. Uses the VL model to - generate a description of the image. + Auto-caption step that generates a text prompt from the input image if none is provided. + Uses the VL model (text_encoder) to generate a description of the image. If prompt is already provided, this step + passes through unchanged. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor (`Qwen2VLProcessor`) + + Inputs: + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + resized_image (`Image`): + The image to generate caption from, should be resized use the resize step + use_en_prompt (`bool`, *optional*, defaults to False): + Whether to use English prompt template + + Outputs: + prompt (`str`): + The prompt or prompts to guide image generation. If not provided, updated using image caption """ model_name = "qwenimage-layered" + def __init__(self): + self.image_caption_prompt_en = QWENIMAGE_LAYERED_CAPTION_PROMPT_EN + self.image_caption_prompt_cn = QWENIMAGE_LAYERED_CAPTION_PROMPT_CN + super().__init__() + @property def description(self) -> str: return ( @@ -554,17 +590,12 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("processor", Qwen2VLProcessor), ] - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ - ConfigSpec(name="image_caption_prompt_en", default=QWENIMAGE_LAYERED_CAPTION_PROMPT_EN), - ConfigSpec(name="image_caption_prompt_cn", default=QWENIMAGE_LAYERED_CAPTION_PROMPT_CN), - ] - @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="prompt", type_hint=str, description="The prompt to encode"), + InputParam.template( + "prompt", required=False + ), # it is not required for qwenimage-layered, unlike other pipelines InputParam( name="resized_image", required=True, @@ -579,6 +610,16 @@ def inputs(self) -> List[InputParam]: ), ] + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="prompt", + type_hint=str, + description="The prompt or prompts to guide image generation. If not provided, updated using image caption", + ), + ] + @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -588,9 +629,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - # If prompt is empty or None, generate caption from image if block_state.prompt is None or block_state.prompt == "" or block_state.prompt == " ": if block_state.use_en_prompt: - caption_prompt = components.config.image_caption_prompt_en + caption_prompt = self.image_caption_prompt_en else: - caption_prompt = components.config.image_caption_prompt_cn + caption_prompt = self.image_caption_prompt_cn model_inputs = components.processor( text=caption_prompt, @@ -616,9 +657,44 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - # ==================== # 3. TEXT ENCODER # ==================== + + +# auto_docstring class QwenImageTextEncoderStep(ModularPipelineBlocks): + """ + Text Encoder step that generates text embeddings to guide the image generation. 
+ + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`): + The tokenizer to use guider (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + model_name = "qwenimage" + def __init__(self): + self.prompt_template_encode = QWENIMAGE_PROMPT_TEMPLATE + self.prompt_template_encode_start_idx = QWENIMAGE_PROMPT_TEMPLATE_START_IDX + self.tokenizer_max_length = 1024 + super().__init__() + @property def description(self) -> str: return "Text Encoder step that generates text embeddings to guide the image generation." @@ -636,51 +712,21 @@ def expected_components(self) -> List[ComponentSpec]: ), ] - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ - ConfigSpec(name="prompt_template_encode", default=QWENIMAGE_PROMPT_TEMPLATE), - ConfigSpec(name="prompt_template_encode_start_idx", default=QWENIMAGE_PROMPT_TEMPLATE_START_IDX), - ConfigSpec(name="tokenizer_max_length", default=1024), - ] - @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"), - InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"), - InputParam( - name="max_sequence_length", type_hint=int, description="The max sequence length to use", default=1024 - ), + InputParam.template("prompt"), + InputParam.template("negative_prompt"), + InputParam.template("max_sequence_length", default=1024), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam( - name="prompt_embeds", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The prompt embeddings", - ), - OutputParam( - name="prompt_embeds_mask", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The encoder attention mask", - ), - OutputParam( - name="negative_prompt_embeds", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The negative prompt embeddings", - ), - OutputParam( - name="negative_prompt_embeds_mask", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The negative prompt embeddings mask", - ), + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), ] @staticmethod @@ -709,9 +755,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): components.text_encoder, components.tokenizer, prompt=block_state.prompt, - prompt_template_encode=components.config.prompt_template_encode, - prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, - tokenizer_max_length=components.config.tokenizer_max_length, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + tokenizer_max_length=self.tokenizer_max_length, device=device, ) @@ 
-726,9 +772,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): components.text_encoder, components.tokenizer, prompt=negative_prompt, - prompt_template_encode=components.config.prompt_template_encode, - prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, - tokenizer_max_length=components.config.tokenizer_max_length, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + tokenizer_max_length=self.tokenizer_max_length, device=device, ) block_state.negative_prompt_embeds = block_state.negative_prompt_embeds[ @@ -742,9 +788,42 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state +# auto_docstring class QwenImageEditTextEncoderStep(ModularPipelineBlocks): + """ + Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image + generation. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor (`Qwen2VLProcessor`) guider + (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + resized_image (`Image`): + The image prompt to encode, should be resized using resize step + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + model_name = "qwenimage" + def __init__(self): + self.prompt_template_encode = QWENIMAGE_EDIT_PROMPT_TEMPLATE + self.prompt_template_encode_start_idx = QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX + super().__init__() + @property def description(self) -> str: return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation." 
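# The text encoder steps above now keep their prompt templates as plain instance
# attributes set in `__init__` (previously ConfigSpec entries read from
# `components.config`), so customization happens on the class or the instance.
# A hedged sketch; the template string below is purely illustrative, not an
# official QwenImage template:
class MyTextEncoderStep(QwenImageTextEncoderStep):
    def __init__(self):
        super().__init__()
        # Overwrite the defaults assigned in QwenImageTextEncoderStep.__init__.
        self.prompt_template_encode = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
        self.prompt_template_encode_start_idx = 0
        self.tokenizer_max_length = 512


# Or tweak a single instance without subclassing:
text_encoder_step = QwenImageTextEncoderStep()
text_encoder_step.tokenizer_max_length = 512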
@@ -762,18 +841,11 @@ def expected_components(self) -> List[ComponentSpec]: ), ] - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ - ConfigSpec(name="prompt_template_encode", default=QWENIMAGE_EDIT_PROMPT_TEMPLATE), - ConfigSpec(name="prompt_template_encode_start_idx", default=QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX), - ] - @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"), - InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"), + InputParam.template("prompt"), + InputParam.template("negative_prompt"), InputParam( name="resized_image", required=True, @@ -785,30 +857,10 @@ def inputs(self) -> List[InputParam]: @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam( - name="prompt_embeds", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The prompt embeddings", - ), - OutputParam( - name="prompt_embeds_mask", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The encoder attention mask", - ), - OutputParam( - name="negative_prompt_embeds", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The negative prompt embeddings", - ), - OutputParam( - name="negative_prompt_embeds_mask", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The negative prompt embeddings mask", - ), + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), ] @staticmethod @@ -836,8 +888,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): components.processor, prompt=block_state.prompt, image=block_state.resized_image, - prompt_template_encode=components.config.prompt_template_encode, - prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, device=device, ) @@ -850,8 +902,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): components.processor, prompt=negative_prompt, image=block_state.resized_image, - prompt_template_encode=components.config.prompt_template_encode, - prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, device=device, ) @@ -859,11 +911,44 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state +# auto_docstring class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks): - """Text encoder for QwenImage Edit Plus (VL encoding with multiple images).""" + """ + Text Encoder step for QwenImage Edit Plus that processes prompt and multiple images together to generate text + embeddings for guiding image generation. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor (`Qwen2VLProcessor`) guider + (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. 
+ resized_cond_image (`Tensor`): + The image(s) to encode, can be a single image or list of images, should be resized to 384x384 using + resize step + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ model_name = "qwenimage-edit-plus" + def __init__(self): + self.prompt_template_encode = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE + self.img_template_encode = QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE + self.prompt_template_encode_start_idx = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX + super().__init__() + @property def description(self) -> str: return ( @@ -884,19 +969,11 @@ def expected_components(self) -> List[ComponentSpec]: ), ] - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ - ConfigSpec(name="prompt_template_encode", default=QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE), - ConfigSpec(name="img_template_encode", default=QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE), - ConfigSpec(name="prompt_template_encode_start_idx", default=QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX), - ] - @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"), - InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"), + InputParam.template("prompt"), + InputParam.template("negative_prompt"), InputParam( name="resized_cond_image", required=True, @@ -908,30 +985,10 @@ def inputs(self) -> List[InputParam]: @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam( - name="prompt_embeds", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The prompt embeddings", - ), - OutputParam( - name="prompt_embeds_mask", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The encoder attention mask", - ), - OutputParam( - name="negative_prompt_embeds", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The negative prompt embeddings", - ), - OutputParam( - name="negative_prompt_embeds_mask", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The negative prompt embeddings mask", - ), + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), ] @staticmethod @@ -959,9 +1016,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): components.processor, prompt=block_state.prompt, image=block_state.resized_cond_image, - prompt_template_encode=components.config.prompt_template_encode, - img_template_encode=components.config.img_template_encode, - prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + prompt_template_encode=self.prompt_template_encode, + img_template_encode=self.img_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, device=device, ) @@ -975,9 +1032,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): components.processor, prompt=negative_prompt, image=block_state.resized_cond_image, - prompt_template_encode=components.config.prompt_template_encode, - img_template_encode=components.config.img_template_encode, - 
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + prompt_template_encode=self.prompt_template_encode, + img_template_encode=self.img_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, device=device, ) ) @@ -989,7 +1046,38 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): # ==================== # 4. IMAGE PREPROCESS # ==================== + + +# auto_docstring class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images will be + resized to the given height and width. + + Components: + image_mask_processor (`InpaintProcessor`) + + Inputs: + mask_image (`Image`): + Mask image for inpainting. + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + + Outputs: + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`Dict`): + The kwargs for the postprocess step to apply the mask overlay + """ + model_name = "qwenimage" @property @@ -1010,18 +1098,26 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("mask_image", required=True), - InputParam("image", required=True), - InputParam("height"), - InputParam("width"), - InputParam("padding_mask_crop"), + InputParam.template("mask_image"), + InputParam.template("image"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("padding_mask_crop"), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam(name="processed_image"), - OutputParam(name="processed_mask_image"), + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ), + OutputParam( + name="processed_mask_image", + type_hint=torch.Tensor, + description="The processed mask image", + ), OutputParam( name="mask_overlay_kwargs", type_hint=Dict, @@ -1061,7 +1157,32 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state +# auto_docstring class QwenImageEditInpaintProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images should be + resized first. + + Components: + image_mask_processor (`InpaintProcessor`) + + Inputs: + mask_image (`Image`): + Mask image for inpainting. + resized_image (`Image`): + The resized image. should be generated using a resize step + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. 
+ + Outputs: + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`Dict`): + The kwargs for the postprocess step to apply the mask overlay + """ + model_name = "qwenimage-edit" @property @@ -1082,16 +1203,25 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("mask_image", required=True), - InputParam("resized_image", required=True), - InputParam("padding_mask_crop"), + InputParam.template("mask_image"), + InputParam( + name="resized_image", + required=True, + type_hint=PIL.Image.Image, + description="The resized image. should be generated using a resize step", + ), + InputParam.template("padding_mask_crop"), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam(name="processed_image"), - OutputParam(name="processed_mask_image"), + OutputParam(name="processed_image", type_hint=torch.Tensor, description="The processed image"), + OutputParam( + name="processed_mask_image", + type_hint=torch.Tensor, + description="The processed mask image", + ), OutputParam( name="mask_overlay_kwargs", type_hint=Dict, @@ -1119,7 +1249,27 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state +# auto_docstring class QwenImageProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step. will resize the image to the given height and width. + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + + Outputs: + processed_image (`Tensor`): + The processed image + """ + model_name = "qwenimage" @property @@ -1140,14 +1290,20 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("image", required=True), - InputParam("height"), - InputParam("width"), + InputParam.template("image"), + InputParam.template("height"), + InputParam.template("width"), ] @property def intermediate_outputs(self) -> List[OutputParam]: - return [OutputParam(name="processed_image")] + return [ + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ) + ] @staticmethod def check_inputs(height, width, vae_scale_factor): @@ -1177,7 +1333,23 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state +# auto_docstring class QwenImageEditProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step. Images needs to be resized first. + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + resized_image (`List`): + The resized image. should be generated using a resize step + + Outputs: + processed_image (`Tensor`): + The processed image + """ + model_name = "qwenimage-edit" @property @@ -1198,12 +1370,23 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("resized_image", required=True), + InputParam( + name="resized_image", + required=True, + type_hint=List[PIL.Image.Image], + description="The resized image. 
should be generated using a resize step", + ), ] @property def intermediate_outputs(self) -> List[OutputParam]: - return [OutputParam(name="processed_image")] + return [ + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ) + ] @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState): @@ -1221,12 +1404,29 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state +# auto_docstring class QwenImageEditPlusProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step. Images can be resized first. If a list of images is provided, will return a list of + processed images. + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + resized_image (`List`): + The resized image. should be generated using a resize step + + Outputs: + processed_image (`Tensor`): + The processed image + """ + model_name = "qwenimage-edit-plus" @property def description(self) -> str: - return "Image Preprocess step. Images can be resized first using QwenImageEditResizeStep." + return "Image Preprocess step. Images can be resized first. If a list of images is provided, will return a list of processed images." @property def expected_components(self) -> List[ComponentSpec]: @@ -1241,11 +1441,24 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: - return [InputParam("resized_image")] + return [ + InputParam( + name="resized_image", + required=True, + type_hint=List[PIL.Image.Image], + description="The resized image. should be generated using a resize step", + ) + ] @property def intermediate_outputs(self) -> List[OutputParam]: - return [OutputParam(name="processed_image")] + return [ + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ) + ] @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState): @@ -1263,7 +1476,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): processed_images.append( components.image_processor.preprocess(image=img, height=img_height, width=img_width) ) - block_state.processed_image = processed_images + if is_image_list: block_state.processed_image = processed_images else: @@ -1276,15 +1489,34 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): # ==================== # 5. VAE ENCODER # ==================== + + +# auto_docstring class QwenImageVaeEncoderStep(ModularPipelineBlocks): - """VAE encoder that handles both single images and lists of images with varied resolutions.""" + """ + VAE Encoder step that converts processed_image into latent representations image_latents. + Handles both single images and lists of images with varied resolutions. + + Components: + vae (`AutoencoderKLQwenImage`) + + Inputs: + processed_image (`Tensor`): + The image tensor to encode + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + image_latents (`Tensor`): + The latent representation of the input image. + """ model_name = "qwenimage" def __init__( self, - input_name: str = "processed_image", - output_name: str = "image_latents", + input: Optional[InputParam] = None, + output: Optional[OutputParam] = None, ): """Initialize a VAE encoder step for converting images to latent representations. @@ -1292,11 +1524,26 @@ def __init__( a single tensor, outputs a single latent tensor. 
Args: - input_name (str, optional): Name of the input image tensor or list. Defaults to "processed_image". - output_name (str, optional): Name of the output latent tensor or list. Defaults to "image_latents". + input (InputParam, optional): Input parameter for the processed image. Defaults to "processed_image". + output (OutputParam, optional): Output parameter for the image latents. Defaults to "image_latents". """ - self._image_input_name = input_name - self._image_latents_output_name = output_name + if input is None: + input = InputParam( + name="processed_image", required=True, type_hint=torch.Tensor, description="The image tensor to encode" + ) + + if output is None: + output = OutputParam.template("image_latents") + + if not isinstance(input, InputParam): + raise ValueError(f"input must be InputParam but is {type(input)}") + if not isinstance(output, OutputParam): + raise ValueError(f"output must be OutputParam but is {type(output)}") + + self._input = input + self._output = output + self._image_input_name = input.name + self._image_latents_output_name = output.name super().__init__() @property @@ -1312,17 +1559,14 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: - return [InputParam(self._image_input_name, required=True), InputParam("generator")] + return [ + self._input, # default is "processed_image" + InputParam.template("generator"), + ] @property def intermediate_outputs(self) -> List[OutputParam]: - return [ - OutputParam( - self._image_latents_output_name, - type_hint=torch.Tensor, - description="The latents representing the reference image(s). Single tensor or list depending on input.", - ) - ] + return [self._output] # default is "image_latents" @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -1359,7 +1603,30 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks): + """ + VAE Encoder step that converts `control_image` into latent representations control_image_latents. + + Components: + vae (`AutoencoderKLQwenImage`) controlnet (`QwenImageControlNetModel`) control_image_processor + (`VaeImageProcessor`) + + Inputs: + control_image (`Image`): + Control image for ControlNet conditioning. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + control_image_latents (`Tensor`): + The latents representing the control image + """ + model_name = "qwenimage" @property @@ -1383,10 +1650,10 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: inputs = [ - InputParam("control_image", required=True), - InputParam("height"), - InputParam("width"), - InputParam("generator"), + InputParam.template("control_image"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("generator"), ] return inputs @@ -1473,23 +1740,38 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - # ==================== # 6. 
PERMUTE LATENTS # ==================== + + +# auto_docstring class QwenImageLayeredPermuteLatentsStep(ModularPipelineBlocks): - """Permute image latents from VAE format to Layered format.""" + """ + Permute image latents from (B, C, 1, H, W) to (B, 1, C, H, W) for Layered packing. - model_name = "qwenimage-layered" + Inputs: + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. - def __init__(self, input_name: str = "image_latents"): - self._input_name = input_name - super().__init__() + Outputs: + image_latents (`Tensor`): + The latent representation of the input image. (permuted from [B, C, 1, H, W] to [B, 1, C, H, W]) + """ + + model_name = "qwenimage-layered" @property def description(self) -> str: - return f"Permute {self._input_name} from (B, C, 1, H, W) to (B, 1, C, H, W) for Layered packing." + return "Permute image latents from (B, C, 1, H, W) to (B, 1, C, H, W) for Layered packing." @property def inputs(self) -> List[InputParam]: return [ - InputParam(self._input_name, required=True), + InputParam.template("image_latents"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam.template("image_latents", note="permuted from [B, C, 1, H, W] to [B, 1, C, H, W]"), ] @torch.no_grad() @@ -1497,8 +1779,8 @@ def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) # Permute: (B, C, 1, H, W) -> (B, 1, C, H, W) - latents = getattr(block_state, self._input_name) - setattr(block_state, self._input_name, latents.permute(0, 2, 1, 3, 4)) + latents = block_state.image_latents + block_state.image_latents = latents.permute(0, 2, 1, 3, 4) self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/inputs.py b/src/diffusers/modular_pipelines/qwenimage/inputs.py index 4a1cf3700c57..818bbca5ed0a 100644 --- a/src/diffusers/modular_pipelines/qwenimage/inputs.py +++ b/src/diffusers/modular_pipelines/qwenimage/inputs.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple +from typing import List, Optional, Tuple import torch @@ -109,7 +109,44 @@ def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: in return height, width +# auto_docstring class QwenImageTextInputsStep(ModularPipelineBlocks): + """ + Text input processing step that standardizes text embeddings for the pipeline. + This step: + 1. Determines `batch_size` and `dtype` based on `prompt_embeds` + 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt) + + This block should be placed after all encoder steps to process the text embeddings before they are used in + subsequent pipeline steps. + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. 
+ + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + """ + model_name = "qwenimage" @property @@ -129,26 +166,22 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"), - InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"), - InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"), - InputParam(name="negative_prompt_embeds_mask", kwargs_type="denoiser_input_fields"), + InputParam.template("num_images_per_prompt"), + InputParam.template("prompt_embeds"), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds"), + InputParam.template("negative_prompt_embeds_mask"), ] @property - def intermediate_outputs(self) -> List[str]: + def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam( - "batch_size", - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", - ), - OutputParam( - "dtype", - type_hint=torch.dtype, - description="Data type of model tensor inputs (determined by `prompt_embeds`)", - ), + OutputParam(name="batch_size", type_hint=int, description="The batch size of the prompt embeddings"), + OutputParam(name="dtype", type_hint=torch.dtype, description="The data type of the prompt embeddings"), + OutputParam.template("prompt_embeds", note="batch-expanded"), + OutputParam.template("prompt_embeds_mask", note="batch-expanded"), + OutputParam.template("negative_prompt_embeds", note="batch-expanded"), + OutputParam.template("negative_prompt_embeds_mask", note="batch-expanded"), ] @staticmethod @@ -221,20 +254,76 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring class QwenImageAdditionalInputsStep(ModularPipelineBlocks): - """Input step for QwenImage: update height/width, expand batch, patchify.""" + """ + Input processing step that: + 1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size + 2. For additional batch inputs: Expands batch dimensions to match final batch size + + Configured inputs: + - Image latent inputs: ['image_latents'] + + This block should be placed after the encoder steps and the text input step. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. 
+ + Outputs: + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + """ model_name = "qwenimage" def __init__( self, - image_latent_inputs: List[str] = ["image_latents"], - additional_batch_inputs: List[str] = [], + image_latent_inputs: Optional[List[InputParam]] = None, + additional_batch_inputs: Optional[List[InputParam]] = None, ): + # by default, process `image_latents` + if image_latent_inputs is None: + image_latent_inputs = [InputParam.template("image_latents")] + if additional_batch_inputs is None: + additional_batch_inputs = [] + if not isinstance(image_latent_inputs, list): - image_latent_inputs = [image_latent_inputs] + raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}") + else: + for input_param in image_latent_inputs: + if not isinstance(input_param, InputParam): + raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}") + if not isinstance(additional_batch_inputs, list): - additional_batch_inputs = [additional_batch_inputs] + raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}") + else: + for input_param in additional_batch_inputs: + if not isinstance(input_param, InputParam): + raise ValueError( + f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}" + ) self._image_latent_inputs = image_latent_inputs self._additional_batch_inputs = additional_batch_inputs @@ -252,9 +341,9 @@ def description(self) -> str: if self._image_latent_inputs or self._additional_batch_inputs: inputs_info = "\n\nConfigured inputs:" if self._image_latent_inputs: - inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}" if self._additional_batch_inputs: - inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}" placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." 
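# With the constructor now taking InputParam objects instead of bare strings, custom
# workflows can describe extra latent inputs with full metadata. A usage sketch;
# `edited_image_latents` is a hypothetical name used only for illustration:
additional_inputs_step = QwenImageAdditionalInputsStep(
    image_latent_inputs=[
        InputParam.template("image_latents"),
        InputParam(
            name="edited_image_latents",
            type_hint=torch.Tensor,
            description="extra reference latents to patchify and batch-expand alongside image_latents",
        ),
    ],
)

# Passing plain strings (the old signature) now raises:
# QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"])  # ValueError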
@@ -269,23 +358,19 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: inputs = [ - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="batch_size", required=True), - InputParam(name="height"), - InputParam(name="width"), + InputParam.template("num_images_per_prompt"), + InputParam.template("batch_size"), + InputParam.template("height"), + InputParam.template("width"), ] - - for image_latent_input_name in self._image_latent_inputs: - inputs.append(InputParam(name=image_latent_input_name)) - - for input_name in self._additional_batch_inputs: - inputs.append(InputParam(name=input_name)) + # default is `image_latents` + inputs += self._image_latent_inputs + self._additional_batch_inputs return inputs @property def intermediate_outputs(self) -> List[OutputParam]: - return [ + outputs = [ OutputParam( name="image_height", type_hint=int, @@ -298,11 +383,43 @@ def intermediate_outputs(self) -> List[OutputParam]: ), ] + # `height`/`width` are not new outputs, but they will be updated if any image latent inputs are provided + if len(self._image_latent_inputs) > 0: + outputs.append( + OutputParam(name="height", type_hint=int, description="if not provided, updated to image height") + ) + outputs.append( + OutputParam(name="width", type_hint=int, description="if not provided, updated to image width") + ) + + # image latent inputs are modified in place (patchified and batch-expanded) + for input_param in self._image_latent_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (patchified and batch-expanded)", + ) + ) + + # additional batch inputs (batch-expanded only) + for input_param in self._additional_batch_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (batch-expanded)", + ) + ) + + return outputs + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) # Process image latent inputs - for image_latent_input_name in self._image_latent_inputs: + for input_param in self._image_latent_inputs: + image_latent_input_name = input_param.name image_latent_tensor = getattr(block_state, image_latent_input_name) if image_latent_tensor is None: continue @@ -331,7 +448,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - setattr(block_state, image_latent_input_name, image_latent_tensor) # Process additional batch inputs (only batch expansion) - for input_name in self._additional_batch_inputs: + for input_param in self._additional_batch_inputs: + input_name = input_param.name input_tensor = getattr(block_state, input_name) if input_tensor is None: continue @@ -349,20 +467,76 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks): - """Input step for QwenImage Edit Plus: handles list of latents with different sizes.""" + """ + Input processing step for Edit Plus that: + 1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch + 2. For additional batch inputs: Expands batch dimensions to match final batch size + Height/width defaults to last image in the list. 
+ + Configured inputs: + - Image latent inputs: ['image_latents'] + + This block should be placed after the encoder steps and the text input step. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_height (`List`): + The image heights calculated from the image latents dimension + image_width (`List`): + The image widths calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified, + concatenated, and batch-expanded) + """ model_name = "qwenimage-edit-plus" def __init__( self, - image_latent_inputs: List[str] = ["image_latents"], - additional_batch_inputs: List[str] = [], + image_latent_inputs: Optional[List[InputParam]] = None, + additional_batch_inputs: Optional[List[InputParam]] = None, ): + if image_latent_inputs is None: + image_latent_inputs = [InputParam.template("image_latents")] + if additional_batch_inputs is None: + additional_batch_inputs = [] + if not isinstance(image_latent_inputs, list): - image_latent_inputs = [image_latent_inputs] + raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}") + else: + for input_param in image_latent_inputs: + if not isinstance(input_param, InputParam): + raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}") + if not isinstance(additional_batch_inputs, list): - additional_batch_inputs = [additional_batch_inputs] + raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}") + else: + for input_param in additional_batch_inputs: + if not isinstance(input_param, InputParam): + raise ValueError( + f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}" + ) self._image_latent_inputs = image_latent_inputs self._additional_batch_inputs = additional_batch_inputs @@ -381,9 +555,9 @@ def description(self) -> str: if self._image_latent_inputs or self._additional_batch_inputs: inputs_info = "\n\nConfigured inputs:" if self._image_latent_inputs: - inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}" if self._additional_batch_inputs: - inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}" placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." 
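The `intermediate_outputs` hunks in these steps all follow the same derivation pattern: each configured `InputParam` is re-emitted as an `OutputParam` whose description records how the tensor was transformed. A condensed, standalone sketch of that pattern (imports assumed; the suffix strings are the ones used in this diff):

```python
# Condensed sketch of the intermediate_outputs derivation pattern used by these steps.
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam

image_latent_inputs = [InputParam.template("image_latents")]
additional_batch_inputs = []

outputs = []
# Image latent inputs are re-emitted with a note describing every transformation applied.
for input_param in image_latent_inputs:
    outputs.append(
        OutputParam(
            name=input_param.name,
            type_hint=input_param.type_hint,
            description=input_param.description + " (patchified, concatenated, and batch-expanded)",
        )
    )
# Additional batch inputs only gain the batch-expansion note.
for input_param in additional_batch_inputs:
    outputs.append(
        OutputParam(
            name=input_param.name,
            type_hint=input_param.type_hint,
            description=input_param.description + " (batch-expanded)",
        )
    )
```

Keeping the type hint and base description on the `InputParam` and only appending the transformation note is what lets the auto-generated docstrings above stay consistent between the Inputs and Outputs sections.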
@@ -398,23 +572,20 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: inputs = [ - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="batch_size", required=True), - InputParam(name="height"), - InputParam(name="width"), + InputParam.template("num_images_per_prompt"), + InputParam.template("batch_size"), + InputParam.template("height"), + InputParam.template("width"), ] - for image_latent_input_name in self._image_latent_inputs: - inputs.append(InputParam(name=image_latent_input_name)) - - for input_name in self._additional_batch_inputs: - inputs.append(InputParam(name=input_name)) + # default is `image_latents` + inputs += self._image_latent_inputs + self._additional_batch_inputs return inputs @property def intermediate_outputs(self) -> List[OutputParam]: - return [ + outputs = [ OutputParam( name="image_height", type_hint=List[int], @@ -427,11 +598,43 @@ def intermediate_outputs(self) -> List[OutputParam]: ), ] + # `height`/`width` are updated if any image latent inputs are provided + if len(self._image_latent_inputs) > 0: + outputs.append( + OutputParam(name="height", type_hint=int, description="if not provided, updated to image height") + ) + outputs.append( + OutputParam(name="width", type_hint=int, description="if not provided, updated to image width") + ) + + # image latent inputs are modified in place (patchified, concatenated, and batch-expanded) + for input_param in self._image_latent_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (patchified, concatenated, and batch-expanded)", + ) + ) + + # additional batch inputs (batch-expanded only) + for input_param in self._additional_batch_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (batch-expanded)", + ) + ) + + return outputs + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) # Process image latent inputs - for image_latent_input_name in self._image_latent_inputs: + for input_param in self._image_latent_inputs: + image_latent_input_name = input_param.name image_latent_tensor = getattr(block_state, image_latent_input_name) if image_latent_tensor is None: continue @@ -476,7 +679,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - setattr(block_state, image_latent_input_name, packed_image_latent_tensors) # Process additional batch inputs (only batch expansion) - for input_name in self._additional_batch_inputs: + for input_param in self._additional_batch_inputs: + input_name = input_param.name input_tensor = getattr(block_state, input_name) if input_tensor is None: continue @@ -494,22 +698,75 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state -# YiYi TODO: support define config default component from the ModularPipeline level. -# it is same as QwenImageAdditionalInputsStep, but with layered pachifier. +# same as QwenImageAdditionalInputsStep, but with layered pachifier. + + +# auto_docstring class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks): - """Input step for QwenImage Layered: update height/width, expand batch, patchify with layered pachifier.""" + """ + Input processing step for Layered that: + 1. 
For image latent inputs: Updates height/width if None, patchifies with layered pachifier, and expands batch + size + 2. For additional batch inputs: Expands batch dimensions to match final batch size + + Configured inputs: + - Image latent inputs: ['image_latents'] + + This block should be placed after the encoder steps and the text input step. + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified + with layered pachifier and batch-expanded) + """ model_name = "qwenimage-layered" def __init__( self, - image_latent_inputs: List[str] = ["image_latents"], - additional_batch_inputs: List[str] = [], + image_latent_inputs: Optional[List[InputParam]] = None, + additional_batch_inputs: Optional[List[InputParam]] = None, ): + if image_latent_inputs is None: + image_latent_inputs = [InputParam.template("image_latents")] + if additional_batch_inputs is None: + additional_batch_inputs = [] + if not isinstance(image_latent_inputs, list): - image_latent_inputs = [image_latent_inputs] + raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}") + else: + for input_param in image_latent_inputs: + if not isinstance(input_param, InputParam): + raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}") + if not isinstance(additional_batch_inputs, list): - additional_batch_inputs = [additional_batch_inputs] + raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}") + else: + for input_param in additional_batch_inputs: + if not isinstance(input_param, InputParam): + raise ValueError( + f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}" + ) self._image_latent_inputs = image_latent_inputs self._additional_batch_inputs = additional_batch_inputs @@ -527,9 +784,9 @@ def description(self) -> str: if self._image_latent_inputs or self._additional_batch_inputs: inputs_info = "\n\nConfigured inputs:" if self._image_latent_inputs: - inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}" if self._additional_batch_inputs: - inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}" placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." 
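The "(patchified with layered pachifier and batch-expanded)" wording above refers to the usual batch contract in these blocks: the final batch dimension equals `batch_size * num_images_per_prompt`. The expansion code itself is not shown in this diff; the sketch below only illustrates that arithmetic with a generic `repeat_interleave`, not the step's actual implementation:

```python
# Illustration of the "batch-expanded" contract only; the real expansion logic
# lives inside the step's __call__ and is not reproduced here.
import torch

batch_size = 2              # number of prompts, produced by the text input step
num_images_per_prompt = 3   # user-facing input, default 1
image_latents = torch.randn(batch_size, 16, 32, 32)  # dummy latents

final_batch_size = batch_size * num_images_per_prompt
expanded = image_latents.repeat_interleave(num_images_per_prompt, dim=0)
assert expanded.shape[0] == final_batch_size
```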
@@ -544,21 +801,18 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: inputs = [ - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="batch_size", required=True), + InputParam.template("num_images_per_prompt"), + InputParam.template("batch_size"), ] + # default is `image_latents` - for image_latent_input_name in self._image_latent_inputs: - inputs.append(InputParam(name=image_latent_input_name)) - - for input_name in self._additional_batch_inputs: - inputs.append(InputParam(name=input_name)) + inputs += self._image_latent_inputs + self._additional_batch_inputs return inputs @property def intermediate_outputs(self) -> List[OutputParam]: - return [ + outputs = [ OutputParam( name="image_height", type_hint=int, @@ -569,15 +823,44 @@ def intermediate_outputs(self) -> List[OutputParam]: type_hint=int, description="The image width calculated from the image latents dimension", ), - OutputParam(name="height", type_hint=int, description="The height of the image output"), - OutputParam(name="width", type_hint=int, description="The width of the image output"), ] + if len(self._image_latent_inputs) > 0: + outputs.append( + OutputParam(name="height", type_hint=int, description="if not provided, updated to image height") + ) + outputs.append( + OutputParam(name="width", type_hint=int, description="if not provided, updated to image width") + ) + + # Add outputs for image latent inputs (patchified with layered pachifier and batch-expanded) + for input_param in self._image_latent_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (patchified with layered pachifier and batch-expanded)", + ) + ) + + # Add outputs for additional batch inputs (batch-expanded only) + for input_param in self._additional_batch_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (batch-expanded)", + ) + ) + + return outputs + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) # Process image latent inputs - for image_latent_input_name in self._image_latent_inputs: + for input_param in self._image_latent_inputs: + image_latent_input_name = input_param.name image_latent_tensor = getattr(block_state, image_latent_input_name) if image_latent_tensor is None: continue @@ -608,7 +891,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - setattr(block_state, image_latent_input_name, image_latent_tensor) # Process additional batch inputs (only batch expansion) - for input_name in self._additional_batch_inputs: + for input_param in self._additional_batch_inputs: + input_name = input_param.name input_tensor = getattr(block_state, input_name) if input_tensor is None: continue @@ -626,7 +910,34 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring class QwenImageControlNetInputsStep(ModularPipelineBlocks): + """ + prepare the `control_image_latents` for controlnet. Insert after all the other inputs steps. + + Inputs: + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. 
+ batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + + Outputs: + control_image_latents (`Tensor`): + The control image latents (patchified and batch-expanded). + height (`int`): + if not provided, updated to control image height + width (`int`): + if not provided, updated to control image width + """ + model_name = "qwenimage" @property @@ -636,11 +947,28 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="control_image_latents", required=True), - InputParam(name="batch_size", required=True), - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="height"), - InputParam(name="width"), + InputParam( + name="control_image_latents", + required=True, + type_hint=torch.Tensor, + description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step.", + ), + InputParam.template("batch_size"), + InputParam.template("num_images_per_prompt"), + InputParam.template("height"), + InputParam.template("width"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="control_image_latents", + type_hint=torch.Tensor, + description="The control image latents (patchified and batch-expanded).", + ), + OutputParam(name="height", type_hint=int, description="if not provided, updated to control image height"), + OutputParam(name="width", type_hint=int, description="if not provided, updated to control image width"), ] @torch.no_grad() diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py index ebe0bbbd75ba..5837799d3431 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py @@ -12,14 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List - -import PIL.Image import torch from ...utils import logging from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks -from ..modular_pipeline_utils import InsertableDict, OutputParam +from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam from .before_denoise import ( QwenImageControlNetBeforeDenoiserStep, QwenImageCreateMaskLatentsStep, @@ -59,11 +56,91 @@ # ==================== -# 1. VAE ENCODER +# 1. TEXT ENCODER # ==================== +# auto_docstring +class QwenImageAutoTextEncoderStep(AutoPipelineBlocks): + """ + Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`): + The tokenizer to use guider (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. 
+ max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage" + block_classes = [QwenImageTextEncoderStep()] + block_names = ["text_encoder"] + block_trigger_inputs = ["prompt"] + + @property + def description(self) -> str: + return "Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block." + " - `QwenImageTextEncoderStep` (text_encoder) is used when `prompt` is provided." + " - if `prompt` is not provided, step will be skipped." + + +# ==================== +# 2. VAE ENCODER +# ==================== + + +# auto_docstring class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks): + """ + This step is used for processing image and mask inputs for inpainting tasks. It: + - Resizes the image to the target size, based on `height` and `width`. + - Processes and updates `image` and `mask_image`. + - Creates `image_latents`. + + Components: + image_mask_processor (`InpaintProcessor`) vae (`AutoencoderKLQwenImage`) + + Inputs: + mask_image (`Image`): + Mask image for inpainting. + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`Dict`): + The kwargs for the postprocess step to apply the mask overlay + image_latents (`Tensor`): + The latent representation of the input image. + """ + model_name = "qwenimage" block_classes = [QwenImageInpaintProcessImagesInputStep(), QwenImageVaeEncoderStep()] block_names = ["preprocess", "encode"] @@ -78,7 +155,31 @@ def description(self) -> str: ) +# auto_docstring class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that preprocess andencode the image inputs into their latent representations. + + Components: + image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + model_name = "qwenimage" block_classes = [QwenImageProcessImagesInputStep(), QwenImageVaeEncoderStep()] @@ -89,7 +190,6 @@ def description(self) -> str: return "Vae encoder step that preprocess andencode the image inputs into their latent representations." 
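`QwenImageAutoTextEncoderStep` above relies on the `AutoPipelineBlocks` trigger mechanism: the wrapped block runs only when its trigger input is present in the pipeline state, and the step is skipped otherwise. A schematic of the pattern using the class attributes exactly as they appear in this diff; the class name `MyAutoTextEncoderStep` is illustrative and the import paths are assumptions:

```python
# Schematic of the AutoPipelineBlocks trigger pattern used by QwenImageAutoTextEncoderStep.
from diffusers.modular_pipelines.modular_pipeline import AutoPipelineBlocks
from diffusers.modular_pipelines.qwenimage.encoders import QwenImageTextEncoderStep  # path assumed


class MyAutoTextEncoderStep(AutoPipelineBlocks):
    model_name = "qwenimage"
    block_classes = [QwenImageTextEncoderStep()]  # the block that may run
    block_names = ["text_encoder"]
    block_trigger_inputs = ["prompt"]             # run only when `prompt` is provided; skip otherwise
```

The same mechanism drives `QwenImageOptionalControlNetVaeEncoderStep` further down, where `control_image` is the trigger input.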
-# Auto VAE encoder class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks): block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep] block_names = ["inpaint", "img2img"] @@ -107,7 +207,33 @@ def description(self): # optional controlnet vae encoder +# auto_docstring class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + This is an auto pipeline block. + - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided. + - if `control_image` is not provided, step will be skipped. + + Components: + vae (`AutoencoderKLQwenImage`) controlnet (`QwenImageControlNetModel`) control_image_processor + (`VaeImageProcessor`) + + Inputs: + control_image (`Image`, *optional*): + Control image for ControlNet conditioning. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + control_image_latents (`Tensor`): + The latents representing the control image + """ + block_classes = [QwenImageControlNetVaeEncoderStep] block_names = ["controlnet"] block_trigger_inputs = ["control_image"] @@ -123,14 +249,65 @@ def description(self): # ==================== -# 2. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) +# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) # ==================== # assemble input steps +# auto_docstring class QwenImageImg2ImgInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the img2img denoising step. It: + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. 
(batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + """ + model_name = "qwenimage" - block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"])] + block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep()] block_names = ["text_inputs", "additional_inputs"] @property @@ -140,12 +317,69 @@ def description(self): " - update height/width based `image_latents`, patchify `image_latents`." +# auto_docstring class QwenImageInpaintInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the inpainting denoising step. It: + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. 
(patchified and + batch-expanded) + processed_mask_image (`Tensor`): + The processed mask image (batch-expanded) + """ + model_name = "qwenimage" block_classes = [ QwenImageTextInputsStep(), QwenImageAdditionalInputsStep( - image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"] + additional_batch_inputs=[ + InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image") + ] ), ] block_names = ["text_inputs", "additional_inputs"] @@ -158,7 +392,42 @@ def description(self): # assemble prepare latents steps +# auto_docstring class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks): + """ + This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It: + - Add noise to the image latents to create the latents input for the denoiser. + - Create the pachified latents `mask` based on the processedmask image. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`QwenImagePachifier`) + + Inputs: + latents (`Tensor`): + The initial random noised, can be generated in prepare latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be + generated from vae encoder and updated in input step.) + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + processed_mask_image (`Tensor`): + The processed mask to use for the inpainting process. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + latents (`Tensor`): + The scaled noisy latents to use for inpainting/image-to-image denoising. + mask (`Tensor`): + The mask to use for the inpainting process. + """ + model_name = "qwenimage" block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()] block_names = ["add_noise_to_latents", "create_mask_latents"] @@ -176,7 +445,49 @@ def description(self) -> str: # Qwen Image (text2image) +# auto_docstring class QwenImageCoreDenoiseStep(SequentialPipelineBlocks): + """ + step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs + (timesteps, latents, rope inputs etc.). + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + height (`int`, *optional*): + The height in pixels of the generated image. 
+ width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageTextInputsStep(), @@ -199,9 +510,63 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)." + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # Qwen Image (inpainting) +# auto_docstring class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageInpaintInputStep(), @@ -226,9 +591,61 @@ class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task." 
+ @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # Qwen Image (image2image) +# auto_docstring class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageImg2ImgInputStep(), @@ -253,9 +670,66 @@ class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task." + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # Qwen Image (text2image) with controlnet +# auto_docstring class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks): + """ + step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs + (timesteps, latents, rope inputs etc.). + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet + (`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. 
+ negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageTextInputsStep(), @@ -282,9 +756,72 @@ class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)." + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # Qwen Image (inpainting) with controlnet +# auto_docstring class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet + (`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. 
+ generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageInpaintInputStep(), @@ -313,9 +850,70 @@ class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task." + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # Qwen Image (image2image) with controlnet +# auto_docstring class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet + (`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. 
+ controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageImg2ImgInputStep(), @@ -344,6 +942,12 @@ class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task." + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # Auto denoise step for QwenImage class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks): @@ -402,19 +1006,36 @@ def description(self): @property def outputs(self): return [ - OutputParam( - name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step" - ), + OutputParam.template("latents"), ] # ==================== -# 3. DECODE +# 4. DECODE # ==================== # standard decode step works for most tasks except for inpaint +# auto_docstring class QwenImageDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. (tensor output of the vae decoder.) + """ + model_name = "qwenimage" block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] block_names = ["decode", "postprocess"] @@ -425,7 +1046,30 @@ def description(self): # Inpaint decode step +# auto_docstring class QwenImageInpaintDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask + overally to the original image. + + Components: + vae (`AutoencoderKLQwenImage`) image_mask_processor (`InpaintProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`Dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`List`): + Generated images. (tensor output of the vae decoder.) + """ + model_name = "qwenimage" block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()] block_names = ["decode", "postprocess"] @@ -452,11 +1096,11 @@ def description(self): # ==================== -# 4. AUTO BLOCKS & PRESETS +# 5. 
AUTO BLOCKS & PRESETS # ==================== AUTO_BLOCKS = InsertableDict( [ - ("text_encoder", QwenImageTextEncoderStep()), + ("text_encoder", QwenImageAutoTextEncoderStep()), ("vae_encoder", QwenImageAutoVaeEncoderStep()), ("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()), ("denoise", QwenImageAutoCoreDenoiseStep()), @@ -465,7 +1109,89 @@ def description(self): ) +# auto_docstring class QwenImageAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage. + - for image-to-image generation, you need to provide `image` + - for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`. + - to run the controlnet workflow, you need to provide `control_image` + - for text-to-image generation, all you need to provide is `prompt` + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`): + The tokenizer to use guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae + (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) controlnet (`QwenImageControlNetModel`) + control_image_processor (`VaeImageProcessor`) pachifier (`QwenImagePachifier`) scheduler + (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + mask_image (`Image`, *optional*): + Mask image for inpainting. + image (`Union[Image, List]`, *optional*): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + control_image (`Image`, *optional*): + Control image for ControlNet conditioning. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + latents (`Tensor`): + Pre-generated noisy latents for image generation. + num_inference_steps (`int`): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. 
+ processed_mask_image (`Tensor`, *optional*): + The processed mask image + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_image_latents (`Tensor`, *optional*): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`Dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage" block_classes = AUTO_BLOCKS.values() @@ -476,7 +1202,7 @@ def description(self): return ( "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n" + "- for image-to-image generation, you need to provide `image`\n" - + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.\n" + "- to run the controlnet workflow, you need to provide `control_image`\n" + "- for text-to-image generation, all you need to provide is `prompt`" ) @@ -484,5 +1210,5 @@ def description(self): @property def outputs(self): return [ - OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]), + OutputParam.template("images"), ] diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py index 2683e64080bf..e1e5c4335481 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import Optional -import PIL.Image import torch from ...utils import logging from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks -from ..modular_pipeline_utils import InsertableDict, OutputParam +from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam from .before_denoise import ( QwenImageCreateMaskLatentsStep, QwenImageEditRoPEInputsStep, @@ -59,8 +58,35 @@ # ==================== +# auto_docstring class QwenImageEditVLEncoderStep(SequentialPipelineBlocks): - """VL encoder that takes both image and text prompts.""" + """ + QwenImage-Edit VL encoder step that encode the image and text prompts together. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. 
+ + Outputs: + resized_image (`List`): + The resized images + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ model_name = "qwenimage-edit" block_classes = [ @@ -80,7 +106,30 @@ def description(self) -> str: # Edit VAE encoder +# auto_docstring class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + + Components: + image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`List`): + The resized images + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageEditResizeStep(), @@ -95,12 +144,46 @@ def description(self) -> str: # Edit Inpaint VAE encoder +# auto_docstring class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks): + """ + This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It: + - resize the image for target area (1024 * 1024) while maintaining the aspect ratio. + - process the resized image and mask image. + - create image latents. + + Components: + image_resize_processor (`VaeImageProcessor`) image_mask_processor (`InpaintProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + mask_image (`Image`): + Mask image for inpainting. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`List`): + The resized images + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`Dict`): + The kwargs for the postprocess step to apply the mask overlay + image_latents (`Tensor`): + The latent representation of the input image. + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageEditResizeStep(), QwenImageEditInpaintProcessImagesInputStep(), - QwenImageVaeEncoderStep(input_name="processed_image", output_name="image_latents"), + QwenImageVaeEncoderStep(), ] block_names = ["resize", "preprocess", "encode"] @@ -137,11 +220,64 @@ def description(self): # assemble input steps +# auto_docstring class QwenImageEditInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the edit denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. 
Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageTextInputsStep(), - QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"]), + QwenImageAdditionalInputsStep(), ] block_names = ["text_inputs", "additional_inputs"] @@ -154,12 +290,71 @@ def description(self): ) +# auto_docstring class QwenImageEditInpaintInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the edit inpaint denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. 
(batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + processed_mask_image (`Tensor`): + The processed mask image (batch-expanded) + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageTextInputsStep(), QwenImageAdditionalInputsStep( - image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"] + additional_batch_inputs=[ + InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image") + ] ), ] block_names = ["text_inputs", "additional_inputs"] @@ -174,7 +369,42 @@ def description(self): # assemble prepare latents steps +# auto_docstring class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks): + """ + This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It: + - Add noise to the image latents to create the latents input for the denoiser. + - Create the patchified latents `mask` based on the processed mask image. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`QwenImagePachifier`) + + Inputs: + latents (`Tensor`): + The initial random noised, can be generated in prepare latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be + generated from vae encoder and updated in input step.) + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + processed_mask_image (`Tensor`): + The processed mask to use for the inpainting process. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + latents (`Tensor`): + The scaled noisy latents to use for inpainting/image-to-image denoising. + mask (`Tensor`): + The mask to use for the inpainting process. + """ + model_name = "qwenimage-edit" block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()] block_names = ["add_noise_to_latents", "create_mask_latents"] @@ -189,7 +419,50 @@ def description(self) -> str: # Qwen Image Edit (image2image) core denoise step +# auto_docstring class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Edit edit (img2img) task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. 
+ negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageEditInputStep(), @@ -212,9 +485,62 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Core denoising workflow for QwenImage-Edit edit (img2img) task." + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # Qwen Image Edit (inpainting) core denoise step +# auto_docstring class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Edit edit inpaint task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. 
+ + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageEditInpaintInputStep(), @@ -239,6 +565,12 @@ class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Core denoising workflow for QwenImage-Edit edit inpaint task." + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # Auto core denoise step for QwenImage Edit class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks): @@ -267,6 +599,12 @@ def description(self): "Supports edit (img2img) and edit inpainting tasks for QwenImage-Edit." ) + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + # ==================== # 4. DECODE @@ -274,7 +612,26 @@ def description(self): # Decode step (standard) +# auto_docstring class QwenImageEditDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. (tensor output of the vae decoder.) + """ + model_name = "qwenimage-edit" block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] block_names = ["decode", "postprocess"] @@ -285,7 +642,30 @@ def description(self): # Inpaint decode step +# auto_docstring class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image, optionally apply the mask + overlay to the original image. + + Components: + vae (`AutoencoderKLQwenImage`) image_mask_processor (`InpaintProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`Dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`List`): + Generated images. (tensor output of the vae decoder.) + """ + model_name = "qwenimage-edit" block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()] block_names = ["decode", "postprocess"] @@ -313,9 +693,7 @@ def description(self): @property def outputs(self): return [ - OutputParam( - name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step" - ), + OutputParam.template("latents"), ] @@ -333,7 +711,66 @@ def outputs(self): ) +# auto_docstring class QwenImageEditAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit. 
+ - for edit (img2img) generation, you need to provide `image` + - for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide + `padding_mask_crop` + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae + (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) pachifier (`QwenImagePachifier`) scheduler + (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + mask_image (`Image`, *optional*): + Mask image for inpainting. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + latents (`Tensor`): + Pre-generated noisy latents for image generation. + num_inference_steps (`int`): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`Dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage-edit" block_classes = EDIT_AUTO_BLOCKS.values() block_names = EDIT_AUTO_BLOCKS.keys() @@ -349,5 +786,5 @@ def description(self): @property def outputs(self): return [ - OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"), + OutputParam.template("images"), ] diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py index 99c5b109bf38..37656cef5d76 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py @@ -12,11 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List - -import PIL.Image -import torch - from ...utils import logging from ..modular_pipeline import SequentialPipelineBlocks from ..modular_pipeline_utils import InsertableDict, OutputParam @@ -53,12 +48,41 @@ # ==================== +# auto_docstring class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks): - """VL encoder that takes both image and text prompts. Uses 384x384 target area.""" + """ + QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + + Outputs: + resized_image (`List`): + Images resized to 1024x1024 target area for VAE encoding + resized_cond_image (`List`): + Images resized to 384x384 target area for VL text encoding + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ model_name = "qwenimage-edit-plus" block_classes = [ - QwenImageEditPlusResizeStep(target_area=384 * 384, output_name="resized_cond_image"), + QwenImageEditPlusResizeStep(), QwenImageEditPlusTextEncoderStep(), ] block_names = ["resize", "encode"] @@ -73,12 +97,36 @@ def description(self) -> str: # ==================== +# auto_docstring class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks): - """VAE encoder that handles multiple images with different sizes. Uses 1024x1024 target area.""" + """ + VAE encoder step that encodes image inputs into latent representations. + Each image is resized independently based on its own aspect ratio to 1024x1024 target area. + + Components: + image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`List`): + Images resized to 1024x1024 target area for VAE encoding + resized_cond_image (`List`): + Images resized to 384x384 target area for VL text encoding + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ model_name = "qwenimage-edit-plus" block_classes = [ - QwenImageEditPlusResizeStep(target_area=1024 * 1024, output_name="resized_image"), + QwenImageEditPlusResizeStep(), QwenImageEditPlusProcessImagesInputStep(), QwenImageVaeEncoderStep(), ] @@ -98,11 +146,66 @@ def description(self) -> str: # assemble input steps +# auto_docstring class QwenImageEditPlusInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the Edit Plus denoising step. It: + - Standardizes text embeddings batch size. + - Processes list of image latents: patchifies, concatenates along dim=1, expands batch. + - Outputs lists of image_height/image_width for RoPE calculation. + - Defaults height/width from last image in the list. 
+ + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`List`): + The image heights calculated from the image latents dimension + image_width (`List`): + The image widths calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified, + concatenated, and batch-expanded) + """ + model_name = "qwenimage-edit-plus" block_classes = [ QwenImageTextInputsStep(), - QwenImageEditPlusAdditionalInputsStep(image_latent_inputs=["image_latents"]), + QwenImageEditPlusAdditionalInputsStep(), ] block_names = ["text_inputs", "additional_inputs"] @@ -118,7 +221,50 @@ def description(self): # Qwen Image Edit Plus (image2image) core denoise step +# auto_docstring class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Edit Plus edit (img2img) task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. 
+ latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage-edit-plus" block_classes = [ QwenImageEditPlusInputStep(), @@ -144,9 +290,7 @@ def description(self): @property def outputs(self): return [ - OutputParam( - name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step" - ), + OutputParam.template("latents"), ] @@ -155,7 +299,26 @@ def outputs(self): # ==================== +# auto_docstring class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocesses the generated image. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. (tensor output of the vae decoder.) + """ + model_name = "qwenimage-edit-plus" block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] block_names = ["decode", "postprocess"] @@ -179,7 +342,53 @@ def description(self): ) +# auto_docstring class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus. + - `image` is required input (can be single image or list of images). + - Each image is resized independently based on its own aspect ratio. + - VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) + transformer (`QwenImageTransformer2DModel`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. 
+ **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage-edit-plus" block_classes = EDIT_PLUS_AUTO_BLOCKS.values() block_names = EDIT_PLUS_AUTO_BLOCKS.keys() @@ -196,5 +405,5 @@ def description(self): @property def outputs(self): return [ - OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"), + OutputParam.template("images"), ] diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py index 63ee36df5112..fdfeab048835 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py @@ -12,12 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from typing import List - -import PIL.Image -import torch - from ...utils import logging from ..modular_pipeline import SequentialPipelineBlocks from ..modular_pipeline_utils import InsertableDict, OutputParam @@ -55,8 +49,44 @@ # ==================== +# auto_docstring class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks): - """Text encoder that takes text prompt, will generate a prompt based on image if not provided.""" + """ + QwenImage-Layered Text encoder step that encode the text prompt, will generate a prompt based on image if not + provided. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + use_en_prompt (`bool`, *optional*, defaults to False): + Whether to use English prompt template + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + + Outputs: + resized_image (`List`): + The resized images + prompt (`str`): + The prompt or prompts to guide image generation. If not provided, updated using image caption + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ model_name = "qwenimage-layered" block_classes = [ @@ -77,7 +107,32 @@ def description(self) -> str: # Edit VAE encoder +# auto_docstring class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + + Components: + image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. 
+ resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`List`): + The resized images + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + model_name = "qwenimage-layered" block_classes = [ QwenImageLayeredResizeStep(), @@ -98,11 +153,60 @@ def description(self) -> str: # assemble input steps +# auto_docstring class QwenImageLayeredInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the layered denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified + with layered pachifier and batch-expanded) + """ + model_name = "qwenimage-layered" block_classes = [ QwenImageTextInputsStep(), - QwenImageLayeredAdditionalInputsStep(image_latent_inputs=["image_latents"]), + QwenImageLayeredAdditionalInputsStep(), ] block_names = ["text_inputs", "additional_inputs"] @@ -116,7 +220,48 @@ def description(self): # Qwen Image Layered (image2image) core denoise step +# auto_docstring class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Layered img2img task. + + Components: + pachifier (`QwenImageLayeredPachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. 
Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage-layered" block_classes = [ QwenImageLayeredInputStep(), @@ -142,9 +287,7 @@ def description(self): @property def outputs(self): return [ - OutputParam( - name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step" - ), + OutputParam.template("latents"), ] @@ -162,7 +305,54 @@ def outputs(self): ) +# auto_docstring class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for layered denoising tasks using QwenImage-Layered. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`) + image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`) pachifier (`QwenImageLayeredPachifier`) + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + use_en_prompt (`bool`, *optional*, defaults to False): + Whether to use English prompt template + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. 
+ **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage-layered" block_classes = LAYERED_AUTO_BLOCKS.values() block_names = LAYERED_AUTO_BLOCKS.keys() @@ -174,5 +364,5 @@ def description(self): @property def outputs(self): return [ - OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"), + OutputParam.template("images"), ] diff --git a/src/diffusers/modular_pipelines/z_image/denoise.py b/src/diffusers/modular_pipelines/z_image/denoise.py index 3d5a00a9df50..5f76a8459fde 100644 --- a/src/diffusers/modular_pipelines/z_image/denoise.py +++ b/src/diffusers/modular_pipelines/z_image/denoise.py @@ -131,7 +131,7 @@ def inputs(self) -> List[Tuple[str, Any]]: ), InputParam( kwargs_type="denoiser_input_fields", - description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", + description="The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", ), ] guider_input_names = [] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 94c4c394465b..2ea7307fec32 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -84,7 +84,6 @@ >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL >>> from diffusers.utils import load_image - >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") >>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas") >>> controlnet = ControlNetModel.from_pretrained( diff --git a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py index d259f7ee7865..b41d9772a7cc 100644 --- a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py +++ b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py @@ -53,7 +53,6 @@ >>> from transformers import AutoTokenizer, LlamaForCausalLM >>> from diffusers import HiDreamImagePipeline - >>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") >>> text_encoder_4 = LlamaForCausalLM.from_pretrained( ... 
"meta-llama/Meta-Llama-3.1-8B-Instruct", diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py index df5b3f5c10a5..5a6b8d5e9f37 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py @@ -85,7 +85,6 @@ >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetPAGImg2ImgPipeline, AutoencoderKL >>> from diffusers.utils import load_image - >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas") >>> controlnet = ControlNetModel.from_pretrained( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py index 66d5ffa6b849..a1d0407caf5e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -459,7 +459,6 @@ def __call__( >>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline >>> import torch - >>> pipeline = StableDiffusionPipeline.from_pretrained( ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 ... ) diff --git a/utils/modular_auto_docstring.py b/utils/modular_auto_docstring.py new file mode 100644 index 000000000000..7bb2c87e81da --- /dev/null +++ b/utils/modular_auto_docstring.py @@ -0,0 +1,300 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Auto Docstring Generator for Modular Pipeline Blocks + +This script scans Python files for classes that have `# auto_docstring` comment above them +and inserts/updates the docstring from the class's `doc` property. + +Run from the root of the repo: + python utils/modular_auto_docstring.py [path] [--fix_and_overwrite] + +Examples: + # Check for auto_docstring markers (will error if found without proper docstring) + python utils/modular_auto_docstring.py + + # Check specific directory + python utils/modular_auto_docstring.py src/diffusers/modular_pipelines/ + + # Fix and overwrite the docstrings + python utils/modular_auto_docstring.py --fix_and_overwrite + +Usage in code: + # auto_docstring + class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks): + # docstring will be automatically inserted here + + @property + def doc(self): + return "Your docstring content..." +""" + +import argparse +import ast +import glob +import importlib +import os +import re +import sys + + +# All paths are set with the intent you should run this script from the root of the repo +DIFFUSERS_PATH = "src/diffusers" +REPO_PATH = "." 
+ +# Pattern to match the auto_docstring comment +AUTO_DOCSTRING_PATTERN = re.compile(r"^\s*#\s*auto_docstring\s*$") + + +def setup_diffusers_import(): + """Setup import path to use the local diffusers module.""" + src_path = os.path.join(REPO_PATH, "src") + if src_path not in sys.path: + sys.path.insert(0, src_path) + + +def get_module_from_filepath(filepath: str) -> str: + """Convert a filepath to a module name.""" + filepath = os.path.normpath(filepath) + + if filepath.startswith("src" + os.sep): + filepath = filepath[4:] + + if filepath.endswith(".py"): + filepath = filepath[:-3] + + module_name = filepath.replace(os.sep, ".") + return module_name + + +def load_module(filepath: str): + """Load a module from filepath.""" + setup_diffusers_import() + module_name = get_module_from_filepath(filepath) + + try: + module = importlib.import_module(module_name) + return module + except Exception as e: + print(f"Warning: Could not import module {module_name}: {e}") + return None + + +def get_doc_from_class(module, class_name: str) -> str: + """Get the doc property from an instantiated class.""" + if module is None: + return None + + cls = getattr(module, class_name, None) + if cls is None: + return None + + try: + instance = cls() + if hasattr(instance, "doc"): + return instance.doc + except Exception as e: + print(f"Warning: Could not instantiate {class_name}: {e}") + + return None + + +def find_auto_docstring_classes(filepath: str) -> list: + """ + Find all classes in a file that have # auto_docstring comment above them. + + Returns list of (class_name, class_line_number, has_existing_docstring, docstring_end_line) + """ + with open(filepath, "r", encoding="utf-8", newline="\n") as f: + lines = f.readlines() + + # Parse AST to find class locations and their docstrings + content = "".join(lines) + try: + tree = ast.parse(content) + except SyntaxError as e: + print(f"Syntax error in {filepath}: {e}") + return [] + + # Build a map of class_name -> (class_line, has_docstring, docstring_end_line) + class_info = {} + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + has_docstring = False + docstring_end_line = node.lineno # default to class line + + if node.body and isinstance(node.body[0], ast.Expr): + first_stmt = node.body[0] + if isinstance(first_stmt.value, ast.Constant) and isinstance(first_stmt.value.value, str): + has_docstring = True + docstring_end_line = first_stmt.end_lineno or first_stmt.lineno + + class_info[node.name] = (node.lineno, has_docstring, docstring_end_line) + + # Now scan for # auto_docstring comments + classes_to_update = [] + + for i, line in enumerate(lines): + if AUTO_DOCSTRING_PATTERN.match(line): + # Found the marker, look for class definition on next non-empty, non-comment line + j = i + 1 + while j < len(lines): + next_line = lines[j].strip() + if next_line and not next_line.startswith("#"): + break + j += 1 + + if j < len(lines) and lines[j].strip().startswith("class "): + # Extract class name + match = re.match(r"class\s+(\w+)", lines[j].strip()) + if match: + class_name = match.group(1) + if class_name in class_info: + class_line, has_docstring, docstring_end_line = class_info[class_name] + classes_to_update.append((class_name, class_line, has_docstring, docstring_end_line)) + + return classes_to_update + + +def strip_class_name_line(doc: str, class_name: str) -> str: + """Remove the 'class ClassName' line from the doc if present.""" + lines = doc.strip().split("\n") + if lines and lines[0].strip() == f"class {class_name}": + # Remove the class 
line and any blank line following it + lines = lines[1:] + while lines and not lines[0].strip(): + lines = lines[1:] + return "\n".join(lines) + + +def format_docstring(doc: str, indent: str = " ") -> str: + """Format a doc string as a properly indented docstring.""" + lines = doc.strip().split("\n") + + if len(lines) == 1: + return f'{indent}"""{lines[0]}"""\n' + else: + result = [f'{indent}"""\n'] + for line in lines: + if line.strip(): + result.append(f"{indent}{line}\n") + else: + result.append("\n") + result.append(f'{indent}"""\n') + return "".join(result) + + +def process_file(filepath: str, overwrite: bool = False) -> list: + """ + Process a file and find/insert docstrings for # auto_docstring marked classes. + + Returns list of classes that need updating. + """ + classes_to_update = find_auto_docstring_classes(filepath) + + if not classes_to_update: + return [] + + if not overwrite: + # Just return the list of classes that need updating + return [(filepath, cls_name, line) for cls_name, line, _, _ in classes_to_update] + + # Load the module to get doc properties + module = load_module(filepath) + + with open(filepath, "r", encoding="utf-8", newline="\n") as f: + lines = f.readlines() + + # Process in reverse order to maintain line numbers + updated = False + for class_name, class_line, has_docstring, docstring_end_line in reversed(classes_to_update): + doc = get_doc_from_class(module, class_name) + + if doc is None: + print(f"Warning: Could not get doc for {class_name} in {filepath}") + continue + + # Remove the "class ClassName" line since it's redundant in a docstring + doc = strip_class_name_line(doc, class_name) + + # Format the new docstring with 4-space indent + new_docstring = format_docstring(doc, " ") + + if has_docstring: + # Replace existing docstring (line after class definition to docstring_end_line) + # class_line is 1-indexed, we want to replace from class_line+1 to docstring_end_line + lines = lines[:class_line] + [new_docstring] + lines[docstring_end_line:] + else: + # Insert new docstring right after class definition line + # class_line is 1-indexed, so lines[class_line-1] is the class line + # Insert at position class_line (which is right after the class line) + lines = lines[:class_line] + [new_docstring] + lines[class_line:] + + updated = True + print(f"Updated docstring for {class_name} in {filepath}") + + if updated: + with open(filepath, "w", encoding="utf-8", newline="\n") as f: + f.writelines(lines) + + return [(filepath, cls_name, line) for cls_name, line, _, _ in classes_to_update] + + +def check_auto_docstrings(path: str = None, overwrite: bool = False): + """ + Check all files for # auto_docstring markers and optionally fix them. + """ + if path is None: + path = DIFFUSERS_PATH + + if os.path.isfile(path): + all_files = [path] + else: + all_files = glob.glob(os.path.join(path, "**/*.py"), recursive=True) + + all_markers = [] + + for filepath in all_files: + markers = process_file(filepath, overwrite) + all_markers.extend(markers) + + if not overwrite and len(all_markers) > 0: + message = "\n".join([f"- {f}: {cls} at line {line}" for f, cls, line in all_markers]) + raise ValueError( + f"Found the following # auto_docstring markers that need docstrings:\n{message}\n\n" + f"Run `python utils/modular_auto_docstring.py --fix_and_overwrite` to fix them." 
+ ) + + if overwrite and len(all_markers) > 0: + print(f"\nUpdated {len(all_markers)} docstring(s).") + elif len(all_markers) == 0: + print("No # auto_docstring markers found.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Check and fix # auto_docstring markers in modular pipeline blocks", + ) + parser.add_argument("path", nargs="?", default=None, help="File or directory to process (default: src/diffusers)") + parser.add_argument( + "--fix_and_overwrite", + action="store_true", + help="Whether to fix the docstrings by inserting them from doc property.", + ) + + args = parser.parse_args() + + check_auto_docstrings(args.path, args.fix_and_overwrite)
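
Editorial note (not part of the diff): a quick illustration of what the `format_docstring` helper defined in `utils/modular_auto_docstring.py` above produces for a short multi-line doc. The sample `doc` string is made up for demonstration only; the expected output follows directly from the implementation shown in this patch.

# Sketch: run inside utils/modular_auto_docstring.py (or after importing format_docstring from it).
sample_doc = "Decode step that decodes the latents to images.\n\nInputs:\n    latents (`Tensor`)"
print(format_docstring(sample_doc, "    "))
# Produces a 4-space-indented triple-quoted block with blank lines preserved:
#     """
#     Decode step that decodes the latents to images.
#
#     Inputs:
#         latents (`Tensor`)
#     """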
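
Editorial note (not part of the diff): a minimal, hypothetical sketch of how a modular block opts into the tooling introduced here, i.e. the `# auto_docstring` marker consumed by `utils/modular_auto_docstring.py` and the `OutputParam.template(...)` helper. The class name, `model_name`, and empty block lists are placeholders, and the absolute import paths are assumed to match the package layout shown in this diff; only the marker, the script invocation, and `OutputParam.template` come from the patch itself.

from diffusers.modular_pipelines.modular_pipeline import SequentialPipelineBlocks
from diffusers.modular_pipelines.modular_pipeline_utils import OutputParam


# auto_docstring
class MyDenoiseStep(SequentialPipelineBlocks):
    # Running `python utils/modular_auto_docstring.py --fix_and_overwrite` instantiates this
    # class, reads its `doc` property, and inserts the generated docstring right here.

    model_name = "my-model"  # placeholder
    block_classes = []       # a real block would list its sub-block instances here
    block_names = []

    @property
    def description(self):
        return "Placeholder step used only to illustrate the # auto_docstring marker."

    @property
    def outputs(self):
        # Presumably resolves the name/type_hint/description registered under "latents"
        # in the shared output templates, as used throughout this diff.
        return [OutputParam.template("latents")]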