[modular]support klein #13002
Changes from all commits: 618a8a9, fb2cb18, 9357d8f, 3c7494a, d295367, e1e1629, c10041e, dea47aa, e13377e, 5c1fc44, f49c68c, 1c500c8, a232cd9, eb221d5, a81893c
```diff
@@ -129,17 +129,9 @@ def inputs(self) -> List[InputParam]:
             InputParam("num_inference_steps", default=50),
             InputParam("timesteps"),
             InputParam("sigmas"),
-            InputParam("guidance_scale", default=4.0),
             InputParam("latents", type_hint=torch.Tensor),
-            InputParam("num_images_per_prompt", default=1),
             InputParam("height", type_hint=int),
             InputParam("width", type_hint=int),
-            InputParam(
-                "batch_size",
-                required=True,
-                type_hint=int,
-                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
-            ),
         ]
 
     @property
```
```diff
@@ -151,13 +143,12 @@ def intermediate_outputs(self) -> List[OutputParam]:
                 type_hint=int,
                 description="The number of denoising steps to perform at inference time",
             ),
-            OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
         ]
 
     @torch.no_grad()
     def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-        block_state.device = components._execution_device
+        device = components._execution_device
 
         scheduler = components.scheduler
```
```diff
@@ -183,19 +174,14 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
         timesteps, num_inference_steps = retrieve_timesteps(
             scheduler,
             num_inference_steps,
-            block_state.device,
+            device,
             timesteps=timesteps,
             sigmas=sigmas,
             mu=mu,
         )
         block_state.timesteps = timesteps
         block_state.num_inference_steps = num_inference_steps
 
-        batch_size = block_state.batch_size * block_state.num_images_per_prompt
-        guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
-        guidance = guidance.expand(batch_size)
-        block_state.guidance = guidance
-
         components.scheduler.set_begin_index(0)
 
         self.set_block_state(state, block_state)
```
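For orientation, here is a minimal standalone sketch of the scheduler call this hunk feeds into; the scheduler class is an assumption (the diff itself does not pin it), and the only actual change above is that the device now comes from a local `device` variable instead of `block_state.device`.

```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler

# Stand-in scheduler; the real one comes from `components.scheduler`.
scheduler = FlowMatchEulerDiscreteScheduler()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# retrieve_timesteps ultimately defers to set_timesteps like this,
# then reads the resolved schedule back from scheduler.timesteps.
scheduler.set_timesteps(num_inference_steps=8, device=device)
print(scheduler.timesteps)  # 8 timesteps on `device`
```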
```diff
@@ -353,7 +339,61 @@ def description(self) -> str:
     def inputs(self) -> List[InputParam]:
         return [
             InputParam(name="prompt_embeds", required=True),
-            InputParam(name="latent_ids"),
```

**Collaborator (Author):** Removed because `latent_ids` are not used in this block, I think.
```diff
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                name="txt_ids",
+                kwargs_type="denoiser_input_fields",
+                type_hint=torch.Tensor,
+                description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
+            ),
+        ]
+
+    @staticmethod
+    def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
+        """Prepare 4D position IDs for text tokens."""
+        B, L, _ = x.shape
+        out_ids = []
+
+        for i in range(B):
+            t = torch.arange(1) if t_coord is None else t_coord[i]
+            h = torch.arange(1)
+            w = torch.arange(1)
+            seq_l = torch.arange(L)
+
+            coords = torch.cartesian_prod(t, h, w, seq_l)
+            out_ids.append(coords)
+
+        return torch.stack(out_ids)
+
+    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        prompt_embeds = block_state.prompt_embeds
+        device = prompt_embeds.device
+
+        block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
+        block_state.txt_ids = block_state.txt_ids.to(device)
+
+        self.set_block_state(state, block_state)
+        return components, state
```
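To make the 4D text IDs concrete, here is a small standalone sketch of what `_prepare_text_ids` computes in the default `t_coord=None` case, with stand-in shapes: every text token `i` receives the coordinate `(0, 0, 0, i)`, so `txt_ids` comes out as `(B, L, 4)`.

```python
import torch

# Stand-in prompt embeddings; B, L, D are arbitrary here.
B, L, D = 2, 5, 16
prompt_embeds = torch.randn(B, L, D)

# Same construction as _prepare_text_ids with t_coord=None:
# the T/H/W axes collapse to a single 0, only the sequence axis varies.
t = torch.arange(1)
h = torch.arange(1)
w = torch.arange(1)
seq_l = torch.arange(L)

coords = torch.cartesian_prod(t, h, w, seq_l)  # shape (L, 4)
txt_ids = torch.stack([coords] * B)            # shape (B, L, 4)

print(txt_ids.shape)   # torch.Size([2, 5, 4])
print(txt_ids[0, :3])  # rows: [0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 2]
```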
```diff
+
+
+class Flux2KleinBaseRoPEInputsStep(ModularPipelineBlocks):
+    model_name = "flux2-klein"
+
+    @property
+    def description(self) -> str:
+        return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps."
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="prompt_embeds", required=True),
+            InputParam(name="negative_prompt_embeds", required=False),
```

**Member:** No strong opinions, but WDYT of creating a separate block for Klein altogether? I think this way it will be a bit easier to debug, and it also separates concerns. My suggestion mainly comes from the fact that Flux.2-Dev doesn't use …

**Collaborator (Author):** It's a fair point, but on the other hand, I've personally found that having too many blocks can become overwhelming: each time you need to add something, you still need to go through all of them and understand which ones to use.

**Collaborator (Author):** Actually, I changed my mind. I agree it's better to separate them out. Otherwise `negative_prompt_embeds` will show up as an optional argument in the auto docstring for both Klein and Dev, which is confusing.

**Member:** Thank you!

```diff
         ]
 
     @property
```
```diff
@@ -366,10 +406,10 @@ def intermediate_outputs(self) -> List[OutputParam]:
                 description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
             ),
             OutputParam(
-                name="latent_ids",
+                name="negative_txt_ids",
                 kwargs_type="denoiser_input_fields",
                 type_hint=torch.Tensor,
-                description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
+                description="4D position IDs (T, H, W, L) for negative text tokens, used for RoPE calculation.",
             ),
         ]
```
```diff
@@ -399,6 +439,11 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
         block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
         block_state.txt_ids = block_state.txt_ids.to(device)
 
+        block_state.negative_txt_ids = None
+        if block_state.negative_prompt_embeds is not None:
+            block_state.negative_txt_ids = self._prepare_text_ids(block_state.negative_prompt_embeds)
+            block_state.negative_txt_ids = block_state.negative_txt_ids.to(device)
+
         self.set_block_state(state, block_state)
         return components, state
```
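A short sketch of the guard added here, with the helper re-derived from `_prepare_text_ids` above and stand-in tensors: the negative IDs are built exactly like `txt_ids`, and stay `None` when no negative prompt is supplied.

```python
import torch

def prepare_text_ids(x: torch.Tensor) -> torch.Tensor:
    # Same construction as _prepare_text_ids with t_coord=None.
    B, L, _ = x.shape
    coords = torch.cartesian_prod(
        torch.arange(1), torch.arange(1), torch.arange(1), torch.arange(L)
    )
    return torch.stack([coords] * B)

for negative_prompt_embeds in (None, torch.randn(2, 7, 16)):
    negative_txt_ids = None
    if negative_prompt_embeds is not None:
        negative_txt_ids = prepare_text_ids(negative_prompt_embeds)
    print(None if negative_txt_ids is None else negative_txt_ids.shape)
# None
# torch.Size([2, 7, 4])
```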
```diff
@@ -506,3 +551,42 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
 
         self.set_block_state(state, block_state)
         return components, state
+
+
+class Flux2PrepareGuidanceStep(ModularPipelineBlocks):
+    model_name = "flux2"
+
+    @property
+    def description(self) -> str:
+        return "Step that prepares the guidance scale tensor for Flux2 inference"
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("guidance_scale", default=4.0),
+            InputParam("num_images_per_prompt", default=1),
+            InputParam(
+                "batch_size",
+                required=True,
+                type_hint=int,
+                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        device = components._execution_device
+        batch_size = block_state.batch_size * block_state.num_images_per_prompt
+        guidance = torch.full([1], block_state.guidance_scale, device=device, dtype=torch.float32)
+        guidance = guidance.expand(batch_size)
+        block_state.guidance = guidance
+
+        self.set_block_state(state, block_state)
+        return components, state
```
**Collaborator (Author):** Separated this to a `prepare_guidance` block.
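As a sanity check, a standalone sketch of the tensor this block produces, with assumed values (two prompts, two images per prompt); `expand` broadcasts the single-element tensor without allocating a copy.

```python
import torch

# Mirror of the block body above, outside the pipeline state machinery.
guidance_scale = 4.0
batch_size = 2             # number of prompts
num_images_per_prompt = 2

final_batch = batch_size * num_images_per_prompt
guidance = torch.full([1], guidance_scale, dtype=torch.float32)
guidance = guidance.expand(final_batch)  # broadcast view, shape (4,)

print(guidance)  # tensor([4., 4., 4., 4.])
```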