|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import random |
| 8 | +import unittest |
| 9 | + |
| 10 | +import torch |
| 11 | +from torch import nn |
| 12 | +from torch.testing._internal import common_utils |
| 13 | + |
| 14 | +from torchao.prototype.pat.group import ( |
| 15 | + AttentionHeadGrouperDim0, |
| 16 | + AttentionHeadGrouperDim1, |
| 17 | + PackedSVDGrouper, |
| 18 | + QKGrouper, |
| 19 | + QKSVDGrouper, |
| 20 | + SVDGrouper, |
| 21 | +) |
| 22 | +from torchao.prototype.pat.layers.masked_layernorm import MaskedLayerNorm |
| 23 | +from torchao.prototype.pat.optim import ProxGroupLasso, ProxNuclearNorm, PruneOptimizer |
| 24 | +from torchao.prototype.pat.utils import get_param_groups |
| 25 | + |
| 26 | + |
| 27 | +class TestMaskedLayerNorm(common_utils.TestCase): |
| 28 | + @common_utils.parametrize("batch", [1, 4]) |
| 29 | + @common_utils.parametrize("seq_len", [2, 8]) |
| 30 | + @common_utils.parametrize("embed_dim", [16, 64]) |
| 31 | + def test_masked_layernorm(self, batch=1, seq_len=2, embed_dim=16): |
| 32 | + dim2_nz = embed_dim // 2 |
| 33 | + embed = torch.randn(batch, seq_len, embed_dim) |
| 34 | + embed[..., dim2_nz:] = 0 |
| 35 | + |
| 36 | + masked_layer_norm = MaskedLayerNorm(embed_dim) |
| 37 | + layer_norm = nn.LayerNorm(dim2_nz) |
| 38 | + with torch.no_grad(): |
| 39 | + layer_norm.weight.copy_(masked_layer_norm.weight[:dim2_nz]) |
| 40 | + layer_norm.bias.copy_(masked_layer_norm.bias[:dim2_nz]) |
| 41 | + |
| 42 | + out = masked_layer_norm(embed) |
| 43 | + expected_out = layer_norm(embed[..., :dim2_nz]) |
| 44 | + torch.testing.assert_close(out[..., :dim2_nz], expected_out) |
| 45 | + |
| 46 | + |
| 47 | +class MHADummyModel(nn.Module): |
| 48 | + def __init__(self, embed_dim, num_heads, n_cls): |
| 49 | + super().__init__() |
| 50 | + self.mha = nn.MultiheadAttention(embed_dim, num_heads, bias=False) |
| 51 | + self.classifier = nn.Linear(embed_dim, n_cls) |
| 52 | + |
| 53 | + def forward(self, x): |
| 54 | + attn_output, _ = self.mha(x, x, x) |
| 55 | + out = self.classifier(attn_output) |
| 56 | + return out |
| 57 | + |
| 58 | + |
| 59 | +class TestQKGrouper(common_utils.TestCase): |
| 60 | + def __init__(self, methodName): |
| 61 | + super(TestQKGrouper, self).__init__(methodName) |
| 62 | + self.reg_lambda = 1.0 |
| 63 | + self.prox_map = ProxGroupLasso(self.reg_lambda) |
| 64 | + |
| 65 | + @staticmethod |
| 66 | + def _get_qk(p, embed_dim, qk_reg_index): |
| 67 | + qk = p[:embed_dim] if qk_reg_index == 0 else p[embed_dim : (embed_dim * 2)] |
| 68 | + return qk |
| 69 | + |
| 70 | + def get_gamma(self, p): |
| 71 | + """Heuristic that uses the mean of the group to set gamma.""" |
| 72 | + p_col = p[:, 0] |
| 73 | + gamma = (1 - p_col.mean()) * torch.linalg.vector_norm(p_col) |
| 74 | + gamma.div_(self.prox_map.tau(p_col)) |
| 75 | + return gamma |
| 76 | + |
| 77 | + def _test_post_prune(self, p, qk_orig, embed_dim, qk_reg_index, gamma): |
| 78 | + qk = self._get_qk(p, embed_dim, qk_reg_index) |
| 79 | + nz_mask = qk.sum(dim=0).ne(0) |
| 80 | + self.assertTrue(nz_mask.eq(0).any(), "No columns of Q/K were pruned") |
| 81 | + |
| 82 | + # original columns that are <= gamma are pruned |
| 83 | + expect_nz_mask = qk_orig.gt(gamma).all(dim=0) |
| 84 | + torch.testing.assert_close(nz_mask > 1, expect_nz_mask, atol=0, rtol=0) |
| 85 | + |
| 86 | + def _test_mha_inner(self, p, embed_dim, qk_reg_index): |
| 87 | + qk_orig = self._get_qk(p, embed_dim, qk_reg_index).clone() |
| 88 | + qk_no_prune = self._get_qk(p, embed_dim, int(not qk_reg_index)).clone() |
| 89 | + v_orig = p[(embed_dim * 2) :].clone() |
| 90 | + qk_pack_dim = 0 |
| 91 | + with QKGrouper(p, qk_pack_dim, qk_reg_index) as grouper: |
| 92 | + self.assertTrue(grouper.p.equal(qk_orig)) |
| 93 | + |
| 94 | + gamma = self.get_gamma(grouper.p) |
| 95 | + _ = torch.vmap( |
| 96 | + self.prox_map.apply_, in_dims=(grouper.in_dims, None), out_dims=0 |
| 97 | + )(grouper.p, gamma) |
| 98 | + |
| 99 | + self._test_post_prune(p, qk_orig, embed_dim, qk_reg_index, gamma) |
| 100 | + |
| 101 | + # unregularized query or key was not modified |
| 102 | + no_prune = self._get_qk(p, embed_dim, int(not qk_reg_index)) |
| 103 | + torch.testing.assert_close(no_prune, qk_no_prune, atol=0, rtol=0) |
| 104 | + |
| 105 | + # value was not modified |
| 106 | + v = p[(embed_dim * 2) :] |
| 107 | + torch.testing.assert_close(v, v_orig, atol=0, rtol=0) |
| 108 | + |
| 109 | + @common_utils.parametrize("embed_dim", [16, 64]) |
| 110 | + @common_utils.parametrize("num_heads", [2, 4]) |
| 111 | + @common_utils.parametrize("qk_reg_index", [0, 1]) |
| 112 | + def test_pytorch_mha(self, embed_dim=16, num_heads=4, qk_reg_index=0): |
| 113 | + assert embed_dim % num_heads == 0, ( |
| 114 | + f"{embed_dim=} must be divisible by {num_heads=}" |
| 115 | + ) |
| 116 | + |
| 117 | + # single in_proj_weight of shape (embed_dim * 3, embed_dim) |
| 118 | + model = nn.MultiheadAttention(embed_dim, num_heads, bias=False) |
| 119 | + p = model.in_proj_weight.detach() |
| 120 | + self._test_mha_inner(p, embed_dim, qk_reg_index) |
| 121 | + |
| 122 | + @common_utils.parametrize("qk_reg_index", [0, 1]) |
| 123 | + def test_e2e_optimizer(self, embed_dim=64, qk_reg_index=0): |
| 124 | + n_cls = 3 |
| 125 | + model = MHADummyModel(embed_dim, num_heads=4, n_cls=n_cls) |
| 126 | + prune_config = { |
| 127 | + "mha.in_proj_weight": { |
| 128 | + "group_type": "QKGrouper", |
| 129 | + "prox_type": "ProxGroupLasso", |
| 130 | + "qk_pack_dim": 0, |
| 131 | + "qk_reg_index": qk_reg_index, |
| 132 | + } |
| 133 | + } |
| 134 | + param_groups = get_param_groups(model, prune_config, verbose=False) |
| 135 | + self.assertEqual(len(param_groups), 3) |
| 136 | + |
| 137 | + p = model.mha.in_proj_weight.detach() |
| 138 | + qk_orig = self._get_qk(p, embed_dim, qk_reg_index).clone() |
| 139 | + |
| 140 | + # set lr to gamma since we run a single step |
| 141 | + gamma = self.get_gamma(qk_orig) |
| 142 | + optimizer = PruneOptimizer( |
| 143 | + torch.optim.SGD(param_groups, lr=gamma), reg_lambda=self.reg_lambda |
| 144 | + ) |
| 145 | + |
| 146 | + data = torch.randn(1, 8, embed_dim) |
| 147 | + label = torch.arange(0, n_cls) * data.mean(axis=-1, keepdim=True) |
| 148 | + output = model(data) |
| 149 | + loss = nn.functional.mse_loss(output, label) |
| 150 | + |
| 151 | + optimizer.zero_grad() |
| 152 | + loss.backward() |
| 153 | + optimizer.step() |
| 154 | + |
| 155 | + self._test_post_prune(p, qk_orig, embed_dim, qk_reg_index, gamma) |
| 156 | + |
| 157 | + |
| 158 | +class TestAttentionHeadGrouper(common_utils.TestCase): |
| 159 | + def __init__(self, methodName): |
| 160 | + super(TestAttentionHeadGrouper, self).__init__(methodName) |
| 161 | + self.reg_lambda = 1.0 |
| 162 | + self.prox_map = ProxGroupLasso(self.reg_lambda) |
| 163 | + |
| 164 | + @staticmethod |
| 165 | + def _get_view_shape_reduce_dim(dim, num_heads, head_pack_dim): |
| 166 | + if head_pack_dim == 0: |
| 167 | + view_shape = (num_heads, -1, dim) |
| 168 | + reduce_dim = (1, 2) |
| 169 | + else: |
| 170 | + view_shape = (dim, num_heads, -1) |
| 171 | + reduce_dim = (0, 2) |
| 172 | + return view_shape, reduce_dim |
| 173 | + |
| 174 | + def _test_post_prune(self, p, p_orig, head_pack_dim, view_shape, reduce_dim, gamma): |
| 175 | + nz_mask = p.view(*view_shape).sum(dim=reduce_dim).ne(0) |
| 176 | + self.assertTrue(nz_mask.eq(0).any(), "No groups of p were pruned") |
| 177 | + |
| 178 | + # original groups that are <= gamma are pruned |
| 179 | + expect_nz_mask = p_orig.view(*view_shape).gt(gamma).all(dim=reduce_dim) |
| 180 | + torch.testing.assert_close(nz_mask > 1, expect_nz_mask, atol=0, rtol=0) |
| 181 | + |
| 182 | + def get_gamma(self, p, head_pack_dim, view_shape): |
| 183 | + """Heuristic that uses the mean of the group to set gamma.""" |
| 184 | + p = p.view(*view_shape) |
| 185 | + p_group = p[0] if head_pack_dim == 0 else p[:, 0] |
| 186 | + gamma = (1 - p_group.mean()) * torch.linalg.vector_norm(p_group) |
| 187 | + gamma.div_(self.prox_map.tau(p_group)) |
| 188 | + return gamma |
| 189 | + |
| 190 | + @common_utils.parametrize("dim", [64, 128]) |
| 191 | + @common_utils.parametrize("head_pack_dim", [0, 1]) |
| 192 | + def test_head_grouper(self, dim=16, head_pack_dim=0, head_dim_ratio=8): |
| 193 | + assert dim % head_dim_ratio == 0, ( |
| 194 | + f"{dim=} must be divisible by {head_dim_ratio=}" |
| 195 | + ) |
| 196 | + num_heads = dim // 8 |
| 197 | + packed_dim = dim * num_heads |
| 198 | + shape = (dim, packed_dim) if head_pack_dim == 0 else (packed_dim, dim) |
| 199 | + model = nn.Linear(*shape, bias=False) |
| 200 | + p = model.weight.detach() |
| 201 | + p_orig = p.clone() |
| 202 | + view_shape, reduce_dim = self._get_view_shape_reduce_dim( |
| 203 | + dim, num_heads, head_pack_dim |
| 204 | + ) |
| 205 | + grouper_cls = ( |
| 206 | + AttentionHeadGrouperDim0 if head_pack_dim == 0 else AttentionHeadGrouperDim1 |
| 207 | + ) |
| 208 | + with grouper_cls(p, num_heads) as grouper: |
| 209 | + gamma = self.get_gamma(grouper.p, head_pack_dim, view_shape) |
| 210 | + _ = torch.vmap( |
| 211 | + self.prox_map.apply_, in_dims=(grouper.in_dims, None), out_dims=0 |
| 212 | + )(grouper.p, gamma) |
| 213 | + self.assertEqual(grouper.p.size(head_pack_dim), num_heads) |
| 214 | + self._test_post_prune(p, p_orig, head_pack_dim, view_shape, reduce_dim, gamma) |
| 215 | + |
| 216 | + |
| 217 | +class TestSVDGrouper(common_utils.TestCase): |
| 218 | + def __init__(self, methodName): |
| 219 | + super(TestSVDGrouper, self).__init__(methodName) |
| 220 | + self.reg_lambda = 1.0 |
| 221 | + self.prox_map = ProxNuclearNorm(self.reg_lambda) |
| 222 | + |
| 223 | + @common_utils.parametrize("embed_dim", (16, 64)) |
| 224 | + def test_grouper(self, embed_dim=16): |
| 225 | + model = torch.nn.Linear(embed_dim, embed_dim) |
| 226 | + p = model.weight |
| 227 | + with SVDGrouper(p) as grouper: |
| 228 | + gamma = grouper.p.mean() |
| 229 | + p_orig = grouper.p.clone() |
| 230 | + torch.vmap( |
| 231 | + self.prox_map.apply_, in_dims=(grouper.in_dims, None), out_dims=0 |
| 232 | + )(grouper.p, gamma) |
| 233 | + expect_nz_mask = p_orig.gt(gamma) |
| 234 | + torch.testing.assert_close(grouper.p.ne(0), expect_nz_mask, atol=0, rtol=0) |
| 235 | + |
| 236 | + @common_utils.parametrize("embed_dim", (16, 64)) |
| 237 | + @common_utils.parametrize("pack_dim", (0, 1)) |
| 238 | + def test_qk_grouper(self, embed_dim=16, pack_dim=0): |
| 239 | + shape = [embed_dim, embed_dim] |
| 240 | + shape[int(not pack_dim)] *= 3 |
| 241 | + model = torch.nn.Linear(*shape) |
| 242 | + p = model.weight |
| 243 | + with QKSVDGrouper(p, pack_dim=pack_dim) as grouper: |
| 244 | + gamma = grouper.p.mean() |
| 245 | + p_orig = grouper.p.clone() |
| 246 | + torch.vmap( |
| 247 | + self.prox_map.apply_, in_dims=(grouper.in_dims, None), out_dims=0 |
| 248 | + )(grouper.p, gamma) |
| 249 | + expect_nz_mask = p_orig.gt(gamma) |
| 250 | + torch.testing.assert_close(grouper.p.ne(0), expect_nz_mask, atol=0, rtol=0) |
| 251 | + |
| 252 | + @common_utils.parametrize("embed_dim", (16, 64)) |
| 253 | + @common_utils.parametrize("pack_dim", (0, 1)) |
| 254 | + def test_packed_grouper(self, embed_dim=16, npack=3, pack_dim=0): |
| 255 | + shape = [embed_dim, embed_dim] |
| 256 | + shape[int(not pack_dim)] *= npack |
| 257 | + model = torch.nn.Linear(*shape) |
| 258 | + p = model.weight |
| 259 | + with PackedSVDGrouper(p, npack, pack_dim=pack_dim) as grouper: |
| 260 | + gamma = grouper.p.mean(0).mean() |
| 261 | + p_orig = grouper.p.clone() |
| 262 | + torch.vmap( |
| 263 | + self.prox_map.apply_, in_dims=(grouper.in_dims, None), out_dims=0 |
| 264 | + )(grouper.p.flatten(), gamma) |
| 265 | + torch.testing.assert_close( |
| 266 | + grouper.p.ne(0), p_orig.gt(gamma), atol=0, rtol=0 |
| 267 | + ) |
| 268 | + self.assertEqual(p.data_ptr(), grouper._p.data_ptr()) |
| 269 | + |
| 270 | + |
| 271 | +common_utils.instantiate_parametrized_tests(TestMaskedLayerNorm) |
| 272 | +common_utils.instantiate_parametrized_tests(TestQKGrouper) |
| 273 | +common_utils.instantiate_parametrized_tests(TestAttentionHeadGrouper) |
| 274 | +common_utils.instantiate_parametrized_tests(TestSVDGrouper) |
| 275 | + |
| 276 | +if __name__ == "__main__": |
| 277 | + random.seed(0) |
| 278 | + torch.manual_seed(0) |
| 279 | + unittest.main() |
0 commit comments