File tree (Expand file tree / Collapse file tree) — 1 file changed: +6 −0 lines changed
transformer_engine/pytorch/module — 1 file changed: +6 −0 lines changed
lines changed | Original file line number | Diff line number | Diff line change
45 45  from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
46 46  from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
47 47  from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
   48 + from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
48 49  from ..utils import (
49 50      is_non_tn_fp8_gemm_supported,
50 51      torch_get_autocast_gpu_dtype,
@@ -1388,6 +1389,11 @@ def get_weight_workspace(
1388 1389              reset_cache = True
1389 1390          elif quantizer.columnwise_usage and out._columnwise_data is None:
1390 1391              reset_cache = True
     1392 +    elif isinstance(out, NVFP4TensorStorage):
     1393 +        if quantizer.rowwise_usage and out._rowwise_data is None:
     1394 +            reset_cache = True
     1395 +        elif quantizer.columnwise_usage and out._columnwise_data is None:
     1396 +            reset_cache = True
1391 1397      if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer):
1392 1398          reset_cache = True
1393 1399      if reset_cache:
You can’t perform that action at this time.
0 commit comments