@@ -119,7 +119,8 @@ def gumbel_sample(
     stochastic = False,
     straight_through = False,
     dim = -1,
-    training = True
+    training = True,
+    topk = None
 ):
     dtype, size = logits.dtype, logits.shape[dim]
 
@@ -128,7 +129,11 @@ def gumbel_sample(
     else:
         sampling_logits = logits
 
-    ind = sampling_logits.argmax(dim = dim)
+    if exists(topk):
+        ind = sampling_logits.topk(topk, dim = dim).indices
+    else:
+        ind = sampling_logits.argmax(dim = dim)
+
     one_hot = F.one_hot(ind, size).type(dtype)
 
     if not straight_through or temperature <= 0. or not training:
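
For context, here is a minimal runnable sketch of the sampler after this change, reduced to the greedy path (gumbel noise, temperature, and the straight-through branch omitted). The `exists` helper follows the repo's convention; the shapes are illustrative:

import torch
import torch.nn.functional as F

def exists(v):
    return v is not None

def gumbel_sample(logits, dim = -1, topk = None):
    # sketch of the topk-extended sampler from the hunk above
    dtype, size = logits.dtype, logits.shape[dim]

    if exists(topk):
        # indices gain a trailing k dimension
        ind = logits.topk(topk, dim = dim).indices
    else:
        ind = logits.argmax(dim = dim)

    # one_hot appends a codebook-size axis after the (optional) k axis
    one_hot = F.one_hot(ind, size).type(dtype)
    return ind, one_hot

logits = torch.randn(2, 4, 512)                   # (batch, seq, codebook size)
ind, one_hot = gumbel_sample(logits, topk = 3)
print(ind.shape, one_hot.shape)                   # (2, 4, 3) and (2, 4, 3, 512)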
@@ -629,7 +634,8 @@ def forward(
         codebook_transform_fn: Callable | None = None,
         ema_update_weight: Tensor | Callable | None = None,
         accum_ema_update = False,
-        ema_update = None
+        ema_update = None,
+        topk = None
     ):
         ema_update = default(ema_update, self.ema_update)
 
@@ -686,20 +692,26 @@ def forward(
 
         # sample or argmax depending on temperature
 
-        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)
+        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, topk = topk, temperature = sample_codebook_temp, training = self.training)
 
-        embed_ind = unpack_one(embed_ind, 'h *')
+        if exists(topk):
+            embed_ind = unpack_one(embed_ind, 'h * k')
+        else:
+            embed_ind = unpack_one(embed_ind, 'h *')
 
         if exists(codebook_transform_fn):
             transformed_embed = unpack_one(transformed_embed, 'h * c d')
 
         if self.training:
-            unpacked_onehot = unpack_one(embed_onehot, 'h * c')
+            if exists(topk):
+                unpacked_onehot = unpack_one(embed_onehot, 'h * k c')
+            else:
+                unpacked_onehot = unpack_one(embed_onehot, 'h * c')
 
             if exists(codebook_transform_fn):
-                quantize = einsum('h b n c, h b n c d -> h b n d', unpacked_onehot, transformed_embed)
+                quantize = einsum('h b n ... c, h b n c d -> h b n ... d', unpacked_onehot, transformed_embed)
             else:
-                quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)
+                quantize = einsum('h b n ... c, h c d -> h b n ... d', unpacked_onehot, embed)
 
         else:
             if exists(codebook_transform_fn):
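
The widened einsum patterns are what keep a single code path serving both modes: the ellipsis matches zero dimensions in the plain case and absorbs the extra `k` axis in the topk case. A quick shape check with plain `torch.einsum` (dimension sizes are made up):

import torch

h, b, n, c, d, k = 1, 2, 4, 512, 64, 3
embed = torch.randn(h, c, d)

# without topk: `...` matches nothing
onehot = torch.randn(h, b, n, c)
out = torch.einsum('h b n ... c, h c d -> h b n ... d', onehot, embed)
print(out.shape)                                  # (1, 2, 4, 64)

# with topk: `...` absorbs the k axis
onehot_k = torch.randn(h, b, n, k, c)
out_k = torch.einsum('h b n ... c, h c d -> h b n ... d', onehot_k, embed)
print(out_k.shape)                                # (1, 2, 4, 3, 64)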
@@ -1007,6 +1019,7 @@ def forward(
         indices = None,
         mask = None,
         lens = None,
+        topk = None,
         sample_codebook_temp = None,
         freeze_codebook = None,
         return_loss_breakdown = False,
@@ -1072,7 +1085,8 @@ def forward(
             codebook_transform_fn = codebook_transform_fn,
             ema_update_weight = ema_update_weight,
             accum_ema_update = accum_ema_update,
-            ema_update = ema_update
+            ema_update = ema_update and not exists(topk),
+            topk = topk
         )
 
         # quantize
@@ -1196,17 +1210,27 @@ def calculate_ce_loss(codes):
 
                 commit_loss = calculate_ce_loss(embed_ind)
             else:
-                if exists(mask):
+                if exists(topk):
+                    # handle special case when returning topk
+
+                    repeated_input = repeat(orig_input, '... d -> ... k d', k = topk)
+                    commit_loss = F.mse_loss(commit_quantize, repeated_input, reduction = 'none')
+                    commit_loss = reduce(commit_loss, '... k d -> ... k', 'mean')
+
+                    if exists(mask):
+                        commit_loss = einx.where('..., ... k, -> ... k', mask, commit_loss, 0.)
+
+                elif exists(mask):
                     # with variable lengthed sequences
-                    commit_loss = F.mse_loss(commit_quantize, x, reduction = 'none')
+                    commit_loss = F.mse_loss(commit_quantize, orig_input, reduction = 'none')
 
                     loss_mask = mask
                     if is_multiheaded:
                         loss_mask = repeat(loss_mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])
 
                     commit_loss = commit_loss[loss_mask].mean()
                 else:
-                    commit_loss = F.mse_loss(commit_quantize, x)
+                    commit_loss = F.mse_loss(commit_quantize, orig_input)
 
             loss = loss + commit_loss * self.commitment_weight
 
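
Under topk the commitment loss changes shape rather than collapsing to a scalar: each of the k candidate quantizations is scored against a repeated copy of the input, giving one loss per candidate, with masked positions zeroed rather than dropped. A shape sketch with einops, using made-up sizes:

import torch
import torch.nn.functional as F
from einops import repeat, reduce

b, n, d, k = 2, 4, 64, 3
orig_input = torch.randn(b, n, d)
commit_quantize = torch.randn(b, n, k, d)         # k candidate quantizations per position

repeated_input = repeat(orig_input, '... d -> ... k d', k = k)
commit_loss = F.mse_loss(commit_quantize, repeated_input, reduction = 'none')
commit_loss = reduce(commit_loss, '... k d -> ... k', 'mean')
print(commit_loss.shape)                          # (2, 4, 3): one loss per candidate code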
@@ -1261,7 +1285,7 @@ def calculate_ce_loss(codes):
             masked_out_value = torch.zeros_like(orig_input)
 
             quantize = einx.where(
-                'b n, b n d, b n d -> b n d',
+                'b n, b n ... d, b n d -> b n ... d',
                 mask,
                 quantize,
                 masked_out_value
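
Taken together, the commit appears to let a forward pass request the k nearest codes per position instead of a single argmax. Note that `ema_update = ema_update and not exists(topk)` disables EMA codebook updates whenever topk is passed, presumably because the k-way one-hots would distort the cluster statistics. A hypothetical usage sketch, with output shapes inferred from the diff rather than verified against the library:

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim = 256, codebook_size = 512)
x = torch.randn(1, 1024, 256)

# default path: one code per position
quantized, indices, commit_loss = vq(x)           # (1, 1024, 256), (1, 1024)

# topk path: k candidate codes per position, EMA update skipped
quantized, indices, commit_loss = vq(x, topk = 3) # expected (1, 1024, 3, 256), (1, 1024, 3)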