@@ -166,7 +166,6 @@ class ControlNetBlock : public GGMLBlock {
166166
167167 struct ggml_tensor * resblock_forward (std::string name,
168168 struct ggml_context * ctx,
169- struct ggml_allocr * allocr,
170169 struct ggml_tensor * x,
171170 struct ggml_tensor * emb) {
172171 auto block = std::dynamic_pointer_cast<ResBlock>(blocks[name]);
@@ -175,7 +174,6 @@ class ControlNetBlock : public GGMLBlock {
175174
176175 struct ggml_tensor * attention_layer_forward (std::string name,
177176 struct ggml_context * ctx,
178- struct ggml_allocr * allocr,
179177 struct ggml_tensor * x,
180178 struct ggml_tensor * context) {
181179 auto block = std::dynamic_pointer_cast<SpatialTransformer>(blocks[name]);
@@ -201,11 +199,10 @@ class ControlNetBlock : public GGMLBlock {
201199 }
202200
203201 std::vector<struct ggml_tensor *> forward (struct ggml_context * ctx,
204- struct ggml_allocr * allocr,
205202 struct ggml_tensor * x,
206203 struct ggml_tensor * hint,
207204 struct ggml_tensor * guided_hint,
208- std::vector< float > timesteps,
205+ struct ggml_tensor * timesteps,
209206 struct ggml_tensor * context,
210207 struct ggml_tensor * y = NULL ) {
211208 // x: [N, in_channels, h, w] or [N, in_channels/2, h, w]
@@ -231,7 +228,7 @@ class ControlNetBlock : public GGMLBlock {
231228
232229 auto middle_block_out = std::dynamic_pointer_cast<Conv2d>(blocks[" middle_block_out.0" ]);
233230
234- auto t_emb = new_timestep_embedding (ctx, allocr , timesteps, model_channels); // [N, model_channels]
231+ auto t_emb = ggml_nn_timestep_embedding (ctx, timesteps, model_channels); // [N, model_channels]
235232
236233 auto emb = time_embed_0->forward (ctx, t_emb);
237234 emb = ggml_silu_inplace (ctx, emb);
@@ -272,10 +269,10 @@ class ControlNetBlock : public GGMLBlock {
272269 for (int j = 0 ; j < num_res_blocks; j++) {
273270 input_block_idx += 1 ;
274271 std::string name = " input_blocks." + std::to_string (input_block_idx) + " .0" ;
275- h = resblock_forward (name, ctx, allocr, h, emb); // [N, mult*model_channels, h, w]
272+ h = resblock_forward (name, ctx, h, emb); // [N, mult*model_channels, h, w]
276273 if (std::find (attention_resolutions.begin (), attention_resolutions.end (), ds) != attention_resolutions.end ()) {
277274 std::string name = " input_blocks." + std::to_string (input_block_idx) + " .1" ;
278- h = attention_layer_forward (name, ctx, allocr, h, context); // [N, mult*model_channels, h, w]
275+ h = attention_layer_forward (name, ctx, h, context); // [N, mult*model_channels, h, w]
279276 }
280277
281278 auto zero_conv = std::dynamic_pointer_cast<Conv2d>(blocks[" zero_convs." + std::to_string (input_block_idx) + " .0" ]);
@@ -299,9 +296,9 @@ class ControlNetBlock : public GGMLBlock {
299296 // [N, 4*model_channels, h/8, w/8]
300297
301298 // middle_block
302- h = resblock_forward (" middle_block.0" , ctx, allocr, h, emb); // [N, 4*model_channels, h/8, w/8]
303- h = attention_layer_forward (" middle_block.1" , ctx, allocr, h, context); // [N, 4*model_channels, h/8, w/8]
304- h = resblock_forward (" middle_block.2" , ctx, allocr, h, emb); // [N, 4*model_channels, h/8, w/8]
299+ h = resblock_forward (" middle_block.0" , ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
300+ h = attention_layer_forward (" middle_block.1" , ctx, h, context); // [N, 4*model_channels, h/8, w/8]
301+ h = resblock_forward (" middle_block.2" , ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
305302
306303 // out
307304 outs.push_back (middle_block_out->forward (ctx, h));
@@ -386,18 +383,22 @@ struct ControlNet : public GGMLModule {
386383
387384 struct ggml_cgraph * build_graph (struct ggml_tensor * x,
388385 struct ggml_tensor * hint,
389- std::vector< float > timesteps,
386+ struct ggml_tensor * timesteps,
390387 struct ggml_tensor * context,
391388 struct ggml_tensor * y = NULL ) {
392389 struct ggml_cgraph * gf = ggml_new_graph_custom (compute_ctx, CONTROL_NET_GRAPH_SIZE, false );
393390
394- x = to_backend (x);
395- hint = to_backend (hint);
396- context = to_backend (context);
397- y = to_backend (y);
391+ x = to_backend (x);
392+ if (guided_hint_cached) {
393+ hint = NULL ;
394+ } else {
395+ hint = to_backend (hint);
396+ }
397+ context = to_backend (context);
398+ y = to_backend (y);
399+ timesteps = to_backend (timesteps);
398400
399401 auto outs = control_net.forward (compute_ctx,
400- compute_allocr,
401402 x,
402403 hint,
403404 guided_hint_cached ? guided_hint : NULL ,
@@ -420,7 +421,7 @@ struct ControlNet : public GGMLModule {
420421 void compute (int n_threads,
421422 struct ggml_tensor * x,
422423 struct ggml_tensor * hint,
423- std::vector< float > timesteps,
424+ struct ggml_tensor * timesteps,
424425 struct ggml_tensor * context,
425426 struct ggml_tensor * y,
426427 struct ggml_tensor ** output = NULL ,
@@ -434,7 +435,6 @@ struct ControlNet : public GGMLModule {
434435 };
435436
436437 GGMLModule::compute (get_graph, n_threads, false , output, output_ctx);
437-
438438 guided_hint_cached = true ;
439439 }
440440
0 commit comments