Skip to content

Commit dbd0197

Browse files
authored
Reset cache logic of weight workspace for NVFP4TensorStorage (#2524)
Reset the weight workspace cache for NVFP4TensorStorage. Signed-off-by: Jinhang Choi <jinhangc@nvidia.com>
1 parent eac8af6 commit dbd0197

File tree

1 file changed

+6
-0
lines changed
  • transformer_engine/pytorch/module

1 file changed

+6
-0
lines changed

transformer_engine/pytorch/module/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
4646
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
4747
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
48+
from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
4849
from ..utils import (
4950
is_non_tn_fp8_gemm_supported,
5051
torch_get_autocast_gpu_dtype,
@@ -1388,6 +1389,11 @@ def get_weight_workspace(
13881389
reset_cache = True
13891390
elif quantizer.columnwise_usage and out._columnwise_data is None:
13901391
reset_cache = True
1392+
elif isinstance(out, NVFP4TensorStorage):
1393+
if quantizer.rowwise_usage and out._rowwise_data is None:
1394+
reset_cache = True
1395+
elif quantizer.columnwise_usage and out._columnwise_data is None:
1396+
reset_cache = True
13911397
if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer):
13921398
reset_cache = True
13931399
if reset_cache:

0 commit comments

Comments (0)