@@ -850,12 +850,15 @@ namespace Flux {
850850
851851 // auto arrange = ggml_arange(ctx, 0, (float)mod_index_length, 1); // Not working on a lot of backends
852852 auto arrange = y;
853- auto modulation_index = ggml_nn_timestep_embedding (ctx, arrange, 32 , 10000 , 1000 .f );
853+ auto modulation_index = ggml_nn_timestep_embedding (ctx, arrange, 32 , 10000 , 1000 .f );// [1, 344, 32]
854+
855+ // Batch broadcast: repeat modulation_index along ne[2] to match img->ne[2] (will it ever be useful?)
856+ modulation_index = ggml_repeat (ctx, modulation_index, ggml_new_tensor_4d (ctx, GGML_TYPE_F32, modulation_index->ne [0 ], modulation_index->ne [1 ], img->ne [2 ], modulation_index->ne [3 ]));// [N, 344, 32]
854857
855- auto timestep_guidance = ggml_concat (ctx, distill_timestep, distill_guidance, 0 );
856- timestep_guidance = ggml_repeat (ctx, distill_timestep, modulation_index);
857- // TODO Batch broadcast?
858858
859+ auto timestep_guidance = ggml_concat (ctx, distill_timestep, distill_guidance, 0 ); // [N, 1, 32]
860+ timestep_guidance = ggml_repeat (ctx, timestep_guidance, modulation_index); // [N, 344, 32]
861+
859862 vec = ggml_concat (ctx, timestep_guidance, modulation_index, 0 ); // [N, 344, 64]
860863 vec = approx->forward (ctx, vec); // [N, 344, hidden_size]
861864
@@ -1094,7 +1097,7 @@ namespace Flux {
10941097 set_backend_tensor_data (y, range.data ());
10951098 }
10961099 timesteps = to_backend (timesteps);
1097- if (flux_params.guidance_embed ) {
1100+ if (flux_params.guidance_embed || flux_params. is_chroma ) {
10981101 guidance = to_backend (guidance);
10991102 }
11001103
0 commit comments