Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 236 additions & 12 deletions src/diffusers/modular_pipelines/modular_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union

import PIL.Image
import torch

from ..configuration_utils import ConfigMixin, FrozenDict
Expand Down Expand Up @@ -323,11 +324,192 @@ class ConfigSpec:
description: Optional[str] = None


# YiYi Notes: both inputs and intermediate_inputs are InputParam objects
# however some fields are not relevant for intermediate_inputs
# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed
# default is not used for intermediate_inputs, we only use default from inputs, so it is ignored if it is set for intermediate_inputs
# -> should we use different class for inputs and intermediate_inputs?
# ======================================================
# InputParam and OutputParam templates
# ======================================================

INPUT_PARAM_TEMPLATES = {
"prompt": {
"type_hint": str,
"required": True,
"description": "The prompt or prompts to guide image generation.",
},
"negative_prompt": {
"type_hint": str,
"description": "The prompt or prompts not to guide the image generation.",
},
"max_sequence_length": {
"type_hint": int,
"default": 512,
"description": "Maximum sequence length for prompt encoding.",
},
"height": {
"type_hint": int,
"description": "The height in pixels of the generated image.",
},
"width": {
"type_hint": int,
"description": "The width in pixels of the generated image.",
},
"num_inference_steps": {
"type_hint": int,
"default": 50,
"description": "The number of denoising steps.",
},
"num_images_per_prompt": {
"type_hint": int,
"default": 1,
"description": "The number of images to generate per prompt.",
},
"generator": {
"type_hint": torch.Generator,
"description": "Torch generator for deterministic generation.",
},
"sigmas": {
"type_hint": List[float],
"description": "Custom sigmas for the denoising process.",
},
"strength": {
"type_hint": float,
"default": 0.9,
"description": "Strength for img2img/inpainting.",
},
"image": {
"type_hint": Union[PIL.Image.Image, List[PIL.Image.Image]],
"required": True,
"description": "Reference image(s) for denoising. Can be a single image or list of images.",
},
"latents": {
"type_hint": torch.Tensor,
"description": "Pre-generated noisy latents for image generation.",
},
"timesteps": {
"type_hint": torch.Tensor,
"description": "Timesteps for the denoising process.",
},
"output_type": {
"type_hint": str,
"default": "pil",
"description": "Output format: 'pil', 'np', 'pt'.",
},
"attention_kwargs": {
"type_hint": Dict[str, Any],
"description": "Additional kwargs for attention processors.",
},
"denoiser_input_fields": {
"name": None,
"kwargs_type": "denoiser_input_fields",
"description": "conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
},
# inpainting
"mask_image": {
"type_hint": PIL.Image.Image,
"required": True,
"description": "Mask image for inpainting.",
},
"padding_mask_crop": {
"type_hint": int,
"description": "Padding for mask cropping in inpainting.",
},
# controlnet
"control_image": {
"type_hint": PIL.Image.Image,
"required": True,
"description": "Control image for ControlNet conditioning.",
},
"control_guidance_start": {
"type_hint": float,
"default": 0.0,
"description": "When to start applying ControlNet.",
},
"control_guidance_end": {
"type_hint": float,
"default": 1.0,
"description": "When to stop applying ControlNet.",
},
"controlnet_conditioning_scale": {
"type_hint": float,
"default": 1.0,
"description": "Scale for ControlNet conditioning.",
},
"layers": {
"type_hint": int,
"default": 4,
"description": "Number of layers to extract from the image",
},
# common intermediate inputs
"prompt_embeds": {
"type_hint": torch.Tensor,
"required": True,
"description": "text embeddings used to guide the image generation. Can be generated from text_encoder step.",
},
"prompt_embeds_mask": {
"type_hint": torch.Tensor,
"required": True,
"description": "mask for the text embeddings. Can be generated from text_encoder step.",
},
"negative_prompt_embeds": {
"type_hint": torch.Tensor,
"description": "negative text embeddings used to guide the image generation. Can be generated from text_encoder step.",
},
"negative_prompt_embeds_mask": {
"type_hint": torch.Tensor,
"description": "mask for the negative text embeddings. Can be generated from text_encoder step.",
},
"image_latents": {
"type_hint": torch.Tensor,
"required": True,
"description": "image latents used to guide the image generation. Can be generated from vae_encoder step.",
},
"batch_size": {
"type_hint": int,
"default": 1,
"description": "Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
},
"dtype": {
"type_hint": torch.dtype,
"default": torch.float32,
"description": "The dtype of the model inputs, can be generated in input step.",
},
}

OUTPUT_PARAM_TEMPLATES = {
"images": {
"type_hint": List[PIL.Image.Image],
"description": "Generated images.",
},
"latents": {
"type_hint": torch.Tensor,
"description": "Denoised latents.",
},
# intermediate outputs
"prompt_embeds": {
"type_hint": torch.Tensor,
"kwargs_type": "denoiser_input_fields",
"description": "The prompt embeddings.",
},
"prompt_embeds_mask": {
"type_hint": torch.Tensor,
"kwargs_type": "denoiser_input_fields",
"description": "The encoder attention mask.",
},
"negative_prompt_embeds": {
"type_hint": torch.Tensor,
"kwargs_type": "denoiser_input_fields",
"description": "The negative prompt embeddings.",
},
"negative_prompt_embeds_mask": {
"type_hint": torch.Tensor,
"kwargs_type": "denoiser_input_fields",
"description": "The negative prompt embeddings mask.",
},
"image_latents": {
"type_hint": torch.Tensor,
"description": "The latent representation of the input image.",
},
}


@dataclass
class InputParam:
"""Specification for an input parameter."""
Expand All @@ -337,11 +519,31 @@ class InputParam:
default: Any = None
required: bool = False
description: str = ""
kwargs_type: str = None # YiYi Notes: remove this feature (maybe)
kwargs_type: str = None

def __repr__(self):
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"

@classmethod
def template(cls, template_name: str, note: str = None, **overrides) -> "InputParam":
"""Get template for name if exists, otherwise raise ValueError."""
if template_name not in INPUT_PARAM_TEMPLATES:
raise ValueError(f"InputParam template for {template_name} not found")

template_kwargs = INPUT_PARAM_TEMPLATES[template_name].copy()

# Determine the actual param name:
# 1. From overrides if provided
# 2. From template if present
# 3. Fall back to template_name
name = overrides.pop("name", template_kwargs.pop("name", template_name))

if note and "description" in template_kwargs:
template_kwargs["description"] = f"{template_kwargs['description']} ({note})"

template_kwargs.update(overrides)
return cls(name=name, **template_kwargs)


@dataclass
class OutputParam:
Expand All @@ -350,13 +552,33 @@ class OutputParam:
name: str
type_hint: Any = None
description: str = ""
kwargs_type: str = None # YiYi notes: remove this feature (maybe)
kwargs_type: str = None

def __repr__(self):
return (
f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
)

@classmethod
def template(cls, template_name: str, note: str = None, **overrides) -> "OutputParam":
"""Get template for name if exists, otherwise raise ValueError."""
if template_name not in OUTPUT_PARAM_TEMPLATES:
raise ValueError(f"OutputParam template for {template_name} not found")

template_kwargs = OUTPUT_PARAM_TEMPLATES[template_name].copy()

# Determine the actual param name:
# 1. From overrides if provided
# 2. From template if present
# 3. Fall back to template_name
name = overrides.pop("name", template_kwargs.pop("name", template_name))

if note and "description" in template_kwargs:
template_kwargs["description"] = f"{template_kwargs['description']} ({note})"

template_kwargs.update(overrides)
return cls(name=name, **template_kwargs)


def format_inputs_short(inputs):
"""
Expand Down Expand Up @@ -509,10 +731,12 @@ def wrap_text(text, indent, max_length):
desc = re.sub(r"\[(.*?)\]\((https?://[^\s\)]+)\)", r"[\1](\2)", param.description)
wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
param_str += f"\n{desc_indent}{wrapped_desc}"
else:
param_str += f"\n{desc_indent}TODO: Add description."

formatted_params.append(param_str)

return "\n\n".join(formatted_params)
return "\n".join(formatted_params)


def format_input_params(input_params, indent_level=4, max_line_length=115):
Expand Down Expand Up @@ -582,7 +806,7 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty
loading_field_values = []
for field_name in component.loading_fields():
field_value = getattr(component, field_name)
if field_value is not None:
if field_value:
loading_field_values.append(f"{field_name}={field_value}")

# Add loading field information if available
Expand Down Expand Up @@ -669,17 +893,17 @@ def make_doc_string(
# Add description
if description:
desc_lines = description.strip().split("\n")
aligned_desc = "\n".join(" " + line for line in desc_lines)
aligned_desc = "\n".join(" " + line.rstrip() for line in desc_lines)
output += aligned_desc + "\n\n"

# Add components section if provided
if expected_components and len(expected_components) > 0:
components_str = format_components(expected_components, indent_level=2)
components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
output += components_str + "\n\n"

# Add configs section if provided
if expected_configs and len(expected_configs) > 0:
configs_str = format_configs(expected_configs, indent_level=2)
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
output += configs_str + "\n\n"

# Add inputs section
Expand Down
Loading
Loading