@@ -44,8 +44,6 @@ def _run_observer_test(tensor, observer_name, strategy, symmetric, num_bits, gro
4444 module .weight_global_scale = global_scale
4545
4646 scale , zero_point = observer (tensor )
47-
48- # Sanity check: scales should be non-negative
4947 assert (scale >= 0 ).all (), "Scale values should be non-negative"
5048
5149 weights_clean = _create_base_quantization_args (num_bits , strategy , symmetric , group_size )
@@ -59,18 +57,11 @@ def _run_observer_test(tensor, observer_name, strategy, symmetric, num_bits, gro
5957
6058
6159def _assert_mse_comparison (mse_mse , minmax_mse , strategy , symmetric , is_real_weights = False ):
62- """
63- Assert MSE observer performance with appropriate slack.
64-
65- For tensor+symmetric: strict assertion (MSE should be better)
66- For others: allow slack (10% for synthetic, 20% for real weights)
67- Also add epsilon to handle cases where minmax_mse is near 0.
68- """
60+ """Assert MSE observer performance with appropriate slack."""
6961 epsilon = 1e-8
7062 slack = 1.20 if is_real_weights else 1.10
7163
7264 if strategy == "tensor" and symmetric :
73- # Cases where MSE SHOULD be better
7465 assert mse_mse <= minmax_mse + epsilon , (
7566 f"MSE observer performed worse than MinMax observer!\n "
7667 f"Strategy: { strategy } , Symmetric: { symmetric } \n "
@@ -79,7 +70,6 @@ def _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_wei
7970 f"Difference: { (mse_mse - minmax_mse ).item ():.6e} "
8071 )
8172 else :
82- # Not guaranteed, but ensure not catastrophically worse
8373 assert mse_mse <= minmax_mse * slack + epsilon , (
8474 f"MSE observer performed significantly worse than MinMax observer!\n "
8575 f"Strategy: { strategy } , Symmetric: { symmetric } \n "
@@ -109,26 +99,12 @@ def _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_wei
10999 ids = ["narrow" , "medium" , "wide" ],
110100)
111101def test_mse_vs_minmax_on_random_tensor (strategy , symmetric , num_bits , std ):
112- """
113- Test that MSE observer produces quantization error <= MinMax observer
114- on random tensors with normal distribution (similar to real model weights).
115-
116- Real model weights typically follow a normal distribution with:
117- - Mean near 0
118- - Standard deviation around 0.02-0.1 for initialized weights
119- - Range roughly [-0.5, 0.5] for most layers
120-
121- Testing with different std values exposes cases where MinMax performs poorly
122- on wide or heavy-tailed distributions, where MSE should shine.
123- """
124- # Generate random tensor with normal distribution similar to real weights
102+ """Test that MSE observer produces quantization error <= MinMax observer on random tensors."""
125103 torch .manual_seed (42 )
126- # Use different std values to test various distribution widths
127- tensor = torch .randn (128 , 256 ) * std # Normal distribution with specified std
104+ tensor = torch .randn (128 , 256 ) * std
128105
129106 group_size = 32 if strategy == "tensor_group" else None
130107
131- # Create separate modules for tensor_group to avoid shared mutable state
132108 module_minmax = None
133109 module_mse = None
134110 if strategy == "tensor_group" :
@@ -137,17 +113,14 @@ def test_mse_vs_minmax_on_random_tensor(strategy, symmetric, num_bits, std):
137113 module_mse = torch .nn .Linear (256 , 128 )
138114 module_mse .weight .data = tensor
139115
140- # Test with MinMax observer
141116 _ , _ , _ , minmax_mse , _ = _run_observer_test (
142117 tensor , "memoryless_minmax" , strategy , symmetric , num_bits , group_size , module_minmax
143118 )
144119
145- # Test with MSE observer
146120 _ , _ , _ , mse_mse , _ = _run_observer_test (
147121 tensor , "memoryless_mse" , strategy , symmetric , num_bits , group_size , module_mse
148122 )
149123
150- # Assert with appropriate slack for synthetic data
151124 _assert_mse_comparison (mse_mse , minmax_mse , strategy , symmetric , is_real_weights = False )
152125
153126
@@ -161,37 +134,27 @@ def test_mse_vs_minmax_on_random_tensor(strategy, symmetric, num_bits, std):
161134 ],
162135)
def test_mse_vs_minmax_various_shapes(tensor_shape):
    """Compare the MSE and MinMax observers on a seeded tensor of ``tensor_shape``.

    A normally distributed tensor scaled by 0.05 is pushed through both
    memoryless observers using the "channel" strategy, symmetric, 8 bits,
    with no group size and no module. The two resulting errors are then
    checked by ``_assert_mse_comparison`` with ``is_real_weights=False``.
    """
    # Fixed seed so every parametrized shape sees reproducible values.
    torch.manual_seed(42)
    weights = torch.randn(*tensor_shape) * 0.05

    strategy = "channel"
    symmetric = True
    num_bits = 8

    # Same run order as before: MinMax first, then MSE.
    errors = {}
    for observer_name in ("memoryless_minmax", "memoryless_mse"):
        _, _, _, observer_error, _ = _run_observer_test(
            weights, observer_name, strategy, symmetric, num_bits, None, None
        )
        errors[observer_name] = observer_error

    _assert_mse_comparison(
        errors["memoryless_mse"],
        errors["memoryless_minmax"],
        strategy,
        symmetric,
        is_real_weights=False,
    )
184150
185151
186152def test_mse_vs_minmax_extreme_values ():
187153 """Test MSE vs MinMax on tensors with extreme values."""
188154 torch .manual_seed (42 )
189155
190- # Test with very small values
191156 tensor_small = torch .randn (64 , 128 ) * 0.01
192- # Test with very large values
193157 tensor_large = torch .randn (64 , 128 ) * 100.0
194- # Test with skewed distribution
195158 tensor_skewed = torch .cat ([
196159 torch .randn (64 , 100 ) * 0.1 ,
197160 torch .randn (64 , 28 ) * 10.0
@@ -202,17 +165,14 @@ def test_mse_vs_minmax_extreme_values():
202165 (tensor_large , "large" ),
203166 (tensor_skewed , "skewed" ),
204167 ]:
205- # MinMax
206168 _ , _ , _ , minmax_mse , _ = _run_observer_test (
207169 tensor , "memoryless_minmax" , "channel" , True , 8 , None , None
208170 )
209171
210- # MSE
211172 _ , _ , _ , mse_mse , _ = _run_observer_test (
212173 tensor , "memoryless_mse" , "channel" , True , 8 , None , None
213174 )
214175
215- # Channel quantization: MSE not guaranteed better, allow 10% slack
216176 _assert_mse_comparison (mse_mse , minmax_mse , "channel" , True , is_real_weights = False )
217177
218178
@@ -227,30 +187,20 @@ def test_mse_vs_minmax_extreme_values():
227187 ],
228188)
229189def test_mse_vs_minmax_on_real_model_weights (strategy , symmetric , num_bits ):
230- """
231- Test that MSE observer produces quantization error <= MinMax observer
232- on actual model weights from a real neural network.
233-
234- This test loads weights from a small model to verify observer behavior
235- on real weight distributions, which may differ from synthetic data.
236- """
190+ """Test that MSE observer produces quantization error <= MinMax observer on real model weights."""
237191 try :
238192 from transformers import AutoModelForCausalLM
239193 except ImportError :
240194 pytest .skip ("transformers not available" )
241195
242- # Use a small, publicly available model for testing
243196 model_id = "nm-testing/tinysmokellama-3.2"
244197
245198 try :
246- # Load model and extract a weight tensor
247- # Use no_grad context to avoid unnecessary gradient computation
248199 with torch .no_grad ():
249200 model = AutoModelForCausalLM .from_pretrained (
250201 model_id , torch_dtype = torch .float32
251202 )
252203
253- # Get a representative weight tensor (e.g., from first Linear layer)
254204 weight_tensor = None
255205 for name , module in model .named_modules ():
256206 if isinstance (module , torch .nn .Linear ) and weight_tensor is None :
@@ -260,13 +210,11 @@ def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits):
260210 if weight_tensor is None :
261211 pytest .skip ("No Linear layer found in model" )
262212
263- # Flatten or reshape to 2D if needed for testing
264213 if weight_tensor .dim () > 2 :
265214 weight_tensor = weight_tensor .view (- 1 , weight_tensor .shape [- 1 ])
266215 elif weight_tensor .dim () == 1 :
267216 weight_tensor = weight_tensor .unsqueeze (0 )
268217
269- # Limit size for faster testing
270218 if weight_tensor .shape [0 ] > 512 :
271219 weight_tensor = weight_tensor [:512 , :]
272220 if weight_tensor .shape [1 ] > 512 :
@@ -277,7 +225,6 @@ def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits):
277225
278226 group_size = 32 if strategy == "tensor_group" else None
279227
280- # Create separate modules for tensor_group to avoid shared mutable state
281228 module_minmax = None
282229 module_mse = None
283230 if strategy == "tensor_group" :
@@ -286,17 +233,13 @@ def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits):
286233 module_mse = torch .nn .Linear (weight_tensor .shape [1 ], weight_tensor .shape [0 ])
287234 module_mse .weight .data = weight_tensor
288235
289- # Test with MinMax observer
290236 _ , _ , _ , minmax_mse , _ = _run_observer_test (
291237 weight_tensor , "memoryless_minmax" , strategy , symmetric , num_bits , group_size , module_minmax
292238 )
293239
294- # Test with MSE observer
295240 _ , _ , _ , mse_mse , _ = _run_observer_test (
296241 weight_tensor , "memoryless_mse" , strategy , symmetric , num_bits , group_size , module_mse
297242 )
298243
299- # For channel and tensor_group strategies, MSE is not guaranteed to be better
300- # Allow 20% slack for real model weights (more structure & extreme channels)
301244 _assert_mse_comparison (mse_mse , minmax_mse , strategy , symmetric , is_real_weights = True )
302245
0 commit comments