6 changes: 6 additions & 0 deletions src/diffusers/__init__.py
@@ -413,6 +413,9 @@
_import_structure["modular_pipelines"].extend(
[
"Flux2AutoBlocks",
"Flux2KleinAutoBlocks",
"Flux2KleinBaseAutoBlocks",
"Flux2KleinModularPipeline",
"Flux2ModularPipeline",
"FluxAutoBlocks",
"FluxKontextAutoBlocks",
@@ -1146,6 +1149,9 @@
else:
from .modular_pipelines import (
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
FluxAutoBlocks,
FluxKontextAutoBlocks,
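For quick reference, the three new Klein exports become importable from the top-level package. A minimal smoke test, as a sketch: it assumes a diffusers build that includes this PR and that the auto-blocks classes construct without arguments, as other modular auto blocks do.

```python
# Sketch: exercise the new public exports added above.
from diffusers import (
    Flux2KleinAutoBlocks,
    Flux2KleinBaseAutoBlocks,
    Flux2KleinModularPipeline,
)

blocks = Flux2KleinAutoBlocks()
print(blocks)  # prints the block tree, useful for inspecting the assembled pipeline
```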
11 changes: 10 additions & 1 deletion src/diffusers/modular_pipelines/__init__.py
@@ -54,7 +54,10 @@
]
_import_structure["flux2"] = [
"Flux2AutoBlocks",
"Flux2KleinAutoBlocks",
"Flux2KleinBaseAutoBlocks",
"Flux2ModularPipeline",
"Flux2KleinModularPipeline",
]
_import_structure["qwenimage"] = [
"QwenImageAutoBlocks",
@@ -81,7 +84,13 @@
else:
from .components_manager import ComponentsManager
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
from .flux2 import (
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
)
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,
17 changes: 11 additions & 6 deletions src/diffusers/modular_pipelines/flux2/__init__.py
@@ -43,18 +43,19 @@
"Flux2ProcessImagesInputStep",
"Flux2TextInputStep",
]
_import_structure["modular_blocks"] = [
_import_structure["modular_blocks_flux2"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"REMOTE_AUTO_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"IMAGE_CONDITIONED_BLOCKS",
"Flux2AutoBlocks",
"Flux2AutoVaeEncoderStep",
"Flux2BeforeDenoiseStep",
"Flux2CoreDenoiseStep",
"Flux2VaeEncoderSequentialStep",
]
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -85,18 +86,22 @@
Flux2ProcessImagesInputStep,
Flux2TextInputStep,
)
from .modular_blocks import (
from .modular_blocks_flux2 import (
ALL_BLOCKS,
AUTO_BLOCKS,
IMAGE_CONDITIONED_BLOCKS,
REMOTE_AUTO_BLOCKS,
TEXT2IMAGE_BLOCKS,
Flux2AutoBlocks,
Flux2AutoVaeEncoderStep,
Flux2BeforeDenoiseStep,
Flux2CoreDenoiseStep,
Flux2VaeEncoderSequentialStep,
)
from .modular_pipeline import Flux2ModularPipeline
from .modular_blocks_flux2_klein import (
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
)
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
else:
import sys

122 changes: 103 additions & 19 deletions src/diffusers/modular_pipelines/flux2/before_denoise.py
@@ -129,17 +129,9 @@ def inputs(self) -> List[InputParam]:
InputParam("num_inference_steps", default=50),
InputParam("timesteps"),
InputParam("sigmas"),
InputParam("guidance_scale", default=4.0),
InputParam("latents", type_hint=torch.Tensor),
InputParam("num_images_per_prompt", default=1),
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
),
]

@property
@@ -151,13 +143,12 @@ def intermediate_outputs(self) -> List[OutputParam]:
type_hint=int,
description="The number of denoising steps to perform at inference time",
),
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
]

@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
device = components._execution_device

scheduler = components.scheduler

@@ -183,19 +174,14 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
timesteps, num_inference_steps = retrieve_timesteps(
scheduler,
num_inference_steps,
block_state.device,
device,
timesteps=timesteps,
sigmas=sigmas,
mu=mu,
)
block_state.timesteps = timesteps
block_state.num_inference_steps = num_inference_steps

batch_size = block_state.batch_size * block_state.num_images_per_prompt
> Collaborator Author: separated this out into a prepare_guidance block (see Flux2PrepareGuidanceStep at the bottom of this file).
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
guidance = guidance.expand(batch_size)
block_state.guidance = guidance

components.scheduler.set_begin_index(0)

self.set_block_state(state, block_state)
@@ -353,7 +339,61 @@ def description(self) -> str:
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt_embeds", required=True),
InputParam(name="latent_ids"),
> Collaborator Author: removed because latent_ids are not used in this block, I think.
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="txt_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
),
]

@staticmethod
def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
"""Prepare 4D position IDs for text tokens."""
B, L, _ = x.shape
out_ids = []

for i in range(B):
t = torch.arange(1) if t_coord is None else t_coord[i]
h = torch.arange(1)
w = torch.arange(1)
seq_l = torch.arange(L)

coords = torch.cartesian_prod(t, h, w, seq_l)
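# with the default singleton t/h/w axes, cartesian_prod yields L rows of (t, h, w, l): per-sample ids of shape (L, 4)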
out_ids.append(coords)

return torch.stack(out_ids)

def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

prompt_embeds = block_state.prompt_embeds
device = prompt_embeds.device

block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
block_state.txt_ids = block_state.txt_ids.to(device)

self.set_block_state(state, block_state)
return components, state


class Flux2KleinBaseRoPEInputsStep(ModularPipelineBlocks):
model_name = "flux2-klein"

@property
def description(self) -> str:
return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps."

@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt_embeds", required=True),
InputParam(name="negative_prompt_embeds", required=False),
> Member: No strong opinions, but WDYT of creating a separate block for Klein altogether? I think this way it will be a bit easier to debug and also separates concerns.
> My suggestion mainly comes from the fact that Flux.2-Dev doesn't use negative_prompt_embeds while Flux.2-Klein does. So maybe that warrants creating separate blocks.

> Collaborator Author (yiyixuxu, Jan 20, 2026): It's a fair point, but on the other hand, I've personally found that having too many blocks can become overwhelming: each time you need to add something, you still need to go through all of them and understand which ones to use.
> I think it makes sense to just add the code to the same blocks here; it is so small and fits in. But this is really a matter of preference, not right or wrong. Maybe we'll know better in the future, after building more pipelines :)

> Collaborator Author (yiyixuxu): Actually, I changed my mind; I agree it's better to separate them out. Otherwise negative_prompt_embeds will show up as an optional argument in the auto docstring for both Klein and Dev, which is confusing.
> Note that in Qwen (https://github.com/huggingface/diffusers/blob/main/src/diffusers/modular_pipelines/qwenimage/inputs.py#L232), I'm experimenting with more composable blocks for situations like this that you can just reuse. But that also makes the blocks more complex, and I'm not sure if I'm over-engineering. So let's keep them simple here and see how it goes.

> Member: Thank you!
]

@property
@@ -366,10 +406,10 @@ def intermediate_outputs(self) -> List[OutputParam]:
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
),
OutputParam(
name="latent_ids",
name="negative_txt_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
description="4D position IDs (T, H, W, L) for negative text tokens, used for RoPE calculation.",
),
]

@@ -399,6 +439,11 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
block_state.txt_ids = block_state.txt_ids.to(device)

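# Flux.2-Klein passes negative_prompt_embeds (Flux.2-Dev does not); mirror the positive ids when they are present.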
block_state.negative_txt_ids = None
if block_state.negative_prompt_embeds is not None:
block_state.negative_txt_ids = self._prepare_text_ids(block_state.negative_prompt_embeds)
block_state.negative_txt_ids = block_state.negative_txt_ids.to(device)

self.set_block_state(state, block_state)
return components, state

@@ -506,3 +551,42 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:

self.set_block_state(state, block_state)
return components, state


class Flux2PrepareGuidanceStep(ModularPipelineBlocks):
model_name = "flux2"

@property
def description(self) -> str:
return "Step that prepares the guidance scale tensor for Flux2 inference"

@property
def inputs(self) -> List[InputParam]:
return [
InputParam("guidance_scale", default=4.0),
InputParam("num_images_per_prompt", default=1),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
]

@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
batch_size = block_state.batch_size * block_state.num_images_per_prompt
guidance = torch.full([1], block_state.guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(batch_size)
block_state.guidance = guidance

self.set_block_state(state, block_state)
return components, state
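
The guidance tensor produced by this block is just `guidance_scale` broadcast across the effective batch of `batch_size * num_images_per_prompt`. A standalone sketch of the same computation, using plain torch with no pipeline state assumed:

```python
import torch

# Mirrors Flux2PrepareGuidanceStep for 2 prompts with 2 images each.
batch_size, num_images_per_prompt, guidance_scale = 2, 2, 4.0
effective_batch = batch_size * num_images_per_prompt

guidance = torch.full([1], guidance_scale, dtype=torch.float32)
guidance = guidance.expand(effective_batch)
print(guidance)  # tensor([4., 4., 4., 4.])
```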