@@ -91,6 +91,41 @@ const float flux_latent_rgb_proj[16][3] = {
9191 {-0 .111849f , -0 .055589f , -0 .032361f }};
// Three floats, presumably a per-channel (R, G, B) offset used together with
// flux_latent_rgb_proj; where it is applied is not visible in this excerpt — TODO confirm.
9292float flux_latent_rgb_bias[3 ] = {0 .024600f , -0 .006937f , -0 .008089f };
9393
// flux2 latent preview coefficients: a 32 x 3 table followed by a 3-element bias.
// Presumably one row per latent channel with the three columns weighting R, G, B,
// i.e. the [d][0..2] layout that preview_latent_video reads from its
// latent_rgb_proj argument — TODO confirm the call site that passes this table.
// NOTE(review): no source or licence for these values is cited in this excerpt,
// unlike the table that follows; worth adding a reference.
94+ const float flux2_latent_rgb_proj[32 ][3 ] = {
95+ {0 .000736f , -0 .008385f , -0 .019710f },
96+ {-0 .001352f , -0 .016392f , 0 .020693f },
97+ {-0 .006376f , 0 .002428f , 0 .036736f },
98+ {0 .039384f , 0 .074167f , 0 .119789f },
99+ {0 .007464f , -0 .005705f , -0 .004734f },
100+ {-0 .004086f , 0 .005287f , -0 .000409f },
101+ {-0 .032835f , 0 .050802f , -0 .028120f },
102+ {-0 .003158f , -0 .000835f , 0 .000406f },
103+ {-0 .112840f , -0 .084337f , -0 .023083f },
104+ {0 .001462f , -0 .006656f , 0 .000549f },
105+ {-0 .009980f , -0 .007480f , 0 .009702f },
106+ {0 .032540f , 0 .000214f , -0 .061388f },
107+ {0 .011023f , 0 .000694f , 0 .007143f },
108+ {-0 .001468f , -0 .006723f , -0 .001678f },
109+ {-0 .005921f , -0 .010320f , -0 .003907f },
110+ {-0 .028434f , 0 .027584f , 0 .018457f },
111+ {0 .014349f , 0 .011523f , 0 .000441f },
112+ {0 .009874f , 0 .003081f , 0 .001507f },
113+ {0 .002218f , 0 .005712f , 0 .001563f },
114+ {0 .053010f , -0 .019844f , 0 .008683f },
115+ {-0 .002507f , 0 .005384f , 0 .000938f },
116+ {-0 .002177f , -0 .011366f , 0 .003559f },
117+ {-0 .000261f , 0 .015121f , -0 .003240f },
118+ {-0 .003944f , -0 .002083f , 0 .005043f },
119+ {-0 .009138f , 0 .011336f , 0 .003781f },
120+ {0 .011429f , 0 .003985f , -0 .003855f },
121+ {0 .010518f , -0 .005586f , 0 .010131f },
122+ {0 .007883f , 0 .002912f , -0 .001473f },
123+ {-0 .003318f , -0 .003160f , 0 .003684f },
124+ {-0 .034560f , -0 .008740f , 0 .012996f },
125+ {0 .000166f , 0 .001079f , -0 .012153f },
126+ {0 .017772f , 0 .000937f , -0 .011953f }};
// Three floats, presumably one offset per output colour channel; how it is
// applied is not visible in this excerpt — TODO confirm.
127+ float flux2_latent_rgb_bias[3 ] = {-0 .028738f , -0 .098463f , -0 .107619f };
128+
94129// This one was taken straight from
95130// https://github.com/Stability-AI/sd3.5/blob/8565799a3b41eb0c7ba976d18375f0f753f56402/sd3_impls.py#L288-L303
96131// (MIT License)
@@ -128,16 +163,42 @@ const float sd_latent_rgb_proj[4][3] = {
128163 {-0 .178022f , -0 .200862f , -0 .678514f }};
// Three floats, presumably a per-channel (R, G, B) offset used together with
// sd_latent_rgb_proj; where it is applied is not visible in this excerpt — TODO confirm.
129164float sd_latent_rgb_bias[3 ] = {-0 .017478f , -0 .055834f , -0 .105825f };
130165
131- void preview_latent_video (uint8_t * buffer, struct ggml_tensor * latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int width, int height, int frames, int dim ) {
166+ void preview_latent_video (uint8_t * buffer, struct ggml_tensor * latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int patch_size ) {
132167 size_t buffer_head = 0 ;
168+
169+ uint32_t latent_width = latents->ne [0 ];
170+ uint32_t latent_height = latents->ne [1 ];
171+ uint32_t dim = latents->ne [ggml_n_dims (latents) - 1 ];
172+ uint32_t frames = 1 ;
173+ if (ggml_n_dims (latents) == 4 ) {
174+ frames = latents->ne [2 ];
175+ }
176+
177+ uint32_t rgb_width = latent_width * patch_size;
178+ uint32_t rgb_height = latent_height * patch_size;
179+
180+ uint32_t unpatched_dim = dim / (patch_size * patch_size);
181+
133182 for (int k = 0 ; k < frames; k++) {
134- for (int j = 0 ; j < height; j++) {
135- for (int i = 0 ; i < width; i++) {
136- size_t latent_id = (i * latents->nb [0 ] + j * latents->nb [1 ] + k * latents->nb [2 ]);
183+ for (int rgb_x = 0 ; rgb_x < rgb_width; rgb_x++) {
184+ for (int rgb_y = 0 ; rgb_y < rgb_height; rgb_y++) {
185+ int latent_x = rgb_x / patch_size;
186+ int latent_y = rgb_y / patch_size;
187+
188+ int channel_offset = 0 ;
189+ if (patch_size > 1 ) {
190+ channel_offset = ((rgb_y % patch_size) * patch_size + (rgb_x % patch_size));
191+ }
192+
193+ size_t latent_id = (latent_x * latents->nb [0 ] + latent_y * latents->nb [1 ] + k * latents->nb [2 ]);
194+
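195+ // row-major pixel index (frame, then row, then column); it would advance by 1 per iteration only if rgb_y were the outer loop — here rgb_x is outer, so it advances by rgb_width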
196+ size_t pixel_id = k * rgb_width * rgb_height + rgb_y * rgb_width + rgb_x;
197+
137198 float r = 0 , g = 0 , b = 0 ;
138199 if (latent_rgb_proj != nullptr ) {
139- for (int d = 0 ; d < dim ; d++) {
140- float value = *(float *)((char *)latents->data + latent_id + d * latents->nb [ggml_n_dims (latents) - 1 ]);
200+ for (int d = 0 ; d < unpatched_dim ; d++) {
201+ float value = *(float *)((char *)latents->data + latent_id + (d * patch_size * patch_size + channel_offset) * latents->nb [ggml_n_dims (latents) - 1 ]);
141202 r += value * latent_rgb_proj[d][0 ];
142203 g += value * latent_rgb_proj[d][1 ];
143204 b += value * latent_rgb_proj[d][2 ];
@@ -164,9 +225,9 @@ void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const fl
164225 g = g >= 0 ? g <= 1 ? g : 1 : 0 ;
165226 b = b >= 0 ? b <= 1 ? b : 1 : 0 ;
166227
167- buffer[buffer_head++ ] = (uint8_t )(r * 255 );
168- buffer[buffer_head++ ] = (uint8_t )(g * 255 );
169- buffer[buffer_head++ ] = (uint8_t )(b * 255 );
228+ buffer[pixel_id * 3 + 0 ] = (uint8_t )(r * 255 );
229+ buffer[pixel_id * 3 + 1 ] = (uint8_t )(g * 255 );
230+ buffer[pixel_id * 3 + 2 ] = (uint8_t )(b * 255 );
170231 }
171232 }
172233 }
0 commit comments