@@ -1057,7 +1057,7 @@ def forward(
10571057 assert not exists (mask )
10581058 x = rearrange (x , 'b d -> b 1 d' )
10591059
1060- shape , device , heads , is_multiheaded , codebook_size , return_loss = x .shape , x .device , self .heads , self .heads > 1 , self .codebook_size , exists (indices )
1060+ shape , dtype , device , heads , is_multiheaded , codebook_size , return_loss = x .shape , x . dtype , x .device , self .heads , self .heads > 1 , self .codebook_size , exists (indices )
10611061
10621062 need_transpose = not self .channel_last and not self .accept_image_fmap
10631063 should_inplace_optimize = exists (self .in_place_codebook_optimizer )
@@ -1101,6 +1101,8 @@ def forward(
11011101
11021102 quantize , embed_ind , distances = self ._codebook (x , ** codebook_forward_kwargs )
11031103
1104+ quantize = quantize .type (dtype )
1105+
11041106 # losses for loss breakdown
11051107
11061108 commit_loss = orthogonal_reg_loss = inplace_optimize_loss = codebook_diversity_loss = self .zero
@@ -1146,15 +1148,14 @@ def forward(
11461148 # spare rotation trick calculation if inputs do not need gradients
11471149
11481150 if input_requires_grad :
1149- x_for_grad = x .to (quantize )
11501151
11511152 if self .rotation_trick :
1152- quantize = rotate_to (x_for_grad , quantize )
1153+ quantize = rotate_to (x , quantize )
11531154 elif self .directional_reparam :
1154- quantize = directional_reparam (x_for_grad , quantize , self .directional_reparam_variance )
1155+ quantize = directional_reparam (x , quantize , self .directional_reparam_variance )
11551156 else :
11561157 # standard STE to get gradients through VQ layer.
1157- quantize = straight_through (x_for_grad , quantize )
1158+ quantize = straight_through (x , quantize )
11581159
11591160 if self .sync_update_v > 0. :
11601161 # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
0 commit comments