From 618a8a989736318e2eff267bdd8bf69ab0651186 Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Tue, 20 Jan 2026 01:20:14 +0000 Subject: [PATCH 01/13] support klein --- src/diffusers/__init__.py | 6 + src/diffusers/modular_pipelines/__init__.py | 5 +- .../modular_pipelines/flux2/__init__.py | 13 +- .../modular_pipelines/flux2/before_denoise.py | 11 +- .../modular_pipelines/flux2/denoise.py | 275 +++++++++++++++++- .../modular_pipelines/flux2/encoders.py | 187 ++++++++++-- .../modular_pipelines/flux2/inputs.py | 22 +- ...ular_blocks.py => modular_blocks_flux2.py} | 0 .../flux2/modular_blocks_flux2_klein.py | 164 +++++++++++ .../flux2/modular_pipeline.py | 59 ++++ .../modular_pipelines/modular_pipeline.py | 1 + 11 files changed, 701 insertions(+), 42 deletions(-) rename src/diffusers/modular_pipelines/flux2/{modular_blocks.py => modular_blocks_flux2.py} (100%) create mode 100644 src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 24b9c12db6d4..71228a5598b3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -414,6 +414,9 @@ [ "Flux2AutoBlocks", "Flux2ModularPipeline", + "Flux2KleinAutoBlocks", + "Flux2KleinBaseAutoBlocks", + "Flux2KleinModularPipeline", "FluxAutoBlocks", "FluxKontextAutoBlocks", "FluxKontextModularPipeline", @@ -1147,6 +1150,9 @@ from .modular_pipelines import ( Flux2AutoBlocks, Flux2ModularPipeline, + Flux2KleinAutoBlocks, + Flux2KleinBaseAutoBlocks, + Flux2KleinModularPipeline, FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index e64db23f3831..099e86a553a8 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -54,7 +54,10 @@ ] _import_structure["flux2"] = [ "Flux2AutoBlocks", + "Flux2KleinAutoBlocks", + "Flux2KleinBaseAutoBlocks", "Flux2ModularPipeline", + "Flux2KleinModularPipeline", ] _import_structure["qwenimage"] = [ "QwenImageAutoBlocks", @@ -81,7 +84,7 @@ else: from .components_manager import ComponentsManager from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline - from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline + from .flux2 import Flux2AutoBlocks, Flux2KleinAutoBlocks, Flux2KleinBaseAutoBlocks, Flux2ModularPipeline, Flux2KleinModularPipeline from .modular_pipeline import ( AutoPipelineBlocks, BlockState, diff --git a/src/diffusers/modular_pipelines/flux2/__init__.py b/src/diffusers/modular_pipelines/flux2/__init__.py index 21a41c1fe941..64ced29bddd6 100644 --- a/src/diffusers/modular_pipelines/flux2/__init__.py +++ b/src/diffusers/modular_pipelines/flux2/__init__.py @@ -43,7 +43,7 @@ "Flux2ProcessImagesInputStep", "Flux2TextInputStep", ] - _import_structure["modular_blocks"] = [ + _import_structure["modular_blocks_flux2"] = [ "ALL_BLOCKS", "AUTO_BLOCKS", "REMOTE_AUTO_BLOCKS", @@ -54,7 +54,8 @@ "Flux2BeforeDenoiseStep", "Flux2VaeEncoderSequentialStep", ] - _import_structure["modular_pipeline"] = ["Flux2ModularPipeline"] + _import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"] + _import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -85,7 +86,7 @@ Flux2ProcessImagesInputStep, Flux2TextInputStep, ) - from .modular_blocks import ( + from .modular_blocks_flux2 import ( ALL_BLOCKS, 
AUTO_BLOCKS, IMAGE_CONDITIONED_BLOCKS, @@ -96,7 +97,11 @@ Flux2BeforeDenoiseStep, Flux2VaeEncoderSequentialStep, ) - from .modular_pipeline import Flux2ModularPipeline + from .modular_blocks_flux2_klein import ( + Flux2KleinAutoBlocks, + Flux2KleinBaseAutoBlocks, + ) + from .modular_pipeline import Flux2ModularPipeline, Flux2KleinModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/flux2/before_denoise.py b/src/diffusers/modular_pipelines/flux2/before_denoise.py index 42624688adfa..e1001924c70b 100644 --- a/src/diffusers/modular_pipelines/flux2/before_denoise.py +++ b/src/diffusers/modular_pipelines/flux2/before_denoise.py @@ -353,7 +353,7 @@ def description(self) -> str: def inputs(self) -> List[InputParam]: return [ InputParam(name="prompt_embeds", required=True), - InputParam(name="latent_ids"), + InputParam(name="negative_prompt_embeds", required=False), ] @property @@ -366,10 +366,10 @@ def intermediate_outputs(self) -> List[OutputParam]: description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.", ), OutputParam( - name="latent_ids", + name="negative_txt_ids", kwargs_type="denoiser_input_fields", type_hint=torch.Tensor, - description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.", + description="4D position IDs (T, H, W, L) for negative text tokens, used for RoPE calculation.", ), ] @@ -399,6 +399,11 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi block_state.txt_ids = self._prepare_text_ids(prompt_embeds) block_state.txt_ids = block_state.txt_ids.to(device) + block_state.negative_txt_ids = None + if block_state.negative_prompt_embeds is not None: + block_state.negative_txt_ids = self._prepare_text_ids(block_state.negative_prompt_embeds) + block_state.negative_txt_ids = block_state.negative_txt_ids.to(device) + self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/flux2/denoise.py b/src/diffusers/modular_pipelines/flux2/denoise.py index c12eca65c6a9..84cad52ab7cf 100644 --- a/src/diffusers/modular_pipelines/flux2/denoise.py +++ b/src/diffusers/modular_pipelines/flux2/denoise.py @@ -13,20 +13,23 @@ # limitations under the License. from typing import Any, List, Tuple +import inspect import torch +from ...configuration_utils import FrozenDict from ...models import Flux2Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging +from ...guiders import ClassifierFreeGuidance from ..modular_pipeline import ( BlockState, LoopSequentialPipelineBlocks, ModularPipelineBlocks, PipelineState, ) -from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_pipeline import Flux2ModularPipeline +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec +from .modular_pipeline import Flux2ModularPipeline, Flux2KleinModularPipeline if is_torch_xla_available(): @@ -133,6 +136,241 @@ def __call__( return components, block_state +# sane as Flux2 but guidance=None +class Flux2KleinLoopDenoiser(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("transformer", Flux2Transformer2DModel)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents for Flux2. 
" + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise. Shape: (B, seq_len, C)", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)", + ), + InputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents. Shape: (B, img_seq_len, 4)", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings from Mistral3", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for text tokens (T, H, W, L)", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for latent tokens (T, H, W, L)", + ), + ] + + @torch.no_grad() + def __call__( + self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents.to(components.transformer.dtype) + img_ids = block_state.latent_ids + + image_latents = getattr(block_state, "image_latents", None) + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype) + image_latent_ids = block_state.image_latent_ids + img_ids = torch.cat([img_ids, image_latent_ids], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=block_state.prompt_embeds, + txt_ids=block_state.txt_ids, + img_ids=img_ids, + joint_attention_kwargs=block_state.joint_attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1)] + block_state.noise_pred = noise_pred + + return components, block_state + + +# support CFG for Flux2-Klein base model +class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("transformer", Flux2Transformer2DModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ + ConfigSpec(name="is_distilled", default=False), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents for Flux2. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise. Shape: (B, seq_len, C)", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)", + ), + InputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents. 
Shape: (B, img_seq_len, 4)", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings from Mistral3", + ), + InputParam( + "negative_prompt_embeds", + required=False, + type_hint=torch.Tensor, + description="Negative text embeddings from Qwen3", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for text tokens (T, H, W, L)", + ), + InputParam( + "negative_txt_ids", + required=False, + type_hint=torch.Tensor, + description="4D position IDs for negative text tokens (T, H, W, L)", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for latent tokens (T, H, W, L)", + ), + InputParam( + kwargs_type="denoiser_input_fields", + description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", + ) + ] + + @torch.no_grad() + def __call__( + self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents.to(components.transformer.dtype) + img_ids = block_state.latent_ids + + image_latents = getattr(block_state, "image_latents", None) + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype) + image_latent_ids = block_state.image_latent_ids + img_ids = torch.cat([img_ids, image_latent_ids], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + guider_inputs = { + "encoder_hidden_states": ( + getattr(block_state, "prompt_embeds", None), + getattr(block_state, "negative_prompt_embeds", None), + ), + "txt_ids": ( + getattr(block_state, "txt_ids", None), + getattr(block_state, "negative_txt_ids", None), + ), + } + + transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) + additional_cond_kwargs = {} + for field_name, field_value in block_state.denoiser_input_fields.items(): + if field_name in transformer_args and field_name not in guider_inputs: + additional_cond_kwargs[field_name] = field_value + block_state.additional_cond_kwargs.update(additional_cond_kwargs) + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(guider_inputs) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} + cond_kwargs.update(additional_cond_kwargs) + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + img_ids=img_ids, + joint_attention_kwargs=block_state.joint_attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + guider_state_batch.noise_pred = noise_pred[:, : latents.size(1)] + components.guider.cleanup_models(components.transformer) + + # perform guidance + block_state.noise_pred = components.guider(guider_state)[0] + + + return components, block_state + class Flux2LoopAfterDenoiser(ModularPipelineBlocks): model_name = "flux2" @@ -220,6 +458,8 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 ) + block_state.additional_cond_kwargs = {} + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: for i, t 
in enumerate(block_state.timesteps): components, block_state = self.loop_step(components, block_state, i=i, t=t) @@ -250,3 +490,34 @@ def description(self) -> str: " - `Flux2LoopAfterDenoiser`\n" "This block supports both text-to-image and image-conditioned generation." ) + +class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper): + block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents for Flux2. \n" + "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `Flux2KleinLoopDenoiser`\n" + " - `Flux2LoopAfterDenoiser`\n" + "This block supports both text-to-image and image-conditioned generation." + ) + + +class Flux2KleinBaseDenoiseStep(Flux2DenoiseLoopWrapper): + block_classes = [Flux2KleinBaseLoopDenoiser, Flux2LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents for Flux2. \n" + "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `Flux2KleinBaseLoopDenoiser`\n" + " - `Flux2LoopAfterDenoiser`\n" + "This block supports both text-to-image and image-conditioned generation." + ) diff --git a/src/diffusers/modular_pipelines/flux2/encoders.py b/src/diffusers/modular_pipelines/flux2/encoders.py index 6cb0e3bf0a26..5c06746aa24f 100644 --- a/src/diffusers/modular_pipelines/flux2/encoders.py +++ b/src/diffusers/modular_pipelines/flux2/encoders.py @@ -15,13 +15,13 @@ from typing import List, Optional, Tuple, Union import torch -from transformers import AutoProcessor, Mistral3ForConditionalGeneration +from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen3ForCausalLM, Qwen2TokenizerFast from ...models import AutoencoderKLFlux2 from ...utils import logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState -from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_pipeline import Flux2ModularPipeline +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec +from .modular_pipeline import Flux2ModularPipeline, Flux2KleinModularPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -79,10 +79,8 @@ def expected_components(self) -> List[ComponentSpec]: def inputs(self) -> List[InputParam]: return [ InputParam("prompt"), - InputParam("prompt_embeds", type_hint=torch.Tensor, required=False), InputParam("max_sequence_length", type_hint=int, default=512, required=False), InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False), - InputParam("joint_attention_kwargs"), ] @property @@ -99,14 +97,7 @@ def intermediate_outputs(self) -> List[OutputParam]: @staticmethod def check_inputs(block_state): prompt = block_state.prompt - prompt_embeds = getattr(block_state, "prompt_embeds", None) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. " - "Please make sure to only forward one of the two." 
- ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") @staticmethod @@ -165,10 +156,6 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi block_state.device = components._execution_device - if block_state.prompt_embeds is not None: - self.set_block_state(state, block_state) - return components, state - prompt = block_state.prompt if prompt is None: prompt = "" @@ -205,7 +192,6 @@ def expected_components(self) -> List[ComponentSpec]: def inputs(self) -> List[InputParam]: return [ InputParam("prompt"), - InputParam("prompt_embeds", type_hint=torch.Tensor, required=False), ] @property @@ -222,15 +208,8 @@ def intermediate_outputs(self) -> List[OutputParam]: @staticmethod def check_inputs(block_state): prompt = block_state.prompt - prompt_embeds = getattr(block_state, "prompt_embeds", None) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. " - "Please make sure to only forward one of the two." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") @torch.no_grad() def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: @@ -244,10 +223,6 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi block_state.device = components._execution_device - if block_state.prompt_embeds is not None: - self.set_block_state(state, block_state) - return components, state - prompt = block_state.prompt if prompt is None: prompt = "" @@ -270,6 +245,156 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi return components, state + +class Flux2KleinTextEncoderStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen3ForCausalLM), + ComponentSpec("tokenizer", Qwen2TokenizerFast), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ + ConfigSpec(name="is_distilled", default=False), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("max_sequence_length", type_hint=int, default=512, required=False), + InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from qwen3 used to guide the image generation", + ), + + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="Negative text embeddings from qwen3 used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + + if 
prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + hidden_states_layers: List[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @torch.no_grad() + def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + device = components._execution_device + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + block_state.prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=block_state.max_sequence_length, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + + if components.requires_unconditional_embeds: + negative_prompt = "" + block_state.negative_prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=negative_prompt, + device=device, + max_sequence_length=block_state.max_sequence_length, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + else: + block_state.negative_prompt_embeds = None + + self.set_block_state(state, block_state) + return components, state + + class Flux2VaeEncoderStep(ModularPipelineBlocks): model_name = "flux2" diff --git a/src/diffusers/modular_pipelines/flux2/inputs.py b/src/diffusers/modular_pipelines/flux2/inputs.py index c9e337fb0bf0..0de0040c3923 100644 --- a/src/diffusers/modular_pipelines/flux2/inputs.py +++ b/src/diffusers/modular_pipelines/flux2/inputs.py @@ -47,7 +47,14 @@ def inputs(self) -> List[InputParam]: required=True, 
kwargs_type="denoiser_input_fields", type_hint=torch.Tensor, - description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.", + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + required=False, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", ), ] @@ -70,6 +77,12 @@ def intermediate_outputs(self) -> List[str]: kwargs_type="denoiser_input_fields", description="Text embeddings used to guide the image generation", ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Negative text embeddings used to guide the image generation", + ), ] @torch.no_grad() @@ -85,6 +98,13 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 ) + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py similarity index 100% rename from src/diffusers/modular_pipelines/flux2/modular_blocks.py rename to src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py new file mode 100644 index 000000000000..2f89106b1351 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py @@ -0,0 +1,164 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict +from .before_denoise import ( + Flux2PrepareImageLatentsStep, + Flux2PrepareLatentsStep, + Flux2RoPEInputsStep, + Flux2SetTimestepsStep, +) +from .decoders import Flux2DecodeStep +from .denoise import Flux2KleinDenoiseStep, Flux2KleinBaseDenoiseStep +from .encoders import ( + Flux2KleinTextEncoderStep, + Flux2VaeEncoderStep, +) +from .inputs import ( + Flux2ProcessImagesInputStep, + Flux2TextInputStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +Flux2KleinVaeEncoderBlocks = InsertableDict( + [ + ("preprocess", Flux2ProcessImagesInputStep()), + ("encode", Flux2VaeEncoderStep()), + ] +) + + +class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks): + model_name = "flux2" + + block_classes = Flux2KleinVaeEncoderBlocks.values() + block_names = Flux2KleinVaeEncoderBlocks.keys() + + @property + def description(self) -> str: + return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations." + + + +class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [Flux2KleinVaeEncoderSequentialStep] + block_names = ["img_conditioning"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "VAE encoder step that encodes the image inputs into their latent representations.\n" + "This is an auto pipeline block that works for image conditioning tasks.\n" + " - `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided.\n" + " - If `image` is not provided, step will be skipped." + ) + + + +Flux2KleinCoreDenoiseBlocks = InsertableDict( + [ + ("input", Flux2TextInputStep()), + ("prepare_image_latents", Flux2PrepareImageLatentsStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2KleinDenoiseStep()), + ] +) + + +class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks): + model_name = "flux2-klein" + + block_classes = Flux2KleinCoreDenoiseBlocks.values() + block_names = Flux2KleinCoreDenoiseBlocks.keys() + + @property + def description(self): + return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model)." 
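# A minimal inspection sketch, assuming `sub_blocks` exposes the mapping assembled
# from `Flux2KleinCoreDenoiseBlocks` above (the attribute the block descriptions in
# denoise.py refer to); the printed names should match the InsertableDict keys:
#
#     step = Flux2KleinCoreDenoiseStep()
#     for name, block in step.sub_blocks.items():
#         print(name, "->", type(block).__name__)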
+
+
+Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
+    [
+        ("input", Flux2TextInputStep()),
+        ("prepare_latents", Flux2PrepareLatentsStep()),
+        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
+        ("set_timesteps", Flux2SetTimestepsStep()),
+        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
+        ("denoise", Flux2KleinBaseDenoiseStep()),
+    ]
+)
+
+class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "flux2-klein"
+    block_classes = Flux2KleinBaseCoreDenoiseBlocks.values()
+    block_names = Flux2KleinBaseCoreDenoiseBlocks.keys()
+
+    @property
+    def description(self):
+        return "Core denoise step that performs the denoising process for Flux2-Klein (base model)."
+
+
+
+class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
+    model_name = "flux2-klein"
+    block_classes = [Flux2KleinTextEncoderStep(), Flux2KleinAutoVaeEncoderStep(), Flux2KleinCoreDenoiseStep(), Flux2DecodeStep()]
+    block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"]
+
+    @property
+    def description(self):
+        return (
+            "Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein.\n"
+            + " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
+            + " - for text-to-image generation, all you need to provide is `prompt`.\n"
+        )
+
+
+
+class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
+    model_name = "flux2-klein"
+    block_classes = [Flux2KleinTextEncoderStep(), Flux2KleinAutoVaeEncoderStep(), Flux2KleinBaseCoreDenoiseStep(), Flux2DecodeStep()]
+    block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"]
+
+    @property
+    def description(self):
+        return (
+            "Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model).\n"
+            + " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
+            + " - for text-to-image generation, all you need to provide is `prompt`.\n"
+        )
diff --git a/src/diffusers/modular_pipelines/flux2/modular_pipeline.py 
b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py index 3e497f3b1e98..e37dafcfce6e 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py @@ -17,6 +17,8 @@ from ...utils import logging from ..modular_pipeline import ModularPipeline +from typing import Optional, Dict, Any + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -55,3 +57,60 @@ def num_channels_latents(self): if getattr(self, "transformer", None): num_channels_latents = self.transformer.config.in_channels // 4 return num_channels_latents + + + + +class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin): + """ + A ModularPipeline for Flux2-Klein. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "Flux2KleinBaseAutoBlocks" + + def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]: + + if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]: + return "Flux2KleinAutoBlocks" + else: + return "Flux2KleinBaseAutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if getattr(self, "vae", None) is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 32 + if getattr(self, "transformer", None): + num_channels_latents = self.transformer.config.in_channels // 4 + return num_channels_latents + + @property + def requires_unconditional_embeds(self): + + if hasattr(self.config, "is_distilled") and self.config.is_distilled: + return False + + requires_unconditional_embeds = False + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index d857fd040955..98ede73c21fe 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -59,6 +59,7 @@ ("flux", "FluxModularPipeline"), ("flux-kontext", "FluxKontextModularPipeline"), ("flux2", "Flux2ModularPipeline"), + ("flux2-klein", "Flux2KleinModularPipeline"), ("qwenimage", "QwenImageModularPipeline"), ("qwenimage-edit", "QwenImageEditModularPipeline"), ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"), From fb2cb18f730d4d668436e9da500103dff8644c34 Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Tue, 20 Jan 2026 01:31:41 +0000 Subject: [PATCH 02/13] style --- src/diffusers/__init__.py | 4 +-- src/diffusers/modular_pipelines/__init__.py | 8 ++++- .../modular_pipelines/flux2/__init__.py | 2 +- .../modular_pipelines/flux2/denoise.py | 15 +++++---- .../modular_pipelines/flux2/encoders.py | 11 +++---- .../modular_pipelines/flux2/inputs.py | 4 ++- .../flux2/modular_blocks_flux2_klein.py | 33 +++++++++++-------- .../flux2/modular_pipeline.py | 10 ++---- 8 files changed, 48 insertions(+), 39 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 
71228a5598b3..52ec30c536bd 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -413,10 +413,10 @@ _import_structure["modular_pipelines"].extend( [ "Flux2AutoBlocks", - "Flux2ModularPipeline", "Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks", "Flux2KleinModularPipeline", + "Flux2ModularPipeline", "FluxAutoBlocks", "FluxKontextAutoBlocks", "FluxKontextModularPipeline", @@ -1149,10 +1149,10 @@ else: from .modular_pipelines import ( Flux2AutoBlocks, - Flux2ModularPipeline, Flux2KleinAutoBlocks, Flux2KleinBaseAutoBlocks, Flux2KleinModularPipeline, + Flux2ModularPipeline, FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 099e86a553a8..823a3d263ea9 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -84,7 +84,13 @@ else: from .components_manager import ComponentsManager from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline - from .flux2 import Flux2AutoBlocks, Flux2KleinAutoBlocks, Flux2KleinBaseAutoBlocks, Flux2ModularPipeline, Flux2KleinModularPipeline + from .flux2 import ( + Flux2AutoBlocks, + Flux2KleinAutoBlocks, + Flux2KleinBaseAutoBlocks, + Flux2KleinModularPipeline, + Flux2ModularPipeline, + ) from .modular_pipeline import ( AutoPipelineBlocks, BlockState, diff --git a/src/diffusers/modular_pipelines/flux2/__init__.py b/src/diffusers/modular_pipelines/flux2/__init__.py index 64ced29bddd6..fb97a56fb049 100644 --- a/src/diffusers/modular_pipelines/flux2/__init__.py +++ b/src/diffusers/modular_pipelines/flux2/__init__.py @@ -101,7 +101,7 @@ Flux2KleinAutoBlocks, Flux2KleinBaseAutoBlocks, ) - from .modular_pipeline import Flux2ModularPipeline, Flux2KleinModularPipeline + from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/flux2/denoise.py b/src/diffusers/modular_pipelines/flux2/denoise.py index 84cad52ab7cf..b2e1b41dde88 100644 --- a/src/diffusers/modular_pipelines/flux2/denoise.py +++ b/src/diffusers/modular_pipelines/flux2/denoise.py @@ -12,24 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, List, Tuple import inspect +from typing import Any, List, Tuple import torch from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance from ...models import Flux2Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging -from ...guiders import ClassifierFreeGuidance from ..modular_pipeline import ( BlockState, LoopSequentialPipelineBlocks, ModularPipelineBlocks, PipelineState, ) -from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec -from .modular_pipeline import Flux2ModularPipeline, Flux2KleinModularPipeline +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline if is_torch_xla_available(): @@ -136,7 +136,8 @@ def __call__( return components, block_state -# sane as Flux2 but guidance=None + +# same as Flux2LoopDenoiser but guidance=None class Flux2KleinLoopDenoiser(ModularPipelineBlocks): model_name = "flux2-klein" @@ -308,7 +309,7 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam( kwargs_type="denoiser_input_fields", description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", - ) + ), ] @torch.no_grad() @@ -368,7 +369,6 @@ def __call__( # perform guidance block_state.noise_pred = components.guider(guider_state)[0] - return components, block_state @@ -491,6 +491,7 @@ def description(self) -> str: "This block supports both text-to-image and image-conditioned generation." ) + class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper): block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser] block_names = ["denoiser", "after_denoiser"] diff --git a/src/diffusers/modular_pipelines/flux2/encoders.py b/src/diffusers/modular_pipelines/flux2/encoders.py index 5c06746aa24f..1d9e56bdf028 100644 --- a/src/diffusers/modular_pipelines/flux2/encoders.py +++ b/src/diffusers/modular_pipelines/flux2/encoders.py @@ -15,13 +15,13 @@ from typing import List, Optional, Tuple, Union import torch -from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen3ForCausalLM, Qwen2TokenizerFast +from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM from ...models import AutoencoderKLFlux2 from ...utils import logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState -from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec -from .modular_pipeline import Flux2ModularPipeline, Flux2KleinModularPipeline +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -245,11 +245,9 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi return components, state - class Flux2KleinTextEncoderStep(ModularPipelineBlocks): model_name = "flux2-klein" - @property def description(self) -> str: return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation" @@ -284,7 +282,6 @@ def intermediate_outputs(self) -> List[OutputParam]: type_hint=torch.Tensor, description="Text embeddings from qwen3 used to guide the image generation", ), - OutputParam( "negative_prompt_embeds", type_hint=torch.Tensor, @@ -390,7 +387,7 @@ def __call__(self, 
components: Flux2KleinModularPipeline, state: PipelineState) ) else: block_state.negative_prompt_embeds = None - + self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/flux2/inputs.py b/src/diffusers/modular_pipelines/flux2/inputs.py index 0de0040c3923..cc078c826206 100644 --- a/src/diffusers/modular_pipelines/flux2/inputs.py +++ b/src/diffusers/modular_pipelines/flux2/inputs.py @@ -100,7 +100,9 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi if block_state.negative_prompt_embeds is not None: _, seq_len, _ = block_state.negative_prompt_embeds.shape - block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 ) diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py index 2f89106b1351..1dd63a6123e2 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py @@ -22,7 +22,7 @@ Flux2SetTimestepsStep, ) from .decoders import Flux2DecodeStep -from .denoise import Flux2KleinDenoiseStep, Flux2KleinBaseDenoiseStep +from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep from .encoders import ( Flux2KleinTextEncoderStep, Flux2VaeEncoderStep, @@ -55,7 +55,6 @@ def description(self) -> str: return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations." - class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks): block_classes = [Flux2KleinVaeEncoderSequentialStep] block_names = ["img_conditioning"] @@ -71,9 +70,8 @@ def description(self): ) - Flux2KleinCoreDenoiseBlocks = InsertableDict( - [ + [ ("input", Flux2TextInputStep()), ("prepare_image_latents", Flux2PrepareImageLatentsStep()), ("prepare_latents", Flux2PrepareLatentsStep()), @@ -89,7 +87,7 @@ class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks): block_classes = Flux2KleinCoreDenoiseBlocks.values() block_names = Flux2KleinCoreDenoiseBlocks.keys() - + @property def description(self): return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model)." @@ -105,7 +103,7 @@ def description(self): Flux2KleinBaseCoreDenoiseBlocks = InsertableDict( - [ + [ ("input", Flux2TextInputStep()), ("prepare_latents", Flux2PrepareLatentsStep()), ("prepare_image_latents", Flux2PrepareImageLatentsStep()), @@ -115,11 +113,12 @@ def description(self): ] ) + class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks): model_name = "flux2-klein" block_classes = Flux2KleinBaseCoreDenoiseBlocks.values() block_names = Flux2KleinBaseCoreDenoiseBlocks.keys() - + @property def description(self): return "Core denoise step that performs the denoising process for Flux2-Klein (base model)." 
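For orientation: once merged, these auto blocks are driven through the standard modular-pipelines flow. A minimal sketch, assuming the usual `init_pipeline` / `load_default_components` API applies here as it does for the other modular pipelines, and using a hypothetical modular repository id:

    import torch
    from diffusers.modular_pipelines import Flux2KleinAutoBlocks

    blocks = Flux2KleinAutoBlocks()
    # repo id below is illustrative, not a confirmed checkpoint
    pipe = blocks.init_pipeline("black-forest-labs/FLUX.2-klein")
    pipe.load_default_components(torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # text-to-image needs only `prompt`; passing `image=[...]` (PIL images) triggers
    # the image-conditioning branch of Flux2KleinAutoVaeEncoderStep
    image = pipe(prompt="a photo of a cat", num_inference_steps=28, output="images")[0]
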
@@ -134,12 +133,16 @@ def description(self): ) - class Flux2KleinAutoBlocks(SequentialPipelineBlocks): model_name = "flux2-klein" - block_classes = [Flux2KleinTextEncoderStep(), Flux2KleinAutoVaeEncoderStep(), Flux2KleinCoreDenoiseStep(), Flux2DecodeStep()] + block_classes = [ + Flux2KleinTextEncoderStep(), + Flux2KleinAutoVaeEncoderStep(), + Flux2KleinCoreDenoiseStep(), + Flux2DecodeStep(), + ] block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"] - + @property def description(self): return ( @@ -149,12 +152,16 @@ def description(self): ) - class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks): model_name = "flux2-klein" - block_classes = [Flux2KleinTextEncoderStep(), Flux2KleinAutoVaeEncoderStep(), Flux2KleinBaseCoreDenoiseStep(), Flux2DecodeStep()] + block_classes = [ + Flux2KleinTextEncoderStep(), + Flux2KleinAutoVaeEncoderStep(), + Flux2KleinBaseCoreDenoiseStep(), + Flux2DecodeStep(), + ] block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"] - + @property def description(self): return ( diff --git a/src/diffusers/modular_pipelines/flux2/modular_pipeline.py b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py index e37dafcfce6e..29fbeba07c24 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py @@ -13,12 +13,12 @@ # limitations under the License. +from typing import Any, Dict, Optional + from ...loaders import Flux2LoraLoaderMixin from ...utils import logging from ..modular_pipeline import ModularPipeline -from typing import Optional, Dict, Any - logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -59,8 +59,6 @@ def num_channels_latents(self): return num_channels_latents - - class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin): """ A ModularPipeline for Flux2-Klein. 
@@ -71,7 +69,6 @@ class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin): default_blocks_name = "Flux2KleinBaseAutoBlocks" def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]: - if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]: return "Flux2KleinAutoBlocks" else: @@ -105,7 +102,6 @@ def num_channels_latents(self): @property def requires_unconditional_embeds(self): - if hasattr(self.config, "is_distilled") and self.config.is_distilled: return False @@ -113,4 +109,4 @@ def requires_unconditional_embeds(self): if hasattr(self, "guider") and self.guider is not None: requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 - return requires_unconditional_embeds \ No newline at end of file + return requires_unconditional_embeds From 9357d8f4f75364ba88b4030537be736c87ea4eb9 Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Tue, 20 Jan 2026 01:32:08 +0000 Subject: [PATCH 03/13] copies --- .../dummy_torch_and_transformers_objects.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 63f381419fda..a23f852616c0 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -17,6 +17,51 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Flux2KleinAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class Flux2KleinBaseAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class Flux2KleinModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Flux2ModularPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 3c7494a6517836431544ad3a2c82565f802acf44 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 20 Jan 2026 08:09:03 -1000 Subject: [PATCH 04/13] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sayak Paul Co-authored-by: Álvaro Somoza --- src/diffusers/modular_pipelines/flux2/denoise.py | 4 ++-- .../modular_pipelines/flux2/modular_blocks_flux2_klein.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/modular_pipelines/flux2/denoise.py b/src/diffusers/modular_pipelines/flux2/denoise.py index b2e1b41dde88..3dd5661d4934 100644 
--- a/src/diffusers/modular_pipelines/flux2/denoise.py +++ b/src/diffusers/modular_pipelines/flux2/denoise.py @@ -177,7 +177,7 @@ def inputs(self) -> List[Tuple[str, Any]]: "prompt_embeds", required=True, type_hint=torch.Tensor, - description="Text embeddings from Mistral3", + description="Text embeddings from Qwen3", ), InputParam( "txt_ids", @@ -280,7 +280,7 @@ def inputs(self) -> List[Tuple[str, Any]]: "prompt_embeds", required=True, type_hint=torch.Tensor, - description="Text embeddings from Mistral3", + description="Text embeddings from Qwen3", ), InputParam( "negative_prompt_embeds", diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py index 1dd63a6123e2..fc787ad1d2ad 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py @@ -141,7 +141,7 @@ class Flux2KleinAutoBlocks(SequentialPipelineBlocks): Flux2KleinCoreDenoiseStep(), Flux2DecodeStep(), ] - block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"] + block_names = ["text_encoder", "vae_encoder", "denoise", "decode"] @property def description(self): @@ -160,7 +160,7 @@ class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks): Flux2KleinBaseCoreDenoiseStep(), Flux2DecodeStep(), ] - block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"] + block_names = ["text_encoder", "vae_encoder", "denoise", "decode"] @property def description(self): From d2953678ad9ead0ff6d697abfade52be1ce77e64 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 20 Jan 2026 09:25:29 -1000 Subject: [PATCH 05/13] Update src/diffusers/modular_pipelines/flux2/encoders.py --- src/diffusers/modular_pipelines/flux2/encoders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/modular_pipelines/flux2/encoders.py b/src/diffusers/modular_pipelines/flux2/encoders.py index 1d9e56bdf028..835feb86cce9 100644 --- a/src/diffusers/modular_pipelines/flux2/encoders.py +++ b/src/diffusers/modular_pipelines/flux2/encoders.py @@ -284,6 +284,7 @@ def intermediate_outputs(self) -> List[OutputParam]: ), OutputParam( "negative_prompt_embeds", + kwargs_type="denoiser_input_fields", type_hint=torch.Tensor, description="Negative text embeddings from qwen3 used to guide the image generation", ), From c10041e57e1f9c3cf2d3ff96fd535e25dfa4f150 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 20 Jan 2026 23:13:53 +0100 Subject: [PATCH 06/13] a few fix: unpack latents before decoder etc --- .../modular_pipelines/flux2/decoders.py | 109 +++++++++----- .../modular_pipelines/flux2/denoise.py | 14 -- .../modular_pipelines/flux2/encoders.py | 138 ++++++++++++++++++ .../flux2/modular_blocks_flux2.py | 6 +- .../flux2/modular_blocks_flux2_klein.py | 20 ++- 5 files changed, 233 insertions(+), 54 deletions(-) diff --git a/src/diffusers/modular_pipelines/flux2/decoders.py b/src/diffusers/modular_pipelines/flux2/decoders.py index b769d9119891..e8813672085d 100644 --- a/src/diffusers/modular_pipelines/flux2/decoders.py +++ b/src/diffusers/modular_pipelines/flux2/decoders.py @@ -29,29 +29,16 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class Flux2DecodeStep(ModularPipelineBlocks): +class Flux2UnpackLatentsStep(ModularPipelineBlocks): model_name = "flux2" - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKLFlux2), - ComponentSpec( - "image_processor", - Flux2ImageProcessor, - 
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}), - default_creation_method="from_config", - ), - ] - @property def description(self) -> str: - return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization" + return "Step that unpacks the latents from the denoising step" @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("output_type", default="pil"), InputParam( "latents", required=True, @@ -70,9 +57,9 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediate_outputs(self) -> List[str]: return [ OutputParam( - "images", - type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray], - description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array", + "latents", + type_hint=torch.Tensor, + description="The denoise latents from denoising step, unpacked with position IDs.", ) ] @@ -107,6 +94,64 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tens return torch.stack(x_list, dim=0) + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + + latents = block_state.latents + latent_ids = block_state.latent_ids + + latents = self._unpack_latents_with_ids(latents, latent_ids) + + block_state.latents = latents + + self.set_block_state(state, block_state) + return components, state + + +class Flux2DecodeStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLFlux2), + ComponentSpec( + "image_processor", + Flux2ImageProcessor, + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "images", + type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray], + description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + @staticmethod def _unpatchify_latents(latents): """Convert patchified latents back to regular format.""" @@ -121,26 +166,20 @@ def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) vae = components.vae - if block_state.output_type == "latent": - block_state.images = block_state.latents - else: - latents = block_state.latents - latent_ids = block_state.latent_ids - - latents = self._unpack_latents_with_ids(latents, latent_ids) + latents = block_state.latents - latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) - latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( - latents.device, latents.dtype - ) - latents = latents * latents_bn_std + latents_bn_mean + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + 
latents = latents * latents_bn_std + latents_bn_mean - latents = self._unpatchify_latents(latents) + latents = self._unpatchify_latents(latents) - block_state.images = vae.decode(latents, return_dict=False)[0] - block_state.images = components.image_processor.postprocess( - block_state.images, output_type=block_state.output_type - ) + block_state.images = vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/flux2/denoise.py b/src/diffusers/modular_pipelines/flux2/denoise.py index 3dd5661d4934..a30382b5f774 100644 --- a/src/diffusers/modular_pipelines/flux2/denoise.py +++ b/src/diffusers/modular_pipelines/flux2/denoise.py @@ -306,10 +306,6 @@ def inputs(self) -> List[Tuple[str, Any]]: type_hint=torch.Tensor, description="4D position IDs for latent tokens (T, H, W, L)", ), - InputParam( - kwargs_type="denoiser_input_fields", - description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", - ), ] @torch.no_grad() @@ -339,20 +335,12 @@ def __call__( ), } - transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) - additional_cond_kwargs = {} - for field_name, field_value in block_state.denoiser_input_fields.items(): - if field_name in transformer_args and field_name not in guider_inputs: - additional_cond_kwargs[field_name] = field_value - block_state.additional_cond_kwargs.update(additional_cond_kwargs) - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) guider_state = components.guider.prepare_inputs(guider_inputs) for guider_state_batch in guider_state: components.guider.prepare_models(components.transformer) cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} - cond_kwargs.update(additional_cond_kwargs) noise_pred = components.transformer( hidden_states=latent_model_input, @@ -458,8 +446,6 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 ) - block_state.additional_cond_kwargs = {} - with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: for i, t in enumerate(block_state.timesteps): components, block_state = self.loop_step(components, block_state, i=i, t=t) diff --git a/src/diffusers/modular_pipelines/flux2/encoders.py b/src/diffusers/modular_pipelines/flux2/encoders.py index 835feb86cce9..b2a93e0a2548 100644 --- a/src/diffusers/modular_pipelines/flux2/encoders.py +++ b/src/diffusers/modular_pipelines/flux2/encoders.py @@ -17,6 +17,9 @@ import torch from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM +from ...guiders import ClassifierFreeGuidance +from ...configuration_utils import FrozenDict + from ...models import AutoencoderKLFlux2 from ...utils import logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState @@ -259,6 +262,141 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("tokenizer", Qwen2TokenizerFast), ] + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ + ConfigSpec(name="is_distilled", default=True), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + 
InputParam("max_sequence_length", type_hint=int, default=512, required=False), + InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from qwen3 used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + hidden_states_layers: List[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @torch.no_grad() + def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + device = components._execution_device + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + block_state.prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=block_state.max_sequence_length, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + 
ComponentSpec("text_encoder", Qwen3ForCausalLM), + ComponentSpec("tokenizer", Qwen2TokenizerFast), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + @property def expected_configs(self) -> List[ConfigSpec]: return [ diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py index a31673b6e78c..bad167f84280 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py @@ -21,7 +21,7 @@ Flux2RoPEInputsStep, Flux2SetTimestepsStep, ) -from .decoders import Flux2DecodeStep +from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep from .denoise import Flux2DenoiseStep from .encoders import ( Flux2RemoteTextEncoderStep, @@ -99,6 +99,7 @@ def description(self): ("vae_image_encoder", Flux2AutoVaeEncoderStep()), ("before_denoise", Flux2BeforeDenoiseStep()), ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), ("decode", Flux2DecodeStep()), ] ) @@ -111,6 +112,7 @@ def description(self): ("vae_image_encoder", Flux2AutoVaeEncoderStep()), ("before_denoise", Flux2BeforeDenoiseStep()), ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), ("decode", Flux2DecodeStep()), ] ) @@ -139,6 +141,7 @@ def description(self): ("set_timesteps", Flux2SetTimestepsStep()), ("prepare_rope_inputs", Flux2RoPEInputsStep()), ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), ("decode", Flux2DecodeStep()), ] ) @@ -154,6 +157,7 @@ def description(self): ("set_timesteps", Flux2SetTimestepsStep()), ("prepare_rope_inputs", Flux2RoPEInputsStep()), ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), ("decode", Flux2DecodeStep()), ] ) diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py index fc787ad1d2ad..22949c99d7e0 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py @@ -21,10 +21,11 @@ Flux2RoPEInputsStep, Flux2SetTimestepsStep, ) -from .decoders import Flux2DecodeStep +from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep from .encoders import ( Flux2KleinTextEncoderStep, + Flux2KleinBaseTextEncoderStep, Flux2VaeEncoderStep, ) from .inputs import ( @@ -35,7 +36,9 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name - +### +### VAE encoder +### Flux2KleinVaeEncoderBlocks = InsertableDict( [ ("preprocess", Flux2ProcessImagesInputStep()), @@ -69,6 +72,9 @@ def description(self): " - If `image` is not provided, step will be skipped." 
) +### +### Core denoise +### Flux2KleinCoreDenoiseBlocks = InsertableDict( [ @@ -78,6 +84,7 @@ def description(self): ("set_timesteps", Flux2SetTimestepsStep()), ("prepare_rope_inputs", Flux2RoPEInputsStep()), ("denoise", Flux2KleinDenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), ] ) @@ -99,6 +106,7 @@ def description(self): " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n" " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n" " - `Flux2KleinDenoiseStep` (denoise) iteratively denoises the latents.\n" + " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n" ) @@ -110,6 +118,7 @@ def description(self): ("set_timesteps", Flux2SetTimestepsStep()), ("prepare_rope_inputs", Flux2RoPEInputsStep()), ("denoise", Flux2KleinBaseDenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), ] ) @@ -130,9 +139,12 @@ def description(self): " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n" " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n" " - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n" + " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n" ) - +### +### Auto blocks +### class Flux2KleinAutoBlocks(SequentialPipelineBlocks): model_name = "flux2-klein" block_classes = [ @@ -155,7 +167,7 @@ def description(self): class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks): model_name = "flux2-klein" block_classes = [ - Flux2KleinTextEncoderStep(), + Flux2KleinBaseTextEncoderStep(), Flux2KleinAutoVaeEncoderStep(), Flux2KleinBaseCoreDenoiseStep(), Flux2DecodeStep(), From e13377e84170fb1f69b089cc7576817862235bbd Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 20 Jan 2026 23:14:59 +0100 Subject: [PATCH 07/13] style --- src/diffusers/modular_pipelines/flux2/decoders.py | 2 -- src/diffusers/modular_pipelines/flux2/denoise.py | 1 - src/diffusers/modular_pipelines/flux2/encoders.py | 3 +-- .../modular_pipelines/flux2/modular_blocks_flux2_klein.py | 4 +++- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/modular_pipelines/flux2/decoders.py b/src/diffusers/modular_pipelines/flux2/decoders.py index e8813672085d..c79375072037 100644 --- a/src/diffusers/modular_pipelines/flux2/decoders.py +++ b/src/diffusers/modular_pipelines/flux2/decoders.py @@ -94,12 +94,10 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tens return torch.stack(x_list, dim=0) - @torch.no_grad() def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - latents = block_state.latents latent_ids = block_state.latent_ids diff --git a/src/diffusers/modular_pipelines/flux2/denoise.py b/src/diffusers/modular_pipelines/flux2/denoise.py index a30382b5f774..a726959a29e2 100644 --- a/src/diffusers/modular_pipelines/flux2/denoise.py +++ b/src/diffusers/modular_pipelines/flux2/denoise.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import inspect
 from typing import Any, List, Tuple
 
 import torch
diff --git a/src/diffusers/modular_pipelines/flux2/encoders.py b/src/diffusers/modular_pipelines/flux2/encoders.py
index b2a93e0a2548..b4b6a4b533a7 100644
--- a/src/diffusers/modular_pipelines/flux2/encoders.py
+++ b/src/diffusers/modular_pipelines/flux2/encoders.py
@@ -17,9 +17,8 @@
 import torch
 from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM
 
-from ...guiders import ClassifierFreeGuidance
 from ...configuration_utils import FrozenDict
-
+from ...guiders import ClassifierFreeGuidance
 from ...models import AutoencoderKLFlux2
 from ...utils import logging
 from ..modular_pipeline import ModularPipelineBlocks, PipelineState
diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py
index 22949c99d7e0..b681238628e1 100644
--- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py
+++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py
@@ -24,8 +24,8 @@
 from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
 from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep
 from .encoders import (
-    Flux2KleinTextEncoderStep,
     Flux2KleinBaseTextEncoderStep,
+    Flux2KleinTextEncoderStep,
     Flux2VaeEncoderStep,
 )
 from .inputs import (
@@ -72,6 +72,7 @@ def description(self):
         " - If `image` is not provided, step will be skipped."
     )
 
+
 ###
 ### Core denoise
 ###
@@ -142,6 +143,7 @@ def description(self):
         " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
     )
 
+
 ###
 ### Auto blocks
 ###

From 5c1fc4489f95b162a2f8b5d69fc89e9bfa40c8f5 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Wed, 21 Jan 2026 00:59:56 +0100
Subject: [PATCH 08/13] move guidance into its own block

---
 .../modular_pipelines/flux2/before_denoise.py | 111 +++++++++++++++---
 .../modular_pipelines/flux2/inputs.py         |  62 ++++++++++
 .../flux2/modular_blocks_flux2.py             |   4 +
 .../flux2/modular_blocks_flux2_klein.py       |  56 ++++++++-
 4 files changed, 211 insertions(+), 22 deletions(-)

diff --git a/src/diffusers/modular_pipelines/flux2/before_denoise.py b/src/diffusers/modular_pipelines/flux2/before_denoise.py
index e1001924c70b..d5bab16586d7 100644
--- a/src/diffusers/modular_pipelines/flux2/before_denoise.py
+++ b/src/diffusers/modular_pipelines/flux2/before_denoise.py
@@ -129,17 +129,9 @@ def inputs(self) -> List[InputParam]:
             InputParam("num_inference_steps", default=50),
             InputParam("timesteps"),
             InputParam("sigmas"),
-            InputParam("guidance_scale", default=4.0),
             InputParam("latents", type_hint=torch.Tensor),
-            InputParam("num_images_per_prompt", default=1),
             InputParam("height", type_hint=int),
             InputParam("width", type_hint=int),
-            InputParam(
-                "batch_size",
-                required=True,
-                type_hint=int,
-                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
-            ),
         ]
 
     @property
@@ -151,13 +143,12 @@ def intermediate_outputs(self) -> List[OutputParam]:
                 type_hint=int,
                 description="The number of denoising steps to perform at inference time",
             ),
-            OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
         ]
 
     @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-        block_state.device = components._execution_device
+        device = components._execution_device
 
         scheduler = components.scheduler
 
@@ -183,7 
+174,7 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi timesteps, num_inference_steps = retrieve_timesteps( scheduler, num_inference_steps, - block_state.device, + device, timesteps=timesteps, sigmas=sigmas, mu=mu, @@ -191,11 +182,6 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi block_state.timesteps = timesteps block_state.num_inference_steps = num_inference_steps - batch_size = block_state.batch_size * block_state.num_images_per_prompt - guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32) - guidance = guidance.expand(batch_size) - block_state.guidance = guidance - components.scheduler.set_begin_index(0) self.set_block_state(state, block_state) @@ -349,6 +335,60 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks): def description(self) -> str: return "Step that prepares the 4D RoPE position IDs for Flux2 denoising. Should be placed after text encoder and latent preparation steps." + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="prompt_embeds", required=True), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.", + ), + ] + + @staticmethod + def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None): + """Prepare 4D position IDs for text tokens.""" + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + seq_l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, seq_l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + prompt_embeds = block_state.prompt_embeds + device = prompt_embeds.device + + block_state.txt_ids = self._prepare_text_ids(prompt_embeds) + block_state.txt_ids = block_state.txt_ids.to(device) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2KleinBaseRoPEInputsStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps." 
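+    # The base (non-distilled) Klein checkpoint runs classifier-free guidance, so in
+    # addition to `txt_ids` this step also builds `negative_txt_ids` from
+    # `negative_prompt_embeds` whenever negative embeddings are present.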
+ @property def inputs(self) -> List[InputParam]: return [ @@ -511,3 +551,42 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi self.set_block_state(state, block_state) return components, state + + +class Flux2PrepareGuidanceStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that prepares the guidance scale tensor for Flux2 inference" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("guidance_scale", default=4.0), + InputParam("num_images_per_prompt", default=1), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + guidance = torch.full([1], block_state.guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(batch_size) + block_state.guidance = guidance + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/flux2/inputs.py b/src/diffusers/modular_pipelines/flux2/inputs.py index cc078c826206..3463de1999c6 100644 --- a/src/diffusers/modular_pipelines/flux2/inputs.py +++ b/src/diffusers/modular_pipelines/flux2/inputs.py @@ -30,6 +30,68 @@ class Flux2TextInputStep(ModularPipelineBlocks): model_name = "flux2" + @property + def description(self) -> str: + return ( + "This step:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated text embeddings. 
Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Text embeddings used to guide the image generation", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2KleinBaseTextInputStep(ModularPipelineBlocks): + model_name = "flux2-klein" + @property def description(self) -> str: return ( diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py index bad167f84280..af6c1819ec2c 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py @@ -19,6 +19,7 @@ Flux2PrepareImageLatentsStep, Flux2PrepareLatentsStep, Flux2RoPEInputsStep, + Flux2PrepareGuidanceStep, Flux2SetTimestepsStep, ) from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep @@ -76,6 +77,7 @@ def description(self): [ ("prepare_latents", Flux2PrepareLatentsStep()), ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_guidance", Flux2PrepareGuidanceStep()), ("prepare_rope_inputs", Flux2RoPEInputsStep()), ] ) @@ -139,6 +141,7 @@ def description(self): ("text_input", Flux2TextInputStep()), ("prepare_latents", Flux2PrepareLatentsStep()), ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_guidance", Flux2PrepareGuidanceStep()), ("prepare_rope_inputs", Flux2RoPEInputsStep()), ("denoise", Flux2DenoiseStep()), ("after_denoise", Flux2UnpackLatentsStep()), @@ -155,6 +158,7 @@ def description(self): ("prepare_image_latents", Flux2PrepareImageLatentsStep()), ("prepare_latents", Flux2PrepareLatentsStep()), ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_guidance", Flux2PrepareGuidanceStep()), ("prepare_rope_inputs", Flux2RoPEInputsStep()), ("denoise", Flux2DenoiseStep()), ("after_denoise", Flux2UnpackLatentsStep()), diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py index b681238628e1..6e1cb985e77b 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py @@ -12,13 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import torch +import PIL.Image +from typing import List from ...utils import logging from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks -from ..modular_pipeline_utils import InsertableDict +from ..modular_pipeline_utils import InsertableDict, OutputParam from .before_denoise import ( Flux2PrepareImageLatentsStep, Flux2PrepareLatentsStep, Flux2RoPEInputsStep, + Flux2KleinBaseRoPEInputsStep, Flux2SetTimestepsStep, ) from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep @@ -31,6 +35,7 @@ from .inputs import ( Flux2ProcessImagesInputStep, Flux2TextInputStep, + Flux2KleinBaseTextInputStep, ) @@ -101,7 +106,7 @@ def description(self): return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model)." return ( "Core denoise step that performs the denoising process for Flux2-Klein.\n" - " - `Flux2KleinTextInputStep` (input) standardizes the text inputs for the denoising step.\n" + " - `Flux2KleinTextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n" " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n" " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n" " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n" @@ -110,14 +115,24 @@ def description(self): " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n" ) + @property + def outputs(self): + return [ + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The latents from the denoising step.", + ) + ] + Flux2KleinBaseCoreDenoiseBlocks = InsertableDict( [ - ("input", Flux2TextInputStep()), + ("input", Flux2KleinBaseTextInputStep()), ("prepare_latents", Flux2PrepareLatentsStep()), ("prepare_image_latents", Flux2PrepareImageLatentsStep()), ("set_timesteps", Flux2SetTimestepsStep()), - ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()), ("denoise", Flux2KleinBaseDenoiseStep()), ("after_denoise", Flux2UnpackLatentsStep()), ] @@ -134,14 +149,23 @@ def description(self): return "Core denoise step that performs the denoising process for Flux2-Klein (base model)." 
return ( "Core denoise step that performs the denoising process for Flux2-Klein (base model).\n" - " - `Flux2KleinTextInputStep` (input) standardizes the text inputs for the denoising step.\n" + " - `Flux2KleinBaseTextInputStep` (input) standardizes the text inputs (prompt_embeds + negative_prompt_embeds) for the denoising step.\n" " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n" " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n" " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n" - " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n" + " - `Flux2KleinBaseRoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids + negative_txt_ids) for the denoising step.\n" " - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n" " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n" ) + @property + def outputs(self): + return [ + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The latents from the denoising step.", + ) + ] ### @@ -165,6 +189,16 @@ def description(self): + " - for text-to-image generation, all you need to provide is `prompt`.\n" ) + @property + def outputs(self): + return [ + OutputParam( + name="images", + type_hint=List[PIL.Image.Image], + description="The images from the decoding step.", + ) + ] + class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks): model_name = "flux2-klein" @@ -183,3 +217,13 @@ def description(self): + " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n" + " - for text-to-image generation, all you need to provide is `prompt`.\n" ) + + @property + def outputs(self): + return [ + OutputParam( + name="images", + type_hint=List[PIL.Image.Image], + description="The images from the decoding step.", + ) + ] From f49c68cecfd6665f84cc68aba6daaac55379900f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 21 Jan 2026 01:01:56 +0100 Subject: [PATCH 09/13] style --- .../modular_pipelines/flux2/modular_blocks_flux2.py | 2 +- .../flux2/modular_blocks_flux2_klein.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py index af6c1819ec2c..66509454c3ea 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py @@ -16,10 +16,10 @@ from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks from ..modular_pipeline_utils import InsertableDict from .before_denoise import ( + Flux2PrepareGuidanceStep, Flux2PrepareImageLatentsStep, Flux2PrepareLatentsStep, Flux2RoPEInputsStep, - Flux2PrepareGuidanceStep, Flux2SetTimestepsStep, ) from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py index 6e1cb985e77b..0ecbbceb6d85 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py @@ -12,17 +12,19 @@ # See the License for the specific language governing permissions and # 
limitations under the License. -import torch -import PIL.Image from typing import List + +import PIL.Image +import torch + from ...utils import logging from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks from ..modular_pipeline_utils import InsertableDict, OutputParam from .before_denoise import ( + Flux2KleinBaseRoPEInputsStep, Flux2PrepareImageLatentsStep, Flux2PrepareLatentsStep, Flux2RoPEInputsStep, - Flux2KleinBaseRoPEInputsStep, Flux2SetTimestepsStep, ) from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep @@ -33,9 +35,9 @@ Flux2VaeEncoderStep, ) from .inputs import ( + Flux2KleinBaseTextInputStep, Flux2ProcessImagesInputStep, Flux2TextInputStep, - Flux2KleinBaseTextInputStep, ) @@ -157,6 +159,7 @@ def description(self): " - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n" " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n" ) + @property def outputs(self): return [ From 1c500c8eeb27ec97ef1bcfe20ab33159c4490e2b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 21 Jan 2026 08:06:32 +0100 Subject: [PATCH 10/13] flux2-dev work in modular setting --- .../flux2/modular_blocks_flux2.py | 65 ++++++++++++++----- .../flux2/modular_blocks_flux2_klein.py | 10 +-- 2 files changed, 53 insertions(+), 22 deletions(-) diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py index 66509454c3ea..eba2cbbd00f2 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import PIL.Image +from typing import List +import torch + from ...utils import logging from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks -from ..modular_pipeline_utils import InsertableDict +from ..modular_pipeline_utils import InsertableDict, OutputParam from .before_denoise import ( Flux2PrepareGuidanceStep, Flux2PrepareImageLatentsStep, @@ -42,7 +46,6 @@ [ ("preprocess", Flux2ProcessImagesInputStep()), ("encode", Flux2VaeEncoderStep()), - ("prepare_image_latents", Flux2PrepareImageLatentsStep()), ] ) @@ -73,35 +76,56 @@ def description(self): ) -Flux2BeforeDenoiseBlocks = InsertableDict( +Flux2CoreDenoiseBlocks = InsertableDict( [ + ("input", Flux2TextInputStep()), + ("prepare_image_latents", Flux2PrepareImageLatentsStep()), ("prepare_latents", Flux2PrepareLatentsStep()), ("set_timesteps", Flux2SetTimestepsStep()), ("prepare_guidance", Flux2PrepareGuidanceStep()), ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), ] ) -class Flux2BeforeDenoiseStep(SequentialPipelineBlocks): +class Flux2CoreDenoiseStep(SequentialPipelineBlocks): model_name = "flux2" - block_classes = Flux2BeforeDenoiseBlocks.values() - block_names = Flux2BeforeDenoiseBlocks.keys() + block_classes = Flux2CoreDenoiseBlocks.values() + block_names = Flux2CoreDenoiseBlocks.keys() @property def description(self): - return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation." 
+ return ( + "Core denoise step that performs the denoising process for Flux2-dev.\n" + " - `Flux2TextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n" + " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n" + " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n" + " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n" + " - `Flux2PrepareGuidanceStep` (prepare_guidance) prepares the guidance tensor for the denoising step.\n" + " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n" + " - `Flux2DenoiseStep` (denoise) iteratively denoises the latents.\n" + " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n" + ) + + @property + def outputs(self): + return [ + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The latents from the denoising step.", + ) + ] AUTO_BLOCKS = InsertableDict( [ ("text_encoder", Flux2TextEncoderStep()), - ("text_input", Flux2TextInputStep()), - ("vae_image_encoder", Flux2AutoVaeEncoderStep()), - ("before_denoise", Flux2BeforeDenoiseStep()), - ("denoise", Flux2DenoiseStep()), - ("after_denoise", Flux2UnpackLatentsStep()), + ("vae_encoder", Flux2AutoVaeEncoderStep()), + ("denoise", Flux2CoreDenoiseStep()), ("decode", Flux2DecodeStep()), ] ) @@ -110,11 +134,8 @@ def description(self): REMOTE_AUTO_BLOCKS = InsertableDict( [ ("text_encoder", Flux2RemoteTextEncoderStep()), - ("text_input", Flux2TextInputStep()), - ("vae_image_encoder", Flux2AutoVaeEncoderStep()), - ("before_denoise", Flux2BeforeDenoiseStep()), - ("denoise", Flux2DenoiseStep()), - ("after_denoise", Flux2UnpackLatentsStep()), + ("vae_encoder", Flux2AutoVaeEncoderStep()), + ("denoise", Flux2CoreDenoiseStep()), ("decode", Flux2DecodeStep()), ] ) @@ -134,6 +155,16 @@ def description(self): "- For image-conditioned generation, you need to provide `image` (list of PIL images)." ) + @property + def outputs(self): + return [ + OutputParam( + name="images", + type_hint=List[PIL.Image.Image], + description="The images from the decoding step.", + ) + ] + TEXT2IMAGE_BLOCKS = InsertableDict( [ diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py index 0ecbbceb6d85..984832d77be5 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py @@ -43,9 +43,10 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -### -### VAE encoder -### +################ +# VAE encoder +################ + Flux2KleinVaeEncoderBlocks = InsertableDict( [ ("preprocess", Flux2ProcessImagesInputStep()), @@ -105,9 +106,8 @@ class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks): @property def description(self): - return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model)." 
return ( - "Core denoise step that performs the denoising process for Flux2-Klein.\n" + "Core denoise step that performs the denoising process for Flux2-Klein (distilled model).\n" " - `Flux2KleinTextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n" " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n" " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n" From a232cd9d305442f6eb44f5b8dd94b33bceea8218 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 21 Jan 2026 12:29:12 +0100 Subject: [PATCH 11/13] up --- src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py index eba2cbbd00f2..41a0ff7dee28 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import PIL.Image from typing import List + +import PIL.Image import torch from ...utils import logging From eb221d5bc1745acebc14ab1e1697fb52c429a6f3 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 21 Jan 2026 12:43:37 +0100 Subject: [PATCH 12/13] up up --- src/diffusers/modular_pipelines/flux2/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/modular_pipelines/flux2/__init__.py b/src/diffusers/modular_pipelines/flux2/__init__.py index fb97a56fb049..220ec0c4ab65 100644 --- a/src/diffusers/modular_pipelines/flux2/__init__.py +++ b/src/diffusers/modular_pipelines/flux2/__init__.py @@ -51,7 +51,7 @@ "IMAGE_CONDITIONED_BLOCKS", "Flux2AutoBlocks", "Flux2AutoVaeEncoderStep", - "Flux2BeforeDenoiseStep", + "Flux2CoreDenoiseStep", "Flux2VaeEncoderSequentialStep", ] _import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"] @@ -94,7 +94,7 @@ TEXT2IMAGE_BLOCKS, Flux2AutoBlocks, Flux2AutoVaeEncoderStep, - Flux2BeforeDenoiseStep, + Flux2CoreDenoiseStep, Flux2VaeEncoderSequentialStep, ) from .modular_blocks_flux2_klein import ( From a81893c40753bf44872906043d290cd1cff60325 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 21 Jan 2026 13:55:38 +0100 Subject: [PATCH 13/13] add tests --- .../modular_pipelines/flux2/encoders.py | 2 +- .../test_modular_pipeline_flux2_klein.py | 91 +++++++++++++++++++ .../test_modular_pipeline_flux2_klein_base.py | 91 +++++++++++++++++++ 3 files changed, 183 insertions(+), 1 deletion(-) create mode 100644 tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein.py create mode 100644 tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein_base.py diff --git a/src/diffusers/modular_pipelines/flux2/encoders.py b/src/diffusers/modular_pipelines/flux2/encoders.py index b4b6a4b533a7..265fb387367c 100644 --- a/src/diffusers/modular_pipelines/flux2/encoders.py +++ b/src/diffusers/modular_pipelines/flux2/encoders.py @@ -514,7 +514,7 @@ def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) ) if components.requires_unconditional_embeds: - negative_prompt = "" + negative_prompt = [""] * len(prompt) block_state.negative_prompt_embeds = self._get_qwen3_prompt_embeds( text_encoder=components.text_encoder, 
tokenizer=components.tokenizer, diff --git a/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein.py b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein.py new file mode 100644 index 000000000000..26653b20f8c4 --- /dev/null +++ b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein.py @@ -0,0 +1,91 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import numpy as np +import PIL +import pytest + +from diffusers.modular_pipelines import ( + Flux2KleinAutoBlocks, + Flux2KleinModularPipeline, +) + +from ...testing_utils import floats_tensor, torch_device +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2KleinModularPipeline + pipeline_blocks_class = Flux2KleinAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular" + + params = frozenset(["prompt", "height", "width"]) + batch_params = frozenset(["prompt"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + "text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + "height": 32, + "width": 32, + "output_type": "pt", + } + return inputs + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + +class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2KleinModularPipeline + pipeline_blocks_class = Flux2KleinAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular" + + params = frozenset(["prompt", "height", "width", "image"]) + batch_params = frozenset(["prompt", "image"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + "text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + "height": 32, + "width": 32, + "output_type": "pt", + } + image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB") + inputs["image"] = init_image + + return inputs + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + @pytest.mark.skip(reason="batched inference is currently not supported") + def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001): + return diff --git a/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein_base.py 
b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein_base.py new file mode 100644 index 000000000000..701dd0fed896 --- /dev/null +++ b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein_base.py @@ -0,0 +1,91 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import numpy as np +import PIL +import pytest + +from diffusers.modular_pipelines import ( + Flux2KleinBaseAutoBlocks, + Flux2KleinModularPipeline, +) + +from ...testing_utils import floats_tensor, torch_device +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2KleinModularPipeline + pipeline_blocks_class = Flux2KleinBaseAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular" + + params = frozenset(["prompt", "height", "width"]) + batch_params = frozenset(["prompt"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + "text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + "height": 32, + "width": 32, + "output_type": "pt", + } + return inputs + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + +class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2KleinModularPipeline + pipeline_blocks_class = Flux2KleinBaseAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular" + + params = frozenset(["prompt", "height", "width", "image"]) + batch_params = frozenset(["prompt", "image"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + "text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + "height": 32, + "width": 32, + "output_type": "pt", + } + image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB") + inputs["image"] = init_image + + return inputs + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + @pytest.mark.skip(reason="batched inference is currently not supported") + def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001): + return
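
A quick usage sketch for the blocks introduced in this series. The repo id below is a placeholder, and `init_pipeline` / `load_components` are assumed to follow the existing modular-pipelines API; treat this as a sketch rather than the canonical entry point:

    import torch
    from diffusers.modular_pipelines import Flux2KleinAutoBlocks

    # Build the block graph, then bind it to a modular repo that provides the components.
    blocks = Flux2KleinAutoBlocks()
    pipe = blocks.init_pipeline("<klein-modular-repo-id>")  # placeholder repo id
    pipe.load_components(torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # `output="images"` asks the pipeline for the decode step's `images` output.
    image = pipe(
        prompt="A painting of a squirrel eating a burger",
        num_inference_steps=28,
        height=1024,
        width=1024,
        output="images",
    )[0]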
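For reference, `Flux2PrepareGuidanceStep` reduces to two tensor ops; a standalone check of the broadcast it performs (toy batch sizes, same calls as the step's `__call__`):

    import torch

    batch_size, num_images_per_prompt, guidance_scale = 2, 2, 4.0

    # One scalar expanded to the effective batch size, as in the step itself.
    guidance = torch.full([1], guidance_scale, dtype=torch.float32)
    guidance = guidance.expand(batch_size * num_images_per_prompt)
    print(guidance)  # tensor([4., 4., 4., 4.])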
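The 4D text position IDs prepared by `Flux2RoPEInputsStep` can likewise be sanity-checked in isolation. This snippet mirrors `_prepare_text_ids` from the patch: for text tokens the T, H and W axes collapse to a single coordinate and only the sequence axis varies:

    import torch

    def prepare_text_ids(prompt_embeds: torch.Tensor) -> torch.Tensor:
        B, L, _ = prompt_embeds.shape
        out_ids = []
        for _ in range(B):
            t, h, w = torch.arange(1), torch.arange(1), torch.arange(1)
            seq_l = torch.arange(L)
            out_ids.append(torch.cartesian_prod(t, h, w, seq_l))  # (L, 4)
        return torch.stack(out_ids)  # (B, L, 4)

    ids = prepare_text_ids(torch.randn(2, 8, 64))
    print(ids.shape)   # torch.Size([2, 8, 4])
    print(ids[0, :3])  # tensor([[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 2]])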
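Finally, the Qwen3 prompt-embedding construction in `_get_qwen3_prompt_embeds` stacks a few intermediate hidden states and folds them into the feature axis; the shape arithmetic, with toy sizes standing in for the real Qwen3 dimensions:

    import torch

    layers = (9, 18, 27)
    # Stand-ins for `output.hidden_states[k]`, each (batch, seq_len, hidden_dim).
    hidden_states = {k: torch.randn(2, 8, 64) for k in layers}

    out = torch.stack([hidden_states[k] for k in layers], dim=1)  # (2, 3, 8, 64)
    batch_size, num_channels, seq_len, hidden_dim = out.shape
    prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
    print(prompt_embeds.shape)  # torch.Size([2, 8, 192])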