
Commit c93122b

sam: fuse batch norm into conv-2d
* model conversion fuses batch norm weights into conv-2d kernel
* inference just does conv-2d with bias
1 parent: f3ba06a
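For reference, the change relies on the standard batch-norm folding identity: batch norm computes gamma * (y - mean) / sqrt(var + eps) + beta on the conv output y, so the per-channel scale gamma / sqrt(var + eps) can be multiplied into the output channels of the conv kernel, and the remainder beta - mean * gamma / sqrt(var + eps) becomes a conv bias. Below is a minimal PyTorch sketch (illustration only, not part of this commit) that applies the same folding as fuse_conv_2d_batch_norm in scripts/convert.py and checks it against the unfused pair:

# Sketch (not from the diff): fold an eval-mode BatchNorm2d into the preceding
# bias-free Conv2d and verify the fused conv reproduces conv -> batch norm.
import torch
from torch import nn

conv = nn.Conv2d(4, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8).eval()      # eval mode: uses running statistics
bn.running_mean.uniform_(-1, 1)    # make the running stats non-trivial
bn.running_var.uniform_(0.5, 2.0)

with torch.no_grad():
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)       # gamma / sqrt(var + eps)
    fused = nn.Conv2d(4, 8, kernel_size=3, padding=1, bias=True)
    fused.weight.copy_(conv.weight * scale[:, None, None, None])  # scale each output channel
    fused.bias.copy_(bn.bias - bn.running_mean * scale)

    x = torch.rand(1, 4, 16, 16)
    assert torch.allclose(fused(x), bn(conv(x)), atol=1e-5)

The folding is only valid with batch norm in eval mode, i.e. using the stored running statistics, which is exactly the inference-time setting this commit targets.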

6 files changed (+79 lines, -52 lines)


scripts/convert.py

Lines changed: 30 additions & 15 deletions
@@ -79,7 +79,8 @@ def convert_sam(
         input_filepath, map_location="cpu", weights_only=True
     )

-    for name, tensor in model.items():
+    for key, tensor in model.items():
+        name = key
         name = name.replace("image_encoder.", "enc.")
         name = name.replace("mask_decoder.", "dec.")
         name = name.replace("_image_to_token.", "_i2t.")

@@ -92,14 +93,17 @@ def convert_sam(
             name = name + "_indexed"
             tensor = tensor[:, attention_bias_idxs]

-        if name.endswith("running_var"):
-            tensor = torch.sqrt(tensor + batch_norm_eps)
-
-        if (
-            name.endswith("c.weight")
-            or name.endswith("neck.0.weight")
-            or name.endswith("neck.2.weight")
-        ):
+        if name.endswith("c.weight"):
+            name = name.removesuffix(".c.weight")
+            weight, bias = fuse_conv_2d_batch_norm(model, key.removesuffix(".c.weight"))
+            weight = conv_2d_to_nhwc(weight)
+            add_tensor(writer, f"{name}.weight", weight, quantize, verbose)
+            add_tensor(writer, f"{name}.bias", bias, quantize, verbose)
+            continue
+        if ".bn." in name:
+            continue  # batch norm is fused above
+
+        if name.endswith("neck.0.weight") or name.endswith("neck.2.weight"):
             assert tensor.shape[2] == tensor.shape[3] and tensor.shape[2] <= 3
             tensor = conv_2d_to_nhwc(tensor)

@@ -115,6 +119,19 @@ def convert_sam(
         add_tensor(writer, name, tensor, data_type, verbose)


+def fuse_conv_2d_batch_norm(model: dict[str, Tensor], key: str):
+    conv_weight = model[f"{key}.c.weight"]
+    bn_weight = model[f"{key}.bn.weight"]
+    bn_bias = model[f"{key}.bn.bias"]
+    bn_mean = model[f"{key}.bn.running_mean"]
+    bn_var = model[f"{key}.bn.running_var"]
+
+    bn_weight = bn_weight / torch.sqrt(bn_var + batch_norm_eps)
+    fused_weight = conv_weight * bn_weight[:, None, None, None]
+    fused_bias = bn_bias - bn_mean * bn_weight
+    return fused_weight, fused_bias
+
+
 def build_attention_bias_indices(resolution: int):
     points = list(itertools.product(range(resolution), range(resolution)))
     N = len(points)

@@ -238,11 +255,7 @@ def convert_esrgan(
     "esrgan": "esrgan",
 }

-file_types = {
-    None: 0,
-    "f32": 0,
-    "f16": 1
-}
+file_types = {None: 0, "f32": 0, "f16": 1}

 if __name__ == "__main__":
     # fmt: off

@@ -269,7 +282,9 @@ def convert_esrgan(

     try:
         writer = GGUFWriter(output_path, arch_names.get(args.arch, args.arch))
-        metadata = Metadata.load(args.metadata, input_path.with_suffix(""), args.model_name)
+        metadata = Metadata.load(
+            args.metadata, input_path.with_suffix(""), args.model_name
+        )

         match args.arch:
             case "sam":
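A note on the resulting tensor names, using a hypothetical checkpoint key for illustration (my reading of the conversion loop above, not something spelled out in the commit): each conv + batch-norm pair collapses into a single fused weight/bias pair in the GGUF file, and the ".bn." entries are dropped.

# Hypothetical key, for illustration only -- how one conv + bn pair is mapped:
key = "image_encoder.layers.0.blocks.0.local_conv.c.weight"   # original checkpoint key
name = "enc.layers.0.blocks.0.local_conv"                      # after the renames + removesuffix
# fuse_conv_2d_batch_norm() reads the ".c.weight" and ".bn.*" tensors under the
# original prefix; the writer then receives just two fused tensors:
#   "enc.layers.0.blocks.0.local_conv.weight"  (converted to NHWC)
#   "enc.layers.0.blocks.0.local_conv.bias"
# The remaining ".bn." entries hit the "continue" branch and never reach the file.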

src/visp/mobile-sam.cpp

Lines changed: 10 additions & 19 deletions
@@ -59,33 +59,24 @@ tensor window_reverse(model_ref m, tensor x, int w, int h, int window) {
 // Image encoder
 //

-tensor conv_2d_batch_norm(model_ref m, tensor x, int stride, int pad, int groups) {
-    if (groups == 1) {
-        x = conv_2d(m["c"], x, stride, pad);
-    } else {
-        x = conv_2d_depthwise(m["c"], x, stride, pad);
-    }
-    x = batch_norm_2d(m["bn"], x);
-    return named(m, x);
-}

 tensor patch_embed(model_ref m, tensor x) {
-    x = conv_2d_batch_norm(m["seq.0"], x, 2, 1);
+    x = conv_2d(m["seq.0"], x, 2, 1);
     x = ggml_gelu_inplace(m, x);
-    x = conv_2d_batch_norm(m["seq.2"], x, 2, 1);
+    x = conv_2d(m["seq.2"], x, 2, 1);
     return named(m, x);
 }

 tensor mb_conv(model_ref m, tensor x) {
     tensor shortcut = x;

-    x = conv_2d_batch_norm(m["conv1"], x);
+    x = conv_2d(m["conv1"], x);
     x = ggml_gelu_inplace(m, x);

-    x = conv_2d_batch_norm(m["conv2"], x, 1, 1, /* groups */ int(x->ne[2]));
+    x = conv_2d_depthwise(m["conv2"], x, 1, 1);
     x = ggml_gelu_inplace(m, x);

-    x = conv_2d_batch_norm(m["conv3"], x);
+    x = conv_2d(m["conv3"], x);
     x = ggml_add_inplace(m, x, shortcut);
     x = ggml_gelu_inplace(m, x);

@@ -96,16 +87,16 @@ tensor patch_merging(model_ref m, tensor x, int input_resolution) {
     if (x->ne[2] == 1) {
         x = ggml_reshape_4d(m, x, x->ne[0], input_resolution, input_resolution, x->ne[3]);
     }
-    x = conv_2d_batch_norm(m["conv1"], x);
+    x = conv_2d(m["conv1"], x);
     x = ggml_gelu_inplace(m, x);

-    int c_out = int(m.weights("conv2.c.weight")->ne[0]);
+    int c_out = int(m.weights("conv2.weight")->ne[0]);
     int stride = (c_out == 320 || c_out == 448 || c_out == 576) ? 1 : 2;
-    x = conv_2d_batch_norm(m["conv2"], x, stride, 1, c_out);
+    x = conv_2d_depthwise(m["conv2"], x, stride, 1);
     x = ggml_gelu_inplace(m, x);

     auto [c, h, w, b] = nelements(x);
-    x = conv_2d_batch_norm(m["conv3"], x);
+    x = conv_2d(m["conv3"], x);
     x = ggml_reshape_3d(m, x, c, w * h, b);
     return named(m, x);
 }

@@ -175,7 +166,7 @@ tensor tiny_vit_block(
     x = ggml_add_inplace(m, x, res_x);

     x = ggml_reshape_4d(m, x, c, w, h, b);
-    x = conv_2d_batch_norm(m["local_conv"], x, 1, 1, /* groups */ dim);
+    x = conv_2d_depthwise(m["local_conv"], x, 1, 1);
     x = ggml_reshape_3d(m, x, c, spatial, b);

     tensor x_mlp = mlp(m["mlp"], x);

src/visp/mobile-sam.hpp

Lines changed: 0 additions & 1 deletion
@@ -39,7 +39,6 @@ struct tiny_vit_params {

 float resize_longest_side(i32x2 extent, int target_longest_side);

-tensor conv_2d_batch_norm(model_ref m, tensor x, int stride = 1, int pad = 0, int groups = 1);
 tensor patch_embed(model_ref m, tensor x);
 tensor mb_conv(model_ref m, tensor x);
 tensor patch_merging(model_ref m, tensor x, int input_resolution);

tests/test_mobile_sam.py

Lines changed: 33 additions & 13 deletions
@@ -30,10 +30,30 @@ def __init__(
         self.add_module("bn", bn)


-def add_variance_epsilon(state: dict[str, torch.Tensor], epsilon=1e-5):
-    for k in state:
-        if k.endswith("running_var"):
-            state[k] = torch.sqrt(state[k] + 1e-5).contiguous()
+def fuse_conv_2d_batch_norm(model: dict[str, Tensor], key: str, epsilon=1e-5):
+    conv_weight = model[f"{key}c.weight"]
+    bn_weight = model[f"{key}bn.weight"]
+    bn_bias = model[f"{key}bn.bias"]
+    bn_mean = model[f"{key}bn.running_mean"]
+    bn_var = model[f"{key}bn.running_var"]
+
+    bn_weight = bn_weight / torch.sqrt(bn_var + epsilon)
+    fused_weight = conv_weight * bn_weight[:, None, None, None]
+    fused_bias = bn_bias - bn_mean * bn_weight
+    return fused_weight, fused_bias
+
+
+def fuse_all_conv_2d_batch_norm(model: dict[str, Tensor]):
+    fused_weights = {}
+    for k in model:
+        if k.endswith("c.weight"):
+            key = k.removesuffix("c.weight")
+            weight, bias = fuse_conv_2d_batch_norm(model, key)
+            fused_weights[f"{key}weight"] = weight
+            fused_weights[f"{key}bias"] = bias
+        elif not k.endswith("num_batches_tracked"):
+            fused_weights[k] = model[k]
+    return fused_weights


 def test_conv_2d_batch_norm():

@@ -45,8 +65,8 @@ def test_conv_2d_batch_norm():
     x = torch.rand(1, 4, 8, 8)
     expected = conv2dbn(x)

-    add_variance_epsilon(state)
-    convert_to_nhwc(state)
+    state = fuse_all_conv_2d_batch_norm(state)
+    state = convert_to_nhwc(state)
     x = to_nhwc(x)
     result = workbench.invoke_test("sam_conv_2d_batch_norm", x, state)
     result = to_nchw(result)

@@ -89,7 +109,7 @@ def test_patch_embed():
     x = torch.rand(1, 3, 8, 8)
     expected = patch_embed(x)

-    add_variance_epsilon(state)
+    state = fuse_all_conv_2d_batch_norm(state)
     convert_to_nhwc(state)
     x = to_nhwc(x)
     result = to_nhwc(torch.zeros_like(expected))

@@ -184,7 +204,7 @@ def test_mb_conv():
     x = torch.rand(1, 4, 8, 8)
     expected = mb_conv(x)

-    add_variance_epsilon(state)
+    state = fuse_all_conv_2d_batch_norm(state)
     convert_to_nhwc(state)
     x = to_nhwc(x)
     result = workbench.invoke_test("sam_mb_conv", x, state)

@@ -235,10 +255,10 @@ def test_patch_merging():
     x = torch.rand(1, 8, 32, 32)
     expected = patch_merging(x)

-    add_variance_epsilon(state)
+    state = fuse_all_conv_2d_batch_norm(state)
     convert_to_nhwc(state)
     x = to_nhwc(x)
-    result = result = workbench.invoke_test("sam_patch_merging", x, state)
+    result = workbench.invoke_test("sam_patch_merging", x, state)
     result = result.transpose(1, 2).reshape_as(expected)

     # precision: ggml_gelu uses fp16 look-up table & tanh approximation

@@ -497,9 +517,9 @@ def test_tiny_vit_block():
     state["attn.attention_biases_indexed"] = state["attn.attention_biases"][
         :, tiny_vit_block.attn.attention_bias_idxs
     ]
-    add_variance_epsilon(state)
-    convert_to_nhwc(state)
-    result = result = workbench.invoke_test("sam_tiny_vit_block", x, state)
+    state = fuse_all_conv_2d_batch_norm(state)
+    state = convert_to_nhwc(state)
+    result = workbench.invoke_test("sam_tiny_vit_block", x, state)

     assert torch.allclose(result, expected, rtol=0.001, atol=0.02)

tests/workbench.cpp

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@ DEF(linear)(model_ref m, span<tensor> input, param_dict const& p) {
 // Mobile SAM

 DEF(sam_conv_2d_batch_norm)(model_ref m, span<tensor> input, param_dict const& p) {
-    return {sam::conv_2d_batch_norm(m, input[0], 2, 1)};
+    return {conv_2d(m, input[0], 2, 1)}; // fused conv_2d + batch_norm
 }

 DEF(sam_patch_embed)(model_ref m, span<tensor> input, param_dict const& p) {

tests/workbench.py

Lines changed: 5 additions & 3 deletions
@@ -90,7 +90,9 @@ def encode_params(params: dict[str, str | int | float]):
 os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

 root_dir = Path(__file__).parent.parent
-lib = ctypes.CDLL(str(root_dir / "build" / "bin" / "vision-workbench.dll"))
+bin_dir = root_dir / "build" / "bin"
+
+lib = ctypes.CDLL(str(bin_dir / "vision-workbench.dll"))
 lib.visp_workbench.argtypes = [
     ctypes.c_char_p,
     ctypes.POINTER(RawTensor),

@@ -174,15 +176,15 @@ def to_nchw(tensor: torch.Tensor):
     return tensor.permute(0, 3, 1, 2).contiguous()


-def convert_to_nhwc(state: dict[str, torch.Tensor], key="c."):
+def convert_to_nhwc(state: dict[str, torch.Tensor], key=""):
     for k, v in state.items():
         is_conv = (
             v.ndim == 4
             and v.shape[2] == v.shape[3]
             and v.shape[2] in (1, 3, 4, 7)
             and k.endswith("weight")
         )
-        if key in k and is_conv:
+        if is_conv and (key == "" or key in k):
             if v.shape[1] == 1:  # depthwise
                 state[k] = v.permute(2, 3, 1, 0).contiguous()
             else:
