@@ -59,33 +59,24 @@ tensor window_reverse(model_ref m, tensor x, int w, int h, int window) {
5959// Image encoder
6060//
6161
62- tensor conv_2d_batch_norm (model_ref m, tensor x, int stride, int pad, int groups) {
63- if (groups == 1 ) {
64- x = conv_2d (m[" c" ], x, stride, pad);
65- } else {
66- x = conv_2d_depthwise (m[" c" ], x, stride, pad);
67- }
68- x = batch_norm_2d (m[" bn" ], x);
69- return named (m, x);
70- }
7162
7263tensor patch_embed (model_ref m, tensor x) {
73- x = conv_2d_batch_norm (m[" seq.0" ], x, 2 , 1 );
64+ x = conv_2d (m[" seq.0" ], x, 2 , 1 );
7465 x = ggml_gelu_inplace (m, x);
75- x = conv_2d_batch_norm (m[" seq.2" ], x, 2 , 1 );
66+ x = conv_2d (m[" seq.2" ], x, 2 , 1 );
7667 return named (m, x);
7768}
7869
7970tensor mb_conv (model_ref m, tensor x) {
8071 tensor shortcut = x;
8172
82- x = conv_2d_batch_norm (m[" conv1" ], x);
73+ x = conv_2d (m[" conv1" ], x);
8374 x = ggml_gelu_inplace (m, x);
8475
85- x = conv_2d_batch_norm (m[" conv2" ], x, 1 , 1 , /* groups */ int (x-> ne [ 2 ]) );
76+ x = conv_2d_depthwise (m[" conv2" ], x, 1 , 1 );
8677 x = ggml_gelu_inplace (m, x);
8778
88- x = conv_2d_batch_norm (m[" conv3" ], x);
79+ x = conv_2d (m[" conv3" ], x);
8980 x = ggml_add_inplace (m, x, shortcut);
9081 x = ggml_gelu_inplace (m, x);
9182
@@ -96,16 +87,16 @@ tensor patch_merging(model_ref m, tensor x, int input_resolution) {
9687 if (x->ne [2 ] == 1 ) {
9788 x = ggml_reshape_4d (m, x, x->ne [0 ], input_resolution, input_resolution, x->ne [3 ]);
9889 }
99- x = conv_2d_batch_norm (m[" conv1" ], x);
90+ x = conv_2d (m[" conv1" ], x);
10091 x = ggml_gelu_inplace (m, x);
10192
102- int c_out = int (m.weights (" conv2.c. weight" )->ne [0 ]);
93+ int c_out = int (m.weights (" conv2.weight" )->ne [0 ]);
10394 int stride = (c_out == 320 || c_out == 448 || c_out == 576 ) ? 1 : 2 ;
104- x = conv_2d_batch_norm (m[" conv2" ], x, stride, 1 , c_out );
95+ x = conv_2d_depthwise (m[" conv2" ], x, stride, 1 );
10596 x = ggml_gelu_inplace (m, x);
10697
10798 auto [c, h, w, b] = nelements (x);
108- x = conv_2d_batch_norm (m[" conv3" ], x);
99+ x = conv_2d (m[" conv3" ], x);
109100 x = ggml_reshape_3d (m, x, c, w * h, b);
110101 return named (m, x);
111102}
@@ -175,7 +166,7 @@ tensor tiny_vit_block(
175166 x = ggml_add_inplace (m, x, res_x);
176167
177168 x = ggml_reshape_4d (m, x, c, w, h, b);
178- x = conv_2d_batch_norm (m[" local_conv" ], x, 1 , 1 , /* groups */ dim );
169+ x = conv_2d_depthwise (m[" local_conv" ], x, 1 , 1 );
179170 x = ggml_reshape_3d (m, x, c, spatial, b);
180171
181172 tensor x_mlp = mlp (m[" mlp" ], x);
0 commit comments