Commit 28e4e56

handle VQ returning topk minimum distance codes
1 parent 29f94f4 commit 28e4e56
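In short, passing a topk argument to VectorQuantize.forward now returns the k nearest codes per position instead of only the single nearest one (the EMA codebook update is deferred to the caller in that case). A minimal usage sketch, assuming only what the updated test below asserts; the constructor settings other than dim = 256 and the choice of topk = 2 are illustrative:

import torch
from vector_quantize_pytorch import VectorQuantize

# illustrative settings; only dim = 256 mirrors the updated test
vq = VectorQuantize(dim = 256, codebook_size = 512)

x = torch.randn(1, 1024, 256)

# topk = k appends a k dimension to every returned tensor
quantized, indices, commit_loss = vq(x, topk = 2)

print(quantized.shape)    # expected (1, 1024, 2, 256)
print(indices.shape)      # expected (1, 1024, 2)
print(commit_loss.shape)  # expected (1, 1024, 2) while the module is in training mode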

File tree

2 files changed: +51 −19 lines

tests/test_manual_ema.py

Lines changed: 14 additions & 6 deletions
@@ -1,7 +1,7 @@
 import torch
 from vector_quantize_pytorch import VectorQuantize
 
-def test_manual_ema_update():
+def test_topk_and_manual_ema_update():
 
     vq1 = VectorQuantize(
         dim = 256,
@@ -19,17 +19,25 @@ def test_manual_ema_update():
     mask = torch.randint(0, 2, (1, 1024)).bool()
 
     vq1.train()
-    quantize1, indices1, _ = vq1(x, mask = mask)
+    quantize1, indices1, commit_loss1 = vq1(x, mask = mask)
 
     vq2.train()
-    quantize2, indices2, _ = vq2(x, mask = mask, ema_update = False)
+    quantize2, indices2, commit_losses = vq2(x, mask = mask, topk = 1, ema_update = False)
 
-    assert torch.allclose(quantize1, quantize2)
-    assert torch.equal(indices1, indices2)
+    assert quantize2.shape == (1, 1024, 1, 256)
+    assert indices2.shape == (1, 1024, 1)
+    assert commit_losses.shape == (1, 1024, 1)
+
+    top_quantize2 = quantize2[..., 0, :]
+    top_indices2 = indices2[..., 0]
+
+    assert torch.allclose(commit_loss1, commit_losses.sum() / mask.sum())
+    assert torch.equal(indices1, top_indices2)
+    assert torch.allclose(quantize1, top_quantize2)
 
     assert not torch.allclose(vq1._codebook.embed_avg, vq2._codebook.embed_avg)
 
-    vq2.update_ema_indices(x, indices2, mask = mask)
+    vq2.update_ema_indices(x, top_indices2, mask = mask)
 
     assert torch.allclose(vq1._codebook.cluster_size, vq2._codebook.cluster_size)
     assert torch.allclose(vq1._codebook.embed_avg, vq2._codebook.embed_avg)
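The core of the library change below is the selection step in gumbel_sample: a topk over the logits (negative distances) replaces the argmax, the resulting one-hot gains an extra axis per selected code, and the einsum patterns gain a '...' so that axis is carried through to the quantized output. A standalone sketch of that idea with toy tensors; this is not the library's internal layout, which also packs heads and batch:

import torch
import torch.nn.functional as F

codebook_size, dim, k = 8, 4, 2

logits = torch.randn(3, codebook_size)   # e.g. negative distances for 3 positions
embed = torch.randn(codebook_size, dim)  # codebook vectors

# previous behavior: one code per position via argmax
ind = logits.argmax(dim = -1)                        # (3,)
onehot = F.one_hot(ind, codebook_size).float()       # (3, c)
quantize = torch.einsum('nc,cd->nd', onehot, embed)  # (3, d)

# this commit: k codes per position, the extra axis rides through the einsum
ind_k = logits.topk(k, dim = -1).indices                   # (3, k), best code first
onehot_k = F.one_hot(ind_k, codebook_size).float()         # (3, k, c)
quantize_k = torch.einsum('nkc,cd->nkd', onehot_k, embed)  # (3, k, d)

# the top-1 slice of the topk result matches the plain argmax result,
# which is what the updated test above checks with topk = 1
assert torch.equal(ind_k[:, 0], ind)
assert torch.allclose(quantize_k[:, 0], quantize)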

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 37 additions & 13 deletions
@@ -119,7 +119,8 @@ def gumbel_sample(
     stochastic = False,
     straight_through = False,
     dim = -1,
-    training = True
+    training = True,
+    topk = None
 ):
     dtype, size = logits.dtype, logits.shape[dim]
 
@@ -128,7 +129,11 @@ def gumbel_sample(
     else:
         sampling_logits = logits
 
-    ind = sampling_logits.argmax(dim = dim)
+    if exists(topk):
+        ind = sampling_logits.topk(topk, dim = dim).indices
+    else:
+        ind = sampling_logits.argmax(dim = dim)
+
     one_hot = F.one_hot(ind, size).type(dtype)
 
     if not straight_through or temperature <= 0. or not training:
@@ -629,7 +634,8 @@ def forward(
         codebook_transform_fn: Callable | None = None,
         ema_update_weight: Tensor | Callable | None = None,
         accum_ema_update = False,
-        ema_update = None
+        ema_update = None,
+        topk = None
     ):
         ema_update = default(ema_update, self.ema_update)
 
@@ -686,20 +692,26 @@ def forward(
 
         # sample or argmax depending on temperature
 
-        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)
+        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, topk = topk, temperature = sample_codebook_temp, training = self.training)
 
-        embed_ind = unpack_one(embed_ind, 'h *')
+        if exists(topk):
+            embed_ind = unpack_one(embed_ind, 'h * k')
+        else:
+            embed_ind = unpack_one(embed_ind, 'h *')
 
         if exists(codebook_transform_fn):
             transformed_embed = unpack_one(transformed_embed, 'h * c d')
 
         if self.training:
-            unpacked_onehot = unpack_one(embed_onehot, 'h * c')
+            if exists(topk):
+                unpacked_onehot = unpack_one(embed_onehot, 'h * k c')
+            else:
+                unpacked_onehot = unpack_one(embed_onehot, 'h * c')
 
             if exists(codebook_transform_fn):
-                quantize = einsum('h b n c, h b n c d -> h b n d', unpacked_onehot, transformed_embed)
+                quantize = einsum('h b n ... c, h b n c d -> h b n ... d', unpacked_onehot, transformed_embed)
             else:
-                quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)
+                quantize = einsum('h b n ... c, h c d -> h b n ... d', unpacked_onehot, embed)
 
         else:
             if exists(codebook_transform_fn):
@@ -1007,6 +1019,7 @@ def forward(
         indices = None,
        mask = None,
         lens = None,
+        topk = None,
         sample_codebook_temp = None,
         freeze_codebook = None,
         return_loss_breakdown = False,
@@ -1072,7 +1085,8 @@ def forward(
             codebook_transform_fn = codebook_transform_fn,
             ema_update_weight = ema_update_weight,
             accum_ema_update = accum_ema_update,
-            ema_update = ema_update
+            ema_update = ema_update and not exists(topk),
+            topk = topk
         )
 
         # quantize
@@ -1196,17 +1210,27 @@ def calculate_ce_loss(codes):
 
                     commit_loss = calculate_ce_loss(embed_ind)
                 else:
-                    if exists(mask):
+                    if exists(topk):
+                        # handle special case when returning topk
+
+                        repeated_input = repeat(orig_input, '... d -> ... k d', k = topk)
+                        commit_loss = F.mse_loss(commit_quantize, repeated_input, reduction = 'none')
+                        commit_loss = reduce(commit_loss, '... k d -> ... k', 'mean')
+
+                        if exists(mask):
+                            commit_loss = einx.where('..., ... k, -> ... k', mask, commit_loss, 0.)
+
+                    elif exists(mask):
                         # with variable lengthed sequences
-                        commit_loss = F.mse_loss(commit_quantize, x, reduction = 'none')
+                        commit_loss = F.mse_loss(commit_quantize, orig_input, reduction = 'none')
 
                         loss_mask = mask
                         if is_multiheaded:
                             loss_mask = repeat(loss_mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])
 
                         commit_loss = commit_loss[loss_mask].mean()
                     else:
-                        commit_loss = F.mse_loss(commit_quantize, x)
+                        commit_loss = F.mse_loss(commit_quantize, orig_input)
 
                 loss = loss + commit_loss * self.commitment_weight
 
@@ -1261,7 +1285,7 @@ def calculate_ce_loss(codes):
                 masked_out_value = torch.zeros_like(orig_input)
 
             quantize = einx.where(
-                'b n, b n d, b n d -> b n d',
+                'b n, b n ... d, b n d -> b n ... d',
                 mask,
                 quantize,
                 masked_out_value
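Two caller-facing consequences of the change above, both exercised by the test: EMA codebook updates are skipped whenever topk is passed (ema_update = ema_update and not exists(topk)), so the caller applies them afterwards via update_ema_indices with whichever indices it settles on; and the commitment loss comes back per position and per code rather than as a scalar, zeroed at masked positions. A hedged sketch of the reduction the test relies on; the tensors here are placeholders standing in for a masked, training-mode forward pass with topk = 1:

import torch

commit_losses = torch.rand(1, 1024, 1)        # (b, n, k) per-position, per-code losses
mask = torch.randint(0, 2, (1, 1024)).bool()  # (b, n) valid-position mask
commit_losses[~mask] = 0.                     # masked positions are already zeroed in the real output

# with topk = 1, summing and dividing by the number of valid positions recovers the
# scalar masked commitment loss of a plain call, which is exactly what the test asserts;
# for larger k one still has to decide how to average over the k codes
scalar_commit_loss = commit_losses.sum() / mask.sum()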
