diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 24b9c12db6d4..52ec30c536bd 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -413,6 +413,9 @@ _import_structure["modular_pipelines"].extend( [ "Flux2AutoBlocks", + "Flux2KleinAutoBlocks", + "Flux2KleinBaseAutoBlocks", + "Flux2KleinModularPipeline", "Flux2ModularPipeline", "FluxAutoBlocks", "FluxKontextAutoBlocks", @@ -1146,6 +1149,9 @@ else: from .modular_pipelines import ( Flux2AutoBlocks, + Flux2KleinAutoBlocks, + Flux2KleinBaseAutoBlocks, + Flux2KleinModularPipeline, Flux2ModularPipeline, FluxAutoBlocks, FluxKontextAutoBlocks, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index e64db23f3831..823a3d263ea9 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -54,7 +54,10 @@ ] _import_structure["flux2"] = [ "Flux2AutoBlocks", + "Flux2KleinAutoBlocks", + "Flux2KleinBaseAutoBlocks", "Flux2ModularPipeline", + "Flux2KleinModularPipeline", ] _import_structure["qwenimage"] = [ "QwenImageAutoBlocks", @@ -81,7 +84,13 @@ else: from .components_manager import ComponentsManager from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline - from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline + from .flux2 import ( + Flux2AutoBlocks, + Flux2KleinAutoBlocks, + Flux2KleinBaseAutoBlocks, + Flux2KleinModularPipeline, + Flux2ModularPipeline, + ) from .modular_pipeline import ( AutoPipelineBlocks, BlockState, diff --git a/src/diffusers/modular_pipelines/flux2/__init__.py b/src/diffusers/modular_pipelines/flux2/__init__.py index 21a41c1fe941..220ec0c4ab65 100644 --- a/src/diffusers/modular_pipelines/flux2/__init__.py +++ b/src/diffusers/modular_pipelines/flux2/__init__.py @@ -43,7 +43,7 @@ "Flux2ProcessImagesInputStep", "Flux2TextInputStep", ] - _import_structure["modular_blocks"] = [ + _import_structure["modular_blocks_flux2"] = [ "ALL_BLOCKS", "AUTO_BLOCKS", "REMOTE_AUTO_BLOCKS", @@ -51,10 +51,11 @@ "IMAGE_CONDITIONED_BLOCKS", "Flux2AutoBlocks", "Flux2AutoVaeEncoderStep", - "Flux2BeforeDenoiseStep", + "Flux2CoreDenoiseStep", "Flux2VaeEncoderSequentialStep", ] - _import_structure["modular_pipeline"] = ["Flux2ModularPipeline"] + _import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"] + _import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -85,7 +86,7 @@ Flux2ProcessImagesInputStep, Flux2TextInputStep, ) - from .modular_blocks import ( + from .modular_blocks_flux2 import ( ALL_BLOCKS, AUTO_BLOCKS, IMAGE_CONDITIONED_BLOCKS, @@ -93,10 +94,14 @@ TEXT2IMAGE_BLOCKS, Flux2AutoBlocks, Flux2AutoVaeEncoderStep, - Flux2BeforeDenoiseStep, + Flux2CoreDenoiseStep, Flux2VaeEncoderSequentialStep, ) - from .modular_pipeline import Flux2ModularPipeline + from .modular_blocks_flux2_klein import ( + Flux2KleinAutoBlocks, + Flux2KleinBaseAutoBlocks, + ) + from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/flux2/before_denoise.py b/src/diffusers/modular_pipelines/flux2/before_denoise.py index 42624688adfa..d5bab16586d7 100644 --- a/src/diffusers/modular_pipelines/flux2/before_denoise.py +++ b/src/diffusers/modular_pipelines/flux2/before_denoise.py @@ -129,17 +129,9 @@ def inputs(self) -> List[InputParam]: InputParam("num_inference_steps", default=50), 
InputParam("timesteps"), InputParam("sigmas"), - InputParam("guidance_scale", default=4.0), InputParam("latents", type_hint=torch.Tensor), - InputParam("num_images_per_prompt", default=1), InputParam("height", type_hint=int), InputParam("width", type_hint=int), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.", - ), ] @property @@ -151,13 +143,12 @@ def intermediate_outputs(self) -> List[OutputParam]: type_hint=int, description="The number of denoising steps to perform at inference time", ), - OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"), ] @torch.no_grad() def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state.device = components._execution_device + device = components._execution_device scheduler = components.scheduler @@ -183,7 +174,7 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi timesteps, num_inference_steps = retrieve_timesteps( scheduler, num_inference_steps, - block_state.device, + device, timesteps=timesteps, sigmas=sigmas, mu=mu, @@ -191,11 +182,6 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi block_state.timesteps = timesteps block_state.num_inference_steps = num_inference_steps - batch_size = block_state.batch_size * block_state.num_images_per_prompt - guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32) - guidance = guidance.expand(batch_size) - block_state.guidance = guidance - components.scheduler.set_begin_index(0) self.set_block_state(state, block_state) @@ -353,7 +339,61 @@ def description(self) -> str: def inputs(self) -> List[InputParam]: return [ InputParam(name="prompt_embeds", required=True), - InputParam(name="latent_ids"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.", + ), + ] + + @staticmethod + def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None): + """Prepare 4D position IDs for text tokens.""" + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + seq_l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, seq_l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + prompt_embeds = block_state.prompt_embeds + device = prompt_embeds.device + + block_state.txt_ids = self._prepare_text_ids(prompt_embeds) + block_state.txt_ids = block_state.txt_ids.to(device) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2KleinBaseRoPEInputsStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps." 
+ + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="prompt_embeds", required=True), + InputParam(name="negative_prompt_embeds", required=False), ] @property @@ -366,10 +406,10 @@ def intermediate_outputs(self) -> List[OutputParam]: description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.", ), OutputParam( - name="latent_ids", + name="negative_txt_ids", kwargs_type="denoiser_input_fields", type_hint=torch.Tensor, - description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.", + description="4D position IDs (T, H, W, L) for negative text tokens, used for RoPE calculation.", ), ] @@ -399,6 +439,11 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi block_state.txt_ids = self._prepare_text_ids(prompt_embeds) block_state.txt_ids = block_state.txt_ids.to(device) + block_state.negative_txt_ids = None + if block_state.negative_prompt_embeds is not None: + block_state.negative_txt_ids = self._prepare_text_ids(block_state.negative_prompt_embeds) + block_state.negative_txt_ids = block_state.negative_txt_ids.to(device) + self.set_block_state(state, block_state) return components, state @@ -506,3 +551,42 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi self.set_block_state(state, block_state) return components, state + + +class Flux2PrepareGuidanceStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that prepares the guidance scale tensor for Flux2 inference" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("guidance_scale", default=4.0), + InputParam("num_images_per_prompt", default=1), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + guidance = torch.full([1], block_state.guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(batch_size) + block_state.guidance = guidance + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/flux2/decoders.py b/src/diffusers/modular_pipelines/flux2/decoders.py index b769d9119891..c79375072037 100644 --- a/src/diffusers/modular_pipelines/flux2/decoders.py +++ b/src/diffusers/modular_pipelines/flux2/decoders.py @@ -29,29 +29,16 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class Flux2DecodeStep(ModularPipelineBlocks): +class Flux2UnpackLatentsStep(ModularPipelineBlocks): model_name = "flux2" - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKLFlux2), - ComponentSpec( - "image_processor", - Flux2ImageProcessor, - config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}), - default_creation_method="from_config", - ), - ] - @property def description(self) -> str: - return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm 
denormalization" + return "Step that unpacks the latents from the denoising step" @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("output_type", default="pil"), InputParam( "latents", required=True, @@ -70,9 +57,9 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediate_outputs(self) -> List[str]: return [ OutputParam( - "images", - type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray], - description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array", + "latents", + type_hint=torch.Tensor, + description="The denoise latents from denoising step, unpacked with position IDs.", ) ] @@ -107,6 +94,62 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tens return torch.stack(x_list, dim=0) + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + latents = block_state.latents + latent_ids = block_state.latent_ids + + latents = self._unpack_latents_with_ids(latents, latent_ids) + + block_state.latents = latents + + self.set_block_state(state, block_state) + return components, state + + +class Flux2DecodeStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLFlux2), + ComponentSpec( + "image_processor", + Flux2ImageProcessor, + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "images", + type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray], + description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + @staticmethod def _unpatchify_latents(latents): """Convert patchified latents back to regular format.""" @@ -121,26 +164,20 @@ def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) vae = components.vae - if block_state.output_type == "latent": - block_state.images = block_state.latents - else: - latents = block_state.latents - latent_ids = block_state.latent_ids - - latents = self._unpack_latents_with_ids(latents, latent_ids) + latents = block_state.latents - latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) - latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( - latents.device, latents.dtype - ) - latents = latents * latents_bn_std + latents_bn_mean + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean - latents = self._unpatchify_latents(latents) + latents = self._unpatchify_latents(latents) - block_state.images = vae.decode(latents, return_dict=False)[0] - block_state.images = 
components.image_processor.postprocess( - block_state.images, output_type=block_state.output_type - ) + block_state.images = vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/flux2/denoise.py b/src/diffusers/modular_pipelines/flux2/denoise.py index c12eca65c6a9..a726959a29e2 100644 --- a/src/diffusers/modular_pipelines/flux2/denoise.py +++ b/src/diffusers/modular_pipelines/flux2/denoise.py @@ -16,6 +16,8 @@ import torch +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance from ...models import Flux2Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging @@ -25,8 +27,8 @@ ModularPipelineBlocks, PipelineState, ) -from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_pipeline import Flux2ModularPipeline +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline if is_torch_xla_available(): @@ -134,6 +136,229 @@ def __call__( return components, block_state +# same as Flux2LoopDenoiser but guidance=None +class Flux2KleinLoopDenoiser(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("transformer", Flux2Transformer2DModel)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents for Flux2. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise. Shape: (B, seq_len, C)", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)", + ), + InputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents. 
Shape: (B, img_seq_len, 4)", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings from Qwen3", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for text tokens (T, H, W, L)", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for latent tokens (T, H, W, L)", + ), + ] + + @torch.no_grad() + def __call__( + self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents.to(components.transformer.dtype) + img_ids = block_state.latent_ids + + image_latents = getattr(block_state, "image_latents", None) + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype) + image_latent_ids = block_state.image_latent_ids + img_ids = torch.cat([img_ids, image_latent_ids], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=block_state.prompt_embeds, + txt_ids=block_state.txt_ids, + img_ids=img_ids, + joint_attention_kwargs=block_state.joint_attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1)] + block_state.noise_pred = noise_pred + + return components, block_state + + +# support CFG for Flux2-Klein base model +class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("transformer", Flux2Transformer2DModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ + ConfigSpec(name="is_distilled", default=False), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents for Flux2. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise. Shape: (B, seq_len, C)", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)", + ), + InputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents. 
Shape: (B, img_seq_len, 4)", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings from Qwen3", + ), + InputParam( + "negative_prompt_embeds", + required=False, + type_hint=torch.Tensor, + description="Negative text embeddings from Qwen3", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for text tokens (T, H, W, L)", + ), + InputParam( + "negative_txt_ids", + required=False, + type_hint=torch.Tensor, + description="4D position IDs for negative text tokens (T, H, W, L)", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for latent tokens (T, H, W, L)", + ), + ] + + @torch.no_grad() + def __call__( + self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents.to(components.transformer.dtype) + img_ids = block_state.latent_ids + + image_latents = getattr(block_state, "image_latents", None) + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype) + image_latent_ids = block_state.image_latent_ids + img_ids = torch.cat([img_ids, image_latent_ids], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + guider_inputs = { + "encoder_hidden_states": ( + getattr(block_state, "prompt_embeds", None), + getattr(block_state, "negative_prompt_embeds", None), + ), + "txt_ids": ( + getattr(block_state, "txt_ids", None), + getattr(block_state, "negative_txt_ids", None), + ), + } + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(guider_inputs) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + img_ids=img_ids, + joint_attention_kwargs=block_state.joint_attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + guider_state_batch.noise_pred = noise_pred[:, : latents.size(1)] + components.guider.cleanup_models(components.transformer) + + # perform guidance + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + class Flux2LoopAfterDenoiser(ModularPipelineBlocks): model_name = "flux2" @@ -250,3 +475,35 @@ def description(self) -> str: " - `Flux2LoopAfterDenoiser`\n" "This block supports both text-to-image and image-conditioned generation." ) + + +class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper): + block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents for Flux2. \n" + "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `Flux2KleinLoopDenoiser`\n" + " - `Flux2LoopAfterDenoiser`\n" + "This block supports both text-to-image and image-conditioned generation." 
+ ) + + +class Flux2KleinBaseDenoiseStep(Flux2DenoiseLoopWrapper): + block_classes = [Flux2KleinBaseLoopDenoiser, Flux2LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents for Flux2. \n" + "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `Flux2KleinBaseLoopDenoiser`\n" + " - `Flux2LoopAfterDenoiser`\n" + "This block supports both text-to-image and image-conditioned generation." + ) diff --git a/src/diffusers/modular_pipelines/flux2/encoders.py b/src/diffusers/modular_pipelines/flux2/encoders.py index 6cb0e3bf0a26..265fb387367c 100644 --- a/src/diffusers/modular_pipelines/flux2/encoders.py +++ b/src/diffusers/modular_pipelines/flux2/encoders.py @@ -15,13 +15,15 @@ from typing import List, Optional, Tuple, Union import torch -from transformers import AutoProcessor, Mistral3ForConditionalGeneration +from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance from ...models import AutoencoderKLFlux2 from ...utils import logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState -from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_pipeline import Flux2ModularPipeline +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -79,10 +81,8 @@ def expected_components(self) -> List[ComponentSpec]: def inputs(self) -> List[InputParam]: return [ InputParam("prompt"), - InputParam("prompt_embeds", type_hint=torch.Tensor, required=False), InputParam("max_sequence_length", type_hint=int, default=512, required=False), InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False), - InputParam("joint_attention_kwargs"), ] @property @@ -99,14 +99,7 @@ def intermediate_outputs(self) -> List[OutputParam]: @staticmethod def check_inputs(block_state): prompt = block_state.prompt - prompt_embeds = getattr(block_state, "prompt_embeds", None) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. " - "Please make sure to only forward one of the two." 
- ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") @staticmethod @@ -165,10 +158,6 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi block_state.device = components._execution_device - if block_state.prompt_embeds is not None: - self.set_block_state(state, block_state) - return components, state - prompt = block_state.prompt if prompt is None: prompt = "" @@ -205,7 +194,6 @@ def expected_components(self) -> List[ComponentSpec]: def inputs(self) -> List[InputParam]: return [ InputParam("prompt"), - InputParam("prompt_embeds", type_hint=torch.Tensor, required=False), ] @property @@ -222,15 +210,8 @@ def intermediate_outputs(self) -> List[OutputParam]: @staticmethod def check_inputs(block_state): prompt = block_state.prompt - prompt_embeds = getattr(block_state, "prompt_embeds", None) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. " - "Please make sure to only forward one of the two." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") @torch.no_grad() def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: @@ -244,10 +225,6 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi block_state.device = components._execution_device - if block_state.prompt_embeds is not None: - self.set_block_state(state, block_state) - return components, state - prompt = block_state.prompt if prompt is None: prompt = "" @@ -270,6 +247,289 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi return components, state +class Flux2KleinTextEncoderStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen3ForCausalLM), + ComponentSpec("tokenizer", Qwen2TokenizerFast), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ + ConfigSpec(name="is_distilled", default=True), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("max_sequence_length", type_hint=int, default=512, required=False), + InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from qwen3 used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is 
{type(prompt)}") + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + hidden_states_layers: List[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @torch.no_grad() + def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + device = components._execution_device + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + block_state.prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=block_state.max_sequence_length, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen3ForCausalLM), + ComponentSpec("tokenizer", Qwen2TokenizerFast), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ + ConfigSpec(name="is_distilled", default=False), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("max_sequence_length", type_hint=int, default=512, required=False), + InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False), + ] + + @property + def intermediate_outputs(self) -> 
List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from qwen3 used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Negative text embeddings from qwen3 used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + hidden_states_layers: List[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @torch.no_grad() + def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + device = components._execution_device + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + block_state.prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=block_state.max_sequence_length, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + + if components.requires_unconditional_embeds: + negative_prompt = [""] * len(prompt) + block_state.negative_prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=negative_prompt, + device=device, + max_sequence_length=block_state.max_sequence_length, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + else: + 
block_state.negative_prompt_embeds = None + + self.set_block_state(state, block_state) + return components, state + + class Flux2VaeEncoderStep(ModularPipelineBlocks): model_name = "flux2" diff --git a/src/diffusers/modular_pipelines/flux2/inputs.py b/src/diffusers/modular_pipelines/flux2/inputs.py index c9e337fb0bf0..3463de1999c6 100644 --- a/src/diffusers/modular_pipelines/flux2/inputs.py +++ b/src/diffusers/modular_pipelines/flux2/inputs.py @@ -47,7 +47,7 @@ def inputs(self) -> List[InputParam]: required=True, kwargs_type="denoiser_input_fields", type_hint=torch.Tensor, - description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.", + description="Pre-generated text embeddings. Can be generated from text_encoder step.", ), ] @@ -89,6 +89,90 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi return components, state +class Flux2KleinBaseTextInputStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return ( + "This step:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + required=False, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Negative text embeddings used to guide the image generation", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + self.set_block_state(state, 
block_state) + return components, state + + class Flux2ProcessImagesInputStep(ModularPipelineBlocks): model_name = "flux2" diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py similarity index 64% rename from src/diffusers/modular_pipelines/flux2/modular_blocks.py rename to src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py index a31673b6e78c..41a0ff7dee28 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_blocks.py +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py @@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + +import PIL.Image +import torch + from ...utils import logging from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks -from ..modular_pipeline_utils import InsertableDict +from ..modular_pipeline_utils import InsertableDict, OutputParam from .before_denoise import ( + Flux2PrepareGuidanceStep, Flux2PrepareImageLatentsStep, Flux2PrepareLatentsStep, Flux2RoPEInputsStep, Flux2SetTimestepsStep, ) -from .decoders import Flux2DecodeStep +from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep from .denoise import Flux2DenoiseStep from .encoders import ( Flux2RemoteTextEncoderStep, @@ -41,7 +47,6 @@ [ ("preprocess", Flux2ProcessImagesInputStep()), ("encode", Flux2VaeEncoderStep()), - ("prepare_image_latents", Flux2PrepareImageLatentsStep()), ] ) @@ -72,33 +77,56 @@ def description(self): ) -Flux2BeforeDenoiseBlocks = InsertableDict( +Flux2CoreDenoiseBlocks = InsertableDict( [ + ("input", Flux2TextInputStep()), + ("prepare_image_latents", Flux2PrepareImageLatentsStep()), ("prepare_latents", Flux2PrepareLatentsStep()), ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_guidance", Flux2PrepareGuidanceStep()), ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), ] ) -class Flux2BeforeDenoiseStep(SequentialPipelineBlocks): +class Flux2CoreDenoiseStep(SequentialPipelineBlocks): model_name = "flux2" - block_classes = Flux2BeforeDenoiseBlocks.values() - block_names = Flux2BeforeDenoiseBlocks.keys() + block_classes = Flux2CoreDenoiseBlocks.values() + block_names = Flux2CoreDenoiseBlocks.keys() @property def description(self): - return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation." 
+ return ( + "Core denoise step that performs the denoising process for Flux2-dev.\n" + " - `Flux2TextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n" + " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n" + " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n" + " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n" + " - `Flux2PrepareGuidanceStep` (prepare_guidance) prepares the guidance tensor for the denoising step.\n" + " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n" + " - `Flux2DenoiseStep` (denoise) iteratively denoises the latents.\n" + " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n" + ) + + @property + def outputs(self): + return [ + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The latents from the denoising step.", + ) + ] AUTO_BLOCKS = InsertableDict( [ ("text_encoder", Flux2TextEncoderStep()), - ("text_input", Flux2TextInputStep()), - ("vae_image_encoder", Flux2AutoVaeEncoderStep()), - ("before_denoise", Flux2BeforeDenoiseStep()), - ("denoise", Flux2DenoiseStep()), + ("vae_encoder", Flux2AutoVaeEncoderStep()), + ("denoise", Flux2CoreDenoiseStep()), ("decode", Flux2DecodeStep()), ] ) @@ -107,10 +135,8 @@ def description(self): REMOTE_AUTO_BLOCKS = InsertableDict( [ ("text_encoder", Flux2RemoteTextEncoderStep()), - ("text_input", Flux2TextInputStep()), - ("vae_image_encoder", Flux2AutoVaeEncoderStep()), - ("before_denoise", Flux2BeforeDenoiseStep()), - ("denoise", Flux2DenoiseStep()), + ("vae_encoder", Flux2AutoVaeEncoderStep()), + ("denoise", Flux2CoreDenoiseStep()), ("decode", Flux2DecodeStep()), ] ) @@ -130,6 +156,16 @@ def description(self): "- For image-conditioned generation, you need to provide `image` (list of PIL images)." ) + @property + def outputs(self): + return [ + OutputParam( + name="images", + type_hint=List[PIL.Image.Image], + description="The images from the decoding step.", + ) + ] + TEXT2IMAGE_BLOCKS = InsertableDict( [ @@ -137,8 +173,10 @@ def description(self): ("text_input", Flux2TextInputStep()), ("prepare_latents", Flux2PrepareLatentsStep()), ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_guidance", Flux2PrepareGuidanceStep()), ("prepare_rope_inputs", Flux2RoPEInputsStep()), ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), ("decode", Flux2DecodeStep()), ] ) @@ -152,8 +190,10 @@ def description(self): ("prepare_image_latents", Flux2PrepareImageLatentsStep()), ("prepare_latents", Flux2PrepareLatentsStep()), ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_guidance", Flux2PrepareGuidanceStep()), ("prepare_rope_inputs", Flux2RoPEInputsStep()), ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), ("decode", Flux2DecodeStep()), ] ) diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py new file mode 100644 index 000000000000..984832d77be5 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py @@ -0,0 +1,232 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import PIL.Image +import torch + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + Flux2KleinBaseRoPEInputsStep, + Flux2PrepareImageLatentsStep, + Flux2PrepareLatentsStep, + Flux2RoPEInputsStep, + Flux2SetTimestepsStep, +) +from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep +from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep +from .encoders import ( + Flux2KleinBaseTextEncoderStep, + Flux2KleinTextEncoderStep, + Flux2VaeEncoderStep, +) +from .inputs import ( + Flux2KleinBaseTextInputStep, + Flux2ProcessImagesInputStep, + Flux2TextInputStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +################ +# VAE encoder +################ + +Flux2KleinVaeEncoderBlocks = InsertableDict( + [ + ("preprocess", Flux2ProcessImagesInputStep()), + ("encode", Flux2VaeEncoderStep()), + ] +) + + +class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks): + model_name = "flux2" + + block_classes = Flux2KleinVaeEncoderBlocks.values() + block_names = Flux2KleinVaeEncoderBlocks.keys() + + @property + def description(self) -> str: + return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations." + + +class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [Flux2KleinVaeEncoderSequentialStep] + block_names = ["img_conditioning"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "VAE encoder step that encodes the image inputs into their latent representations.\n" + "This is an auto pipeline block that works for image conditioning tasks.\n" + " - `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided.\n" + " - If `image` is not provided, step will be skipped." 
+        )
+
+
+###
+### Core denoise
+###
+
+Flux2KleinCoreDenoiseBlocks = InsertableDict(
+    [
+        ("input", Flux2TextInputStep()),
+        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
+        ("prepare_latents", Flux2PrepareLatentsStep()),
+        ("set_timesteps", Flux2SetTimestepsStep()),
+        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
+        ("denoise", Flux2KleinDenoiseStep()),
+        ("after_denoise", Flux2UnpackLatentsStep()),
+    ]
+)
+
+
+class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "flux2-klein"
+
+    block_classes = Flux2KleinCoreDenoiseBlocks.values()
+    block_names = Flux2KleinCoreDenoiseBlocks.keys()
+
+    @property
+    def description(self):
+        return (
+            "Core denoise step that performs the denoising process for Flux2-Klein (distilled model).\n"
+            " - `Flux2TextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
+            " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
+            " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
+            " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
+            " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
+            " - `Flux2KleinDenoiseStep` (denoise) iteratively denoises the latents.\n"
+            " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
+        )
+
+    @property
+    def outputs(self):
+        return [
+            OutputParam(
+                name="latents",
+                type_hint=torch.Tensor,
+                description="The latents from the denoising step.",
+            )
+        ]
+
+
+Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
+    [
+        ("input", Flux2KleinBaseTextInputStep()),
+        ("prepare_latents", Flux2PrepareLatentsStep()),
+        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
+        ("set_timesteps", Flux2SetTimestepsStep()),
+        ("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()),
+        ("denoise", Flux2KleinBaseDenoiseStep()),
+        ("after_denoise", Flux2UnpackLatentsStep()),
+    ]
+)
+
+
+class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "flux2-klein"
+    block_classes = Flux2KleinBaseCoreDenoiseBlocks.values()
+    block_names = Flux2KleinBaseCoreDenoiseBlocks.keys()
+
+    @property
+    def description(self):
+ return ( + "Core denoise step that performs the denoising process for Flux2-Klein (base model).\n" + " - `Flux2KleinBaseTextInputStep` (input) standardizes the text inputs (prompt_embeds + negative_prompt_embeds) for the denoising step.\n" + " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n" + " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n" + " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n" + " - `Flux2KleinBaseRoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids + negative_txt_ids) for the denoising step.\n" + " - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n" + " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n" + ) + + @property + def outputs(self): + return [ + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The latents from the denoising step.", + ) + ] + + +### +### Auto blocks +### +class Flux2KleinAutoBlocks(SequentialPipelineBlocks): + model_name = "flux2-klein" + block_classes = [ + Flux2KleinTextEncoderStep(), + Flux2KleinAutoVaeEncoderStep(), + Flux2KleinCoreDenoiseStep(), + Flux2DecodeStep(), + ] + block_names = ["text_encoder", "vae_encoder", "denoise", "decode"] + + @property + def description(self): + return ( + "Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein.\n" + + " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n" + + " - for text-to-image generation, all you need to provide is `prompt`.\n" + ) + + @property + def outputs(self): + return [ + OutputParam( + name="images", + type_hint=List[PIL.Image.Image], + description="The images from the decoding step.", + ) + ] + + +class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks): + model_name = "flux2-klein" + block_classes = [ + Flux2KleinBaseTextEncoderStep(), + Flux2KleinAutoVaeEncoderStep(), + Flux2KleinBaseCoreDenoiseStep(), + Flux2DecodeStep(), + ] + block_names = ["text_encoder", "vae_encoder", "denoise", "decode"] + + @property + def description(self): + return ( + "Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model).\n" + + " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n" + + " - for text-to-image generation, all you need to provide is `prompt`.\n" + ) + + @property + def outputs(self): + return [ + OutputParam( + name="images", + type_hint=List[PIL.Image.Image], + description="The images from the decoding step.", + ) + ] diff --git a/src/diffusers/modular_pipelines/flux2/modular_pipeline.py b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py index 3e497f3b1e98..29fbeba07c24 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py @@ -13,6 +13,8 @@ # limitations under the License. 
+from typing import Any, Dict, Optional + from ...loaders import Flux2LoraLoaderMixin from ...utils import logging from ..modular_pipeline import ModularPipeline @@ -55,3 +57,56 @@ def num_channels_latents(self): if getattr(self, "transformer", None): num_channels_latents = self.transformer.config.in_channels // 4 return num_channels_latents + + +class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin): + """ + A ModularPipeline for Flux2-Klein. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "Flux2KleinBaseAutoBlocks" + + def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]: + if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]: + return "Flux2KleinAutoBlocks" + else: + return "Flux2KleinBaseAutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if getattr(self, "vae", None) is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 32 + if getattr(self, "transformer", None): + num_channels_latents = self.transformer.config.in_channels // 4 + return num_channels_latents + + @property + def requires_unconditional_embeds(self): + if hasattr(self.config, "is_distilled") and self.config.is_distilled: + return False + + requires_unconditional_embeds = False + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index d857fd040955..98ede73c21fe 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -59,6 +59,7 @@ ("flux", "FluxModularPipeline"), ("flux-kontext", "FluxKontextModularPipeline"), ("flux2", "Flux2ModularPipeline"), + ("flux2-klein", "Flux2KleinModularPipeline"), ("qwenimage", "QwenImageModularPipeline"), ("qwenimage-edit", "QwenImageEditModularPipeline"), ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"), diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 63f381419fda..a23f852616c0 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -17,6 +17,51 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Flux2KleinAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class Flux2KleinBaseAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, 
["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class Flux2KleinModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Flux2ModularPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein.py b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein.py new file mode 100644 index 000000000000..26653b20f8c4 --- /dev/null +++ b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein.py @@ -0,0 +1,91 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import numpy as np +import PIL +import pytest + +from diffusers.modular_pipelines import ( + Flux2KleinAutoBlocks, + Flux2KleinModularPipeline, +) + +from ...testing_utils import floats_tensor, torch_device +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2KleinModularPipeline + pipeline_blocks_class = Flux2KleinAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular" + + params = frozenset(["prompt", "height", "width"]) + batch_params = frozenset(["prompt"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + "text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + "height": 32, + "width": 32, + "output_type": "pt", + } + return inputs + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + +class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2KleinModularPipeline + pipeline_blocks_class = Flux2KleinAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular" + + params = frozenset(["prompt", "height", "width", "image"]) + batch_params = frozenset(["prompt", "image"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + 
"text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + "height": 32, + "width": 32, + "output_type": "pt", + } + image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB") + inputs["image"] = init_image + + return inputs + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + @pytest.mark.skip(reason="batched inference is currently not supported") + def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001): + return diff --git a/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein_base.py b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein_base.py new file mode 100644 index 000000000000..701dd0fed896 --- /dev/null +++ b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein_base.py @@ -0,0 +1,91 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import numpy as np +import PIL +import pytest + +from diffusers.modular_pipelines import ( + Flux2KleinBaseAutoBlocks, + Flux2KleinModularPipeline, +) + +from ...testing_utils import floats_tensor, torch_device +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2KleinModularPipeline + pipeline_blocks_class = Flux2KleinBaseAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular" + + params = frozenset(["prompt", "height", "width"]) + batch_params = frozenset(["prompt"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + "text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + "height": 32, + "width": 32, + "output_type": "pt", + } + return inputs + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + +class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2KleinModularPipeline + pipeline_blocks_class = Flux2KleinBaseAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular" + + params = frozenset(["prompt", "height", "width", "image"]) + batch_params = frozenset(["prompt", "image"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + "text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + 
"height": 32, + "width": 32, + "output_type": "pt", + } + image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB") + inputs["image"] = init_image + + return inputs + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + @pytest.mark.skip(reason="batched inference is currently not supported") + def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001): + return