
Commit 2c9b5a1

Avishek Goswami committed
Fix ruff formatting in MSE vs MinMax observer tests
Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
1 parent 30fff56 commit 2c9b5a1
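
Note: a formatting-only change like this one is typically produced by running the project's formatter locally, for example ruff format tests/llmcompressor/observers/test_mse_vs_minmax.py (or ruff check --fix for lint autofixes). The exact invocation is an assumption here; the commit records only the resulting diff.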

File tree

1 file changed: +87 -45 lines changed


tests/llmcompressor/observers/test_mse_vs_minmax.py

Lines changed: 87 additions & 45 deletions
@@ -1,6 +1,7 @@
 """
 Test to verify that MSE observer performs equal to or better than MinMax observer
-on various tensor distributions, including normal distributions (similar to real weights)
+on various tensor distributions, including normal distributions (similar to real
+weights)
 and actual model weights.
 
 This test checks that the quantization error (MSE) from using MSE observer
@@ -25,42 +26,51 @@ def _create_base_quantization_args(num_bits, strategy, symmetric, group_size):
     )
 
 
-def _run_observer_test(tensor, observer_name, strategy, symmetric, num_bits, group_size, module=None):
+def _run_observer_test(
+    tensor, observer_name, strategy, symmetric, num_bits, group_size, module=None
+):
     """
     Helper function to run observer and compute quantization error.
-
+
     Returns: (scale, zero_point, quantized_tensor, mse, global_scale)
     """
     weights = _create_base_quantization_args(num_bits, strategy, symmetric, group_size)
     weights.observer = observer_name
-
+
     observer = Observer.load_from_registry(
         observer_name, base_name="weight", args=weights, module=module
     )
-
+
     global_scale = None
     if strategy == "tensor_group" and module is not None:
         global_scale = observer.get_global_scale(tensor)
         module.weight_global_scale = global_scale
-
+
     scale, zero_point = observer(tensor)
     assert (scale >= 0).all(), "Scale values should be non-negative"
-
-    weights_clean = _create_base_quantization_args(num_bits, strategy, symmetric, group_size)
+
+    weights_clean = _create_base_quantization_args(
+        num_bits, strategy, symmetric, group_size
+    )
     quantized = fake_quantize(
-        tensor, scale, zero_point, weights_clean,
+        tensor,
+        scale,
+        zero_point,
+        weights_clean,
         global_scale=global_scale if strategy == "tensor_group" else None
     )
     mse = torch.nn.functional.mse_loss(quantized, tensor)
-
+
     return scale, zero_point, quantized, mse, global_scale
 
 
-def _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False):
+def _assert_mse_comparison(
+    mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False
+):
     """Assert MSE observer performance with appropriate slack."""
     epsilon = 1e-8
     slack = 1.20 if is_real_weights else 1.10
-
+
     if strategy == "tensor" and symmetric:
         assert mse_mse <= minmax_mse + epsilon, (
             f"MSE observer performed worse than MinMax observer!\n"
@@ -99,29 +109,43 @@ def _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_wei
     ids=["narrow", "medium", "wide"],
 )
 def test_mse_vs_minmax_on_random_tensor(strategy, symmetric, num_bits, std):
-    """Test that MSE observer produces quantization error <= MinMax observer on random tensors."""
+    """Test MSE observer error <= MinMax observer error on random tensors."""
     torch.manual_seed(42)
     tensor = torch.randn(128, 256) * std
-
+
     group_size = 32 if strategy == "tensor_group" else None
-
+
     module_minmax = None
     module_mse = None
     if strategy == "tensor_group":
         module_minmax = torch.nn.Linear(256, 128)
         module_minmax.weight.data = tensor
         module_mse = torch.nn.Linear(256, 128)
         module_mse.weight.data = tensor
-
+
     _, _, _, minmax_mse, _ = _run_observer_test(
-        tensor, "memoryless_minmax", strategy, symmetric, num_bits, group_size, module_minmax
+        tensor,
+        "memoryless_minmax",
+        strategy,
+        symmetric,
+        num_bits,
+        group_size,
+        module_minmax,
     )
-
+
     _, _, _, mse_mse, _ = _run_observer_test(
-        tensor, "memoryless_mse", strategy, symmetric, num_bits, group_size, module_mse
+        tensor,
+        "memoryless_mse",
+        strategy,
+        symmetric,
+        num_bits,
+        group_size,
+        module_mse,
+    )
+
+    _assert_mse_comparison(
+        mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False
     )
-
-    _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False)
 
 
 @pytest.mark.parametrize(
@@ -137,29 +161,29 @@ def test_mse_vs_minmax_various_shapes(tensor_shape):
     """Test MSE vs MinMax on tensors of various shapes."""
     torch.manual_seed(42)
     tensor = torch.randn(*tensor_shape) * 0.05
-
+
     _, _, _, minmax_mse, _ = _run_observer_test(
         tensor, "memoryless_minmax", "channel", True, 8, None, None
     )
-
+
     _, _, _, mse_mse, _ = _run_observer_test(
         tensor, "memoryless_mse", "channel", True, 8, None, None
    )
-
+
     _assert_mse_comparison(mse_mse, minmax_mse, "channel", True, is_real_weights=False)
 
 
 def test_mse_vs_minmax_extreme_values():
     """Test MSE vs MinMax on tensors with extreme values."""
     torch.manual_seed(42)
-
+
     tensor_small = torch.randn(64, 128) * 0.01
     tensor_large = torch.randn(64, 128) * 100.0
     tensor_skewed = torch.cat([
         torch.randn(64, 100) * 0.1,
         torch.randn(64, 28) * 10.0
     ], dim=1)
-
+
     for tensor, name in [
         (tensor_small, "small"),
         (tensor_large, "large"),
@@ -168,12 +192,14 @@ def test_mse_vs_minmax_extreme_values():
         _, _, _, minmax_mse, _ = _run_observer_test(
             tensor, "memoryless_minmax", "channel", True, 8, None, None
         )
-
+
         _, _, _, mse_mse, _ = _run_observer_test(
             tensor, "memoryless_mse", "channel", True, 8, None, None
         )
-
-        _assert_mse_comparison(mse_mse, minmax_mse, "channel", True, is_real_weights=False)
+
+        _assert_mse_comparison(
+            mse_mse, minmax_mse, "channel", True, is_real_weights=False
+        )
 
 
 @pytest.mark.slow
@@ -187,59 +213,75 @@ def test_mse_vs_minmax_extreme_values():
     ],
 )
 def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits):
-    """Test that MSE observer produces quantization error <= MinMax observer on real model weights."""
+    """Test MSE observer error <= MinMax observer error on real model weights."""
     try:
         from transformers import AutoModelForCausalLM
     except ImportError:
         pytest.skip("transformers not available")
 
     model_id = "nm-testing/tinysmokellama-3.2"
-
+
     try:
         with torch.no_grad():
             model = AutoModelForCausalLM.from_pretrained(
                 model_id, torch_dtype=torch.float32
             )
-
+
         weight_tensor = None
         for name, module in model.named_modules():
             if isinstance(module, torch.nn.Linear) and weight_tensor is None:
                 weight_tensor = module.weight.data.clone()
                 break
-
+
         if weight_tensor is None:
             pytest.skip("No Linear layer found in model")
-
+
         if weight_tensor.dim() > 2:
             weight_tensor = weight_tensor.view(-1, weight_tensor.shape[-1])
         elif weight_tensor.dim() == 1:
             weight_tensor = weight_tensor.unsqueeze(0)
-
+
         if weight_tensor.shape[0] > 512:
             weight_tensor = weight_tensor[:512, :]
         if weight_tensor.shape[1] > 512:
             weight_tensor = weight_tensor[:, :512]
-
+
     except Exception as e:
         pytest.skip(f"Could not load model {model_id}: {e}")
-
+
     group_size = 32 if strategy == "tensor_group" else None
-
+
     module_minmax = None
     module_mse = None
     if strategy == "tensor_group":
-        module_minmax = torch.nn.Linear(weight_tensor.shape[1], weight_tensor.shape[0])
+        module_minmax = torch.nn.Linear(
+            weight_tensor.shape[1], weight_tensor.shape[0]
+        )
         module_minmax.weight.data = weight_tensor
         module_mse = torch.nn.Linear(weight_tensor.shape[1], weight_tensor.shape[0])
         module_mse.weight.data = weight_tensor
-
+
     _, _, _, minmax_mse, _ = _run_observer_test(
-        weight_tensor, "memoryless_minmax", strategy, symmetric, num_bits, group_size, module_minmax
+        weight_tensor,
+        "memoryless_minmax",
+        strategy,
+        symmetric,
+        num_bits,
+        group_size,
+        module_minmax,
    )
-
+
     _, _, _, mse_mse, _ = _run_observer_test(
-        weight_tensor, "memoryless_mse", strategy, symmetric, num_bits, group_size, module_mse
+        weight_tensor,
+        "memoryless_mse",
+        strategy,
+        symmetric,
+        num_bits,
+        group_size,
+        module_mse,
+    )
+
+    _assert_mse_comparison(
+        mse_mse, minmax_mse, strategy, symmetric, is_real_weights=True
    )
-
-    _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=True)
 
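For context on what the reformatted tests assert (see the module docstring in the first hunk): an MSE observer should yield quantization error no worse than a MinMax observer. Below is a minimal, self-contained PyTorch sketch of that comparison. It does not use llm-compressor's Observer registry or fake_quantize; the helper name symmetric_quant_error and the grid search over clipping ratios are illustrative assumptions standing in for how an MSE-style observer narrows the quantization range, not the library's actual implementation.

import torch


def symmetric_quant_error(x, max_val, num_bits=8):
    # Hypothetical helper: fake-quantize x symmetrically into [-max_val, max_val]
    # at the given bit width and return the mean squared reconstruction error.
    qmax = 2 ** (num_bits - 1) - 1
    scale = max_val / qmax
    q = torch.clamp(torch.round(x / scale), -qmax - 1, qmax) * scale
    return torch.nn.functional.mse_loss(q, x).item()


torch.manual_seed(42)
x = torch.randn(128, 256) * 0.05  # weight-like values, similar to the tests above

# MinMax-style range: clip at the raw absolute maximum of the tensor.
absmax = x.abs().max().item()
minmax_err = symmetric_quant_error(x, absmax)

# MSE-style range (illustrative stand-in): try shrunken clipping ranges and keep
# the one with the lowest error. Since ratio 1.0 is included, the result can only
# match or improve on the MinMax choice.
mse_err = min(
    symmetric_quant_error(x, absmax * ratio)
    for ratio in torch.linspace(0.5, 1.0, steps=51)
)

print(f"minmax MSE: {minmax_err:.3e}, searched MSE: {mse_err:.3e}")
assert mse_err <= minmax_err + 1e-8

In the real tests, the same comparison is parametrized over strategies (tensor, channel, tensor_group), symmetry, and bit widths, and _assert_mse_comparison allows a small slack factor (1.10, or 1.20 for real model weights) outside the strict tensor-symmetric case.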