Skip to content

Commit 9e16c3c

Browse files
author
Avishek Goswami
committed
Fix shape mismatches and remove dead code in MSE vs MinMax tests
- Fix tensor shape mismatch: use tensor directly instead of tensor.T - Fix weight_tensor shape mismatch: use weight_tensor directly instead of weight_tensor.T - Remove unused weights variable in test_mse_vs_minmax_extreme_values Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
1 parent 5352b1c commit 9e16c3c

File tree

1 file changed

+4
-11
lines changed

1 file changed

+4
-11
lines changed

tests/llmcompressor/observers/test_mse_vs_minmax.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,9 @@ def test_mse_vs_minmax_on_random_tensor(strategy, symmetric, num_bits, std):
133133
module_mse = None
134134
if strategy == "tensor_group":
135135
module_minmax = torch.nn.Linear(256, 128)
136-
module_minmax.weight.data = tensor.T
136+
module_minmax.weight.data = tensor
137137
module_mse = torch.nn.Linear(256, 128)
138-
module_mse.weight.data = tensor.T
138+
module_mse.weight.data = tensor
139139

140140
# Test with MinMax observer
141141
_, _, _, minmax_mse, _ = _run_observer_test(
@@ -202,13 +202,6 @@ def test_mse_vs_minmax_extreme_values():
202202
(tensor_large, "large"),
203203
(tensor_skewed, "skewed"),
204204
]:
205-
weights = QuantizationArgs(
206-
num_bits=8,
207-
strategy="channel",
208-
symmetric=True,
209-
observer="memoryless_minmax",
210-
)
211-
212205
# MinMax
213206
_, _, _, minmax_mse, _ = _run_observer_test(
214207
tensor, "memoryless_minmax", "channel", True, 8, None, None
@@ -289,9 +282,9 @@ def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits):
289282
module_mse = None
290283
if strategy == "tensor_group":
291284
module_minmax = torch.nn.Linear(weight_tensor.shape[1], weight_tensor.shape[0])
292-
module_minmax.weight.data = weight_tensor.T
285+
module_minmax.weight.data = weight_tensor
293286
module_mse = torch.nn.Linear(weight_tensor.shape[1], weight_tensor.shape[0])
294-
module_mse.weight.data = weight_tensor.T
287+
module_mse.weight.data = weight_tensor
295288

296289
# Test with MinMax observer
297290
_, _, _, minmax_mse, _ = _run_observer_test(

0 commit comments

Comments
 (0)