Skip to content

Commit c34a005

Browse files
committed
again
1 parent 17e3b4b commit c34a005

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.27.10"
3+
version = "1.27.11"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,7 +1057,7 @@ def forward(
10571057
assert not exists(mask)
10581058
x = rearrange(x, 'b d -> b 1 d')
10591059

1060-
shape, device, heads, is_multiheaded, codebook_size, return_loss = x.shape, x.device, self.heads, self.heads > 1, self.codebook_size, exists(indices)
1060+
shape, dtype, device, heads, is_multiheaded, codebook_size, return_loss = x.shape, x.dtype, x.device, self.heads, self.heads > 1, self.codebook_size, exists(indices)
10611061

10621062
need_transpose = not self.channel_last and not self.accept_image_fmap
10631063
should_inplace_optimize = exists(self.in_place_codebook_optimizer)
@@ -1101,6 +1101,8 @@ def forward(
11011101

11021102
quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
11031103

1104+
quantize = quantize.type(dtype)
1105+
11041106
# losses for loss breakdown
11051107

11061108
commit_loss = orthogonal_reg_loss = inplace_optimize_loss = codebook_diversity_loss = self.zero
@@ -1146,15 +1148,14 @@ def forward(
11461148
# spare rotation trick calculation if inputs do not need gradients
11471149

11481150
if input_requires_grad:
1149-
x_for_grad = x.to(quantize)
11501151

11511152
if self.rotation_trick:
1152-
quantize = rotate_to(x_for_grad, quantize)
1153+
quantize = rotate_to(x, quantize)
11531154
elif self.directional_reparam:
1154-
quantize = directional_reparam(x_for_grad, quantize, self.directional_reparam_variance)
1155+
quantize = directional_reparam(x, quantize, self.directional_reparam_variance)
11551156
else:
11561157
# standard STE to get gradients through VQ layer.
1157-
quantize = straight_through(x_for_grad, quantize)
1158+
quantize = straight_through(x, quantize)
11581159

11591160
if self.sync_update_v > 0.:
11601161
# (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf

0 commit comments

Comments
 (0)