
Commit 01de02e

[gguf][torch.compile time] Convert to plain tensor earlier in dequantize_gguf_tensor (#13166)
Once dequantize_gguf_tensor fetches the quant_type attribute from the GGUFParameter tensor subclass, there is no further need to run the actual dequantize operations on the tensor subclass; we can convert to a plain tensor right away. This not only makes PyTorch eager execution faster, but also reduces torch.compile trace time from 36 seconds to 10 seconds, because there is far less code to trace.
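For intuition, here is a minimal Python sketch of the pattern (TaggedTensor and its tag attribute are hypothetical stand-ins for GGUFParameter and quant_type, not the diffusers implementation): a tensor subclass that defines __torch_function__ pays a Python dispatch cost on every torch op it touches, so the fix is to read the metadata first and then strip the subclass before the heavy tensor work.

import torch

# Hypothetical stand-in for GGUFParameter: a tensor subclass that carries
# metadata (`tag`, playing the role of `quant_type`) and intercepts every
# torch op via __torch_function__.
class TaggedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data, tag=None):
        self = torch.Tensor._make_subclass(cls, data)
        self.tag = tag
        return self

    def as_tensor(self):
        # Strip the subclass wrapper so later ops skip __torch_function__.
        return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Every torch op on a TaggedTensor pays this dispatch cost, and
        # torch.compile must trace through this Python hook as well.
        return super().__torch_function__(func, types, args, kwargs or {})

t = TaggedTensor(torch.randn(4, 4), tag="Q4_0")
tag = t.tag            # 1. read the metadata while we still have the subclass
plain = t.as_tensor()  # 2. drop to a plain tensor as early as possible
out = plain.view(torch.uint8)  # 3. later ops no longer hit __torch_function__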
1 parent db2d7e7 commit 01de02e

File tree: 1 file changed (+4, −1)


src/diffusers/quantizers/gguf/utils.py

Lines changed: 4 additions & 1 deletion
@@ -516,6 +516,9 @@ def dequantize_gguf_tensor(tensor):
 
     block_size, type_size = GGML_QUANT_SIZES[quant_type]
 
+    # Convert to plain tensor to avoid unnecessary __torch_function__ overhead.
+    tensor = tensor.as_tensor()
+
     tensor = tensor.view(torch.uint8)
     shape = _quant_shape_from_byte_shape(tensor.shape, type_size, block_size)
 
@@ -525,7 +528,7 @@ def dequantize_gguf_tensor(tensor):
     dequant = dequant_fn(blocks, block_size, type_size)
     dequant = dequant.reshape(shape)
 
-    return dequant.as_tensor()
+    return dequant
 
 
 class GGUFParameter(torch.nn.Parameter):
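In effect, the as_tensor() call moves from the return value to just after the subclass metadata has been consumed: the view, the per-quant-type dequant_fn, and the reshape all now run on a plain torch.Tensor, so neither eager mode nor the torch.compile tracer has to route those ops through GGUFParameter's __torch_function__ hook.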
