11"""
22Test to verify that MSE observer performs equal to or better than MinMax observer
3- on various tensor distributions, including normal distributions (similar to real weights)
3+ on various tensor distributions, including normal distributions (similar to real
4+ weights)
45and actual model weights.
56
67This test checks that the quantization error (MSE) from using MSE observer
@@ -25,42 +26,51 @@ def _create_base_quantization_args(num_bits, strategy, symmetric, group_size):
2526 )
2627
2728
def _run_observer_test(
    tensor, observer_name, strategy, symmetric, num_bits, group_size, module=None
):
    """
    Run a single observer over ``tensor`` and measure its quantization error.

    The observer named ``observer_name`` is loaded from the ``Observer`` registry
    using args built by ``_create_base_quantization_args``. When ``strategy`` is
    ``"tensor_group"`` and a ``module`` is supplied, a global scale is computed
    first and stored on ``module.weight_global_scale`` before the observer runs.

    Returns: (scale, zero_point, quantized_tensor, mse, global_scale)
        ``mse`` is the ``mse_loss`` between the fake-quantized tensor and the
        input; ``global_scale`` is ``None`` unless the tensor_group path ran.
    """
    observer_args = _create_base_quantization_args(
        num_bits, strategy, symmetric, group_size
    )
    observer_args.observer = observer_name
    observer = Observer.load_from_registry(
        observer_name, base_name="weight", args=observer_args, module=module
    )

    needs_global_scale = strategy == "tensor_group" and module is not None
    if needs_global_scale:
        global_scale = observer.get_global_scale(tensor)
        module.weight_global_scale = global_scale
    else:
        global_scale = None

    scale, zero_point = observer(tensor)
    assert (scale >= 0).all(), "Scale values should be non-negative"

    # The fake-quantize step gets freshly built args (no observer name attached).
    quant_args = _create_base_quantization_args(
        num_bits, strategy, symmetric, group_size
    )
    # global_scale can only be non-None on the tensor_group path, so it is
    # forwarded as-is without re-checking the strategy.
    quantized = fake_quantize(
        tensor, scale, zero_point, quant_args, global_scale=global_scale
    )
    mse = torch.nn.functional.mse_loss(quantized, tensor)

    return scale, zero_point, quantized, mse, global_scale
5765
5866
59- def _assert_mse_comparison (mse_mse , minmax_mse , strategy , symmetric , is_real_weights = False ):
67+ def _assert_mse_comparison (
68+ mse_mse , minmax_mse , strategy , symmetric , is_real_weights = False
69+ ):
6070 """Assert MSE observer performance with appropriate slack."""
6171 epsilon = 1e-8
6272 slack = 1.20 if is_real_weights else 1.10
63-
73+
6474 if strategy == "tensor" and symmetric :
6575 assert mse_mse <= minmax_mse + epsilon , (
6676 f"MSE observer performed worse than MinMax observer!\n "
@@ -99,29 +109,43 @@ def _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_wei
99109 ids = ["narrow" , "medium" , "wide" ],
100110)
def test_mse_vs_minmax_on_random_tensor(strategy, symmetric, num_bits, std):
    """Compare MSE and MinMax observer quantization error on a random tensor.

    A seeded ``128 x 256`` normal tensor scaled by ``std`` is run through both
    observers; the resulting errors are checked by ``_assert_mse_comparison``
    with ``is_real_weights=False``.
    """
    torch.manual_seed(42)
    tensor = torch.randn(128, 256) * std

    grouped = strategy == "tensor_group"
    group_size = 32 if grouped else None

    # One module per observer, only for tensor_group. Creation order
    # (minmax first, then mse) is kept so RNG consumption stays the same.
    modules = {"memoryless_minmax": None, "memoryless_mse": None}
    if grouped:
        for observer_name in modules:
            linear = torch.nn.Linear(256, 128)
            linear.weight.data = tensor
            modules[observer_name] = linear

    errors = {}
    for observer_name, module in modules.items():
        result = _run_observer_test(
            tensor,
            observer_name,
            strategy,
            symmetric,
            num_bits,
            group_size,
            module,
        )
        # result is (scale, zero_point, quantized, mse, global_scale)
        errors[observer_name] = result[3]

    _assert_mse_comparison(
        errors["memoryless_mse"],
        errors["memoryless_minmax"],
        strategy,
        symmetric,
        is_real_weights=False,
    )
125149
126150
127151@pytest .mark .parametrize (
@@ -137,29 +161,29 @@ def test_mse_vs_minmax_various_shapes(tensor_shape):
137161 """Test MSE vs MinMax on tensors of various shapes."""
138162 torch .manual_seed (42 )
139163 tensor = torch .randn (* tensor_shape ) * 0.05
140-
164+
141165 _ , _ , _ , minmax_mse , _ = _run_observer_test (
142166 tensor , "memoryless_minmax" , "channel" , True , 8 , None , None
143167 )
144-
168+
145169 _ , _ , _ , mse_mse , _ = _run_observer_test (
146170 tensor , "memoryless_mse" , "channel" , True , 8 , None , None
147171 )
148-
172+
149173 _assert_mse_comparison (mse_mse , minmax_mse , "channel" , True , is_real_weights = False )
150174
151175
152176def test_mse_vs_minmax_extreme_values ():
153177 """Test MSE vs MinMax on tensors with extreme values."""
154178 torch .manual_seed (42 )
155-
179+
156180 tensor_small = torch .randn (64 , 128 ) * 0.01
157181 tensor_large = torch .randn (64 , 128 ) * 100.0
158182 tensor_skewed = torch .cat ([
159183 torch .randn (64 , 100 ) * 0.1 ,
160184 torch .randn (64 , 28 ) * 10.0
161185 ], dim = 1 )
162-
186+
163187 for tensor , name in [
164188 (tensor_small , "small" ),
165189 (tensor_large , "large" ),
@@ -168,12 +192,14 @@ def test_mse_vs_minmax_extreme_values():
168192 _ , _ , _ , minmax_mse , _ = _run_observer_test (
169193 tensor , "memoryless_minmax" , "channel" , True , 8 , None , None
170194 )
171-
195+
172196 _ , _ , _ , mse_mse , _ = _run_observer_test (
173197 tensor , "memoryless_mse" , "channel" , True , 8 , None , None
174198 )
175-
176- _assert_mse_comparison (mse_mse , minmax_mse , "channel" , True , is_real_weights = False )
199+
200+ _assert_mse_comparison (
201+ mse_mse , minmax_mse , "channel" , True , is_real_weights = False
202+ )
177203
178204
179205@pytest .mark .slow
@@ -187,59 +213,75 @@ def test_mse_vs_minmax_extreme_values():
187213 ],
188214)
def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits):
    """Test MSE observer error <= MinMax observer error on real model weights.

    Loads ``nm-testing/tinysmokellama-3.2``, takes the weight of the first
    ``torch.nn.Linear`` module found, crops it to at most ``512 x 512`` and runs
    both observers on it. The errors are compared by ``_assert_mse_comparison``
    with ``is_real_weights=True``.

    The test is skipped when ``transformers`` cannot be imported, when the model
    cannot be loaded, or when the model contains no ``Linear`` layer.
    """
    try:
        from transformers import AutoModelForCausalLM
    except ImportError:
        pytest.skip("transformers not available")

    model_id = "nm-testing/tinysmokellama-3.2"

    # Only the model load is guarded by the broad except: a failure here is an
    # environment problem (no network, missing checkpoint, ...) and should
    # skip. Errors in the weight handling below must fail the test instead of
    # being reported as a skip.
    try:
        with torch.no_grad():
            model = AutoModelForCausalLM.from_pretrained(
                model_id, torch_dtype=torch.float32
            )
    except Exception as e:
        pytest.skip(f"Could not load model {model_id}: {e}")

    first_linear = next(
        (m for m in model.modules() if isinstance(m, torch.nn.Linear)), None
    )
    if first_linear is None:
        pytest.skip("No Linear layer found in model")
    weight_tensor = first_linear.weight.data.clone()

    # Normalize to a 2-D matrix before cropping.
    if weight_tensor.dim() > 2:
        weight_tensor = weight_tensor.view(-1, weight_tensor.shape[-1])
    elif weight_tensor.dim() == 1:
        weight_tensor = weight_tensor.unsqueeze(0)

    # Keep at most the top-left 512 x 512 corner; slicing past the end of a
    # smaller dimension is a no-op.
    weight_tensor = weight_tensor[:512, :512]

    group_size = 32 if strategy == "tensor_group" else None

    module_minmax = None
    module_mse = None
    if strategy == "tensor_group":
        out_features, in_features = weight_tensor.shape
        module_minmax = torch.nn.Linear(in_features, out_features)
        module_minmax.weight.data = weight_tensor
        module_mse = torch.nn.Linear(in_features, out_features)
        module_mse.weight.data = weight_tensor

    _, _, _, minmax_mse, _ = _run_observer_test(
        weight_tensor,
        "memoryless_minmax",
        strategy,
        symmetric,
        num_bits,
        group_size,
        module_minmax,
    )

    _, _, _, mse_mse, _ = _run_observer_test(
        weight_tensor,
        "memoryless_mse",
        strategy,
        symmetric,
        num_bits,
        group_size,
        module_mse,
    )

    _assert_mse_comparison(
        mse_mse, minmax_mse, strategy, symmetric, is_real_weights=True
    )
245287
0 commit comments