Skip to content

Commit 30fff56

Browse files
author
Avishek Goswami
committed
Remove excessive comments from test file
Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
1 parent 9e16c3c commit 30fff56

File tree

1 file changed, with 5 additions and 62 deletions.

1 file changed, with 5 additions and 62 deletions.

tests/llmcompressor/observers/test_mse_vs_minmax.py

Lines changed: 5 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -44,8 +44,6 @@ def _run_observer_test(tensor, observer_name, strategy, symmetric, num_bits, gro
4444
module.weight_global_scale = global_scale
4545

4646
scale, zero_point = observer(tensor)
47-
48-
# Sanity check: scales should be non-negative
4947
assert (scale >= 0).all(), "Scale values should be non-negative"
5048

5149
weights_clean = _create_base_quantization_args(num_bits, strategy, symmetric, group_size)
@@ -59,18 +57,11 @@ def _run_observer_test(tensor, observer_name, strategy, symmetric, num_bits, gro
5957

6058

6159
def _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False):
62-
"""
63-
Assert MSE observer performance with appropriate slack.
64-
65-
For tensor+symmetric: strict assertion (MSE should be better)
66-
For others: allow slack (10% for synthetic, 20% for real weights)
67-
Also add epsilon to handle cases where minmax_mse is near 0.
68-
"""
60+
"""Assert MSE observer performance with appropriate slack."""
6961
epsilon = 1e-8
7062
slack = 1.20 if is_real_weights else 1.10
7163

7264
if strategy == "tensor" and symmetric:
73-
# Cases where MSE SHOULD be better
7465
assert mse_mse <= minmax_mse + epsilon, (
7566
f"MSE observer performed worse than MinMax observer!\n"
7667
f"Strategy: {strategy}, Symmetric: {symmetric}\n"
@@ -79,7 +70,6 @@ def _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_wei
7970
f"Difference: {(mse_mse - minmax_mse).item():.6e}"
8071
)
8172
else:
82-
# Not guaranteed, but ensure not catastrophically worse
8373
assert mse_mse <= minmax_mse * slack + epsilon, (
8474
f"MSE observer performed significantly worse than MinMax observer!\n"
8575
f"Strategy: {strategy}, Symmetric: {symmetric}\n"
@@ -109,26 +99,12 @@ def _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_wei
10999
ids=["narrow", "medium", "wide"],
110100
)
111101
def test_mse_vs_minmax_on_random_tensor(strategy, symmetric, num_bits, std):
112-
"""
113-
Test that MSE observer produces quantization error <= MinMax observer
114-
on random tensors with normal distribution (similar to real model weights).
115-
116-
Real model weights typically follow a normal distribution with:
117-
- Mean near 0
118-
- Standard deviation around 0.02-0.1 for initialized weights
119-
- Range roughly [-0.5, 0.5] for most layers
120-
121-
Testing with different std values exposes cases where MinMax performs poorly
122-
on wide or heavy-tailed distributions, where MSE should shine.
123-
"""
124-
# Generate random tensor with normal distribution similar to real weights
102+
"""Test that MSE observer produces quantization error <= MinMax observer on random tensors."""
125103
torch.manual_seed(42)
126-
# Use different std values to test various distribution widths
127-
tensor = torch.randn(128, 256) * std # Normal distribution with specified std
104+
tensor = torch.randn(128, 256) * std
128105

129106
group_size = 32 if strategy == "tensor_group" else None
130107

131-
# Create separate modules for tensor_group to avoid shared mutable state
132108
module_minmax = None
133109
module_mse = None
134110
if strategy == "tensor_group":
@@ -137,17 +113,14 @@ def test_mse_vs_minmax_on_random_tensor(strategy, symmetric, num_bits, std):
137113
module_mse = torch.nn.Linear(256, 128)
138114
module_mse.weight.data = tensor
139115

140-
# Test with MinMax observer
141116
_, _, _, minmax_mse, _ = _run_observer_test(
142117
tensor, "memoryless_minmax", strategy, symmetric, num_bits, group_size, module_minmax
143118
)
144119

145-
# Test with MSE observer
146120
_, _, _, mse_mse, _ = _run_observer_test(
147121
tensor, "memoryless_mse", strategy, symmetric, num_bits, group_size, module_mse
148122
)
149123

150-
# Assert with appropriate slack for synthetic data
151124
_assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False)
152125

153126

@@ -161,37 +134,27 @@ def test_mse_vs_minmax_on_random_tensor(strategy, symmetric, num_bits, std):
161134
],
162135
)
163136
def test_mse_vs_minmax_various_shapes(tensor_shape):
164-
"""
165-
Test MSE vs MinMax on tensors of various shapes with normal distribution.
166-
Uses realistic weight distribution parameters.
167-
"""
137+
"""Test MSE vs MinMax on tensors of various shapes."""
168138
torch.manual_seed(42)
169-
# Use realistic weight distribution: mean=0, std=0.05
170139
tensor = torch.randn(*tensor_shape) * 0.05
171140

172-
# MinMax
173141
_, _, _, minmax_mse, _ = _run_observer_test(
174142
tensor, "memoryless_minmax", "channel", True, 8, None, None
175143
)
176144

177-
# MSE
178145
_, _, _, mse_mse, _ = _run_observer_test(
179146
tensor, "memoryless_mse", "channel", True, 8, None, None
180147
)
181148

182-
# Channel quantization: MSE not guaranteed better, allow 10% slack
183149
_assert_mse_comparison(mse_mse, minmax_mse, "channel", True, is_real_weights=False)
184150

185151

186152
def test_mse_vs_minmax_extreme_values():
187153
"""Test MSE vs MinMax on tensors with extreme values."""
188154
torch.manual_seed(42)
189155

190-
# Test with very small values
191156
tensor_small = torch.randn(64, 128) * 0.01
192-
# Test with very large values
193157
tensor_large = torch.randn(64, 128) * 100.0
194-
# Test with skewed distribution
195158
tensor_skewed = torch.cat([
196159
torch.randn(64, 100) * 0.1,
197160
torch.randn(64, 28) * 10.0
@@ -202,17 +165,14 @@ def test_mse_vs_minmax_extreme_values():
202165
(tensor_large, "large"),
203166
(tensor_skewed, "skewed"),
204167
]:
205-
# MinMax
206168
_, _, _, minmax_mse, _ = _run_observer_test(
207169
tensor, "memoryless_minmax", "channel", True, 8, None, None
208170
)
209171

210-
# MSE
211172
_, _, _, mse_mse, _ = _run_observer_test(
212173
tensor, "memoryless_mse", "channel", True, 8, None, None
213174
)
214175

215-
# Channel quantization: MSE not guaranteed better, allow 10% slack
216176
_assert_mse_comparison(mse_mse, minmax_mse, "channel", True, is_real_weights=False)
217177

218178

@@ -227,30 +187,20 @@ def test_mse_vs_minmax_extreme_values():
227187
],
228188
)
229189
def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits):
230-
"""
231-
Test that MSE observer produces quantization error <= MinMax observer
232-
on actual model weights from a real neural network.
233-
234-
This test loads weights from a small model to verify observer behavior
235-
on real weight distributions, which may differ from synthetic data.
236-
"""
190+
"""Test that MSE observer produces quantization error <= MinMax observer on real model weights."""
237191
try:
238192
from transformers import AutoModelForCausalLM
239193
except ImportError:
240194
pytest.skip("transformers not available")
241195

242-
# Use a small, publicly available model for testing
243196
model_id = "nm-testing/tinysmokellama-3.2"
244197

245198
try:
246-
# Load model and extract a weight tensor
247-
# Use no_grad context to avoid unnecessary gradient computation
248199
with torch.no_grad():
249200
model = AutoModelForCausalLM.from_pretrained(
250201
model_id, torch_dtype=torch.float32
251202
)
252203

253-
# Get a representative weight tensor (e.g., from first Linear layer)
254204
weight_tensor = None
255205
for name, module in model.named_modules():
256206
if isinstance(module, torch.nn.Linear) and weight_tensor is None:
@@ -260,13 +210,11 @@ def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits):
260210
if weight_tensor is None:
261211
pytest.skip("No Linear layer found in model")
262212

263-
# Flatten or reshape to 2D if needed for testing
264213
if weight_tensor.dim() > 2:
265214
weight_tensor = weight_tensor.view(-1, weight_tensor.shape[-1])
266215
elif weight_tensor.dim() == 1:
267216
weight_tensor = weight_tensor.unsqueeze(0)
268217

269-
# Limit size for faster testing
270218
if weight_tensor.shape[0] > 512:
271219
weight_tensor = weight_tensor[:512, :]
272220
if weight_tensor.shape[1] > 512:
@@ -277,7 +225,6 @@ def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits):
277225

278226
group_size = 32 if strategy == "tensor_group" else None
279227

280-
# Create separate modules for tensor_group to avoid shared mutable state
281228
module_minmax = None
282229
module_mse = None
283230
if strategy == "tensor_group":
@@ -286,17 +233,13 @@ def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits):
286233
module_mse = torch.nn.Linear(weight_tensor.shape[1], weight_tensor.shape[0])
287234
module_mse.weight.data = weight_tensor
288235

289-
# Test with MinMax observer
290236
_, _, _, minmax_mse, _ = _run_observer_test(
291237
weight_tensor, "memoryless_minmax", strategy, symmetric, num_bits, group_size, module_minmax
292238
)
293239

294-
# Test with MSE observer
295240
_, _, _, mse_mse, _ = _run_observer_test(
296241
weight_tensor, "memoryless_mse", strategy, symmetric, num_bits, group_size, module_mse
297242
)
298243

299-
# For channel and tensor_group strategies, MSE is not guaranteed to be better
300-
# Allow 20% slack for real model weights (more structure & extreme channels)
301244
_assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=True)
302245

0 commit comments

Comments (0)