From d32b7ab9ed3de67ae94b2d83a9f7af720829fe2d Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Fri, 12 Dec 2025 04:17:16 -0800 Subject: [PATCH 01/23] avoiding shape copy, torch dynamo and torch autograd overheads Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/csrc/common.cpp | 62 +++++++- transformer_engine/pytorch/csrc/common.h | 19 ++- .../pytorch/csrc/extensions/bias.cpp | 9 +- .../pytorch/csrc/extensions/gemm.cpp | 8 +- .../pytorch/csrc/extensions/transpose.cpp | 6 +- transformer_engine/pytorch/csrc/quantizer.cpp | 34 +++-- .../pytorch/csrc/type_converters.cpp | 4 +- transformer_engine/pytorch/module/linear.py | 138 ++++++++---------- 8 files changed, 173 insertions(+), 107 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index e054424dd4d..f7a8540197f 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -26,12 +26,8 @@ std::vector convert_shape_back_from_fp4(const std::vector& shape return ret; } -std::vector getTensorShape(const at::Tensor& t) { - std::vector shape; - for (auto s : t.sizes()) { - shape.push_back(s); - } - return shape; +NVTEShape getTensorShape(const at::Tensor& t) { + return convertTorchShape(t.sizes()); } NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) { @@ -178,6 +174,38 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( return ret; } +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, + NVTEScalingMode scaling_mode) { + TensorWrapper ret(scaling_mode); + ret.set_rowwise_data(data_ptr, type, shape); + const size_t meta_shape_data[1] = {1}; + const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1); + ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); + ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + auto scale_inv_dtype = + (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; + ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); + return ret; +} + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const std::vector& shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, + NVTEScalingMode scaling_mode) { + TensorWrapper ret(scaling_mode); + ret.set_rowwise_data(data_ptr, type, shape); + const size_t meta_shape_data[1] = {1}; + const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1); + ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); + ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + auto scale_inv_dtype = + (scaling_mode == NVTE_MXFP8_1D_SCALING) ? 
DType::kFloat8E8M0 : DType::kFloat32;
+  ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
+  return ret;
+}
+
 transformer_engine::TensorWrapper makeTransformerEngineTensor(
     void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
     const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
@@ -199,6 +227,28 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
   return ret;
 }
 
+transformer_engine::TensorWrapper makeTransformerEngineTensor(
+    void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape,
+    const NVTEShape& columnwise_shape, const transformer_engine::DType type,
+    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
+    const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape,
+    NVTEScalingMode scaling_mode) {
+  TensorWrapper ret(scaling_mode);
+  ret.set_rowwise_data(data_ptr, type, shape);
+  ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape);
+  const size_t meta_shape_data[1] = {1};
+  const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1);
+  ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
+  ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
+  auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING)   ? DType::kFloat8E8M0
+                         : (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat8E4M3
+                                                                   : DType::kFloat32;
+  ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
+  ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype,
+                               columnwise_scale_inv_shape);
+  return ret;
+}
+
 transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax,
                                                               const at::Tensor scale,
                                                               at::Tensor scale_inv,
diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h
index 978bee52dc1..883c2a24cad 100644
--- a/transformer_engine/pytorch/csrc/common.h
+++ b/transformer_engine/pytorch/csrc/common.h
@@ -339,7 +339,7 @@ class NVFP4Quantizer : public Quantizer {
 
 std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
 
-std::vector<size_t> getTensorShape(const at::Tensor& t);
+NVTEShape getTensorShape(const at::Tensor& t);
 
 transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
                                                       const std::string& fp8_recipe);
@@ -432,6 +432,16 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
     void* amax_ptr, void* scale_ptr, void* scale_inv_ptr,
     std::vector<size_t> scale_inv_shape = {1},
     NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
 
+transformer_engine::TensorWrapper makeTransformerEngineTensor(
+    void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type,
+    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
+    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
+
+transformer_engine::TensorWrapper makeTransformerEngineTensor(
+    void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
+    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
+    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
+
 transformer_engine::TensorWrapper makeTransformerEngineTensor(
     void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
     const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
@@ -440,6 +450,13 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
     const std::vector<size_t>& columnwise_scale_inv_shape = {1},
     NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
+transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, + const NVTEShape& columnwise_shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); + transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type); diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index b0435d27230..2eef7438068 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -26,7 +26,8 @@ std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle // Grad output tensor auto grad_output_torch = grad_output.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const auto shape = getTensorShape(grad_output_torch); + const auto shape_nvte = getTensorShape(grad_output_torch); + const auto shape = convertShape(shape_nvte); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); // Construct grad bias tensor @@ -116,11 +117,13 @@ std::vector dact_dbias( // Grad output and activation input tensors grad_output_torch = grad_output_torch.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const auto output_shape = getTensorShape(grad_output_torch); + const auto output_shape_nvte = getTensorShape(grad_output_torch); + const auto output_shape = convertShape(output_shape_nvte); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); act_input_torch = act_input_torch.contiguous(); const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch); - const auto input_shape = getTensorShape(act_input_torch); + const auto input_shape_nvte = getTensorShape(act_input_torch); + const auto input_shape = convertShape(input_shape_nvte); // Construct tensors auto quantizer_cpp = convert_quantizer(quantizer_py); diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 13e8bfb6e5f..f704864cb60 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -365,12 +365,16 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING; NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING; + const size_t A_shape_data[2] = {static_cast(A.size(0)), static_cast(A.size(1))}; + const NVTEShape A_shape = nvte_make_shape(A_shape_data, 2); auto te_A = makeTransformerEngineTensor( - A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, + A.data_ptr(), A_shape, A_type, nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), nvte_scaling_modeA); + const size_t B_shape_data[2] = {static_cast(B.size(0)), static_cast(B.size(1))}; + const NVTEShape B_shape = nvte_make_shape(B_shape_data, 2); auto te_B = makeTransformerEngineTensor( - B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, + B.data_ptr(), B_shape, B_type, nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), 
nvte_scaling_modeB); // TODO: D_scale_inv cannot be nullptr when D_type is FP8. diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 7dfdf995475..5ace996afcc 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -19,7 +19,8 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional transpose_shape_int64; if (shape.size() > 0) { transpose_shape_int64.push_back(shape.back()); @@ -60,7 +61,8 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional out) { // Allocate output tensor if needed if (!out) { - auto in_shape = getTensorShape(input); + const auto in_shape_nvte = getTensorShape(input); + const auto in_shape = convertShape(in_shape_nvte); NVTE_CHECK(in_shape.size() >= 2, "Invalid input tensor dimensions (shape=", in_shape, ")"); std::vector out_shape_int64(in_shape.begin(), in_shape.end()); out_shape_int64[0] = static_cast(in_shape[1]); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index d7e8912ac74..3b94d38ac16 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -209,7 +209,8 @@ std::pair Float8Quantizer::convert_and_update_tensor( // Tensor dimensions std::vector shape; if (has_transpose) { - const auto transpose_shape = getTensorShape(*transpose_tensor); + const auto transpose_shape_nvte = getTensorShape(*transpose_tensor); + const auto transpose_shape = convertShape(transpose_shape_nvte); if (transpose_shape.size() > 0) { for (size_t i = 1; i < transpose_shape.size(); ++i) { shape.push_back(transpose_shape[i]); @@ -217,12 +218,13 @@ std::pair Float8Quantizer::convert_and_update_tensor( shape.push_back(transpose_shape.front()); } if (has_data) { - auto expected_shape = getTensorShape(*data_tensor); + const auto expected_shape_nvte = getTensorShape(*data_tensor); + const auto expected_shape = convertShape(expected_shape_nvte); NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, ") and transpose (shape=", transpose_shape, ") do not match"); } } else { // Already checked has_data == true - shape = getTensorShape(*data_tensor); + shape = convertShape(getTensorShape(*data_tensor)); } // Coerce data tensor @@ -430,7 +432,8 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ // Tensor dimensions std::vector shape; if (has_transpose) { - const auto transpose_shape = getTensorShape(*transpose_tensor); + const auto transpose_shape_nvte = getTensorShape(*transpose_tensor); + const auto transpose_shape = convertShape(transpose_shape_nvte); if (transpose_shape.size() > 0) { for (size_t i = 1; i < transpose_shape.size(); ++i) { shape.push_back(transpose_shape[i]); @@ -438,12 +441,13 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ shape.push_back(transpose_shape.front()); } if (has_data) { - auto expected_shape = getTensorShape(*data_tensor); + const auto expected_shape_nvte = getTensorShape(*data_tensor); + const auto expected_shape = convertShape(expected_shape_nvte); NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, ") and transpose (shape=", transpose_shape, ") do not match"); } } else { // Already checked has_data == true - shape = getTensorShape(*data_tensor); + shape = convertShape(getTensorShape(*data_tensor)); } // Coerce data tensor in Python tensor @@ -680,9 +684,9 @@ std::pair Float8BlockQuantizer::convert_and_update_te return 
std::vector(); } if (all_gather_usage) { - return getTensorShape(*columnwise_data); + return convertShape(getTensorShape(*columnwise_data)); } - std::vector shape = getTensorShape(*columnwise_data); + std::vector shape = convertShape(getTensorShape(*columnwise_data)); std::vector shape_transposed(shape.size()); for (size_t i = 0; i + 1 < shape.size(); ++i) { shape_transposed[i] = shape[i + 1]; @@ -694,7 +698,7 @@ std::pair Float8BlockQuantizer::convert_and_update_te }; std::vector shape; if (rowwise_data) { - shape = getTensorShape(*rowwise_data); + shape = convertShape(getTensorShape(*rowwise_data)); if (columnwise_data) { auto expected_shape = get_columnwise_shape(all_gather_usage); NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape, @@ -1004,14 +1008,14 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( // Tensor dimensions std::vector shape; if (columnwise_data) { - shape = getTensorShape(*columnwise_data); + shape = convertShape(getTensorShape(*columnwise_data)); if (rowwise_data) { - auto expected_shape = getTensorShape(*rowwise_data); + const auto expected_shape = convertShape(getTensorShape(*rowwise_data)); NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape, ") and column-wise data (shape=", shape, ") do not match"); } } else { // Already checked columnwise_data_tensor == true - shape = getTensorShape(*rowwise_data); + shape = convertShape(getTensorShape(*rowwise_data)); } // Coerce row-wise data @@ -1320,14 +1324,14 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( // Tensor dimensions, shape means original shape std::vector shape; if (columnwise_data) { - shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); + shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*columnwise_data)), true); if (rowwise_data) { - auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); + auto expected_shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false); NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, ") and column-wise data (shape=", shape, ") do not match"); } } else { // Already checked columnwise_data_tensor == true - shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); + shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false); } size_t flat_first_dim = 1; diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 368e9dcdfa3..780a08da7f8 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -132,7 +132,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast(); const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast(); ret.set_rowwise_data(data.data_ptr(), dtype, - convert_shape_back_from_fp4(getTensorShape(data), false)); + convert_shape_back_from_fp4(convertShape(getTensorShape(data)), false)); ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise)); } @@ -143,7 +143,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); const auto &amax_columnwise = 
tensor.attr("_amax_columnwise").cast(); ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1, - convert_shape_back_from_fp4(getTensorShape(data), false)); + convert_shape_back_from_fp4(convertShape(getTensorShape(data)), false)); ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b65f7005eb3..7557f5c5396 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -96,40 +96,66 @@ def forward( ( is_first_microbatch, - fp8, - fp8_calibration, - wgrad_store, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - fuse_wgrad_accumulation, cpu_offloading, - tp_group, - tp_size, - sequence_parallel, - tensor_parallel, - activation_dtype, - parallel_mode, is_grad_enabled, - ub_overlap_rs_fprop, - ub_overlap_ag_dgrad, - ub_overlap_ag_fprop, - ub_overlap_rs_dgrad, - ub_bulk_dgrad, - ub_bulk_wgrad, - ub_name, - fp8_output, # pylint: disable=unused-variable - fsdp_group, + fp8_output, + fp8_grad, module, skip_fp8_weight_update, - symmetric_ar_type, - save_original_input, debug, ) = non_tensor_args + (fp8, + fp8_calibration, + wgrad_store, + fuse_wgrad_accumulation, + tp_group, + tp_size, + sequence_parallel, + tensor_parallel, + activation_dtype, + parallel_mode, + ub_overlap_rs_fprop, + ub_overlap_ag_dgrad, + ub_overlap_ag_fprop, + ub_overlap_rs_dgrad, + ub_bulk_dgrad, + ub_bulk_wgrad, + ub_name, + fsdp_group, + symmetric_ar_type, + save_original_input + ) = (module.fp8, + module.fp8_calibration, + module.wgrad_store, + module.fuse_wgrad_accumulation, + module.tp_group, + module.tp_size, + module.sequence_parallel, + module.tp_size > 1, + module.activation_dtype, + module.parallel_mode, + module.ub_overlap_rs_fprop, + module.ub_overlap_ag_dgrad, + module.ub_overlap_ag_fprop, + module.ub_overlap_rs_dgrad, + module.ub_bulk_dgrad, + module.ub_bulk_wgrad, + module.ub_name, + module.fsdp_group, + module.symmetric_ar_type, + module.save_original_input, + ) + quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + + if debug: + quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + if module.no_debug_features_active(quantizers): + debug = False + quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + (input_quantizer, weight_quantizer, output_quantizer, grad_input_quantizer, grad_weight_quantizer, grad_output_quantizer) = quantizers + + # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" if ub_name is not None: @@ -981,7 +1007,6 @@ def wgrad_gemm( None, ) - class Linear(TransformerEngineBaseModule): """Applies a linear transformation to the incoming data :math:`y = xA^T + b` @@ -1343,7 +1368,6 @@ def reset_parameters(self, defer_init=False): elif self.parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) - @no_torch_dynamo() def forward( self, inp: torch.Tensor, @@ -1401,28 +1425,7 @@ def forward( inp, allow_non_contiguous=isinstance(inp, QuantizedTensor), ) as inp: - weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ) - if debug: - if 
self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - - ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) = quantizers - if is_grad_enabled: linear_fn = _Linear.apply autograd_ctx = [] @@ -1432,37 +1435,12 @@ def forward( non_tensor_args = ( is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.parallel_mode, is_grad_enabled, - self.ub_overlap_rs_fprop, - self.ub_overlap_ag_dgrad, - self.ub_overlap_ag_fprop, - self.ub_overlap_rs_dgrad, - self.ub_bulk_dgrad, - self.ub_bulk_wgrad, - self.ub_name, fp8_output, - self.fsdp_group, + fp8_grad, self, skip_fp8_weight_update, - self.symmetric_ar_type, - self.save_original_input, debug, ) out = linear_fn( @@ -1687,3 +1665,11 @@ def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Reci self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].all_gather_usage = True + + +# disable torch dynamo just once to reduce wrapped function overhead on each +# forward call of te Linear. +if torch.__version__ >= "2": + Linear.forward._torchdynamo_disable = True + Linear.forward._torchdynamo_disable_msg = None + From e7248151ecf7877e3217b6bdd1fcf3e4b59d28ae Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Fri, 12 Dec 2025 22:47:01 +0000 Subject: [PATCH 02/23] minor additional change Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/csrc/common.cpp | 2 +- transformer_engine/pytorch/csrc/common.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index f7a8540197f..3467223d2ac 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -30,7 +30,7 @@ NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); } -NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) { +NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape) { NVTEShape ret; ret.ndim = torch_shape.size(); constexpr int max_dimensions = sizeof(ret.data) / sizeof(size_t); diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 883c2a24cad..22061de4773 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -496,7 +496,7 @@ std::vector convertShape(const NVTEShape& shape); size_t roundup(const size_t value, const size_t multiple); -NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); +NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape); std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose); From b725f5b31d52adc800514898cf34f8c65851e6be Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Dec 2025 23:02:13 +0000 Subject: [PATCH 03/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/common.cpp | 12 +-- transformer_engine/pytorch/csrc/common.h | 8 +- .../pytorch/csrc/extensions/gemm.cpp | 
14 ++- transformer_engine/pytorch/csrc/quantizer.cpp | 3 +- transformer_engine/pytorch/module/linear.py | 94 ++++++++++--------- 5 files changed, 68 insertions(+), 63 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 3467223d2ac..c7f0975216b 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -26,9 +26,7 @@ std::vector convert_shape_back_from_fp4(const std::vector& shape return ret; } -NVTEShape getTensorShape(const at::Tensor& t) { - return convertTorchShape(t.sizes()); -} +NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); } NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape) { NVTEShape ret; @@ -175,8 +173,8 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( } transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, + void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, NVTEScalingMode scaling_mode) { TensorWrapper ret(scaling_mode); ret.set_rowwise_data(data_ptr, type, shape); @@ -229,8 +227,8 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, - const NVTEShape& columnwise_shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape, NVTEScalingMode scaling_mode) { TensorWrapper ret(scaling_mode); diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 22061de4773..e6c22880323 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -433,8 +433,8 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, + void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); transformer_engine::TensorWrapper makeTransformerEngineTensor( @@ -452,8 +452,8 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, - const NVTEShape& columnwise_shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, const NVTEShape& scale_inv_shape, const NVTEShape& 
columnwise_scale_inv_shape, NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index f704864cb60..35b523b5192 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -367,16 +367,14 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, const size_t A_shape_data[2] = {static_cast(A.size(0)), static_cast(A.size(1))}; const NVTEShape A_shape = nvte_make_shape(A_shape_data, 2); - auto te_A = makeTransformerEngineTensor( - A.data_ptr(), A_shape, A_type, - nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), - nvte_scaling_modeA); + auto te_A = makeTransformerEngineTensor(A.data_ptr(), A_shape, A_type, nullptr, nullptr, + A_scale_inverse.data_ptr(), + getTensorShape(A_scale_inverse), nvte_scaling_modeA); const size_t B_shape_data[2] = {static_cast(B.size(0)), static_cast(B.size(1))}; const NVTEShape B_shape = nvte_make_shape(B_shape_data, 2); - auto te_B = makeTransformerEngineTensor( - B.data_ptr(), B_shape, B_type, - nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), - nvte_scaling_modeB); + auto te_B = makeTransformerEngineTensor(B.data_ptr(), B_shape, B_type, nullptr, nullptr, + B_scale_inverse.data_ptr(), + getTensorShape(B_scale_inverse), nvte_scaling_modeB); // TODO: D_scale_inv cannot be nullptr when D_type is FP8. auto te_D = makeTransformerEngineTensor( D.data_ptr(), diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 3b94d38ac16..aa8416121d0 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1326,7 +1326,8 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( if (columnwise_data) { shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*columnwise_data)), true); if (rowwise_data) { - auto expected_shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false); + auto expected_shape = + convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false); NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, ") and column-wise data (shape=", shape, ") do not match"); } diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 7557f5c5396..965367ac31b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -105,46 +105,48 @@ def forward( debug, ) = non_tensor_args - (fp8, - fp8_calibration, - wgrad_store, - fuse_wgrad_accumulation, - tp_group, - tp_size, - sequence_parallel, - tensor_parallel, - activation_dtype, - parallel_mode, - ub_overlap_rs_fprop, - ub_overlap_ag_dgrad, - ub_overlap_ag_fprop, - ub_overlap_rs_dgrad, - ub_bulk_dgrad, - ub_bulk_wgrad, - ub_name, - fsdp_group, - symmetric_ar_type, - save_original_input - ) = (module.fp8, - module.fp8_calibration, - module.wgrad_store, - module.fuse_wgrad_accumulation, - module.tp_group, - module.tp_size, - module.sequence_parallel, - module.tp_size > 1, - module.activation_dtype, - module.parallel_mode, - module.ub_overlap_rs_fprop, - module.ub_overlap_ag_dgrad, - module.ub_overlap_ag_fprop, - module.ub_overlap_rs_dgrad, - module.ub_bulk_dgrad, - module.ub_bulk_wgrad, - module.ub_name, - module.fsdp_group, - module.symmetric_ar_type, - module.save_original_input, 
+ ( + fp8, + fp8_calibration, + wgrad_store, + fuse_wgrad_accumulation, + tp_group, + tp_size, + sequence_parallel, + tensor_parallel, + activation_dtype, + parallel_mode, + ub_overlap_rs_fprop, + ub_overlap_ag_dgrad, + ub_overlap_ag_fprop, + ub_overlap_rs_dgrad, + ub_bulk_dgrad, + ub_bulk_wgrad, + ub_name, + fsdp_group, + symmetric_ar_type, + save_original_input, + ) = ( + module.fp8, + module.fp8_calibration, + module.wgrad_store, + module.fuse_wgrad_accumulation, + module.tp_group, + module.tp_size, + module.sequence_parallel, + module.tp_size > 1, + module.activation_dtype, + module.parallel_mode, + module.ub_overlap_rs_fprop, + module.ub_overlap_ag_dgrad, + module.ub_overlap_ag_fprop, + module.ub_overlap_rs_dgrad, + module.ub_bulk_dgrad, + module.ub_bulk_wgrad, + module.ub_name, + module.fsdp_group, + module.symmetric_ar_type, + module.save_original_input, ) quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) @@ -153,8 +155,14 @@ def forward( if module.no_debug_features_active(quantizers): debug = False quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - (input_quantizer, weight_quantizer, output_quantizer, grad_input_quantizer, grad_weight_quantizer, grad_output_quantizer) = quantizers - + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" @@ -1007,6 +1015,7 @@ def wgrad_gemm( None, ) + class Linear(TransformerEngineBaseModule): """Applies a linear transformation to the incoming data :math:`y = xA^T + b` @@ -1672,4 +1681,3 @@ def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Reci if torch.__version__ >= "2": Linear.forward._torchdynamo_disable = True Linear.forward._torchdynamo_disable_msg = None - From 7b031d011331324b5f872f9e4038a9ace4a9f86c Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Tue, 23 Dec 2025 12:39:17 +0000 Subject: [PATCH 04/23] changes done to remove the additional nvte_make_shape calls Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/csrc/common.cpp | 31 ++++-- transformer_engine/pytorch/csrc/common.h | 9 ++ .../pytorch/csrc/extensions/attention.cpp | 4 +- .../pytorch/csrc/extensions/bias.cpp | 9 +- .../pytorch/csrc/extensions/cast.cpp | 8 +- .../pytorch/csrc/extensions/gemm.cpp | 105 ++++++++++++------ .../pytorch/csrc/extensions/padding.cpp | 2 +- transformer_engine/pytorch/csrc/quantizer.cpp | 54 +++++---- .../pytorch/csrc/type_converters.cpp | 4 +- transformer_engine/pytorch/csrc/util.cpp | 22 ++-- 10 files changed, 155 insertions(+), 93 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index c7f0975216b..b6a3853f6fd 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -26,7 +26,17 @@ std::vector convert_shape_back_from_fp4(const std::vector& shape return ret; } -NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); } +NVTEShape getTensorShape(const at::Tensor& t) { + return convertTorchShape(t.sizes()); +} + +std::vector getTensorShapeVector(const at::Tensor& t) { + std::vector shape; + for (auto s : t.sizes()) { + shape.push_back(s); + } + return shape; +} NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape) { NVTEShape ret; @@ -113,10 +123,7 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( transformer_engine::TensorWrapper 
makeTransformerEngineTensor(at::Tensor tensor) { transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); - std::vector shape; - for (auto s : tensor.sizes()) { - shape.push_back(s); - } + NVTEShape shape = getTensorShape(tensor); return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype); } @@ -179,7 +186,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( TensorWrapper ret(scaling_mode); ret.set_rowwise_data(data_ptr, type, shape); const size_t meta_shape_data[1] = {1}; - const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1); + NVTEShape meta_shape; + meta_shape.ndim = 1; + meta_shape.data[0] = 1; ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); auto scale_inv_dtype = @@ -194,8 +203,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( NVTEScalingMode scaling_mode) { TensorWrapper ret(scaling_mode); ret.set_rowwise_data(data_ptr, type, shape); - const size_t meta_shape_data[1] = {1}; - const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1); + NVTEShape meta_shape; + meta_shape.ndim = 1; + meta_shape.data[0] = 1; ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); auto scale_inv_dtype = @@ -234,8 +244,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( TensorWrapper ret(scaling_mode); ret.set_rowwise_data(data_ptr, type, shape); ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape); - const size_t meta_shape_data[1] = {1}; - const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1); + NVTEShape meta_shape; + meta_shape.ndim = 1; + meta_shape.data[0] = 1; ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index e6c22880323..a9e7d895192 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -141,6 +141,13 @@ class NoneQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const; + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const; + + /*! 
@brief Construct a tensor with pre-initialized data */ + std::pair create_tensor(const NVTEShape& shape, DType dtype, + at::Tensor data) const; + std::pair convert_and_update_tensor(py::object tensor) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, @@ -341,6 +348,8 @@ std::unique_ptr convert_quantizer(py::handle quantizer); NVTEShape getTensorShape(const at::Tensor& t); +std::vector getTensorShapeVector(const at::Tensor& t); + transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 2480d9aba9b..804a4667d71 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -479,9 +479,9 @@ std::vector fused_attn_bwd( std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), cu_seqlens_kv_padded_sizes.end()}; te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32); + nvte_make_shape(cu_seqlens_q_padded_shape.data(), cu_seqlens_q_padded_shape.size()), DType::kInt32); te_cu_seqlens_kv_padded = makeTransformerEngineTensor( - cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); + cu_seqlens_kv_padded.value().data_ptr(), nvte_make_shape(cu_seqlens_kv_padded_shape.data(), cu_seqlens_kv_padded_shape.size()), DType::kInt32); } // convert auxiliary tensors from forward to NVTETensors diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index 2eef7438068..c3e89ed0856 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -26,8 +26,7 @@ std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle // Grad output tensor auto grad_output_torch = grad_output.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const auto shape_nvte = getTensorShape(grad_output_torch); - const auto shape = convertShape(shape_nvte); + const auto shape = getTensorShapeVector(grad_output_torch); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); // Construct grad bias tensor @@ -117,13 +116,11 @@ std::vector dact_dbias( // Grad output and activation input tensors grad_output_torch = grad_output_torch.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const auto output_shape_nvte = getTensorShape(grad_output_torch); - const auto output_shape = convertShape(output_shape_nvte); + const auto output_shape = getTensorShapeVector(grad_output_torch); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); act_input_torch = act_input_torch.contiguous(); const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch); - const auto input_shape_nvte = getTensorShape(act_input_torch); - const auto input_shape = convertShape(input_shape_nvte); + const auto input_shape = getTensorShapeVector(act_input_torch); // Construct tensors auto quantizer_cpp = convert_quantizer(quantizer_py); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index b12da7542bb..3f107f443c7 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ 
b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -334,12 +334,12 @@ std::tuple, std::vector> bulk_allocate_fp tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_data_shapes[i] : std::vector{}, - columnwise_usage ? columnwise_data_shapes[i] : std::vector{}, fp8_dtype, nullptr, + rowwise_usage ? nvte_make_shape(rowwise_data_shapes[i].data(), rowwise_data_shapes[i].size()) : NVTEShape{}, + columnwise_usage ? nvte_make_shape(columnwise_data_shapes[i].data(), columnwise_data_shapes[i].size()) : NVTEShape{}, fp8_dtype, nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_scale_shapes[i] : std::vector{}, - columnwise_usage ? columnwise_scale_shapes[i] : std::vector{}, scaling_mode)); + rowwise_usage ? nvte_make_shape(rowwise_scale_shapes[i].data(), rowwise_scale_shapes[i].size()) : NVTEShape{}, + columnwise_usage ? nvte_make_shape(columnwise_scale_shapes[i].data(), columnwise_scale_shapes[i].size()) : NVTEShape{}, scaling_mode)); } return retval; diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 35b523b5192..11be2d4e2fe 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -40,8 +40,8 @@ bool is_low_precision(const DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } -std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool transa, - const NVTEShape& B_shape, const bool transb) { +NVTEShape getGemmOutputShape(const NVTEShape& A_shape, const bool transa, + const NVTEShape& B_shape, const bool transb) { // Flatten outer dims to get 2D matrices const size_t A0 = product(A_shape, 0, A_shape.ndim - 1); const size_t A1 = A_shape.data[A_shape.ndim - 1]; @@ -53,27 +53,29 @@ std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool tran A0, ",", A1, "), transa=", transa, ", B=(", B0, ",", B1, "), transb=", transb, ")"); // Construct output dims - std::vector ret; + NVTEShape ret; + size_t idx = 0; if (transb) { - ret.emplace_back(B1); + ret.data[idx++] = B1; } else { // Unflatten B0 for (size_t i = 0; i < B_shape.ndim - 1; ++i) { - ret.emplace_back(B_shape.data[i]); + ret.data[idx++] = B_shape.data[i]; } } if (transa) { - ret.emplace_back(A0); + ret.data[idx++] = A0; } else { - ret.emplace_back(A1); + ret.data[idx++] = A1; } + ret.ndim = idx; return ret; } -bool checkGemmShape(const std::vector& expected, const NVTEShape& actual) { - if (expected.size() != actual.ndim) return false; - for (size_t i = 0; i < expected.size(); ++i) { - if (expected[i] != actual.data[i]) return false; +bool checkGemmShape(const NVTEShape& expected, const NVTEShape& actual) { + if (expected.ndim != actual.ndim) return false; + for (size_t i = 0; i < expected.ndim; ++i) { + if (expected.data[i] != actual.data[i]) return false; } return true; } @@ -117,7 +119,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Check tensor dimensions const auto& A_shape = A_tensor.shape(); const auto& B_shape = B_tensor.shape(); - const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb); + const NVTEShape D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb); NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to 
have at least 1 dimension"); NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension"); @@ -138,7 +140,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Output tensor TensorWrapper D_tensor; if (D.is_none()) { - std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); + std::tie(D_tensor, D) = createOutputTensor(convertShape(D_shape), output_dtype, quantizer); } else { D_tensor = makeTransformerEngineTensor(D, quantizer); NVTE_CHECK(detail::checkGemmShape(D_shape, D_tensor.shape()), @@ -168,7 +170,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans if (unfused_quantization_needed) { NoneQuantizer q{none}; - std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, output_dtype); + std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(convertShape(D_shape), output_dtype); } TensorWrapper& out_tensor = unfused_quantization_needed ? unquantized_D_tensor : D_tensor; @@ -197,8 +199,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans auto dtype = GetATenDType(gelu_type); auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); std::vector torch_shape; - for (auto v : D_shape) { - torch_shape.push_back(v); + for (size_t i = 0; i < D_shape.ndim; ++i) { + torch_shape.push_back(static_cast(D_shape.data[i])); } pre_gelu_out = at::empty(torch_shape, opts); } else { @@ -207,14 +209,21 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } } - const auto gelu_shape = gelu ? D_shape : std::vector{0}; + NVTEShape gelu_shape; + gelu_shape.ndim = 1; + gelu_shape.data[0] = 0; + if (gelu) { + gelu_shape = D_shape; + } auto te_pre_gelu_out = makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type); // Workspace - auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - std::vector{workspaceSize}, DType::kByte); + NVTEShape workspace_shape; + workspace_shape.ndim = 1; + workspace_shape.data[0] = workspaceSize; + auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), workspace_shape, DType::kByte); // Set an external SM Margin to all the GEMMs. 
// This comes in handy when DP is overlapped with GEMMs @@ -263,8 +272,10 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans if (extra_output.has_value()) { extra_output_tensor = makeTransformerEngineTensor(*extra_output); } else { + NVTEShape extra_output_shape; + extra_output_shape.ndim = 0; extra_output_tensor = - makeTransformerEngineTensor(nullptr, std::vector{0}, DType::kByte); + makeTransformerEngineTensor(nullptr, extra_output_shape, DType::kByte); } // Direct GEMM call to the correct overlap @@ -367,28 +378,47 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, const size_t A_shape_data[2] = {static_cast(A.size(0)), static_cast(A.size(1))}; const NVTEShape A_shape = nvte_make_shape(A_shape_data, 2); - auto te_A = makeTransformerEngineTensor(A.data_ptr(), A_shape, A_type, nullptr, nullptr, - A_scale_inverse.data_ptr(), - getTensorShape(A_scale_inverse), nvte_scaling_modeA); + auto te_A = makeTransformerEngineTensor( + A.data_ptr(), A_shape, A_type, + nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), + nvte_scaling_modeA); const size_t B_shape_data[2] = {static_cast(B.size(0)), static_cast(B.size(1))}; const NVTEShape B_shape = nvte_make_shape(B_shape_data, 2); - auto te_B = makeTransformerEngineTensor(B.data_ptr(), B_shape, B_type, nullptr, nullptr, - B_scale_inverse.data_ptr(), - getTensorShape(B_scale_inverse), nvte_scaling_modeB); + auto te_B = makeTransformerEngineTensor( + B.data_ptr(), B_shape, B_type, + nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), + nvte_scaling_modeB); // TODO: D_scale_inv cannot be nullptr when D_type is FP8. + NVTEShape D_shape, D_scale_inv_shape; + D_shape.ndim = 2; + D_scale_inv_shape.ndim = 1; + D_scale_inv_shape.data[0] = 1; + D_shape.data[0] = static_cast(D.size(0)); + D_shape.data[1] = static_cast(D.size(1)); auto te_D = makeTransformerEngineTensor( D.data_ptr(), - std::vector{static_cast(D.size(0)), static_cast(D.size(1))}, D_type, - D_amax.data_ptr(), D_scale.data_ptr(), nullptr); + D_shape, D_type, + D_amax.data_ptr(), D_scale.data_ptr(), nullptr, D_scale_inv_shape); + NVTEShape bias_shape; + bias_shape.ndim = 1; + bias_shape.data[0] = static_cast(bias.size(0)); auto te_bias = makeTransformerEngineTensor( - bias.data_ptr(), std::vector{static_cast(bias.size(0))}, bias_type); + bias.data_ptr(), bias_shape, bias_type); + NVTEShape counter_shape; + counter_shape.ndim = 1; + counter_shape.data[0] = static_cast(counter.size(0)); auto te_counter = makeTransformerEngineTensor( - counter.data_ptr(), std::vector{static_cast(counter.size(0))}, DType::kInt32); + counter.data_ptr(), counter_shape, DType::kInt32); - const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr - ? 
std::vector{static_cast(pre_gelu_out.size(0))} - : std::vector{static_cast(pre_gelu_out.size(0)), - static_cast(pre_gelu_out.size(1))}; + NVTEShape gelu_shape; + if (pre_gelu_out.data_ptr() == nullptr) { + gelu_shape.ndim = 1; + gelu_shape.data[0] = static_cast(pre_gelu_out.size(0)); + } else { + gelu_shape.ndim = 2; + gelu_shape.data[0] = static_cast(pre_gelu_out.size(0)); + gelu_shape.data[1] = static_cast(pre_gelu_out.size(1)); + } auto te_pre_gelu_out = makeTransformerEngineTensor( pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), @@ -432,12 +462,13 @@ std::optional> te_general_grouped_gemm( // if there is single output at::Tensor out_tensor; - auto size_t_shape = + const NVTEShape nvte_D_shape = pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb); bool D_numel_is_zero = false; std::vector D_shape; - for (size_t t : size_t_shape) { - D_shape.push_back(t); + for (size_t j = 0; j < nvte_D_shape.ndim; ++j) { + const size_t t = nvte_D_shape.data[j]; + D_shape.push_back(static_cast(t)); if (t == 0) { D_numel_is_zero = true; } diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index d4b64a485c1..389308405b1 100644 --- a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -34,7 +34,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, input_row_list[tensor_id] * input.size(1) * input.element_size(); input_char_ptr += input_dptr_offset; d_input_ptr = reinterpret_cast(input_char_ptr); - + NVTEShape input_shape = {input_row_list[tensor_id], static_cast(input.size(1))}; input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index aa8416121d0..00f43433435 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -77,6 +77,16 @@ std::pair NoneQuantizer::create_tensor(const std::vec return create_tensor(shape, dtype, at::empty(shape_int64, opts)); } +std::pair NoneQuantizer::create_tensor(const NVTEShape& shape, + DType dtype) const { + std::vector shape_int64; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_int64.push_back(static_cast(shape.data[i])); + } + const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); + return create_tensor(shape, dtype, at::empty(shape_int64, opts)); +} + std::pair NoneQuantizer::create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const { @@ -86,6 +96,15 @@ std::pair NoneQuantizer::create_tensor(const std::vec return {std::move(out_cpp), py::cast(data)}; } +std::pair NoneQuantizer::create_tensor(const NVTEShape& shape, + DType dtype, + at::Tensor data) const { +TensorWrapper out_cpp; +out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape); +set_quantization_params(&out_cpp); +return {std::move(out_cpp), py::cast(data)}; +} + std::pair NoneQuantizer::convert_and_update_tensor( py::object tensor) const { auto tensor_pyt = tensor.cast(); @@ -209,8 +228,7 @@ std::pair Float8Quantizer::convert_and_update_tensor( // Tensor dimensions std::vector shape; if (has_transpose) { - const auto transpose_shape_nvte = getTensorShape(*transpose_tensor); - const auto 
transpose_shape = convertShape(transpose_shape_nvte); + const auto transpose_shape = getTensorShapeVector(*transpose_tensor); if (transpose_shape.size() > 0) { for (size_t i = 1; i < transpose_shape.size(); ++i) { shape.push_back(transpose_shape[i]); @@ -218,13 +236,12 @@ std::pair Float8Quantizer::convert_and_update_tensor( shape.push_back(transpose_shape.front()); } if (has_data) { - const auto expected_shape_nvte = getTensorShape(*data_tensor); - const auto expected_shape = convertShape(expected_shape_nvte); + const auto expected_shape = getTensorShapeVector(*data_tensor); NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, ") and transpose (shape=", transpose_shape, ") do not match"); } } else { // Already checked has_data == true - shape = convertShape(getTensorShape(*data_tensor)); + shape = getTensorShapeVector(*data_tensor); } // Coerce data tensor @@ -432,8 +449,7 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ // Tensor dimensions std::vector shape; if (has_transpose) { - const auto transpose_shape_nvte = getTensorShape(*transpose_tensor); - const auto transpose_shape = convertShape(transpose_shape_nvte); + const auto transpose_shape = getTensorShapeVector(*transpose_tensor); if (transpose_shape.size() > 0) { for (size_t i = 1; i < transpose_shape.size(); ++i) { shape.push_back(transpose_shape[i]); @@ -441,13 +457,12 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ shape.push_back(transpose_shape.front()); } if (has_data) { - const auto expected_shape_nvte = getTensorShape(*data_tensor); - const auto expected_shape = convertShape(expected_shape_nvte); + const auto expected_shape = getTensorShapeVector(*data_tensor); NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, ") and transpose (shape=", transpose_shape, ") do not match"); } } else { // Already checked has_data == true - shape = convertShape(getTensorShape(*data_tensor)); + shape = getTensorShapeVector(*data_tensor); } // Coerce data tensor in Python tensor @@ -684,9 +699,9 @@ std::pair Float8BlockQuantizer::convert_and_update_te return std::vector(); } if (all_gather_usage) { - return convertShape(getTensorShape(*columnwise_data)); + return getTensorShapeVector(*columnwise_data); } - std::vector shape = convertShape(getTensorShape(*columnwise_data)); + std::vector shape = getTensorShapeVector(*columnwise_data); std::vector shape_transposed(shape.size()); for (size_t i = 0; i + 1 < shape.size(); ++i) { shape_transposed[i] = shape[i + 1]; @@ -698,7 +713,7 @@ std::pair Float8BlockQuantizer::convert_and_update_te }; std::vector shape; if (rowwise_data) { - shape = convertShape(getTensorShape(*rowwise_data)); + shape = getTensorShapeVector(*rowwise_data); if (columnwise_data) { auto expected_shape = get_columnwise_shape(all_gather_usage); NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape, @@ -1008,14 +1023,14 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( // Tensor dimensions std::vector shape; if (columnwise_data) { - shape = convertShape(getTensorShape(*columnwise_data)); + shape = getTensorShapeVector(*columnwise_data); if (rowwise_data) { - const auto expected_shape = convertShape(getTensorShape(*rowwise_data)); + const auto expected_shape = getTensorShapeVector(*rowwise_data); NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape, ") and column-wise data (shape=", shape, ") do not match"); } } else { // Already checked columnwise_data_tensor == true - shape = 
convertShape(getTensorShape(*rowwise_data)); + shape = getTensorShapeVector(*rowwise_data); } // Coerce row-wise data @@ -1324,15 +1339,14 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( // Tensor dimensions, shape means original shape std::vector shape; if (columnwise_data) { - shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*columnwise_data)), true); + shape = convert_shape_back_from_fp4(getTensorShapeVector(*columnwise_data), true); if (rowwise_data) { - auto expected_shape = - convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false); + auto expected_shape = convert_shape_back_from_fp4(getTensorShapeVector(*rowwise_data), false); NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, ") and column-wise data (shape=", shape, ") do not match"); } } else { // Already checked columnwise_data_tensor == true - shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false); + shape = convert_shape_back_from_fp4(getTensorShapeVector(*rowwise_data), false); } size_t flat_first_dim = 1; diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 780a08da7f8..48e9f06cc40 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -132,7 +132,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast(); const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast(); ret.set_rowwise_data(data.data_ptr(), dtype, - convert_shape_back_from_fp4(convertShape(getTensorShape(data)), false)); + convert_shape_back_from_fp4(getTensorShapeVector(data), false)); ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise)); } @@ -143,7 +143,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); const auto &amax_columnwise = tensor.attr("_amax_columnwise").cast(); ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1, - convert_shape_back_from_fp4(convertShape(getTensorShape(data)), false)); + convert_shape_back_from_fp4(getTensorShapeVector(data), false)); ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index 134185ac823..7fc04801e49 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -15,7 +15,7 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap if (input.scaling_mode() == NVTE_INVALID_SCALING) { NVTE_ERROR("Invalid scaling mode for swizzle."); - } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING && + } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING || input.scaling_mode() != NVTE_NVFP4_1D_SCALING) { return std::nullopt; } @@ -59,24 +59,24 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap (nvfp4) ? 
transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0; if (rowwise) { - input_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape); - input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); - output_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape); - output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + input_cu.set_rowwise_data(input.dptr(), input_dtype, nvte_input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv.shape); + output_cu.set_rowwise_data(input.dptr(), input_dtype, nvte_input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv.shape); } else { - input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape); - input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); - output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape); - output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, nvte_input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv.shape); + output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, nvte_input_shape); + output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv.shape); } // Launch kernel nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); if (rowwise) { - input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv.shape); } else { - input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv.shape); } return swizzled_scale_inv; From d6ac3f1b28b2d3d2e8abbe0ab51cba7f782712b6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Dec 2025 04:47:56 +0000 Subject: [PATCH 05/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/common.cpp | 4 +-- transformer_engine/pytorch/csrc/common.h | 9 +++-- .../pytorch/csrc/extensions/attention.cpp | 10 ++++-- .../pytorch/csrc/extensions/gemm.cpp | 36 +++++++++---------- transformer_engine/pytorch/csrc/quantizer.cpp | 14 ++++---- 5 files changed, 35 insertions(+), 38 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index b6a3853f6fd..d4ce064facf 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -26,9 +26,7 @@ std::vector convert_shape_back_from_fp4(const std::vector& shape return ret; } -NVTEShape getTensorShape(const at::Tensor& t) { - return convertTorchShape(t.sizes()); -} +NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); } std::vector getTensorShapeVector(const at::Tensor& t) { std::vector shape; diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index a9e7d895192..58e2acb6959 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -141,13 +141,12 @@ class NoneQuantizer : public Quantizer { std::pair 
create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const; - std::pair create_tensor(const NVTEShape& shape, - DType dtype) const; - + std::pair create_tensor(const NVTEShape& shape, DType dtype) const; + /*! @brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const NVTEShape& shape, DType dtype, - at::Tensor data) const; - + at::Tensor data) const; + std::pair convert_and_update_tensor(py::object tensor) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 804a4667d71..1007dcb80c6 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -478,10 +478,14 @@ std::vector fused_attn_bwd( auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), cu_seqlens_kv_padded_sizes.end()}; - te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - nvte_make_shape(cu_seqlens_q_padded_shape.data(), cu_seqlens_q_padded_shape.size()), DType::kInt32); + te_cu_seqlens_q_padded = makeTransformerEngineTensor( + cu_seqlens_q_padded.value().data_ptr(), + nvte_make_shape(cu_seqlens_q_padded_shape.data(), cu_seqlens_q_padded_shape.size()), + DType::kInt32); te_cu_seqlens_kv_padded = makeTransformerEngineTensor( - cu_seqlens_kv_padded.value().data_ptr(), nvte_make_shape(cu_seqlens_kv_padded_shape.data(), cu_seqlens_kv_padded_shape.size()), DType::kInt32); + cu_seqlens_kv_padded.value().data_ptr(), + nvte_make_shape(cu_seqlens_kv_padded_shape.data(), cu_seqlens_kv_padded_shape.size()), + DType::kInt32); } // convert auxiliary tensors from forward to NVTETensors diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index b8928053d77..0e478ecd3ce 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -40,8 +40,8 @@ bool is_low_precision(const DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } -NVTEShape getGemmOutputShape(const NVTEShape& A_shape, const bool transa, - const NVTEShape& B_shape, const bool transb) { +NVTEShape getGemmOutputShape(const NVTEShape& A_shape, const bool transa, const NVTEShape& B_shape, + const bool transb) { // Flatten outer dims to get 2D matrices const size_t A0 = A_shape.ndim > 0 ? product(A_shape, 0, A_shape.ndim - 1) : 1; const size_t A1 = A_shape.ndim > 0 ? A_shape.data[A_shape.ndim - 1] : 1; @@ -170,7 +170,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans if (unfused_quantization_needed) { NoneQuantizer q{none}; - std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(convertShape(D_shape), output_dtype); + std::tie(unquantized_D_tensor, unquantized_out) = + q.create_tensor(convertShape(D_shape), output_dtype); } TensorWrapper& out_tensor = unfused_quantization_needed ? 
unquantized_D_tensor : D_tensor; @@ -223,7 +224,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans NVTEShape workspace_shape; workspace_shape.ndim = 1; workspace_shape.data[0] = workspaceSize; - auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), workspace_shape, DType::kByte); + auto te_workspace = + makeTransformerEngineTensor(workspace.data_ptr(), workspace_shape, DType::kByte); // Set an external SM Margin to all the GEMMs. // This comes in handy when DP is overlapped with GEMMs @@ -378,16 +380,14 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, const size_t A_shape_data[2] = {static_cast(A.size(0)), static_cast(A.size(1))}; const NVTEShape A_shape = nvte_make_shape(A_shape_data, 2); - auto te_A = makeTransformerEngineTensor( - A.data_ptr(), A_shape, A_type, - nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), - nvte_scaling_modeA); + auto te_A = makeTransformerEngineTensor(A.data_ptr(), A_shape, A_type, nullptr, nullptr, + A_scale_inverse.data_ptr(), + getTensorShape(A_scale_inverse), nvte_scaling_modeA); const size_t B_shape_data[2] = {static_cast(B.size(0)), static_cast(B.size(1))}; const NVTEShape B_shape = nvte_make_shape(B_shape_data, 2); - auto te_B = makeTransformerEngineTensor( - B.data_ptr(), B_shape, B_type, - nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), - nvte_scaling_modeB); + auto te_B = makeTransformerEngineTensor(B.data_ptr(), B_shape, B_type, nullptr, nullptr, + B_scale_inverse.data_ptr(), + getTensorShape(B_scale_inverse), nvte_scaling_modeB); // TODO: D_scale_inv cannot be nullptr when D_type is FP8. NVTEShape D_shape, D_scale_inv_shape; D_shape.ndim = 2; @@ -395,20 +395,16 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, D_scale_inv_shape.data[0] = 1; D_shape.data[0] = static_cast(D.size(0)); D_shape.data[1] = static_cast(D.size(1)); - auto te_D = makeTransformerEngineTensor( - D.data_ptr(), - D_shape, D_type, - D_amax.data_ptr(), D_scale.data_ptr(), nullptr, D_scale_inv_shape); + auto te_D = makeTransformerEngineTensor(D.data_ptr(), D_shape, D_type, D_amax.data_ptr(), + D_scale.data_ptr(), nullptr, D_scale_inv_shape); NVTEShape bias_shape; bias_shape.ndim = 1; bias_shape.data[0] = static_cast(bias.size(0)); - auto te_bias = makeTransformerEngineTensor( - bias.data_ptr(), bias_shape, bias_type); + auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), bias_shape, bias_type); NVTEShape counter_shape; counter_shape.ndim = 1; counter_shape.data[0] = static_cast(counter.size(0)); - auto te_counter = makeTransformerEngineTensor( - counter.data_ptr(), counter_shape, DType::kInt32); + auto te_counter = makeTransformerEngineTensor(counter.data_ptr(), counter_shape, DType::kInt32); NVTEShape gelu_shape; if (pre_gelu_out.data_ptr() == nullptr) { diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 64a6fa84766..0f8aa8381a8 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -78,7 +78,7 @@ std::pair NoneQuantizer::create_tensor(const std::vec } std::pair NoneQuantizer::create_tensor(const NVTEShape& shape, - DType dtype) const { + DType dtype) const { std::vector shape_int64; for (size_t i = 0; i < shape.ndim; ++i) { shape_int64.push_back(static_cast(shape.data[i])); @@ -97,12 +97,12 @@ std::pair NoneQuantizer::create_tensor(const std::vec } std::pair NoneQuantizer::create_tensor(const 
NVTEShape& shape, - DType dtype, - at::Tensor data) const { -TensorWrapper out_cpp; -out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape); -set_quantization_params(&out_cpp); -return {std::move(out_cpp), py::cast(data)}; + DType dtype, + at::Tensor data) const { + TensorWrapper out_cpp; + out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape); + set_quantization_params(&out_cpp); + return {std::move(out_cpp), py::cast(data)}; } std::pair NoneQuantizer::convert_and_update_tensor( From 51dd309ed4c4ece54350196d27351c163eb3ca9d Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Sun, 28 Dec 2025 08:49:18 +0000 Subject: [PATCH 06/23] some additional changes Signed-off-by: Varun Thumbe --- .../common/gemm/cublaslt_gemm.cu | 10 +++++---- .../common/transformer_engine.cpp | 2 +- transformer_engine/pytorch/csrc/common.cpp | 2 +- transformer_engine/pytorch/csrc/common.h | 2 +- transformer_engine/pytorch/csrc/quantizer.cpp | 22 ++++++++++--------- transformer_engine/pytorch/module/linear.py | 6 +---- 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 118bf193353..899f5fe5e6d 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -129,7 +129,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; ret.lda = is_A_transposed ? k : m; - if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) { + int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { ret.A = A.columnwise_data.dptr; @@ -140,7 +141,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } - } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) { + } else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) { // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed // data with the mirrored transpose-flag if we don't have row-wise data. NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype), @@ -220,7 +221,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Btype = B.data.dtype; ret.B_scale_inv = B.scale_inv.dptr; ret.ldb = is_B_transposed ? n : k; - if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) { + int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { ret.B = B.columnwise_data.dptr; @@ -231,7 +233,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } - } else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) { + } else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) { // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed // data with the mirrored transpose-flag if we don't have row-wise data. 
NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype), diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 4a140b4376d..0c4c9456c60 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -961,7 +961,7 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) { } int nvte_is_non_tn_fp8_gemm_supported() { - int num_devices = transformer_engine::cuda::num_devices(); + static int num_devices = transformer_engine::cuda::num_devices(); static std::vector cache(num_devices, -1); static std::vector flags(num_devices); int device_id = transformer_engine::cuda::current_device(); diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index d4ce064facf..04ae78c0b78 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -36,7 +36,7 @@ std::vector getTensorShapeVector(const at::Tensor& t) { return shape; } -NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape) { +NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) { NVTEShape ret; ret.ndim = torch_shape.size(); constexpr int max_dimensions = sizeof(ret.data) / sizeof(size_t); diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 58e2acb6959..dd36e178ce7 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -504,7 +504,7 @@ std::vector convertShape(const NVTEShape& shape); size_t roundup(const size_t value, const size_t multiple); -NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape); +NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0f8aa8381a8..4dc776b4546 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -141,8 +141,9 @@ std::pair Float8Quantizer::create_tensor( std::optional transpose, std::optional scale_inv) const { using namespace pybind11::literals; + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Initialize data tensor - const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; if (with_data && !data) { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -153,7 +154,7 @@ std::pair Float8Quantizer::create_tensor( py::object data_py = with_data ? 
py::cast(*data) : py::none(); // Initialize transpose tensor - const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose && !transpose) { const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -204,10 +205,10 @@ std::pair Float8Quantizer::create_tensor( std::pair Float8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Expected buffers - const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); - const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; + const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer."); // Extract buffers from Python tensor @@ -347,7 +348,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso // Initialize data tensor at::Tensor data_tensor; - const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; if (with_data) { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -356,7 +358,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso // Initialize transpose tensor at::Tensor transpose_tensor; - const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose) { const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -425,10 +427,10 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8CurrentScalingQuantizer must output to Float8Tensor."); - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Expected buffers - const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); - const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; + const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages."); // Extract buffers from Python tensor diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 965367ac31b..f71f780be3c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1377,6 +1377,7 @@ def reset_parameters(self, defer_init=False): elif self.parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) + @no_torch_dynamo() def forward( self, inp: torch.Tensor, @@ -1676,8 +1677,3 @@ def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Reci ].all_gather_usage = True -# disable torch dynamo just once to reduce wrapped 
function overhead on each -# forward call of te Linear. -if torch.__version__ >= "2": - Linear.forward._torchdynamo_disable = True - Linear.forward._torchdynamo_disable_msg = None From 425182f7db96dab01e078fd54252415bd179cbd8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 28 Dec 2025 08:50:00 +0000 Subject: [PATCH 07/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/linear.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f71f780be3c..4722d51f59d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1675,5 +1675,3 @@ def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Reci self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].all_gather_usage = True - - From f8748c068313a5e1d6479a1669762560754296c0 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Sun, 28 Dec 2025 09:50:55 +0000 Subject: [PATCH 08/23] some optimizations Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/csrc/common.h | 20 +++- .../pytorch/csrc/extensions/gemm.cpp | 10 +- .../pytorch/csrc/extensions/transpose.cpp | 3 +- transformer_engine/pytorch/csrc/quantizer.cpp | 51 +++++++- transformer_engine/pytorch/module/linear.py | 110 ++++++++++-------- 5 files changed, 127 insertions(+), 67 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index dd36e178ce7..c5670357fb9 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -102,7 +102,8 @@ class Quantizer { /*! @brief Construct a tensor with uninitialized data */ virtual std::pair create_tensor(const std::vector& shape, DType dtype) const = 0; - + virtual std::pair create_tensor(const NVTEShape& shape, + DType dtype) const = 0; /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor * * The PyTorch tensor's attributes are modified to match the @@ -141,7 +142,7 @@ class NoneQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const; - std::pair create_tensor(const NVTEShape& shape, DType dtype) const; + std::pair create_tensor(const NVTEShape& shape, DType dtype) const override; /*! @brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const NVTEShape& shape, DType dtype, @@ -168,7 +169,8 @@ class Float8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; - + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; /*! @brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const std::vector& shape, DType dtype, std::optional data, @@ -200,7 +202,8 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; - + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. * * The amax is zeroed out. Most TE kernels that output amax expect @@ -259,6 +262,8 @@ class Float8BlockQuantizer : public Quantizer { // and optionally columnwise usage. 
std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; std::pair convert_and_update_tensor(py::object shape) const override; @@ -280,6 +285,8 @@ class MXFP8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; std::pair convert_and_update_tensor(py::object shape) const override; @@ -314,7 +321,8 @@ class NVFP4Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; - + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; /*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer * * The amax is zeroed out. Most TE kernels that output amax expect @@ -560,4 +568,4 @@ inline string to_string(const NVTEShape& s) { } } // namespace std -#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ \ No newline at end of file diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 0e478ecd3ce..fa01af53fe8 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -82,7 +82,7 @@ bool checkGemmShape(const NVTEShape& expected, const NVTEShape& actual) { } // namespace detail -std::pair createOutputTensor(const std::vector& shape, +std::pair createOutputTensor(const NVTEShape& shape, DType dtype, py::handle quantizer) { std::unique_ptr my_quantizer = convert_quantizer(quantizer); return my_quantizer->create_tensor(shape, dtype); @@ -119,7 +119,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Check tensor dimensions const auto& A_shape = A_tensor.shape(); const auto& B_shape = B_tensor.shape(); - const NVTEShape D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb); + const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb); NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension"); NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension"); @@ -140,7 +140,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Output tensor TensorWrapper D_tensor; if (D.is_none()) { - std::tie(D_tensor, D) = createOutputTensor(convertShape(D_shape), output_dtype, quantizer); + std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); } else { D_tensor = makeTransformerEngineTensor(D, quantizer); NVTE_CHECK(detail::checkGemmShape(D_shape, D_tensor.shape()), @@ -171,7 +171,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans if (unfused_quantization_needed) { NoneQuantizer q{none}; std::tie(unquantized_D_tensor, unquantized_out) = - q.create_tensor(convertShape(D_shape), output_dtype); + q.create_tensor(D_shape, output_dtype); } TensorWrapper& out_tensor = unfused_quantization_needed ? 
unquantized_D_tensor : D_tensor; @@ -596,4 +596,4 @@ std::optional> te_general_grouped_gemm( return bias; } -} // namespace transformer_engine::pytorch +} // namespace transformer_engine::pytorch \ No newline at end of file diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 5ace996afcc..55c7fd57d79 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -61,8 +61,7 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional out) { // Allocate output tensor if needed if (!out) { - const auto in_shape_nvte = getTensorShape(input); - const auto in_shape = convertShape(in_shape_nvte); + const auto in_shape = getTensorShapeVector(input); NVTE_CHECK(in_shape.size() >= 2, "Invalid input tensor dimensions (shape=", in_shape, ")"); std::vector out_shape_int64(in_shape.begin(), in_shape.end()); out_shape_int64[0] = static_cast(in_shape[1]); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 4dc776b4546..6345ae3894c 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -135,12 +135,19 @@ std::pair Float8Quantizer::create_tensor( at::Tensor scale_inv = at::empty(std::vector{1}, opts); return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); } +std::pair Float8Quantizer::create_tensor( + const NVTEShape& shape, DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); +} std::pair Float8Quantizer::create_tensor( const std::vector& shape, DType dtype, std::optional data, std::optional transpose, std::optional scale_inv) const { using namespace pybind11::literals; - int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Initialize data tensor const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; @@ -205,8 +212,9 @@ std::pair Float8Quantizer::create_tensor( std::pair Float8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); - int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + // Expected buffers + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer."); @@ -341,7 +349,14 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); } - +std::pair Float8CurrentScalingQuantizer::create_tensor( + const NVTEShape& shape, DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); +} std::pair Float8CurrentScalingQuantizer::create_tensor( const std::vector& shape, DType dtype) const { using namespace pybind11::literals; @@ -427,8 +442,9 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8CurrentScalingQuantizer must output to Float8Tensor."); - int 
is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + // Expected buffers + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages."); @@ -581,6 +597,15 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} +std::pair Float8BlockQuantizer::create_tensor( + const NVTEShape& shape, DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); +} + std::pair Float8BlockQuantizer::create_tensor( const std::vector& shape, DType dtype) const { using namespace pybind11::literals; @@ -923,6 +948,15 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {} +std::pair MXFP8Quantizer::create_tensor( + const NVTEShape& shape, DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); +} + std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, DType dtype) const { using namespace pybind11::literals; @@ -1189,7 +1223,14 @@ void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const { tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), columnwise_data.shape); } - +std::pair NVFP4Quantizer::create_tensor( + const NVTEShape& shape, DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); +} std::pair NVFP4Quantizer::create_tensor(const std::vector& shape, DType dtype) const { using namespace pybind11::literals; diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f71f780be3c..09b01e288a5 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -96,26 +96,24 @@ def forward( ( is_first_microbatch, - cpu_offloading, - is_grad_enabled, - fp8_output, - fp8_grad, - module, - skip_fp8_weight_update, - debug, - ) = non_tensor_args - - ( fp8, fp8_calibration, wgrad_store, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, fuse_wgrad_accumulation, + cpu_offloading, tp_group, tp_size, sequence_parallel, tensor_parallel, activation_dtype, parallel_mode, + is_grad_enabled, ub_overlap_rs_fprop, ub_overlap_ag_dgrad, ub_overlap_ag_fprop, @@ -123,46 +121,14 @@ def forward( ub_bulk_dgrad, ub_bulk_wgrad, ub_name, + fp8_output, # pylint: disable=unused-variable fsdp_group, + module, + skip_fp8_weight_update, symmetric_ar_type, save_original_input, - ) = ( - module.fp8, - module.fp8_calibration, - module.wgrad_store, - module.fuse_wgrad_accumulation, - module.tp_group, - module.tp_size, - module.sequence_parallel, - module.tp_size > 1, - module.activation_dtype, - module.parallel_mode, - module.ub_overlap_rs_fprop, - module.ub_overlap_ag_dgrad, - module.ub_overlap_ag_fprop, - module.ub_overlap_rs_dgrad, - module.ub_bulk_dgrad, - module.ub_bulk_wgrad, - module.ub_name, - 
module.fsdp_group, - module.symmetric_ar_type, - module.save_original_input, - ) - quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - - if debug: - quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - if module.no_debug_features_active(quantizers): - debug = False - quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) = quantizers + debug, + ) = non_tensor_args # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" @@ -1377,7 +1343,7 @@ def reset_parameters(self, defer_init=False): elif self.parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) - @no_torch_dynamo() + @no_torch_dynamo(recursive=False) def forward( self, inp: torch.Tensor, @@ -1435,7 +1401,28 @@ def forward( inp, allow_non_contiguous=isinstance(inp, QuantizedTensor), ) as inp: + weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) + ) + if debug: + if self.no_debug_features_active(quantizers): + debug = False + quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers + if is_grad_enabled: linear_fn = _Linear.apply autograd_ctx = [] @@ -1445,12 +1432,37 @@ def forward( non_tensor_args = ( is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), + self.tp_group, + self.tp_size, + self.sequence_parallel, + self.tp_size > 1, + self.activation_dtype, + self.parallel_mode, is_grad_enabled, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + self.ub_name, fp8_output, - fp8_grad, + self.fsdp_group, self, skip_fp8_weight_update, + self.symmetric_ar_type, + self.save_original_input, debug, ) out = linear_fn( From 72886287f48b1501d74ea3d10724c4249806fea7 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Tue, 30 Dec 2025 12:40:16 +0000 Subject: [PATCH 09/23] got rid of vector in makeTransformerEngineTensor Signed-off-by: Varun Thumbe --- .../transformer_engine/transformer_engine.h | 63 +++++++++ transformer_engine/pytorch/csrc/common.cpp | 60 +-------- transformer_engine/pytorch/csrc/common.h | 36 ++--- .../pytorch/csrc/extensions/attention.cpp | 73 +++++----- .../pytorch/csrc/extensions/cast.cpp | 126 ++++++++++-------- .../pytorch/csrc/extensions/gemm.cpp | 23 +++- .../pytorch/csrc/extensions/padding.cpp | 36 +++-- .../pytorch/csrc/extensions/permutation.cpp | 47 +++++-- .../pytorch/csrc/extensions/transpose.cpp | 11 +- transformer_engine/pytorch/csrc/quantizer.cpp | 112 ++++++++++------ 10 files changed, 343 insertions(+), 244 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 19cb646be29..f9eb244cd9c 100644 --- 
a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -518,6 +518,69 @@ NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor); */ namespace transformer_engine { +/*! \class NVTEShapeWrapper + * \brief C++ wrapper for NVTEShape with container-like interface. + */ +class NVTEShapeWrapper { + private: + NVTEShape data; + + public: + // Default constructor + NVTEShapeWrapper() { + data.ndim = 0; + } + + // Constructor from NVTEShape (direct assignment by reference) + NVTEShapeWrapper(const NVTEShape& shape) { + data = shape; + } + + // Constructor from vector (creates a copy) + template NVTEShapeWrapper(const std::vector& shape_vec) { + data.ndim = shape_vec.size(); + for (size_t i = 0; i < data.ndim; ++i) { + data.data[i] = static_cast(shape_vec[i]); + } + } + + operator NVTEShape&() { return data; } + operator const NVTEShape&() const { return data; } + + // Iterator support + size_t* begin() { return data.data; } + const size_t* begin() const { return data.data; } + size_t* end() { return data.data + data.ndim; } + const size_t* end() const { return data.data + data.ndim; } + + // Index access + size_t& operator[](size_t idx) { return data.data[idx]; } + const size_t& operator[](size_t idx) const { return data.data[idx]; } + + // Back access + size_t& back() { return data.data[data.ndim - 1]; } + const size_t& back() const { return data.data[data.ndim - 1]; } + + // Size access + size_t size() const { return data.ndim; } + bool empty() const { return data.ndim == 0; } + + // Container operations + void push_back(size_t value) { + if (data.ndim < 15) { + data.data[data.ndim++] = value; + } + } + + void clear() { data.ndim = 0; } + + void resize(size_t new_size) { + if (new_size <= 15) { + data.ndim = new_size; + } + } +}; + /*! \enum DType * \brief TE datatype. */ diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 04ae78c0b78..66b1e227c25 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -114,11 +114,6 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( return transformer_engine::TensorWrapper(data_ptr, shape, type); } -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type) { - return transformer_engine::TensorWrapper(data_ptr, shape, type); -} - transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) { transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); NVTEShape shape = getTensorShape(tensor); @@ -162,21 +157,6 @@ makeTransformerEngineTensorList(std::vector> at_tensor_l std::move(nvte_tensor_list_ptrs), num_lists, num_tensors); } -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape, - NVTEScalingMode scaling_mode) { - TensorWrapper ret(scaling_mode); - ret.set_rowwise_data(data_ptr, type, shape); - const std::vector meta_shape{1}; - ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); - ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); - auto scale_inv_dtype = - (scaling_mode == NVTE_MXFP8_1D_SCALING) ? 
DType::kFloat8E8M0 : DType::kFloat32; - ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); - return ret; -} - transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, @@ -195,43 +175,6 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( return ret; } -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, - NVTEScalingMode scaling_mode) { - TensorWrapper ret(scaling_mode); - ret.set_rowwise_data(data_ptr, type, shape); - NVTEShape meta_shape; - meta_shape.ndim = 1; - meta_shape.data[0] = 1; - ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); - ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); - auto scale_inv_dtype = - (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; - ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); - return ret; -} - -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, - const std::vector& columnwise_shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, - const std::vector& scale_inv_shape, - const std::vector& columnwise_scale_inv_shape, NVTEScalingMode scaling_mode) { - TensorWrapper ret(scaling_mode); - ret.set_rowwise_data(data_ptr, type, shape); - ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape); - const std::vector meta_shape{1}; - ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); - ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); - auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 - : (scaling_mode == NVTE_NVFP4_1D_SCALING) ? 
DType::kFloat8E4M3 - : DType::kFloat32; - ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); - ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype, - columnwise_scale_inv_shape); - return ret; -} transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, @@ -287,6 +230,9 @@ template size_t product(const std::vector& shape); template int64_t product(const std::vector& shape); size_t product(const NVTEShape& shape, size_t begin, size_t end) { + if(end == -1) { + end = shape.ndim; + } NVTE_CHECK(begin <= end && end <= shape.ndim, "Attempted to access entries ", begin, " to ", end, " in a shape with ", shape.ndim, " entries"); size_t ret = 1; diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index c5670357fb9..b703cbc6810 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -271,6 +271,11 @@ class Float8BlockQuantizer : public Quantizer { const std::optional& noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + NVTEShapeWrapper get_scale_shape(const NVTEShapeWrapper& shape, bool columnwise) const; + + private: + template + ShapeT get_scale_shape_impl(const ShapeT& shape, bool columnwise) const; }; class MXFP8Quantizer : public Quantizer { @@ -294,6 +299,11 @@ class MXFP8Quantizer : public Quantizer { const std::optional& noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + NVTEShapeWrapper get_scale_shape(const NVTEShapeWrapper& shape, bool columnwise) const; + + private: + template + ShapeT get_scale_shape_impl(const ShapeT& shape, bool columnwise) const; }; class NVFP4Quantizer : public Quantizer { @@ -345,8 +355,11 @@ class NVFP4Quantizer : public Quantizer { void quantize_with_amax(TensorWrapper& input, TensorWrapper& out); std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + NVTEShapeWrapper get_scale_shape(const NVTEShapeWrapper& shape, bool columnwise) const; private: + template + ShapeT get_scale_shape_impl(const ShapeT& shape, bool columnwise) const; void quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax); }; @@ -439,32 +452,11 @@ inline transformer_engine::DType GetTransformerEngineDType(int DType_value) { return static_cast(DType_value); } -transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, - const std::vector& shape, - const transformer_engine::DType type); - -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape = {1}, - NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); - transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, - NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); - 
-transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, - const std::vector& columnwise_shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, - const std::vector& scale_inv_shape = {1}, - const std::vector& columnwise_scale_inv_shape = {1}, - NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, @@ -492,7 +484,7 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( template T product(const std::vector& shape); -size_t product(const NVTEShape& shape, size_t begin, size_t end); +size_t product(const NVTEShape& shape, size_t begin=0, size_t end=-1); std::vector nvte_shape_to_vector(const NVTEShape& nvte_shape); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 1007dcb80c6..6d6effce6d9 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -163,52 +163,52 @@ std::vector fused_attn_fwd( } if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { auto bias_sizes = Bias.value().sizes().vec(); - std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32); + NVTEShapeWrapper bias_shape{bias_sizes}; + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), static_cast(bias_shape), DType::kFloat32); } auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); - std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; + NVTEShapeWrapper cu_seqlens_q_shape{cu_seqlens_q_sizes}; auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); - std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; + NVTEShapeWrapper cu_seqlens_kv_shape{cu_seqlens_kv_sizes}; te_cu_seqlens_q = - makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, DType::kInt32); + makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), static_cast(cu_seqlens_q_shape), DType::kInt32); te_cu_seqlens_kv = - makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, DType::kInt32); + makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), static_cast(cu_seqlens_kv_shape), DType::kInt32); if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); - std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), - cu_seqlens_q_padded_sizes.end()}; + NVTEShapeWrapper cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes}; auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); - std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), - cu_seqlens_kv_padded_sizes.end()}; + NVTEShapeWrapper cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes}; te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32); + static_cast(cu_seqlens_q_padded_shape), DType::kInt32); te_cu_seqlens_kv_padded = makeTransformerEngineTensor( - cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); + cu_seqlens_kv_padded.value().data_ptr(), 
static_cast(cu_seqlens_kv_padded_shape), DType::kInt32); } - + NVTEShape default_scale_inv_shape; + default_scale_inv_shape.ndim = 1; + default_scale_inv_shape.data[0] = 1; if ((page_table_k.has_value()) && (page_table_v.has_value())) { auto page_table_k_sizes = page_table_k.value().sizes().vec(); - std::vector page_table_k_shape{page_table_k_sizes.begin(), page_table_k_sizes.end()}; + NVTEShapeWrapper page_table_k_shape{page_table_k_sizes}; auto page_table_v_sizes = page_table_v.value().sizes().vec(); - std::vector page_table_v_shape{page_table_v_sizes.begin(), page_table_v_sizes.end()}; + NVTEShapeWrapper page_table_v_shape{page_table_v_sizes}; te_page_table_k = - makeTransformerEngineTensor(page_table_k.value().data_ptr(), page_table_k_shape, - DType::kInt32, nullptr, nullptr, nullptr); + makeTransformerEngineTensor(page_table_k.value().data_ptr(), static_cast(page_table_k_shape), + DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape); te_page_table_v = - makeTransformerEngineTensor(page_table_v.value().data_ptr(), page_table_v_shape, - DType::kInt32, nullptr, nullptr, nullptr); + makeTransformerEngineTensor(page_table_v.value().data_ptr(), static_cast(page_table_v_shape), + DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape); } // softmax offset TensorWrapper te_SoftmaxOffset; if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) { auto SoftmaxOffset_sizes = SoftmaxOffset.value().sizes().vec(); - std::vector SoftmaxOffset_shape{SoftmaxOffset_sizes.begin(), SoftmaxOffset_sizes.end()}; + NVTEShapeWrapper SoftmaxOffset_shape{SoftmaxOffset_sizes}; te_SoftmaxOffset = - makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), SoftmaxOffset_shape, - DType::kFloat32, nullptr, nullptr, nullptr); + makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), static_cast(SoftmaxOffset_shape), + DType::kFloat32, nullptr, nullptr, nullptr, default_scale_inv_shape); } // extract rng seed and offset @@ -461,30 +461,31 @@ std::vector fused_attn_bwd( // create cu_seqlens tensorwrappers auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); - std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; + NVTEShapeWrapper cu_seqlens_q_shape{cu_seqlens_q_sizes}; auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); - std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; + NVTEShapeWrapper cu_seqlens_kv_shape{cu_seqlens_kv_sizes}; + NVTEShape zero_scale_inv_shape; + zero_scale_inv_shape.ndim = 1; + zero_scale_inv_shape.data[0] = 0; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); + te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), static_cast(cu_seqlens_q_shape), + DType::kInt32, nullptr, nullptr, nullptr, zero_scale_inv_shape); + te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), static_cast(cu_seqlens_kv_shape), + DType::kInt32, nullptr, nullptr, nullptr, zero_scale_inv_shape); TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); - std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), - 
cu_seqlens_q_padded_sizes.end()}; + NVTEShapeWrapper cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes}; auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); - std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), - cu_seqlens_kv_padded_sizes.end()}; + NVTEShapeWrapper cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes}; te_cu_seqlens_q_padded = makeTransformerEngineTensor( cu_seqlens_q_padded.value().data_ptr(), - nvte_make_shape(cu_seqlens_q_padded_shape.data(), cu_seqlens_q_padded_shape.size()), + static_cast(cu_seqlens_q_padded_shape), DType::kInt32); te_cu_seqlens_kv_padded = makeTransformerEngineTensor( cu_seqlens_kv_padded.value().data_ptr(), - nvte_make_shape(cu_seqlens_kv_padded_shape.data(), cu_seqlens_kv_padded_shape.size()), + static_cast(cu_seqlens_kv_padded_shape), DType::kInt32); } @@ -494,12 +495,12 @@ std::vector fused_attn_bwd( nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { const std::vector &signed_shape = Aux_CTX_Tensors[i].sizes().vec(); - const std::vector tmp(signed_shape.begin(), signed_shape.end()); + NVTEShapeWrapper tmp(signed_shape); NVTEBasicTensor temp_data = { Aux_CTX_Tensors[i].data_ptr(), static_cast(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())), - nvte_make_shape(tmp.data(), tmp.size())}; + static_cast(tmp)}; nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index aa9d800c7bb..af04328948b 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -56,9 +56,8 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob TensorWrapper output_cpp; py::object output_py; if (output.is_none()) { - const auto shape = get_tensor_shape(input_cpp); const auto fake_dtype = input_cpp.dtype(); - std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype); + std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(input_cpp.shape(), fake_dtype); } else { std::tie(output_cpp, output_py) = quantizer_cpp->convert_and_update_tensor(output); } @@ -180,8 +179,8 @@ std::vector multi_tensor_quantize(const std::vector &ten const auto input_dtype = GetTransformerEngineDType(input_py.scalar_type()); // Construct output tensor - std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); - auto [output_cpp, output_py] = quantizer_cpp_list[i]->create_tensor(output_shape, input_dtype); + // std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + auto [output_cpp, output_py] = quantizer_cpp_list[i]->create_tensor(input_shape, input_dtype); output_cpp_list.emplace_back(std::move(output_cpp)); output_py_list.emplace_back(std::move(output_py)); } @@ -195,7 +194,7 @@ std::vector multi_tensor_quantize(const std::vector &ten namespace { std::tuple, std::vector> bulk_allocate_fp8_blockwise_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &shape_list, std::vector &quantizer_py_list, std::vector &quantizer_cpp_list) { init_extension(); std::tuple, std::vector> retval; @@ -220,7 +219,7 @@ std::tuple, std::vector> bulk_allocate_fp // Helper function to construct tensor view // Note: Deleter holds a shared_ptr for the buffer, so the buffer // will survive until all views are deleted. 
- auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + auto make_torch_view = [](std::shared_ptr &buffer, const NVTEShapeWrapper &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; @@ -235,13 +234,13 @@ std::tuple, std::vector> bulk_allocate_fp // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list; - std::vector> rowwise_data_shapes, rowwise_scale_shapes; + std::vector rowwise_data_shapes, rowwise_scale_shapes; if (rowwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_shapes.emplace_back(shape_list[i]); rowwise_scale_shapes.emplace_back( - quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); + quantizer_cpp_list[i]->get_scale_shape(rowwise_data_shapes[i], false)); } // Offsets in full buffer @@ -273,7 +272,7 @@ std::tuple, std::vector> bulk_allocate_fp // Allocate column-wise data std::vector columnwise_data_list, columnwise_scale_list; - std::vector> columnwise_data_shapes, columnwise_scale_shapes; + std::vector columnwise_data_shapes, columnwise_scale_shapes; if (columnwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { @@ -330,24 +329,26 @@ std::tuple, std::vector> bulk_allocate_fp tensor_py_list.emplace_back(Float8BlockwiseQTensorClass( rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype, quantizer_py_list[i], is_2D_scaled, Float8BlockScaleTensorFormat::GEMM_READY)); - + NVTEShape zero_shape; + zero_shape.ndim = 1; + zero_shape.data[0] = 0; // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_data_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_data_shapes[i] : std::vector{0}, fp8_dtype, nullptr, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, fp8_dtype, nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode)); + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, scaling_mode)); } return retval; } std::tuple, std::vector> bulk_allocate_mxfp8_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &shape_list, std::vector &quantizer_py_list, std::vector &quantizer_cpp_list) { init_extension(); std::tuple, std::vector> retval; @@ -371,7 +372,7 @@ std::tuple, std::vector> bulk_allocate_mx // Helper function to construct tensor view // Note: Deleter holds a shared_ptr for the buffer, so the buffer // will survive until all views are deleted. 
- auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + auto make_torch_view = [](std::shared_ptr &buffer, const NVTEShapeWrapper &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; @@ -386,13 +387,13 @@ std::tuple, std::vector> bulk_allocate_mx // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list; - std::vector> rowwise_data_shapes, rowwise_scale_shapes; + std::vector rowwise_data_shapes, rowwise_scale_shapes; if (rowwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_shapes.emplace_back(shape_list[i]); rowwise_scale_shapes.emplace_back( - quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); + quantizer_cpp_list[i]->get_scale_shape(rowwise_data_shapes[i], false)); } // Offsets in full buffer @@ -424,7 +425,7 @@ std::tuple, std::vector> bulk_allocate_mx // Allocate column-wise data std::vector columnwise_data_list, columnwise_scale_list; - std::vector> columnwise_data_shapes, columnwise_scale_shapes; + std::vector columnwise_data_shapes, columnwise_scale_shapes; if (columnwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { @@ -477,17 +478,19 @@ std::tuple, std::vector> bulk_allocate_mx tensor_py_list.emplace_back(MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype, quantizer_py_list[i])); - + NVTEShape zero_shape; + zero_shape.ndim = 1; + zero_shape.data[0] = 0; // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_data_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_data_shapes[i] : std::vector{0}, fp8_dtype, nullptr, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, fp8_dtype, nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode)); + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, scaling_mode)); } return retval; @@ -497,7 +500,7 @@ std::tuple, std::vector> bulk_allocate_mx // layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] // amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate std::tuple, std::vector, bool> bulk_allocate_nvfp4_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &shape_list, std::vector &quantizer_py_list, std::vector &quantizer_cpp_list) { init_extension(); std::tuple, std::vector, bool> retval; @@ -522,7 +525,7 @@ std::tuple, std::vector, bool> bulk_alloc // Helper function to construct tensor view // Note: Deleter holds a shared_ptr for the buffer, so the buffer // will survive until all views are deleted. 
- auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + auto make_torch_view = [](std::shared_ptr &buffer, const NVTEShapeWrapper &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; @@ -535,9 +538,9 @@ std::tuple, std::vector, bool> bulk_alloc at::device(at::kCUDA).dtype(dtype)); }; - // Lambda function for converting std::vector shape to NVFP4 shape (last dim divided by 2) - auto to_fp4_shape = [](const std::vector &shape) { - std::vector fp4_shape(shape.begin(), shape.end()); + // Lambda function for converting NVTEShapeWrapper shape to NVFP4 shape (last dim divided by 2) + auto to_fp4_shape = [](const NVTEShapeWrapper &shape) { + NVTEShapeWrapper fp4_shape(shape); if (!fp4_shape.empty()) { fp4_shape.back() /= 2; } @@ -546,13 +549,13 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list, amax_rowwise_list; - std::vector> rowwise_data_shapes, rowwise_scale_shapes; + std::vector rowwise_data_shapes, rowwise_scale_shapes; if (rowwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_shapes.emplace_back(shape_list[i]); rowwise_scale_shapes.emplace_back( - quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); + quantizer_cpp_list[i]->get_scale_shape(rowwise_data_shapes[i], false)); } // Offsets in full buffer @@ -587,7 +590,9 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - + NVTEShape amax_shape; + amax_shape.ndim = 1; + amax_shape.data[0] = 1; // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_list.emplace_back(make_torch_view(buffer, to_fp4_shape(rowwise_data_shapes[i]), @@ -595,13 +600,13 @@ std::tuple, std::vector, bool> bulk_alloc rowwise_scale_list.emplace_back( make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); amax_rowwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); } } // Allocate column-wise data std::vector columnwise_data_list, columnwise_scale_list, amax_columnwise_list; - std::vector> columnwise_data_shapes, columnwise_scale_shapes; + std::vector columnwise_data_shapes, columnwise_scale_shapes; if (columnwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { @@ -649,7 +654,9 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - + NVTEShape amax_shape; + amax_shape.ndim = 1; + amax_shape.data[0] = 1; // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { columnwise_data_list.emplace_back(make_torch_view( @@ -657,7 +664,7 @@ std::tuple, std::vector, bool> bulk_alloc columnwise_scale_list.emplace_back( make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); amax_columnwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); } } @@ -682,26 +689,31 @@ std::tuple, std::vector, bool> bulk_alloc // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, // then set the amax and 
amax_columnwise values. + NVTEShape zero_shape, amax_shape; + zero_shape.ndim = 1; + amax_shape.ndim = 1; + zero_shape.data[0] = 0; + amax_shape.data[0] = 1; { auto tensor_wrapper = makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_data_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_data_shapes[i] : std::vector{0}, fp4_dtype, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, fp4_dtype, /*amax_ptr=*/nullptr, /*scale_ptr=*/nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode); - + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, scaling_mode); + // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); + amax_shape); } if (columnwise_usage) { tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); + amax_shape); } tensor_cpp_list.emplace_back(std::move(tensor_wrapper)); } @@ -765,9 +777,11 @@ static StochasticRngStateResources setup_stochastic_rounding_rng_states_helper( at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); int64_t *rng_state_ptr = static_cast(res.rng_states_tensor.data_ptr()) + i * 2; philox_unpack(philox_args, rng_state_ptr); - + NVTEShape rng_state_shape; + rng_state_shape.ndim = 1; + rng_state_shape.data[0] = 2; res.te_rng_state_list.push_back(makeTransformerEngineTensor( - static_cast(rng_state_ptr), std::vector{2}, DType::kInt64)); + static_cast(rng_state_ptr), rng_state_shape, DType::kInt64)); quant_config_list_rowwise[i].set_rng_state(res.te_rng_state_list[i].data()); quant_config_list_rowwise[i].set_stochastic_rounding(true); @@ -781,7 +795,7 @@ static StochasticRngStateResources setup_stochastic_rounding_rng_states_helper( philox_unpack(philox_args_col, rng_state_ptr_colwise); res.te_rng_state_list_colwise.push_back(makeTransformerEngineTensor( - static_cast(rng_state_ptr_colwise), std::vector{2}, DType::kInt64)); + static_cast(rng_state_ptr_colwise), rng_state_shape, DType::kInt64)); quant_config_list_colwise[i].set_rng_state(res.te_rng_state_list_colwise[i].data()); quant_config_list_colwise[i].set_stochastic_rounding(true); } @@ -997,18 +1011,21 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, // Note that the multi compute amax API expects rowwise amax pointer to be not null // So we need to set the pointer accordingly to make colwise-only quantization work std::vector orig_amax_ptr_list; + NVTEShape amax_shape; + amax_shape.ndim = 1; + amax_shape.data[0] = 1; for (size_t i = 0; i < num_tensors; i++) { auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; orig_amax_ptr_list.push_back(rowwise_amax_ptr); auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; void *amax_ptr = rowwise_amax_ptr != nullptr ? 
rowwise_amax_ptr : columnwise_amax_ptr; NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); - output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + output_list[i].set_amax(amax_ptr, DType::kFloat32, amax_shape); } nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), split_sections.data(), num_tensors, stream); for (size_t i = 0; i < num_tensors; i++) { - output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); + output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, amax_shape); } // Quantize tensors individually @@ -1104,7 +1121,7 @@ std::vector split_quantize(const at::Tensor &tensor, auto input_py = tensor.contiguous(); uint8_t *input_dptr = reinterpret_cast(input_py.data_ptr()); auto input_dtype = GetTransformerEngineDType(input_py.scalar_type()); - std::vector input_shape; + NVTEShapeWrapper input_shape; size_t input_size = 1; for (const auto &d : input_py.sizes()) { input_shape.push_back(d); @@ -1114,7 +1131,7 @@ std::vector split_quantize(const at::Tensor &tensor, // Split input tensor along dim 0 std::vector input_list; - std::vector> split_shapes; + std::vector split_shapes; size_t dim0_offset = 0; const size_t dim0_stride = input_shape[0] == 0 ? 0 : input_py.element_size() * input_size / input_shape[0]; @@ -1122,11 +1139,14 @@ std::vector split_quantize(const at::Tensor &tensor, NVTE_CHECK(dim0_offset + split_sections[i] <= input_shape[0], "Attempted to split tensor with shape=", input_shape, " along dim 0 with split_sections=", split_sections); - split_shapes.push_back(input_shape); + split_shapes.emplace_back(); auto &split_shape = split_shapes.back(); - split_shape[0] = split_sections[i]; + split_shape.push_back(split_sections[i]); + for (size_t j = 1; j < input_shape.size(); ++j) { + split_shape.push_back(input_shape[j]); + } void *split_dptr = static_cast(input_dptr + dim0_offset * dim0_stride); - input_list.emplace_back(makeTransformerEngineTensor(split_dptr, split_shape, input_dtype)); + input_list.emplace_back(makeTransformerEngineTensor(split_dptr, static_cast(split_shape), input_dtype)); dim0_offset += split_sections[i]; } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index fa01af53fe8..07acd44170a 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -406,7 +406,7 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, counter_shape.data[0] = static_cast(counter.size(0)); auto te_counter = makeTransformerEngineTensor(counter.data_ptr(), counter_shape, DType::kInt32); - NVTEShape gelu_shape; + NVTEShape gelu_shape, workspace_shape; if (pre_gelu_out.data_ptr() == nullptr) { gelu_shape.ndim = 1; gelu_shape.data[0] = static_cast(pre_gelu_out.size(0)); @@ -415,10 +415,12 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, gelu_shape.data[0] = static_cast(pre_gelu_out.size(0)); gelu_shape.data[1] = static_cast(pre_gelu_out.size(1)); } + workspace_shape.ndim = 1; + workspace_shape.data[0] = workspaceSize; auto te_pre_gelu_out = makeTransformerEngineTensor( pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - std::vector{workspaceSize}, DType::kByte); + workspace_shape, DType::kByte); NVTE_SCOPED_GIL_RELEASE({ nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), 
@@ -509,10 +511,14 @@ std::optional> te_general_grouped_gemm( auto te_bias = makeTransformerEngineTensor(bias[i]); auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); - const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr - ? std::vector{static_cast(te_pre_gelu_out.size(0))} - : std::vector{static_cast(te_pre_gelu_out.size(0)), - static_cast(te_pre_gelu_out.size(1))}; + NVTEShape gelu_shape; + gelu_shape.data[0] = te_pre_gelu_out.size(0); + if (pre_gelu_out[i].data_ptr() == nullptr) { + gelu_shape.ndim = 1; + } else { + gelu_shape.ndim = 2; + gelu_shape.data[1] = te_pre_gelu_out.size(1); + } DType gelu_type = bias_type; te_pre_gelu_out = @@ -579,9 +585,12 @@ std::optional> te_general_grouped_gemm( std::vector te_workspace_vector; std::vector te_workspace_wrappers; + NVTEShape workspace_shape; + workspace_shape.ndim = 1; + workspace_shape.data[0] = workspaceSize; for (size_t i = 0; i < workspace.size(); i++) { auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), - std::vector{workspaceSize}, DType::kByte); + workspace_shape, DType::kByte); te_workspace_vector.emplace_back(wsp.data()); te_workspace_wrappers.emplace_back(std::move(wsp)); } diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index 389308405b1..cabb65233f7 100644 --- a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -20,7 +20,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, const auto num_tensors = input_row_list.size(); // Extract properties from PyTorch tensors std::vector input_dptr_list, output_dptr_list; - std::vector> input_shape_list, output_shape_list; + std::vector input_shape_list, output_shape_list; std::vector input_type_list; void* d_input_ptr = reinterpret_cast(input.data_ptr()); void* d_output_ptr = reinterpret_cast(output.data_ptr()); @@ -34,8 +34,11 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, input_row_list[tensor_id] * input.size(1) * input.element_size(); input_char_ptr += input_dptr_offset; d_input_ptr = reinterpret_cast(input_char_ptr); - NVTEShape input_shape = {input_row_list[tensor_id], static_cast(input.size(1))}; - input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); + NVTEShape input_shape; + input_shape.ndim = 2; + input_shape.data[0] = input_row_list[tensor_id]; + input_shape.data[1] = static_cast(input.size(1)); + input_shape_list.push_back(input_shape); input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); // Move the output pointer to the next split. 
@@ -45,14 +48,17 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, output_char_ptr += output_dptr_offset; d_output_ptr = reinterpret_cast(output_char_ptr); - output_shape_list.push_back( - {padded_input_row_list[tensor_id], static_cast(output.size(1))}); + NVTEShape output_shape; + output_shape.ndim = 2; + output_shape.data[0] = padded_input_row_list[tensor_id]; + output_shape.data[1] = static_cast(output.size(1)); + output_shape_list.push_back(output_shape); } // Construct TE tensors std::vector nvte_input_list, nvte_output_list; std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + auto make_tensor = [&tensor_wrappers](void* dptr, const NVTEShape& shape, DType dtype) -> NVTETensor { tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype)); return tensor_wrappers.back().data(); @@ -95,7 +101,7 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, const auto num_tensors = input_row_list.size(); // Extract properties from PyTorch tensors std::vector input_dptr_list, output_dptr_list; - std::vector> input_shape_list, output_shape_list; + std::vector input_shape_list, output_shape_list; std::vector input_type_list; void* d_input_ptr = reinterpret_cast(input.data_ptr()); void* d_output_ptr = reinterpret_cast(output.data_ptr()); @@ -109,8 +115,11 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, input_row_list[tensor_id] * input.size(1) * input.element_size(); input_char_ptr += input_dptr_offset; d_input_ptr = reinterpret_cast(input_char_ptr); - - input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); + NVTEShape input_shape; + input_shape.ndim = 2; + input_shape.data[0] = input_row_list[tensor_id]; + input_shape.data[1] = static_cast(input.size(1)); + input_shape_list.push_back(input_shape); input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); // Move the output pointer to the next split. 
@@ -120,14 +129,17 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, output_char_ptr += output_dptr_offset; d_output_ptr = reinterpret_cast(output_char_ptr); - output_shape_list.push_back( - {unpadded_input_row_list[tensor_id], static_cast(output.size(1))}); + NVTEShape output_shape; + output_shape.ndim = 2; + output_shape.data[0] = unpadded_input_row_list[tensor_id]; + output_shape.data[1] = static_cast(output.size(1)); + output_shape_list.push_back(output_shape); } // Construct TE tensors std::vector nvte_input_list, nvte_output_list; std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + auto make_tensor = [&tensor_wrappers](void* dptr, const NVTEShape& shape, transformer_engine::DType dtype) -> NVTETensor { tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype)); return tensor_wrappers.back().data(); diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cpp b/transformer_engine/pytorch/csrc/extensions/permutation.cpp index 97cf4008511..189c7127731 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cpp @@ -60,18 +60,25 @@ std::tuple> moe_permute_fwd( {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - + NVTEShape input_shape, permuted_output_shape, sorted_row_id_cu_shape; + input_shape.ndim = 2; + permuted_output_shape.ndim = 2; + sorted_row_id_cu_shape.ndim = 1; + input_shape.data[0] = static_cast(input.size(0)); + input_shape.data[1] = static_cast(input.size(1)); + permuted_output_shape.data[0] = static_cast(permuted_output.size(0)); + permuted_output_shape.data[1] = static_cast(permuted_output.size(1)); + sorted_row_id_cu_shape.data[0] = static_cast(num_tokens * topK); auto input_cu = makeTransformerEngineTensor( input.data_ptr(), - std::vector{static_cast(input.size(0)), static_cast(num_cols)}, + input_shape, dtype); auto permuted_output_cu = makeTransformerEngineTensor(permuted_output.data_ptr(), - std::vector{static_cast(permuted_output.size(0)), - static_cast(num_cols)}, + permuted_output_shape, dtype); auto sorted_row_id_cu = makeTransformerEngineTensor( - sorted_row_id_ptr, std::vector{static_cast(num_tokens * topK)}, + sorted_row_id_ptr, sorted_row_id_cu_shape, DType::kInt32); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); @@ -97,15 +104,20 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - + NVTEShape input_shape, unpermuted_output_shape; + input_shape.ndim = 2; + unpermuted_output_shape.ndim = 2; + input_shape.data[0] = static_cast(input.size(0)); + input_shape.data[1] = static_cast(input.size(1)); + unpermuted_output_shape.data[0] = static_cast(unpermuted_output.size(0)); + unpermuted_output_shape.data[1] = static_cast(num_cols); auto input_cu = makeTransformerEngineTensor( input.data_ptr(), - std::vector{static_cast(input.size(0)), static_cast(num_cols)}, + input_shape, dtype); auto unpermuted_output_cu = makeTransformerEngineTensor( unpermuted_output.data_ptr(), - std::vector{static_cast(unpermuted_output.size(0)), - static_cast(num_cols)}, + unpermuted_output_shape, dtype); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); auto prob_cu = makeTransformerEngineTensor(prob); @@ -131,18 +143,27 
@@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T {num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - + NVTEShape input_bwd_shape, act_grad_shape, input_fwd_shape; + input_bwd_shape.ndim = 2; + act_grad_shape.ndim = 2; + input_fwd_shape.ndim = 2; + input_bwd_shape.data[0] = static_cast(input_bwd.size(0)); + input_bwd_shape.data[1] = static_cast(num_cols); + act_grad_shape.data[0] = static_cast(act_grad.size(0)); + act_grad_shape.data[1] = static_cast(num_cols); + input_fwd_shape.data[0] = static_cast(input_fwd.size(0)); + input_fwd_shape.data[1] = static_cast(num_cols); auto input_bwd_cu = makeTransformerEngineTensor( input_bwd.data_ptr(), - std::vector{static_cast(input_bwd.size(0)), static_cast(num_cols)}, + input_bwd_shape, dtype); auto act_grad_cu = makeTransformerEngineTensor( act_grad.data_ptr(), - std::vector{static_cast(act_grad.size(0)), static_cast(num_cols)}, + act_grad_shape, dtype); auto input_fwd_cu = makeTransformerEngineTensor( input_fwd.data_ptr(), - std::vector{static_cast(input_fwd.size(0)), static_cast(num_cols)}, + input_fwd_shape, dtype); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); auto prob_cu = makeTransformerEngineTensor(prob); diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 55c7fd57d79..6c1ff313c5c 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -45,9 +45,16 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional{M, N}, otype); - auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector{N, M}, otype); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), input_shape, otype); + auto output_cu = makeTransformerEngineTensor(out.data_ptr(), output_shape, otype); nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return out; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 6345ae3894c..650847dd86d 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -869,12 +869,24 @@ void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& o std::vector Float8BlockQuantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { + return get_scale_shape_impl(shape, columnwise); +} + +NVTEShapeWrapper Float8BlockQuantizer::get_scale_shape(const NVTEShapeWrapper& shape, + bool columnwise) const { + return get_scale_shape_impl(shape, columnwise); +} + +template +ShapeT Float8BlockQuantizer::get_scale_shape_impl(const ShapeT& shape, bool columnwise) const { size_t numel = 1; + size_t k_dim; + for (auto s : shape) { numel *= s; } + k_dim = shape.size() == 0 ? 1u : shape.back(); - size_t k_dim = shape.size() == 0 ? 
1u : shape.back(); size_t m_dim = numel / k_dim; constexpr size_t kBlockLen = 128; @@ -882,27 +894,20 @@ std::vector Float8BlockQuantizer::get_scale_shape(const std::vector scale_shape; - + size_t sinv0 = 0; + size_t sinv1 = 0; bool rowwise_usage = !columnwise; if (rowwise_usage) { - // rowwise scaling factor shape - size_t sinv0 = 0; - size_t sinv1 = 0; if (block_scaling_dim == 2) { - // 2D scaling is always GEMM_READY for now NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY, "2D scaling is always GEMM_READY for now."); sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4); } else if (block_scaling_dim == 1) { - // 1D scaling can be GEMM_READY or COMPACT bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT; - // default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4); - // if the rowwise format is compact, the scaling factor is not be transposed if (rowwise_compact) { std::swap(sinv0, sinv1); } @@ -912,13 +917,8 @@ std::vector Float8BlockQuantizer::get_scale_shape(const std::vector Float8BlockQuantizer::get_scale_shape(const std::vector MXFP8Quantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { + return get_scale_shape_impl(shape, columnwise); +} + +NVTEShapeWrapper MXFP8Quantizer::get_scale_shape(const NVTEShapeWrapper& shape, + bool columnwise) const { + return get_scale_shape_impl(shape, columnwise); +} + +template +ShapeT MXFP8Quantizer::get_scale_shape_impl(const ShapeT& shape, bool columnwise) const { size_t numel = 1; + size_t last_dim; + for (auto s : shape) { numel *= s; } - - auto last_dim = shape.back(); + last_dim = shape.back(); NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); - std::vector scale_shape; - + size_t sinv0 = 0; + size_t sinv1 = 0; bool rowwise_usage = !columnwise; if (rowwise_usage) { - // rowwise scaling factor shape - size_t sinv0 = roundup(numel / last_dim, 128); - size_t sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); - scale_shape = {sinv0, sinv1}; + sinv0 = roundup(numel / last_dim, 128); + sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); } else { - // columnwise scaling factor shape - size_t sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); - size_t sinv1 = roundup(last_dim, 128); - scale_shape = {sinv0, sinv1}; + sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); + sinv1 = roundup(last_dim, 128); } - return scale_shape; + + ShapeT result; + result.resize(2); + result[0] = sinv0; + result[1] = sinv1; + return result; } NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { @@ -1766,12 +1781,24 @@ void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out std::vector NVFP4Quantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { + return get_scale_shape_impl(shape, columnwise); +} + +NVTEShapeWrapper NVFP4Quantizer::get_scale_shape(const NVTEShapeWrapper& shape, + bool columnwise) const { + return get_scale_shape_impl(shape, columnwise); +} + +template +ShapeT NVFP4Quantizer::get_scale_shape_impl(const ShapeT& shape, bool columnwise) const { size_t numel = 1; + size_t last_dim; + for (auto s : shape) { numel *= s; } + last_dim = shape.back(); - auto last_dim = shape.back(); auto 
flat_first_dim = numel / last_dim; NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ", @@ -1780,22 +1807,23 @@ std::vector NVFP4Quantizer::get_scale_shape(const std::vector& s "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); - std::vector scale_shape; - + size_t sinv0 = 0; + size_t sinv1 = 0; bool rowwise_usage = !columnwise; if (rowwise_usage) { - // rowwise scaling factor shape - size_t sinv0 = roundup(flat_first_dim, 128); - size_t sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4); - scale_shape = {sinv0, sinv1}; + sinv0 = roundup(flat_first_dim, 128); + sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4); } else { - // columnwise scaling factor shape - size_t sinv0 = roundup(last_dim, 128); - size_t sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4); - scale_shape = {sinv0, sinv1}; + sinv0 = roundup(last_dim, 128); + sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4); } - return scale_shape; + + ShapeT result; + result.resize(2); + result[0] = sinv0; + result[1] = sinv1; + return result; } } // namespace transformer_engine::pytorch From a66f46b831f9665d6edf0ed9109ac6c52c075b73 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Dec 2025 12:42:50 +0000 Subject: [PATCH 10/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../transformer_engine/transformer_engine.h | 31 ++++---- transformer_engine/pytorch/csrc/common.cpp | 3 +- transformer_engine/pytorch/csrc/common.h | 8 +-- .../pytorch/csrc/extensions/attention.cpp | 55 +++++++------- .../pytorch/csrc/extensions/cast.cpp | 44 ++++++------ .../pytorch/csrc/extensions/gemm.cpp | 16 ++--- .../pytorch/csrc/extensions/permutation.cpp | 40 +++-------- transformer_engine/pytorch/csrc/quantizer.cpp | 72 +++++++++---------- 8 files changed, 125 insertions(+), 144 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index f9eb244cd9c..26a07c707a6 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -527,39 +527,36 @@ class NVTEShapeWrapper { public: // Default constructor - NVTEShapeWrapper() { - data.ndim = 0; - } + NVTEShapeWrapper() { data.ndim = 0; } // Constructor from NVTEShape (direct assignment by reference) - NVTEShapeWrapper(const NVTEShape& shape) { - data = shape; - } + NVTEShapeWrapper(const NVTEShape &shape) { data = shape; } // Constructor from vector (creates a copy) - template NVTEShapeWrapper(const std::vector& shape_vec) { + template + NVTEShapeWrapper(const std::vector &shape_vec) { data.ndim = shape_vec.size(); for (size_t i = 0; i < data.ndim; ++i) { data.data[i] = static_cast(shape_vec[i]); } } - operator NVTEShape&() { return data; } - operator const NVTEShape&() const { return data; } + operator NVTEShape &() { return data; } + operator const NVTEShape &() const { return data; } // Iterator support - size_t* begin() { return data.data; } - const size_t* begin() const { return data.data; } - size_t* end() { return data.data + data.ndim; } - const size_t* end() const { return data.data + data.ndim; } + size_t *begin() { return data.data; } + const size_t *begin() const { return data.data; } + size_t *end() { return data.data + data.ndim; } + const size_t *end() const { return 
data.data + data.ndim; } // Index access - size_t& operator[](size_t idx) { return data.data[idx]; } - const size_t& operator[](size_t idx) const { return data.data[idx]; } + size_t &operator[](size_t idx) { return data.data[idx]; } + const size_t &operator[](size_t idx) const { return data.data[idx]; } // Back access - size_t& back() { return data.data[data.ndim - 1]; } - const size_t& back() const { return data.data[data.ndim - 1]; } + size_t &back() { return data.data[data.ndim - 1]; } + const size_t &back() const { return data.data[data.ndim - 1]; } // Size access size_t size() const { return data.ndim; } diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 66b1e227c25..33060732777 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -175,7 +175,6 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( return ret; } - transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr, @@ -230,7 +229,7 @@ template size_t product(const std::vector& shape); template int64_t product(const std::vector& shape); size_t product(const NVTEShape& shape, size_t begin, size_t end) { - if(end == -1) { + if (end == -1) { end = shape.ndim; } NVTE_CHECK(begin <= end && end <= shape.ndim, "Attempted to access entries ", begin, " to ", end, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index b703cbc6810..39b73c96443 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -142,7 +142,8 @@ class NoneQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const; - std::pair create_tensor(const NVTEShape& shape, DType dtype) const override; + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; /*! 
@brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const NVTEShape& shape, DType dtype, @@ -457,7 +458,6 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); - transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr, @@ -484,7 +484,7 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( template T product(const std::vector& shape); -size_t product(const NVTEShape& shape, size_t begin=0, size_t end=-1); +size_t product(const NVTEShape& shape, size_t begin = 0, size_t end = -1); std::vector nvte_shape_to_vector(const NVTEShape& nvte_shape); @@ -560,4 +560,4 @@ inline string to_string(const NVTEShape& s) { } } // namespace std -#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ \ No newline at end of file +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 6d6effce6d9..bfa989c26c0 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -164,26 +164,29 @@ std::vector fused_attn_fwd( if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { auto bias_sizes = Bias.value().sizes().vec(); NVTEShapeWrapper bias_shape{bias_sizes}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), static_cast(bias_shape), DType::kFloat32); + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), + static_cast(bias_shape), DType::kFloat32); } auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); NVTEShapeWrapper cu_seqlens_q_shape{cu_seqlens_q_sizes}; auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); NVTEShapeWrapper cu_seqlens_kv_shape{cu_seqlens_kv_sizes}; - te_cu_seqlens_q = - makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), static_cast(cu_seqlens_q_shape), DType::kInt32); - te_cu_seqlens_kv = - makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), static_cast(cu_seqlens_kv_shape), DType::kInt32); + te_cu_seqlens_q = makeTransformerEngineTensor( + cu_seqlens_q.data_ptr(), static_cast(cu_seqlens_q_shape), DType::kInt32); + te_cu_seqlens_kv = makeTransformerEngineTensor( + cu_seqlens_kv.data_ptr(), static_cast(cu_seqlens_kv_shape), DType::kInt32); if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); NVTEShapeWrapper cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes}; auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); NVTEShapeWrapper cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes}; - te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - static_cast(cu_seqlens_q_padded_shape), DType::kInt32); + te_cu_seqlens_q_padded = makeTransformerEngineTensor( + cu_seqlens_q_padded.value().data_ptr(), static_cast(cu_seqlens_q_padded_shape), + DType::kInt32); te_cu_seqlens_kv_padded = makeTransformerEngineTensor( - cu_seqlens_kv_padded.value().data_ptr(), static_cast(cu_seqlens_kv_padded_shape), DType::kInt32); + cu_seqlens_kv_padded.value().data_ptr(), + static_cast(cu_seqlens_kv_padded_shape), DType::kInt32); } NVTEShape 
default_scale_inv_shape; default_scale_inv_shape.ndim = 1; @@ -193,12 +196,12 @@ std::vector fused_attn_fwd( NVTEShapeWrapper page_table_k_shape{page_table_k_sizes}; auto page_table_v_sizes = page_table_v.value().sizes().vec(); NVTEShapeWrapper page_table_v_shape{page_table_v_sizes}; - te_page_table_k = - makeTransformerEngineTensor(page_table_k.value().data_ptr(), static_cast(page_table_k_shape), - DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape); - te_page_table_v = - makeTransformerEngineTensor(page_table_v.value().data_ptr(), static_cast(page_table_v_shape), - DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape); + te_page_table_k = makeTransformerEngineTensor( + page_table_k.value().data_ptr(), static_cast(page_table_k_shape), + DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape); + te_page_table_v = makeTransformerEngineTensor( + page_table_v.value().data_ptr(), static_cast(page_table_v_shape), + DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape); } // softmax offset @@ -206,9 +209,9 @@ std::vector fused_attn_fwd( if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) { auto SoftmaxOffset_sizes = SoftmaxOffset.value().sizes().vec(); NVTEShapeWrapper SoftmaxOffset_shape{SoftmaxOffset_sizes}; - te_SoftmaxOffset = - makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), static_cast(SoftmaxOffset_shape), - DType::kFloat32, nullptr, nullptr, nullptr, default_scale_inv_shape); + te_SoftmaxOffset = makeTransformerEngineTensor( + SoftmaxOffset.value().data_ptr(), static_cast(SoftmaxOffset_shape), + DType::kFloat32, nullptr, nullptr, nullptr, default_scale_inv_shape); } // extract rng seed and offset @@ -468,10 +471,12 @@ std::vector fused_attn_bwd( zero_scale_inv_shape.ndim = 1; zero_scale_inv_shape.data[0] = 0; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), static_cast(cu_seqlens_q_shape), - DType::kInt32, nullptr, nullptr, nullptr, zero_scale_inv_shape); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), static_cast(cu_seqlens_kv_shape), - DType::kInt32, nullptr, nullptr, nullptr, zero_scale_inv_shape); + te_cu_seqlens_q = makeTransformerEngineTensor( + cu_seqlens_q.data_ptr(), static_cast(cu_seqlens_q_shape), DType::kInt32, nullptr, + nullptr, nullptr, zero_scale_inv_shape); + te_cu_seqlens_kv = makeTransformerEngineTensor( + cu_seqlens_kv.data_ptr(), static_cast(cu_seqlens_kv_shape), DType::kInt32, + nullptr, nullptr, nullptr, zero_scale_inv_shape); TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { @@ -480,13 +485,11 @@ std::vector fused_attn_bwd( auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); NVTEShapeWrapper cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes}; te_cu_seqlens_q_padded = makeTransformerEngineTensor( - cu_seqlens_q_padded.value().data_ptr(), - static_cast(cu_seqlens_q_padded_shape), + cu_seqlens_q_padded.value().data_ptr(), static_cast(cu_seqlens_q_padded_shape), DType::kInt32); te_cu_seqlens_kv_padded = makeTransformerEngineTensor( cu_seqlens_kv_padded.value().data_ptr(), - static_cast(cu_seqlens_kv_padded_shape), - DType::kInt32); + static_cast(cu_seqlens_kv_padded_shape), DType::kInt32); } // convert auxiliary tensors from forward to NVTETensors @@ -500,7 +503,7 @@ std::vector fused_attn_bwd( NVTEBasicTensor temp_data = { Aux_CTX_Tensors[i].data_ptr(), 
static_cast(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())), - static_cast(tmp)}; + static_cast(tmp)}; nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index af04328948b..95e06278dff 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -336,12 +336,13 @@ std::tuple, std::vector> bulk_allocate_fp tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, - columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, fp8_dtype, nullptr, - nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, + fp8_dtype, nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, - columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, scaling_mode)); + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, + scaling_mode)); } return retval; @@ -485,12 +486,13 @@ std::tuple, std::vector> bulk_allocate_mx tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, - columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, fp8_dtype, nullptr, - nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, + fp8_dtype, nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, - columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, scaling_mode)); + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, + scaling_mode)); } return retval; @@ -698,18 +700,19 @@ std::tuple, std::vector, bool> bulk_alloc auto tensor_wrapper = makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, - columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, fp4_dtype, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, + fp4_dtype, /*amax_ptr=*/nullptr, /*scale_ptr=*/nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, - columnwise_usage ? 
static_cast(columnwise_scale_shapes[i]) : zero_shape, scaling_mode); - + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, + scaling_mode); + // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { - tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, - amax_shape); + tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, amax_shape); } if (columnwise_usage) { tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, @@ -780,8 +783,8 @@ static StochasticRngStateResources setup_stochastic_rounding_rng_states_helper( NVTEShape rng_state_shape; rng_state_shape.ndim = 1; rng_state_shape.data[0] = 2; - res.te_rng_state_list.push_back(makeTransformerEngineTensor( - static_cast(rng_state_ptr), rng_state_shape, DType::kInt64)); + res.te_rng_state_list.push_back(makeTransformerEngineTensor(static_cast(rng_state_ptr), + rng_state_shape, DType::kInt64)); quant_config_list_rowwise[i].set_rng_state(res.te_rng_state_list[i].data()); quant_config_list_rowwise[i].set_stochastic_rounding(true); @@ -1146,7 +1149,8 @@ std::vector split_quantize(const at::Tensor &tensor, split_shape.push_back(input_shape[j]); } void *split_dptr = static_cast(input_dptr + dim0_offset * dim0_stride); - input_list.emplace_back(makeTransformerEngineTensor(split_dptr, static_cast(split_shape), input_dtype)); + input_list.emplace_back(makeTransformerEngineTensor( + split_dptr, static_cast(split_shape), input_dtype)); dim0_offset += split_sections[i]; } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 07acd44170a..82c1ce1e7b6 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -82,8 +82,8 @@ bool checkGemmShape(const NVTEShape& expected, const NVTEShape& actual) { } // namespace detail -std::pair createOutputTensor(const NVTEShape& shape, - DType dtype, py::handle quantizer) { +std::pair createOutputTensor(const NVTEShape& shape, DType dtype, + py::handle quantizer) { std::unique_ptr my_quantizer = convert_quantizer(quantizer); return my_quantizer->create_tensor(shape, dtype); } @@ -170,8 +170,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans if (unfused_quantization_needed) { NoneQuantizer q{none}; - std::tie(unquantized_D_tensor, unquantized_out) = - q.create_tensor(D_shape, output_dtype); + std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, output_dtype); } TensorWrapper& out_tensor = unfused_quantization_needed ? 
unquantized_D_tensor : D_tensor; @@ -419,8 +418,8 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, workspace_shape.data[0] = workspaceSize; auto te_pre_gelu_out = makeTransformerEngineTensor( pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); - auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - workspace_shape, DType::kByte); + auto te_workspace = + makeTransformerEngineTensor(workspace.data_ptr(), workspace_shape, DType::kByte); NVTE_SCOPED_GIL_RELEASE({ nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), @@ -589,8 +588,7 @@ std::optional> te_general_grouped_gemm( workspace_shape.ndim = 1; workspace_shape.data[0] = workspaceSize; for (size_t i = 0; i < workspace.size(); i++) { - auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), - workspace_shape, DType::kByte); + auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), workspace_shape, DType::kByte); te_workspace_vector.emplace_back(wsp.data()); te_workspace_wrappers.emplace_back(std::move(wsp)); } @@ -605,4 +603,4 @@ std::optional> te_general_grouped_gemm( return bias; } -} // namespace transformer_engine::pytorch \ No newline at end of file +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cpp b/transformer_engine/pytorch/csrc/extensions/permutation.cpp index 189c7127731..b0654c326e7 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cpp @@ -69,17 +69,11 @@ std::tuple> moe_permute_fwd( permuted_output_shape.data[0] = static_cast(permuted_output.size(0)); permuted_output_shape.data[1] = static_cast(permuted_output.size(1)); sorted_row_id_cu_shape.data[0] = static_cast(num_tokens * topK); - auto input_cu = makeTransformerEngineTensor( - input.data_ptr(), - input_shape, - dtype); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), input_shape, dtype); auto permuted_output_cu = - makeTransformerEngineTensor(permuted_output.data_ptr(), - permuted_output_shape, - dtype); - auto sorted_row_id_cu = makeTransformerEngineTensor( - sorted_row_id_ptr, sorted_row_id_cu_shape, - DType::kInt32); + makeTransformerEngineTensor(permuted_output.data_ptr(), permuted_output_shape, dtype); + auto sorted_row_id_cu = + makeTransformerEngineTensor(sorted_row_id_ptr, sorted_row_id_cu_shape, DType::kInt32); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(), @@ -111,14 +105,9 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row input_shape.data[1] = static_cast(input.size(1)); unpermuted_output_shape.data[0] = static_cast(unpermuted_output.size(0)); unpermuted_output_shape.data[1] = static_cast(num_cols); - auto input_cu = makeTransformerEngineTensor( - input.data_ptr(), - input_shape, - dtype); - auto unpermuted_output_cu = makeTransformerEngineTensor( - unpermuted_output.data_ptr(), - unpermuted_output_shape, - dtype); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), input_shape, dtype); + auto unpermuted_output_cu = + makeTransformerEngineTensor(unpermuted_output.data_ptr(), unpermuted_output_shape, dtype); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); auto prob_cu = makeTransformerEngineTensor(prob); @@ -153,18 +142,9 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T act_grad_shape.data[1] = 
static_cast(num_cols); input_fwd_shape.data[0] = static_cast(input_fwd.size(0)); input_fwd_shape.data[1] = static_cast(num_cols); - auto input_bwd_cu = makeTransformerEngineTensor( - input_bwd.data_ptr(), - input_bwd_shape, - dtype); - auto act_grad_cu = makeTransformerEngineTensor( - act_grad.data_ptr(), - act_grad_shape, - dtype); - auto input_fwd_cu = makeTransformerEngineTensor( - input_fwd.data_ptr(), - input_fwd_shape, - dtype); + auto input_bwd_cu = makeTransformerEngineTensor(input_bwd.data_ptr(), input_bwd_shape, dtype); + auto act_grad_cu = makeTransformerEngineTensor(act_grad.data_ptr(), act_grad_shape, dtype); + auto input_fwd_cu = makeTransformerEngineTensor(input_fwd.data_ptr(), input_fwd_shape, dtype); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); auto prob_cu = makeTransformerEngineTensor(prob); auto prob_grad_cu = makeTransformerEngineTensor(prob_grad); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 650847dd86d..1520541ade7 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -135,13 +135,13 @@ std::pair Float8Quantizer::create_tensor( at::Tensor scale_inv = at::empty(std::vector{1}, opts); return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); } -std::pair Float8Quantizer::create_tensor( - const NVTEShape& shape, DType dtype) const { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - return create_tensor(shape_vec, dtype); +std::pair Float8Quantizer::create_tensor(const NVTEShape& shape, + DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); } std::pair Float8Quantizer::create_tensor( @@ -351,11 +351,11 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso } std::pair Float8CurrentScalingQuantizer::create_tensor( const NVTEShape& shape, DType dtype) const { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - return create_tensor(shape_vec, dtype); + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); } std::pair Float8CurrentScalingQuantizer::create_tensor( const std::vector& shape, DType dtype) const { @@ -597,13 +597,13 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} -std::pair Float8BlockQuantizer::create_tensor( - const NVTEShape& shape, DType dtype) const { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - return create_tensor(shape_vec, dtype); +std::pair Float8BlockQuantizer::create_tensor(const NVTEShape& shape, + DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); } std::pair Float8BlockQuantizer::create_tensor( @@ -873,7 +873,7 @@ std::vector Float8BlockQuantizer::get_scale_shape(const std::vector MXFP8Quantizer::create_tensor( - const NVTEShape& shape, DType dtype) const { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - return create_tensor(shape_vec, dtype); 
+std::pair MXFP8Quantizer::create_tensor(const NVTEShape& shape, + DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); } std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, @@ -1166,7 +1166,7 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s } NVTEShapeWrapper MXFP8Quantizer::get_scale_shape(const NVTEShapeWrapper& shape, - bool columnwise) const { + bool columnwise) const { return get_scale_shape_impl(shape, columnwise); } @@ -1238,13 +1238,13 @@ void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const { tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), columnwise_data.shape); } -std::pair NVFP4Quantizer::create_tensor( - const NVTEShape& shape, DType dtype) const { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - return create_tensor(shape_vec, dtype); +std::pair NVFP4Quantizer::create_tensor(const NVTEShape& shape, + DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); } std::pair NVFP4Quantizer::create_tensor(const std::vector& shape, DType dtype) const { @@ -1785,7 +1785,7 @@ std::vector NVFP4Quantizer::get_scale_shape(const std::vector& s } NVTEShapeWrapper NVFP4Quantizer::get_scale_shape(const NVTEShapeWrapper& shape, - bool columnwise) const { + bool columnwise) const { return get_scale_shape_impl(shape, columnwise); } From b334a743c55f224b080a7033549001c2fddbb3f8 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Tue, 30 Dec 2025 18:50:07 +0000 Subject: [PATCH 11/23] minor miss Signed-off-by: Varun Thumbe --- .../pytorch/csrc/extensions/gemm.cpp | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 07acd44170a..24bcf3510a0 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -275,7 +275,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans extra_output_tensor = makeTransformerEngineTensor(*extra_output); } else { NVTEShape extra_output_shape; - extra_output_shape.ndim = 0; + extra_output_shape.ndim = 1; + extra_output_shape.data[0] = 0; extra_output_tensor = makeTransformerEngineTensor(nullptr, extra_output_shape, DType::kByte); } @@ -378,13 +379,17 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING; NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING; - const size_t A_shape_data[2] = {static_cast(A.size(0)), static_cast(A.size(1))}; - const NVTEShape A_shape = nvte_make_shape(A_shape_data, 2); + NVTEShape A_shape, B_shape; + A_shape.ndim = 2; + + A_shape.data[0] = static_cast(A.size(0)); + A_shape.data[1] = static_cast(A.size(1)); auto te_A = makeTransformerEngineTensor(A.data_ptr(), A_shape, A_type, nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), nvte_scaling_modeA); - const size_t B_shape_data[2] = {static_cast(B.size(0)), static_cast(B.size(1))}; - const NVTEShape B_shape = nvte_make_shape(B_shape_data, 2); + B_shape.ndim = 2; + B_shape.data[0] = static_cast(B.size(0)); + B_shape.data[1] = static_cast(B.size(1)); auto te_B = 
makeTransformerEngineTensor(B.data_ptr(), B_shape, B_type, nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), nvte_scaling_modeB); From 58bf0f08dc6442fff71bc7070a88e3bc2f647a06 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 31 Dec 2025 06:34:16 +0000 Subject: [PATCH 12/23] all shape copies removed Signed-off-by: Varun Thumbe --- .../transformer_engine/transformer_engine.h | 31 ++- transformer_engine/pytorch/csrc/common.cpp | 11 +- transformer_engine/pytorch/csrc/common.h | 48 ++-- .../pytorch/csrc/extensions/activation.cpp | 5 +- .../pytorch/csrc/extensions/attention.cpp | 7 +- .../pytorch/csrc/extensions/bias.cpp | 6 +- .../pytorch/csrc/extensions/cast.cpp | 3 +- .../pytorch/csrc/extensions/normalization.cpp | 4 +- .../pytorch/csrc/extensions/transpose.cpp | 2 +- transformer_engine/pytorch/csrc/quantizer.cpp | 213 ++++++++++-------- .../pytorch/csrc/type_converters.cpp | 4 +- 11 files changed, 196 insertions(+), 138 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 26a07c707a6..8fed7927a06 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -528,11 +528,11 @@ class NVTEShapeWrapper { public: // Default constructor NVTEShapeWrapper() { data.ndim = 0; } - + NVTEShapeWrapper(int ndim) { data.ndim = ndim; } // Constructor from NVTEShape (direct assignment by reference) NVTEShapeWrapper(const NVTEShape &shape) { data = shape; } - // Constructor from vector (creates a copy) + // Constructor from vector (creates a copy) template NVTEShapeWrapper(const std::vector &shape_vec) { data.ndim = shape_vec.size(); @@ -540,6 +540,15 @@ class NVTEShapeWrapper { data.data[i] = static_cast(shape_vec[i]); } } + // In the NVTEShapeWrapper class definition: + template + NVTEShapeWrapper& operator=(const std::vector& shape_vec) { + data.ndim = shape_vec.size(); + for (size_t i = 0; i < data.ndim; ++i) { + data.data[i] = static_cast(shape_vec[i]); + } + return *this; + } operator NVTEShape &() { return data; } operator const NVTEShape &() const { return data; } @@ -558,6 +567,10 @@ class NVTEShapeWrapper { size_t &back() { return data.data[data.ndim - 1]; } const size_t &back() const { return data.data[data.ndim - 1]; } + // Front access + size_t &front() { return data.data[0]; } + const size_t &front() const { return data.data[0]; } + // Size access size_t size() const { return data.ndim; } bool empty() const { return data.ndim == 0; } @@ -576,6 +589,20 @@ class NVTEShapeWrapper { data.ndim = new_size; } } + + // Equality comparison with another NVTEShapeWrapper + bool operator==(const NVTEShapeWrapper& other) const { + if (data.ndim != other.data.ndim) { + return false; + } + for (size_t i = 0; i < data.ndim; ++i) { + if (data.data[i] != other.data.data[i]) { + return false; + } + } + return true; + } + }; /*! \enum DType diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 33060732777..eb4fbb6bd68 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -13,7 +13,8 @@ namespace transformer_engine::pytorch { /*! 
convert fp4 data shape back to original shape */ -std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose) { +template +T convert_shape_back_from_fp4(const T& shape, bool transpose) { std::vector ret; size_t start_idx = (transpose) ? 1 : 0; for (size_t i = start_idx; i < shape.size() - 1; ++i) { @@ -28,14 +29,6 @@ std::vector convert_shape_back_from_fp4(const std::vector& shape NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); } -std::vector getTensorShapeVector(const at::Tensor& t) { - std::vector shape; - for (auto s : t.sizes()) { - shape.push_back(s); - } - return shape; -} - NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) { NVTEShape ret; ret.ndim = torch_shape.size(); diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 39b73c96443..bf70a2ec236 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -102,7 +102,7 @@ class Quantizer { /*! @brief Construct a tensor with uninitialized data */ virtual std::pair create_tensor(const std::vector& shape, DType dtype) const = 0; - virtual std::pair create_tensor(const NVTEShape& shape, + virtual std::pair create_tensor(const NVTEShapeWrapper& shape, DType dtype) const = 0; /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor * @@ -138,21 +138,25 @@ class NoneQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_tensor(const NVTEShapeWrapper& shape, + DType dtype) const override; + /*! @brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const; - std::pair create_tensor(const NVTEShape& shape, - DType dtype) const override; - /*! @brief Construct a tensor with pre-initialized data */ - std::pair create_tensor(const NVTEShape& shape, DType dtype, + std::pair create_tensor(const NVTEShapeWrapper& shape, DType dtype, at::Tensor data) const; std::pair convert_and_update_tensor(py::object tensor) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt) override; + + private: + template + std::pair create_tensor_impl(const ShapeT& shape, DType dtype) const; }; class Float8Quantizer : public Quantizer { @@ -170,7 +174,7 @@ class Float8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; - std::pair create_tensor(const NVTEShape& shape, + std::pair create_tensor(const NVTEShapeWrapper& shape, DType dtype) const override; /*! @brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const std::vector& shape, DType dtype, @@ -182,6 +186,12 @@ class Float8Quantizer : public Quantizer { void quantize(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt) override; + + private: + template + std::pair create_tensor_impl( + const ShapeT& shape, DType dtype, std::optional data, + std::optional transpose, std::optional scale_inv) const; }; class Float8CurrentScalingQuantizer : public Quantizer { @@ -203,7 +213,7 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; - std::pair create_tensor(const NVTEShape& shape, + std::pair create_tensor(const NVTEShapeWrapper& shape, DType dtype) const override; /*! 
@brief Construct an unquantized tensor that shares the quantizer's amax pointer. * @@ -211,7 +221,7 @@ class Float8CurrentScalingQuantizer : public Quantizer { * amax to be initialized to zero. */ std::pair create_unquantized_tensor_with_amax( - const std::vector& shape, DType dtype, std::optional data = std::nullopt); + const NVTEShapeWrapper& shape, DType dtype, std::optional data = std::nullopt); std::pair convert_and_update_tensor(py::object shape) const override; @@ -228,6 +238,9 @@ class Float8CurrentScalingQuantizer : public Quantizer { const std::optional& noop_flag = std::nullopt); private: + template + std::pair create_tensor_impl(const ShapeT& shape, DType dtype) const; + void quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax); }; @@ -263,7 +276,7 @@ class Float8BlockQuantizer : public Quantizer { // and optionally columnwise usage. std::pair create_tensor(const std::vector& shape, DType dtype) const override; - std::pair create_tensor(const NVTEShape& shape, + std::pair create_tensor(const NVTEShapeWrapper& shape, DType dtype) const override; std::pair convert_and_update_tensor(py::object shape) const override; @@ -275,6 +288,9 @@ class Float8BlockQuantizer : public Quantizer { NVTEShapeWrapper get_scale_shape(const NVTEShapeWrapper& shape, bool columnwise) const; private: + template + std::pair create_tensor_impl(const ShapeT& shape, DType dtype) const; + template ShapeT get_scale_shape_impl(const ShapeT& shape, bool columnwise) const; }; @@ -291,7 +307,7 @@ class MXFP8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; - std::pair create_tensor(const NVTEShape& shape, + std::pair create_tensor(const NVTEShapeWrapper& shape, DType dtype) const override; std::pair convert_and_update_tensor(py::object shape) const override; @@ -303,6 +319,9 @@ class MXFP8Quantizer : public Quantizer { NVTEShapeWrapper get_scale_shape(const NVTEShapeWrapper& shape, bool columnwise) const; private: + template + std::pair create_tensor_impl(const ShapeT& shape, DType dtype) const; + template ShapeT get_scale_shape_impl(const ShapeT& shape, bool columnwise) const; }; @@ -332,7 +351,7 @@ class NVFP4Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; - std::pair create_tensor(const NVTEShape& shape, + std::pair create_tensor(const NVTEShapeWrapper& shape, DType dtype) const override; /*! 
@brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer * @@ -359,6 +378,9 @@ class NVFP4Quantizer : public Quantizer { NVTEShapeWrapper get_scale_shape(const NVTEShapeWrapper& shape, bool columnwise) const; private: + template + std::pair create_tensor_impl(const ShapeT& shape, DType dtype) const; + template ShapeT get_scale_shape_impl(const ShapeT& shape, bool columnwise) const; void quantize_impl(const TensorWrapper& input, TensorWrapper& out, @@ -369,7 +391,6 @@ std::unique_ptr convert_quantizer(py::handle quantizer); NVTEShape getTensorShape(const at::Tensor& t); -std::vector getTensorShapeVector(const at::Tensor& t); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); @@ -506,7 +527,8 @@ size_t roundup(const size_t value, const size_t multiple); NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); -std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose); +template +T convert_shape_back_from_fp4(const T& shape, bool transpose); // unpack the PhiloxCudaState into CUDA tensor void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr); diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 14cc084c0c7..49c3023e0b5 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -26,7 +26,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int // Construct output tensor auto quantizer_cpp = convert_quantizer(quantizer); const auto input_shape = input_nvte.shape(); - std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + NVTEShapeWrapper output_shape{input_shape}; output_shape.back() /= shape_divisor; auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); auto [out_nvte, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); @@ -138,8 +138,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i // Construct grad input tensor auto quantizer_cpp = convert_quantizer(quantizer); const auto input_shape_te = input_nvte.shape(); - const std::vector input_shape(input_shape_te.data, - input_shape_te.data + input_shape_te.ndim); + const NVTEShapeWrapper& input_shape{input_shape_te}; auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index bfa989c26c0..39bad4a0942 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -78,15 +78,16 @@ std::pair quantizer_helper(py::handle quantizer, } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // current scaling auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); + NVTEShapeWrapper nvte_shape_wrapper{shape}; if (create_hp_tensor_for_cs) { if (data.has_value()) { std::tie(te_T, py_T) = - T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); + T_quantizer_fp8->create_unquantized_tensor_with_amax(nvte_shape_wrapper, dtype, data.value()); } else { - std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype); + std::tie(te_T, py_T) = 
T_quantizer_fp8->create_unquantized_tensor_with_amax(nvte_shape_wrapper, dtype); } } else { - std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype); + std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(nvte_shape_wrapper, dtype); NVTE_CHECK( !data.has_value(), "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index c3e89ed0856..e15d9d5dc6f 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -26,7 +26,7 @@ std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle // Grad output tensor auto grad_output_torch = grad_output.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const auto shape = getTensorShapeVector(grad_output_torch); + const NVTEShapeWrapper& shape = getTensorShape(grad_output_torch); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); // Construct grad bias tensor @@ -116,11 +116,11 @@ std::vector dact_dbias( // Grad output and activation input tensors grad_output_torch = grad_output_torch.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const auto output_shape = getTensorShapeVector(grad_output_torch); + const NVTEShapeWrapper& output_shape = getTensorShape(grad_output_torch); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); act_input_torch = act_input_torch.contiguous(); const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch); - const auto input_shape = getTensorShapeVector(act_input_torch); + const NVTEShapeWrapper& input_shape = getTensorShape(act_input_torch); // Construct tensors auto quantizer_cpp = convert_quantizer(quantizer_py); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 95e06278dff..418e8fa0232 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -88,8 +88,7 @@ py::object dequantize(const py::handle &input, transformer_engine::DType otype) NoneQuantizer q(none); - const auto &shape = convertShape(input_tensor.shape()); - + NVTEShapeWrapper shape{input_tensor.shape()}; auto [out_tensor, out] = q.create_tensor(shape, otype); NVTE_SCOPED_GIL_RELEASE({ diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 3c5c17fc6f2..f9ebe32f25d 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -79,7 +79,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe } // Tensor dimensions - const auto shape = nvte_shape_to_vector(input_nvte.shape()); + NVTEShapeWrapper shape{input_nvte.shape()}; const auto outer_size = product(shape) / shape.back(); const auto inner_size = shape.back(); @@ -310,7 +310,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none); // Tensor dimensions - const auto shape = nvte_shape_to_vector(input_nvte.shape()); + const NVTEShapeWrapper shape{input_nvte.shape()}; const auto outer_size = product(shape) / shape.back(); const auto inner_size = shape.back(); diff --git 
a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 6c1ff313c5c..c3cc83122c4 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -68,7 +68,7 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional out) { // Allocate output tensor if needed if (!out) { - const auto in_shape = getTensorShapeVector(input); + const NVTEShapeWrapper in_shape = getTensorShape(input); NVTE_CHECK(in_shape.size() >= 2, "Invalid input tensor dimensions (shape=", in_shape, ")"); std::vector out_shape_int64(in_shape.begin(), in_shape.end()); out_shape_int64[0] = static_cast(in_shape[1]); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 1520541ade7..0c314f8cf42 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -19,9 +19,9 @@ namespace { * The tensor is interpreted as a 2D matrix by flattening all but the * last dimension, and then transposed. */ -template -std::vector make_transpose_shape(const std::vector& shape) { - std::vector ret; +template +T make_transpose_shape(const S& shape) { + T ret; if (shape.size() > 0) { ret.push_back(shape.back()); for (size_t i = 0; i < shape.size() - 1; ++i) { @@ -70,23 +70,25 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti this->dtype = type; } -std::pair NoneQuantizer::create_tensor(const std::vector& shape, - DType dtype) const { +template +std::pair NoneQuantizer::create_tensor_impl(const ShapeT& shape, + DType dtype) const { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); return create_tensor(shape, dtype, at::empty(shape_int64, opts)); } -std::pair NoneQuantizer::create_tensor(const NVTEShape& shape, +std::pair NoneQuantizer::create_tensor(const std::vector& shape, DType dtype) const { - std::vector shape_int64; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_int64.push_back(static_cast(shape.data[i])); - } - const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); - return create_tensor(shape, dtype, at::empty(shape_int64, opts)); + return create_tensor_impl(shape, dtype); } +std::pair NoneQuantizer::create_tensor(const NVTEShapeWrapper& shape, + DType dtype) const { + return create_tensor_impl(shape, dtype); +} + + std::pair NoneQuantizer::create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const { @@ -96,7 +98,7 @@ std::pair NoneQuantizer::create_tensor(const std::vec return {std::move(out_cpp), py::cast(data)}; } -std::pair NoneQuantizer::create_tensor(const NVTEShape& shape, +std::pair NoneQuantizer::create_tensor(const NVTEShapeWrapper& shape, DType dtype, at::Tensor data) const { TensorWrapper out_cpp; @@ -105,6 +107,8 @@ std::pair NoneQuantizer::create_tensor(const NVTEShap return {std::move(out_cpp), py::cast(data)}; } + + std::pair NoneQuantizer::convert_and_update_tensor( py::object tensor) const { auto tensor_pyt = tensor.cast(); @@ -133,19 +137,19 @@ std::pair Float8Quantizer::create_tensor( const std::vector& shape, DType dtype) const { const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); at::Tensor scale_inv = at::empty(std::vector{1}, opts); - return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); -} -std::pair 
Float8Quantizer::create_tensor(const NVTEShape& shape, - DType dtype) const { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - return create_tensor(shape_vec, dtype); + return create_tensor_impl(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); } std::pair Float8Quantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional data, + const NVTEShapeWrapper& shape, DType dtype) const { + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + at::Tensor scale_inv = at::empty(std::vector{1}, opts); + return create_tensor_impl(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); +} + +template +std::pair Float8Quantizer::create_tensor_impl( + const ShapeT& shape, DType dtype, std::optional data, std::optional transpose, std::optional scale_inv) const { using namespace pybind11::literals; int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); @@ -163,7 +167,7 @@ std::pair Float8Quantizer::create_tensor( // Initialize transpose tensor const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose && !transpose) { - const auto transpose_shape = make_transpose_shape(shape); + const auto transpose_shape = make_transpose_shape>(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); transpose = at::empty(transpose_shape, opts); } else if (!with_transpose && transpose) { @@ -199,7 +203,7 @@ std::pair Float8Quantizer::create_tensor( out_cpp.set_rowwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, std::vector{1}); } if (with_transpose) { - const auto transpose_shape = make_transpose_shape(shape); + const auto transpose_shape = make_transpose_shape(shape); out_cpp.set_columnwise_data(transpose->data_ptr(), this->dtype, transpose_shape); out_cpp.set_columnwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, std::vector{1}); @@ -209,6 +213,13 @@ std::pair Float8Quantizer::create_tensor( return {std::move(out_cpp), std::move(out_py)}; } +std::pair Float8Quantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional data, + std::optional transpose, std::optional scale_inv) const { + return create_tensor_impl(shape, dtype, std::move(data), std::move(transpose), + std::move(scale_inv)); +} + std::pair Float8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); @@ -235,9 +246,9 @@ std::pair Float8Quantizer::convert_and_update_tensor( at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast(); // Tensor dimensions - std::vector shape; + NVTEShapeWrapper shape; if (has_transpose) { - const auto transpose_shape = getTensorShapeVector(*transpose_tensor); + const NVTEShapeWrapper transpose_shape = getTensorShape(*transpose_tensor); if (transpose_shape.size() > 0) { for (size_t i = 1; i < transpose_shape.size(); ++i) { shape.push_back(transpose_shape[i]); @@ -245,12 +256,12 @@ std::pair Float8Quantizer::convert_and_update_tensor( shape.push_back(transpose_shape.front()); } if (has_data) { - const auto expected_shape = getTensorShapeVector(*data_tensor); + const NVTEShapeWrapper expected_shape = getTensorShape(*data_tensor); NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, ") and transpose (shape=", transpose_shape, ") do not match"); } } else { // Already checked has_data == true - shape = getTensorShapeVector(*data_tensor); + shape = 
getTensorShape(*data_tensor); } // Coerce data tensor @@ -272,7 +283,7 @@ std::pair Float8Quantizer::convert_and_update_tensor( transpose_py = py::none(); tensor.attr("_transpose") = transpose_py; } else if (!has_transpose && need_transpose) { - const auto transpose_shape = make_transpose_shape(shape); + const auto transpose_shape = make_transpose_shape>(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); transpose_tensor = at::empty(transpose_shape, opts); transpose_py = py::cast(transpose_tensor); @@ -291,7 +302,7 @@ std::pair Float8Quantizer::convert_and_update_tensor( std::vector{1}); } if (transpose_tensor) { - const auto transpose_shape = make_transpose_shape(shape); + const auto transpose_shape = make_transpose_shape(shape); out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape); out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, std::vector{1}); @@ -350,15 +361,18 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso getTensorShape(amax)); } std::pair Float8CurrentScalingQuantizer::create_tensor( - const NVTEShape& shape, DType dtype) const { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - return create_tensor(shape_vec, dtype); + const std::vector& shape, DType dtype) const { + return create_tensor_impl(shape, dtype); } + std::pair Float8CurrentScalingQuantizer::create_tensor( - const std::vector& shape, DType dtype) const { + const NVTEShapeWrapper& shape, DType dtype) const { + return create_tensor_impl(shape, dtype); +} + +template +std::pair Float8CurrentScalingQuantizer::create_tensor_impl( + const ShapeT& shape, DType dtype) const { using namespace pybind11::literals; // Initialize data tensor @@ -375,7 +389,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso at::Tensor transpose_tensor; const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose) { - const auto transpose_shape = make_transpose_shape(shape); + const auto transpose_shape = make_transpose_shape>(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); transpose_tensor = at::empty(transpose_shape, opts); } @@ -414,7 +428,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso std::vector{1}); } if (with_transpose) { - const auto transpose_shape = make_transpose_shape(shape); + const auto transpose_shape = make_transpose_shape(shape); out_cpp.set_columnwise_data(transpose_tensor.data_ptr(), this->dtype, transpose_shape); out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, std::vector{1}); @@ -425,7 +439,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso } std::pair -Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const std::vector& shape, +Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const NVTEShapeWrapper& shape, DType dtype, std::optional data) { amax.zero_(); @@ -465,9 +479,9 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast(); // Tensor dimensions - std::vector shape; + NVTEShapeWrapper shape; if (has_transpose) { - const auto transpose_shape = getTensorShapeVector(*transpose_tensor); + const NVTEShapeWrapper transpose_shape = getTensorShape(*transpose_tensor); if (transpose_shape.size() > 0) { for (size_t i = 1; i < transpose_shape.size(); ++i) { shape.push_back(transpose_shape[i]); @@ -475,12 
+489,12 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ shape.push_back(transpose_shape.front()); } if (has_data) { - const auto expected_shape = getTensorShapeVector(*data_tensor); + const NVTEShapeWrapper expected_shape = getTensorShape(*data_tensor); NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, ") and transpose (shape=", transpose_shape, ") do not match"); } } else { // Already checked has_data == true - shape = getTensorShapeVector(*data_tensor); + shape = getTensorShape(*data_tensor); } // Coerce data tensor in Python tensor @@ -502,7 +516,7 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ transpose_py = py::none(); tensor.attr("_transpose") = transpose_py; } else if (!has_transpose && need_transpose) { - const auto transpose_shape = make_transpose_shape(shape); + const auto transpose_shape = make_transpose_shape>(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); transpose_tensor = at::empty(transpose_shape, opts); transpose_py = py::cast(transpose_tensor); @@ -521,7 +535,7 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ std::vector{1}); } if (transpose_tensor) { - const auto transpose_shape = make_transpose_shape(shape); + const auto transpose_shape = make_transpose_shape(shape); out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape); out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, std::vector{1}); @@ -597,17 +611,20 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} -std::pair Float8BlockQuantizer::create_tensor(const NVTEShape& shape, - DType dtype) const { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - return create_tensor(shape_vec, dtype); -} std::pair Float8BlockQuantizer::create_tensor( const std::vector& shape, DType dtype) const { + return create_tensor_impl(shape, dtype); +} + +std::pair Float8BlockQuantizer::create_tensor( + const NVTEShapeWrapper& shape, DType dtype) const { + return create_tensor_impl(shape, dtype); +} + +template +std::pair Float8BlockQuantizer::create_tensor_impl( + const ShapeT& shape, DType dtype) const { using namespace pybind11::literals; std::vector torch_shape; for (auto s : shape) { @@ -639,13 +656,12 @@ std::pair Float8BlockQuantizer::create_tensor( if (columnwise_usage) { std::vector torch_columnwise_shape; - std::vector columnwise_shape; + NVTEShapeWrapper columnwise_shape; NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. 
Shape ", columnwise_shape, " torch shape: ", torch_columnwise_shape); if (torch_shape.size() > 0) { if (!all_gather_usage) { torch_columnwise_shape.reserve(torch_shape.size()); - columnwise_shape.reserve(shape.size()); torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); columnwise_shape.push_back(shape[shape.size() - 1]); for (size_t i = 0; i < torch_shape.size() - 1; ++i) { @@ -721,15 +737,15 @@ std::pair Float8BlockQuantizer::convert_and_update_te opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); - auto get_columnwise_shape = [&columnwise_data](bool all_gather_usage) -> std::vector { + auto get_columnwise_shape = [&columnwise_data](bool all_gather_usage) -> NVTEShapeWrapper { if (!columnwise_data) { - return std::vector(); + return NVTEShapeWrapper(); } if (all_gather_usage) { - return getTensorShapeVector(*columnwise_data); + return getTensorShape(*columnwise_data); } - std::vector shape = getTensorShapeVector(*columnwise_data); - std::vector shape_transposed(shape.size()); + NVTEShapeWrapper shape = getTensorShape(*columnwise_data); + NVTEShapeWrapper shape_transposed(shape.size()); for (size_t i = 0; i + 1 < shape.size(); ++i) { shape_transposed[i] = shape[i + 1]; } @@ -738,9 +754,9 @@ std::pair Float8BlockQuantizer::convert_and_update_te } return shape_transposed; }; - std::vector shape; + NVTEShapeWrapper shape, columnwise_shape; if (rowwise_data) { - shape = getTensorShapeVector(*rowwise_data); + shape = getTensorShape(*rowwise_data); if (columnwise_data) { auto expected_shape = get_columnwise_shape(all_gather_usage); NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape, @@ -781,12 +797,10 @@ std::pair Float8BlockQuantizer::convert_and_update_te // Coerce column-wise data if (columnwise_usage) { - std::vector columnwise_shape; std::vector torch_columnwise_shape; if (torch_shape.size() > 0) { if (!all_gather_usage) { torch_columnwise_shape.reserve(torch_shape.size()); - columnwise_shape.reserve(shape.size()); torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); columnwise_shape.push_back(shape[shape.size() - 1]); for (size_t i = 0; i < torch_shape.size() - 1; ++i) { @@ -830,8 +844,7 @@ std::pair Float8BlockQuantizer::convert_and_update_te const at::Tensor& data_rowwise = tensor.attr("_rowwise_data").cast(); const at::Tensor& scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); void* scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); - const auto& rowwise_shape = getTensorShape(data_rowwise); - ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, rowwise_shape); + ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, shape); const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape); } @@ -839,8 +852,7 @@ std::pair Float8BlockQuantizer::convert_and_update_te const at::Tensor& data_colwise = tensor.attr("_columnwise_data").cast(); const at::Tensor& scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); void* scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); - const auto& shape = getTensorShape(data_colwise); - ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); + ret.set_columnwise_data(data_colwise.data_ptr(), dtype, columnwise_shape); const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); } @@ 
-951,17 +963,20 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {} -std::pair MXFP8Quantizer::create_tensor(const NVTEShape& shape, + +std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, DType dtype) const { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - return create_tensor(shape_vec, dtype); + return create_tensor_impl(shape, dtype); } -std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, +std::pair MXFP8Quantizer::create_tensor(const NVTEShapeWrapper& shape, DType dtype) const { + return create_tensor_impl(shape, dtype); +} + +template +std::pair MXFP8Quantizer::create_tensor_impl(const ShapeT& shape, + DType dtype) const { using namespace pybind11::literals; // Tensor dimensions @@ -1060,16 +1075,16 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( NVTE_CHECK(rowwise_data || columnwise_data, "MXFP8Tensor has no data."); // Tensor dimensions - std::vector shape; + NVTEShapeWrapper shape; if (columnwise_data) { - shape = getTensorShapeVector(*columnwise_data); + shape = getTensorShape(*columnwise_data); if (rowwise_data) { - const auto expected_shape = getTensorShapeVector(*rowwise_data); + const NVTEShapeWrapper expected_shape = getTensorShape(*rowwise_data); NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape, ") and column-wise data (shape=", shape, ") do not match"); } } else { // Already checked columnwise_data_tensor == true - shape = getTensorShapeVector(*rowwise_data); + shape = getTensorShape(*rowwise_data); } // Coerce row-wise data @@ -1238,16 +1253,19 @@ void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const { tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), columnwise_data.shape); } -std::pair NVFP4Quantizer::create_tensor(const NVTEShape& shape, +std::pair NVFP4Quantizer::create_tensor(const std::vector& shape, DType dtype) const { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - return create_tensor(shape_vec, dtype); + return create_tensor_impl(shape, dtype); } -std::pair NVFP4Quantizer::create_tensor(const std::vector& shape, + +std::pair NVFP4Quantizer::create_tensor(const NVTEShapeWrapper& shape, DType dtype) const { + return create_tensor_impl(shape, dtype); +} + +template +std::pair NVFP4Quantizer::create_tensor_impl(const ShapeT& shape, + DType dtype) const { using namespace pybind11::literals; // Tensor dimensions @@ -1288,7 +1306,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero std::vector shape_int64_2d = {static_cast(flat_first_dim), static_cast(flat_last_dim)}; - const auto transpose_shape_int64 = make_transpose_shape(shape_int64_2d); + const auto transpose_shape_int64 = make_transpose_shape>(shape_int64_2d); columnwise_data_tensor = at::empty(convert_shape_for_fp4(transpose_shape_int64), bit8_tensor_opts); columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); @@ -1341,7 +1359,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve // enforce 2D shape to avoid [S, B, H] shape and B and be 1 // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero std::vector shape_2d = {flat_first_dim, flat_last_dim}; - auto col_data_shape_fp4 = make_transpose_shape(shape_2d); + auto 
col_data_shape_fp4 = make_transpose_shape>(shape_2d); out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), DType::kFloat4E2M1, col_data_shape_fp4); out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, @@ -1357,8 +1375,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve std::pair NVFP4Quantizer::create_unquantized_tensor_with_amax( TensorWrapper& quantized_tensor, DType dtype) { // Construct tensor - auto shape = convertShape(quantized_tensor.shape()); - auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); + auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(quantized_tensor.shape(), dtype); // Register amax pointer from quantized tensor void* amax_ptr = quantized_tensor.amax(); @@ -1395,16 +1412,16 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( NVTE_CHECK(rowwise_data || columnwise_data, "NVFP4Tensor has no data."); // Tensor dimensions, shape means original shape - std::vector shape; + NVTEShapeWrapper shape; if (columnwise_data) { - shape = convert_shape_back_from_fp4(getTensorShapeVector(*columnwise_data), true); + shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); if (rowwise_data) { - auto expected_shape = convert_shape_back_from_fp4(getTensorShapeVector(*rowwise_data), false); + NVTEShapeWrapper expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, ") and column-wise data (shape=", shape, ") do not match"); } } else { // Already checked columnwise_data_tensor == true - shape = convert_shape_back_from_fp4(getTensorShapeVector(*rowwise_data), false); + shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); } size_t flat_first_dim = 1; @@ -1461,7 +1478,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( std::vector shape_int64_2d = {static_cast(flat_first_dim), static_cast(flat_last_dim)}; const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - const auto transpose_shape_int64 = make_transpose_shape(shape_int64_2d); + const auto transpose_shape_int64 = make_transpose_shape>(shape_int64_2d); columnwise_data = at::empty(convert_shape_for_fp4(transpose_shape_int64), opts); tensor.attr("_columnwise_data") = *columnwise_data; } @@ -1507,7 +1524,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( // enforce 2D shape to avoid [S, B, H] shape and B and be 1 // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero std::vector shape_2d = {flat_first_dim, flat_last_dim}; - auto col_data_shape_fp4 = make_transpose_shape(shape_2d); + auto col_data_shape_fp4 = make_transpose_shape(shape_2d); out_cpp.set_columnwise_data(columnwise_data->data_ptr(), DType::kFloat4E2M1, col_data_shape_fp4); out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3, diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 48e9f06cc40..368e9dcdfa3 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -132,7 +132,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast(); const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast(); ret.set_rowwise_data(data.data_ptr(), dtype, - convert_shape_back_from_fp4(getTensorShapeVector(data), false)); + 
convert_shape_back_from_fp4(getTensorShape(data), false)); ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise)); } @@ -143,7 +143,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); const auto &amax_columnwise = tensor.attr("_amax_columnwise").cast(); ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1, - convert_shape_back_from_fp4(getTensorShapeVector(data), false)); + convert_shape_back_from_fp4(getTensorShape(data), false)); ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, From 8a9bb773a9d3bb0c8a92d7a856e7f810240fa923 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 31 Dec 2025 06:35:06 +0000 Subject: [PATCH 13/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../transformer_engine/transformer_engine.h | 7 +++--- transformer_engine/pytorch/csrc/common.h | 1 - .../pytorch/csrc/extensions/attention.cpp | 7 +++--- .../pytorch/csrc/extensions/bias.cpp | 6 ++--- transformer_engine/pytorch/csrc/quantizer.cpp | 24 ++++++++----------- 5 files changed, 20 insertions(+), 25 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 8fed7927a06..3d330db5988 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -532,7 +532,7 @@ class NVTEShapeWrapper { // Constructor from NVTEShape (direct assignment by reference) NVTEShapeWrapper(const NVTEShape &shape) { data = shape; } - // Constructor from vector (creates a copy) + // Constructor from vector (creates a copy) template NVTEShapeWrapper(const std::vector &shape_vec) { data.ndim = shape_vec.size(); @@ -542,7 +542,7 @@ class NVTEShapeWrapper { } // In the NVTEShapeWrapper class definition: template - NVTEShapeWrapper& operator=(const std::vector& shape_vec) { + NVTEShapeWrapper &operator=(const std::vector &shape_vec) { data.ndim = shape_vec.size(); for (size_t i = 0; i < data.ndim; ++i) { data.data[i] = static_cast(shape_vec[i]); @@ -591,7 +591,7 @@ class NVTEShapeWrapper { } // Equality comparison with another NVTEShapeWrapper - bool operator==(const NVTEShapeWrapper& other) const { + bool operator==(const NVTEShapeWrapper &other) const { if (data.ndim != other.data.ndim) { return false; } @@ -602,7 +602,6 @@ class NVTEShapeWrapper { } return true; } - }; /*! 
\enum DType diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index bf70a2ec236..5e7a03daa35 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -391,7 +391,6 @@ std::unique_ptr convert_quantizer(py::handle quantizer); NVTEShape getTensorShape(const at::Tensor& t); - transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 39bad4a0942..b40cec24948 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -81,10 +81,11 @@ std::pair quantizer_helper(py::handle quantizer, NVTEShapeWrapper nvte_shape_wrapper{shape}; if (create_hp_tensor_for_cs) { if (data.has_value()) { - std::tie(te_T, py_T) = - T_quantizer_fp8->create_unquantized_tensor_with_amax(nvte_shape_wrapper, dtype, data.value()); + std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax( + nvte_shape_wrapper, dtype, data.value()); } else { - std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(nvte_shape_wrapper, dtype); + std::tie(te_T, py_T) = + T_quantizer_fp8->create_unquantized_tensor_with_amax(nvte_shape_wrapper, dtype); } } else { std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(nvte_shape_wrapper, dtype); diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index e15d9d5dc6f..38186221753 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -26,7 +26,7 @@ std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle // Grad output tensor auto grad_output_torch = grad_output.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const NVTEShapeWrapper& shape = getTensorShape(grad_output_torch); + const NVTEShapeWrapper &shape = getTensorShape(grad_output_torch); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); // Construct grad bias tensor @@ -116,11 +116,11 @@ std::vector dact_dbias( // Grad output and activation input tensors grad_output_torch = grad_output_torch.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const NVTEShapeWrapper& output_shape = getTensorShape(grad_output_torch); + const NVTEShapeWrapper &output_shape = getTensorShape(grad_output_torch); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); act_input_torch = act_input_torch.contiguous(); const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch); - const NVTEShapeWrapper& input_shape = getTensorShape(act_input_torch); + const NVTEShapeWrapper &input_shape = getTensorShape(act_input_torch); // Construct tensors auto quantizer_cpp = convert_quantizer(quantizer_py); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0c314f8cf42..e825fea32c7 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -72,7 +72,7 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti template std::pair NoneQuantizer::create_tensor_impl(const ShapeT& shape, - DType dtype) const { + DType dtype) 
const { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); return create_tensor(shape, dtype, at::empty(shape_int64, opts)); @@ -88,7 +88,6 @@ std::pair NoneQuantizer::create_tensor(const NVTEShap return create_tensor_impl(shape, dtype); } - std::pair NoneQuantizer::create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const { @@ -107,8 +106,6 @@ std::pair NoneQuantizer::create_tensor(const NVTEShap return {std::move(out_cpp), py::cast(data)}; } - - std::pair NoneQuantizer::convert_and_update_tensor( py::object tensor) const { auto tensor_pyt = tensor.cast(); @@ -140,8 +137,8 @@ std::pair Float8Quantizer::create_tensor( return create_tensor_impl(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); } -std::pair Float8Quantizer::create_tensor( - const NVTEShapeWrapper& shape, DType dtype) const { +std::pair Float8Quantizer::create_tensor(const NVTEShapeWrapper& shape, + DType dtype) const { const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); at::Tensor scale_inv = at::empty(std::vector{1}, opts); return create_tensor_impl(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); @@ -217,7 +214,7 @@ std::pair Float8Quantizer::create_tensor( const std::vector& shape, DType dtype, std::optional data, std::optional transpose, std::optional scale_inv) const { return create_tensor_impl(shape, dtype, std::move(data), std::move(transpose), - std::move(scale_inv)); + std::move(scale_inv)); } std::pair Float8Quantizer::convert_and_update_tensor( @@ -611,7 +608,6 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} - std::pair Float8BlockQuantizer::create_tensor( const std::vector& shape, DType dtype) const { return create_tensor_impl(shape, dtype); @@ -623,8 +619,8 @@ std::pair Float8BlockQuantizer::create_tensor( } template -std::pair Float8BlockQuantizer::create_tensor_impl( - const ShapeT& shape, DType dtype) const { +std::pair Float8BlockQuantizer::create_tensor_impl(const ShapeT& shape, + DType dtype) const { using namespace pybind11::literals; std::vector torch_shape; for (auto s : shape) { @@ -963,7 +959,6 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {} - std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, DType dtype) const { return create_tensor_impl(shape, dtype); @@ -976,7 +971,7 @@ std::pair MXFP8Quantizer::create_tensor(const NVTESha template std::pair MXFP8Quantizer::create_tensor_impl(const ShapeT& shape, - DType dtype) const { + DType dtype) const { using namespace pybind11::literals; // Tensor dimensions @@ -1265,7 +1260,7 @@ std::pair NVFP4Quantizer::create_tensor(const NVTESha template std::pair NVFP4Quantizer::create_tensor_impl(const ShapeT& shape, - DType dtype) const { + DType dtype) const { using namespace pybind11::literals; // Tensor dimensions @@ -1416,7 +1411,8 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( if (columnwise_data) { shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); if (rowwise_data) { - NVTEShapeWrapper expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); + NVTEShapeWrapper expected_shape = + convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); NVTE_CHECK(shape == expected_shape, "NVFP4 
row-wise data (shape=", expected_shape, ") and column-wise data (shape=", shape, ") do not match"); } From 116761f47a7c014d2cc00e8675079ffacde3add4 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Thu, 1 Jan 2026 06:08:56 +0000 Subject: [PATCH 14/23] clean up Signed-off-by: Varun Thumbe --- .../transformer_engine/transformer_engine.h | 18 ++++ transformer_engine/pytorch/csrc/common.cpp | 42 +++++++-- transformer_engine/pytorch/csrc/common.h | 5 +- .../pytorch/csrc/extensions/cast.cpp | 31 ++----- .../pytorch/csrc/extensions/gemm.cpp | 90 +++++++------------ .../pytorch/csrc/extensions/padding.cpp | 20 +---- .../pytorch/csrc/extensions/permutation.cpp | 21 ++--- .../pytorch/csrc/extensions/transpose.cpp | 11 +-- 8 files changed, 107 insertions(+), 131 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 8fed7927a06..351754b3a0c 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -540,6 +540,14 @@ class NVTEShapeWrapper { data.data[i] = static_cast(shape_vec[i]); } } + // Constructor from initializer list + NVTEShapeWrapper(const std::initializer_list& shape_list) { + data.ndim = shape_list.size(); + size_t i = 0; + for (const auto& val : shape_list) { + data.data[i++] = val; + } + } // In the NVTEShapeWrapper class definition: template NVTEShapeWrapper& operator=(const std::vector& shape_vec) { @@ -550,6 +558,16 @@ class NVTEShapeWrapper { return *this; } + // Assignment operator from initializer list + NVTEShapeWrapper& operator=(const std::initializer_list& shape_list) { + data.ndim = shape_list.size(); + size_t i = 0; + for (const auto& val : shape_list) { + data.data[i++] = val; + } + return *this; + } + operator NVTEShape &() { return data; } operator const NVTEShape &() const { return data; } diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index eb4fbb6bd68..57a88c8714f 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -13,18 +13,18 @@ namespace transformer_engine::pytorch { /*! convert fp4 data shape back to original shape */ -template -T convert_shape_back_from_fp4(const T& shape, bool transpose) { - std::vector ret; +NVTEShape convert_shape_back_from_fp4(const NVTEShape& shape, bool transpose) { + NVTEShapeWrapper ret; + const NVTEShapeWrapper input_shape(shape); size_t start_idx = (transpose) ? 
1 : 0; - for (size_t i = start_idx; i < shape.size() - 1; ++i) { - ret.push_back(shape[i]); + for (size_t i = start_idx; i < input_shape.size() - 1; ++i) { + ret.push_back(input_shape[i]); } - ret.push_back(shape.back() * 2); + ret.push_back(input_shape.back() * 2); if (transpose) { - ret.push_back(shape.front()); + ret.push_back(input_shape.front()); } - return ret; + return static_cast(ret); } NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); } @@ -43,6 +43,21 @@ NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) { return ret; } +template NVTEShape make_nvte_1d_shape(T dim0) { + NVTEShape shape; + shape.ndim = 1; + shape.data[0] = static_cast(dim0); + return shape; +} + +template NVTEShape make_nvte_2d_shape(T dim0, U dim1) { + NVTEShape shape; + shape.ndim = 2; + shape.data[0] = static_cast(dim0); + shape.data[1] = static_cast(dim1); + return shape; +} + std::unique_ptr convert_quantizer(py::handle quantizer) { init_extension(); if (quantizer.is_none()) { @@ -317,4 +332,15 @@ at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_pe return philox_args; } +// Explicit template instantiations for make_nvte_1d_shape +template NVTEShape make_nvte_1d_shape(int dim0); +template NVTEShape make_nvte_1d_shape(int64_t dim0); +template NVTEShape make_nvte_1d_shape(size_t dim0); + +// Explicit template instantiations for make_nvte_2d_shape +template NVTEShape make_nvte_2d_shape(int64_t dim0, int64_t dim1); +template NVTEShape make_nvte_2d_shape(size_t dim0, size_t dim1); +template NVTEShape make_nvte_2d_shape(int64_t dim0, size_t dim1); +template NVTEShape make_nvte_2d_shape(size_t dim0, int64_t dim1); + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index bf70a2ec236..d4d77733570 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -391,6 +391,8 @@ std::unique_ptr convert_quantizer(py::handle quantizer); NVTEShape getTensorShape(const at::Tensor& t); +template NVTEShape make_nvte_1d_shape(T dim0); +template NVTEShape make_nvte_2d_shape(T dim0, U dim1); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); @@ -527,8 +529,7 @@ size_t roundup(const size_t value, const size_t multiple); NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); -template -T convert_shape_back_from_fp4(const T& shape, bool transpose); +NVTEShape convert_shape_back_from_fp4(const NVTEShape& shape, bool transpose); // unpack the PhiloxCudaState into CUDA tensor void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 418e8fa0232..f99ba928645 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -328,9 +328,7 @@ std::tuple, std::vector> bulk_allocate_fp tensor_py_list.emplace_back(Float8BlockwiseQTensorClass( rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype, quantizer_py_list[i], is_2D_scaled, Float8BlockScaleTensorFormat::GEMM_READY)); - NVTEShape zero_shape; - zero_shape.ndim = 1; - zero_shape.data[0] = 0; + const NVTEShape& zero_shape = make_nvte_1d_shape(0); // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? 
rowwise_data_list[i].data_ptr() : nullptr, @@ -478,9 +476,7 @@ std::tuple, std::vector> bulk_allocate_mx tensor_py_list.emplace_back(MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype, quantizer_py_list[i])); - NVTEShape zero_shape; - zero_shape.ndim = 1; - zero_shape.data[0] = 0; + const NVTEShape& zero_shape = make_nvte_1d_shape(0); // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, @@ -591,9 +587,7 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - NVTEShape amax_shape; - amax_shape.ndim = 1; - amax_shape.data[0] = 1; + const NVTEShape& amax_shape = make_nvte_1d_shape(1); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_list.emplace_back(make_torch_view(buffer, to_fp4_shape(rowwise_data_shapes[i]), @@ -655,9 +649,7 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - NVTEShape amax_shape; - amax_shape.ndim = 1; - amax_shape.data[0] = 1; + const NVTEShape& amax_shape = make_nvte_1d_shape(1); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { columnwise_data_list.emplace_back(make_torch_view( @@ -690,11 +682,8 @@ std::tuple, std::vector, bool> bulk_alloc // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, // then set the amax and amax_columnwise values. - NVTEShape zero_shape, amax_shape; - zero_shape.ndim = 1; - amax_shape.ndim = 1; - zero_shape.data[0] = 0; - amax_shape.data[0] = 1; + const NVTEShape zero_shape = make_nvte_1d_shape(0); + const NVTEShape amax_shape = make_nvte_1d_shape(1); { auto tensor_wrapper = makeTransformerEngineTensor( rowwise_usage ? 
rowwise_data_list[i].data_ptr() : nullptr, @@ -779,9 +768,7 @@ static StochasticRngStateResources setup_stochastic_rounding_rng_states_helper( at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); int64_t *rng_state_ptr = static_cast(res.rng_states_tensor.data_ptr()) + i * 2; philox_unpack(philox_args, rng_state_ptr); - NVTEShape rng_state_shape; - rng_state_shape.ndim = 1; - rng_state_shape.data[0] = 2; + const NVTEShape rng_state_shape = make_nvte_1d_shape(2); res.te_rng_state_list.push_back(makeTransformerEngineTensor(static_cast(rng_state_ptr), rng_state_shape, DType::kInt64)); quant_config_list_rowwise[i].set_rng_state(res.te_rng_state_list[i].data()); @@ -1013,9 +1000,7 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, // Note that the multi compute amax API expects rowwise amax pointer to be not null // So we need to set the pointer accordingly to make colwise-only quantization work std::vector orig_amax_ptr_list; - NVTEShape amax_shape; - amax_shape.ndim = 1; - amax_shape.data[0] = 1; + const NVTEShape& amax_shape = make_nvte_1d_shape(1); for (size_t i = 0; i < num_tensors; i++) { auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; orig_amax_ptr_list.push_back(rowwise_amax_ptr); diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 7ac924b6105..ee9cf28684f 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -210,9 +210,9 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } NVTEShape gelu_shape; - gelu_shape.ndim = 1; - gelu_shape.data[0] = 0; - if (gelu) { + if (!gelu) { + gelu_shape = make_nvte_1d_shape(0); + } else { gelu_shape = D_shape; } @@ -220,11 +220,9 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type); // Workspace - NVTEShape workspace_shape; - workspace_shape.ndim = 1; - workspace_shape.data[0] = workspaceSize; - auto te_workspace = - makeTransformerEngineTensor(workspace.data_ptr(), workspace_shape, DType::kByte); + auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), + make_nvte_1d_shape(workspaceSize), + DType::kByte); // Set an external SM Margin to all the GEMMs. 
// This comes in handy when DP is overlapped with GEMMs @@ -273,11 +271,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans if (extra_output.has_value()) { extra_output_tensor = makeTransformerEngineTensor(*extra_output); } else { - NVTEShape extra_output_shape; - extra_output_shape.ndim = 1; - extra_output_shape.data[0] = 0; extra_output_tensor = - makeTransformerEngineTensor(nullptr, extra_output_shape, DType::kByte); + makeTransformerEngineTensor(nullptr, make_nvte_1d_shape(0), DType::kByte); } // Direct GEMM call to the correct overlap @@ -378,53 +373,32 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING; NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING; - NVTEShape A_shape, B_shape; - A_shape.ndim = 2; - - A_shape.data[0] = static_cast(A.size(0)); - A_shape.data[1] = static_cast(A.size(1)); - auto te_A = makeTransformerEngineTensor(A.data_ptr(), A_shape, A_type, nullptr, nullptr, - A_scale_inverse.data_ptr(), + auto te_A = makeTransformerEngineTensor(A.data_ptr(), make_nvte_2d_shape(A.size(0), A.size(1)), + A_type, nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), nvte_scaling_modeA); - B_shape.ndim = 2; - B_shape.data[0] = static_cast(B.size(0)); - B_shape.data[1] = static_cast(B.size(1)); - auto te_B = makeTransformerEngineTensor(B.data_ptr(), B_shape, B_type, nullptr, nullptr, - B_scale_inverse.data_ptr(), + auto te_B = makeTransformerEngineTensor(B.data_ptr(), make_nvte_2d_shape(B.size(0), B.size(1)), + B_type, nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), nvte_scaling_modeB); // TODO: D_scale_inv cannot be nullptr when D_type is FP8. - NVTEShape D_shape, D_scale_inv_shape; - D_shape.ndim = 2; - D_scale_inv_shape.ndim = 1; - D_scale_inv_shape.data[0] = 1; - D_shape.data[0] = static_cast(D.size(0)); - D_shape.data[1] = static_cast(D.size(1)); - auto te_D = makeTransformerEngineTensor(D.data_ptr(), D_shape, D_type, D_amax.data_ptr(), - D_scale.data_ptr(), nullptr, D_scale_inv_shape); - NVTEShape bias_shape; - bias_shape.ndim = 1; - bias_shape.data[0] = static_cast(bias.size(0)); - auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), bias_shape, bias_type); - NVTEShape counter_shape; - counter_shape.ndim = 1; - counter_shape.data[0] = static_cast(counter.size(0)); - auto te_counter = makeTransformerEngineTensor(counter.data_ptr(), counter_shape, DType::kInt32); - - NVTEShape gelu_shape, workspace_shape; + auto te_D = makeTransformerEngineTensor( + D.data_ptr(), make_nvte_2d_shape(D.size(0), D.size(1)), D_type, D_amax.data_ptr(), + D_scale.data_ptr(), nullptr, make_nvte_1d_shape(1)); + auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), make_nvte_1d_shape(bias.size(0)), + bias_type); + auto te_counter = makeTransformerEngineTensor(counter.data_ptr(), + make_nvte_1d_shape(counter.size(0)), + DType::kInt32); + + NVTEShape gelu_shape; if (pre_gelu_out.data_ptr() == nullptr) { - gelu_shape.ndim = 1; - gelu_shape.data[0] = static_cast(pre_gelu_out.size(0)); + gelu_shape = make_nvte_1d_shape(pre_gelu_out.size(0)); } else { - gelu_shape.ndim = 2; - gelu_shape.data[0] = static_cast(pre_gelu_out.size(0)); - gelu_shape.data[1] = static_cast(pre_gelu_out.size(1)); + gelu_shape = make_nvte_2d_shape(pre_gelu_out.size(0), pre_gelu_out.size(1)); } - workspace_shape.ndim = 1; - workspace_shape.data[0] = workspaceSize; auto te_pre_gelu_out = makeTransformerEngineTensor( pre_gelu_out.data_ptr(), 
gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); - auto te_workspace = - makeTransformerEngineTensor(workspace.data_ptr(), workspace_shape, DType::kByte); + auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), + make_nvte_1d_shape(workspaceSize), DType::kByte); NVTE_SCOPED_GIL_RELEASE({ nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), @@ -516,17 +490,15 @@ std::optional> te_general_grouped_gemm( auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); NVTEShape gelu_shape; - gelu_shape.data[0] = te_pre_gelu_out.size(0); if (pre_gelu_out[i].data_ptr() == nullptr) { - gelu_shape.ndim = 1; + gelu_shape = make_nvte_1d_shape(te_pre_gelu_out.size(0)); } else { - gelu_shape.ndim = 2; - gelu_shape.data[1] = te_pre_gelu_out.size(1); + gelu_shape = make_nvte_2d_shape(te_pre_gelu_out.size(0), te_pre_gelu_out.size(1)); } DType gelu_type = bias_type; - te_pre_gelu_out = - makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type); + te_pre_gelu_out = makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, + gelu_type); te_A_wrappers.emplace_back(std::move(te_A)); te_B_wrappers.emplace_back(std::move(te_B)); @@ -589,9 +561,7 @@ std::optional> te_general_grouped_gemm( std::vector te_workspace_vector; std::vector te_workspace_wrappers; - NVTEShape workspace_shape; - workspace_shape.ndim = 1; - workspace_shape.data[0] = workspaceSize; + const NVTEShape& workspace_shape = make_nvte_1d_shape(workspaceSize); for (size_t i = 0; i < workspace.size(); i++) { auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), workspace_shape, DType::kByte); te_workspace_vector.emplace_back(wsp.data()); diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index cabb65233f7..00400e1bee0 100644 --- a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -34,10 +34,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, input_row_list[tensor_id] * input.size(1) * input.element_size(); input_char_ptr += input_dptr_offset; d_input_ptr = reinterpret_cast(input_char_ptr); - NVTEShape input_shape; - input_shape.ndim = 2; - input_shape.data[0] = input_row_list[tensor_id]; - input_shape.data[1] = static_cast(input.size(1)); + auto input_shape = make_nvte_2d_shape(input_row_list[tensor_id], input.size(1)); input_shape_list.push_back(input_shape); input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); @@ -48,10 +45,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, output_char_ptr += output_dptr_offset; d_output_ptr = reinterpret_cast(output_char_ptr); - NVTEShape output_shape; - output_shape.ndim = 2; - output_shape.data[0] = padded_input_row_list[tensor_id]; - output_shape.data[1] = static_cast(output.size(1)); + auto output_shape = make_nvte_2d_shape(padded_input_row_list[tensor_id], output.size(1)); output_shape_list.push_back(output_shape); } @@ -115,10 +109,7 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, input_row_list[tensor_id] * input.size(1) * input.element_size(); input_char_ptr += input_dptr_offset; d_input_ptr = reinterpret_cast(input_char_ptr); - NVTEShape input_shape; - input_shape.ndim = 2; - input_shape.data[0] = input_row_list[tensor_id]; - input_shape.data[1] = static_cast(input.size(1)); + auto input_shape = make_nvte_2d_shape(input_row_list[tensor_id], input.size(1)); 
input_shape_list.push_back(input_shape); input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); @@ -129,10 +120,7 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, output_char_ptr += output_dptr_offset; d_output_ptr = reinterpret_cast(output_char_ptr); - NVTEShape output_shape; - output_shape.ndim = 2; - output_shape.data[0] = unpadded_input_row_list[tensor_id]; - output_shape.data[1] = static_cast(output.size(1)); + auto output_shape = make_nvte_2d_shape(unpadded_input_row_list[tensor_id], output.size(1)); output_shape_list.push_back(output_shape); } diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cpp b/transformer_engine/pytorch/csrc/extensions/permutation.cpp index b0654c326e7..21ed3f792ca 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cpp @@ -60,20 +60,15 @@ std::tuple> moe_permute_fwd( {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - NVTEShape input_shape, permuted_output_shape, sorted_row_id_cu_shape; - input_shape.ndim = 2; - permuted_output_shape.ndim = 2; - sorted_row_id_cu_shape.ndim = 1; - input_shape.data[0] = static_cast(input.size(0)); - input_shape.data[1] = static_cast(input.size(1)); - permuted_output_shape.data[0] = static_cast(permuted_output.size(0)); - permuted_output_shape.data[1] = static_cast(permuted_output.size(1)); - sorted_row_id_cu_shape.data[0] = static_cast(num_tokens * topK); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), input_shape, dtype); - auto permuted_output_cu = - makeTransformerEngineTensor(permuted_output.data_ptr(), permuted_output_shape, dtype); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), + make_nvte_2d_shape(input.size(0), input.size(1)), + dtype); + auto permuted_output_cu = makeTransformerEngineTensor( + permuted_output.data_ptr(), + make_nvte_2d_shape(permuted_output.size(0), permuted_output.size(1)), dtype); auto sorted_row_id_cu = - makeTransformerEngineTensor(sorted_row_id_ptr, sorted_row_id_cu_shape, DType::kInt32); + makeTransformerEngineTensor(sorted_row_id_ptr, make_nvte_1d_shape(num_tokens * topK), + DType::kInt32); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(), diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index c3cc83122c4..78f40245d5f 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -45,16 +45,9 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional Date: Thu, 1 Jan 2026 06:10:19 +0000 Subject: [PATCH 15/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../transformer_engine/transformer_engine.h | 8 +++---- transformer_engine/pytorch/csrc/common.cpp | 6 +++-- transformer_engine/pytorch/csrc/common.h | 6 +++-- .../pytorch/csrc/extensions/cast.cpp | 10 ++++----- .../pytorch/csrc/extensions/gemm.cpp | 22 +++++++++---------- .../pytorch/csrc/extensions/permutation.cpp | 10 ++++----- 6 files changed, 31 insertions(+), 31 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 
402fe5329df..246a5032e29 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -541,10 +541,10 @@ class NVTEShapeWrapper { } } // Constructor from initializer list - NVTEShapeWrapper(const std::initializer_list& shape_list) { + NVTEShapeWrapper(const std::initializer_list &shape_list) { data.ndim = shape_list.size(); size_t i = 0; - for (const auto& val : shape_list) { + for (const auto &val : shape_list) { data.data[i++] = val; } } @@ -559,10 +559,10 @@ class NVTEShapeWrapper { } // Assignment operator from initializer list - NVTEShapeWrapper& operator=(const std::initializer_list& shape_list) { + NVTEShapeWrapper &operator=(const std::initializer_list &shape_list) { data.ndim = shape_list.size(); size_t i = 0; - for (const auto& val : shape_list) { + for (const auto &val : shape_list) { data.data[i++] = val; } return *this; diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 57a88c8714f..0f19e952a76 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -43,14 +43,16 @@ NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) { return ret; } -template NVTEShape make_nvte_1d_shape(T dim0) { +template +NVTEShape make_nvte_1d_shape(T dim0) { NVTEShape shape; shape.ndim = 1; shape.data[0] = static_cast(dim0); return shape; } -template NVTEShape make_nvte_2d_shape(T dim0, U dim1) { +template +NVTEShape make_nvte_2d_shape(T dim0, U dim1) { NVTEShape shape; shape.ndim = 2; shape.data[0] = static_cast(dim0); diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index d4d77733570..6d4cab18c9c 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -391,8 +391,10 @@ std::unique_ptr convert_quantizer(py::handle quantizer); NVTEShape getTensorShape(const at::Tensor& t); -template NVTEShape make_nvte_1d_shape(T dim0); -template NVTEShape make_nvte_2d_shape(T dim0, U dim1); +template +NVTEShape make_nvte_1d_shape(T dim0); +template +NVTEShape make_nvte_2d_shape(T dim0, U dim1); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index f99ba928645..ae0cc801fa1 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -328,7 +328,7 @@ std::tuple, std::vector> bulk_allocate_fp tensor_py_list.emplace_back(Float8BlockwiseQTensorClass( rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype, quantizer_py_list[i], is_2D_scaled, Float8BlockScaleTensorFormat::GEMM_READY)); - const NVTEShape& zero_shape = make_nvte_1d_shape(0); + const NVTEShape &zero_shape = make_nvte_1d_shape(0); // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, @@ -476,7 +476,7 @@ std::tuple, std::vector> bulk_allocate_mx tensor_py_list.emplace_back(MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype, quantizer_py_list[i])); - const NVTEShape& zero_shape = make_nvte_1d_shape(0); + const NVTEShape &zero_shape = make_nvte_1d_shape(0); // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? 
rowwise_data_list[i].data_ptr() : nullptr, @@ -587,7 +587,7 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - const NVTEShape& amax_shape = make_nvte_1d_shape(1); + const NVTEShape &amax_shape = make_nvte_1d_shape(1); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_list.emplace_back(make_torch_view(buffer, to_fp4_shape(rowwise_data_shapes[i]), @@ -649,7 +649,7 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - const NVTEShape& amax_shape = make_nvte_1d_shape(1); + const NVTEShape &amax_shape = make_nvte_1d_shape(1); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { columnwise_data_list.emplace_back(make_torch_view( @@ -1000,7 +1000,7 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, // Note that the multi compute amax API expects rowwise amax pointer to be not null // So we need to set the pointer accordingly to make colwise-only quantization work std::vector orig_amax_ptr_list; - const NVTEShape& amax_shape = make_nvte_1d_shape(1); + const NVTEShape &amax_shape = make_nvte_1d_shape(1); for (size_t i = 0; i < num_tensors; i++) { auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; orig_amax_ptr_list.push_back(rowwise_amax_ptr); diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index ee9cf28684f..5de28696ee0 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -221,8 +221,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Workspace auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - make_nvte_1d_shape(workspaceSize), - DType::kByte); + make_nvte_1d_shape(workspaceSize), DType::kByte); // Set an external SM Margin to all the GEMMs. // This comes in handy when DP is overlapped with GEMMs @@ -380,14 +379,13 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, B_type, nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), nvte_scaling_modeB); // TODO: D_scale_inv cannot be nullptr when D_type is FP8. 
- auto te_D = makeTransformerEngineTensor( - D.data_ptr(), make_nvte_2d_shape(D.size(0), D.size(1)), D_type, D_amax.data_ptr(), - D_scale.data_ptr(), nullptr, make_nvte_1d_shape(1)); - auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), make_nvte_1d_shape(bias.size(0)), - bias_type); + auto te_D = makeTransformerEngineTensor(D.data_ptr(), make_nvte_2d_shape(D.size(0), D.size(1)), + D_type, D_amax.data_ptr(), D_scale.data_ptr(), nullptr, + make_nvte_1d_shape(1)); + auto te_bias = + makeTransformerEngineTensor(bias.data_ptr(), make_nvte_1d_shape(bias.size(0)), bias_type); auto te_counter = makeTransformerEngineTensor(counter.data_ptr(), - make_nvte_1d_shape(counter.size(0)), - DType::kInt32); + make_nvte_1d_shape(counter.size(0)), DType::kInt32); NVTEShape gelu_shape; if (pre_gelu_out.data_ptr() == nullptr) { @@ -398,7 +396,7 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, auto te_pre_gelu_out = makeTransformerEngineTensor( pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - make_nvte_1d_shape(workspaceSize), DType::kByte); + make_nvte_1d_shape(workspaceSize), DType::kByte); NVTE_SCOPED_GIL_RELEASE({ nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), @@ -497,8 +495,8 @@ std::optional> te_general_grouped_gemm( } DType gelu_type = bias_type; - te_pre_gelu_out = makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, - gelu_type); + te_pre_gelu_out = + makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type); te_A_wrappers.emplace_back(std::move(te_A)); te_B_wrappers.emplace_back(std::move(te_B)); diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cpp b/transformer_engine/pytorch/csrc/extensions/permutation.cpp index 21ed3f792ca..dc26d9b959d 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cpp @@ -60,15 +60,13 @@ std::tuple> moe_permute_fwd( {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), - make_nvte_2d_shape(input.size(0), input.size(1)), - dtype); + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), make_nvte_2d_shape(input.size(0), input.size(1)), dtype); auto permuted_output_cu = makeTransformerEngineTensor( permuted_output.data_ptr(), make_nvte_2d_shape(permuted_output.size(0), permuted_output.size(1)), dtype); - auto sorted_row_id_cu = - makeTransformerEngineTensor(sorted_row_id_ptr, make_nvte_1d_shape(num_tokens * topK), - DType::kInt32); + auto sorted_row_id_cu = makeTransformerEngineTensor( + sorted_row_id_ptr, make_nvte_1d_shape(num_tokens * topK), DType::kInt32); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(), From 43b693e9872524778e2d27185fa15508fa98855c Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Thu, 1 Jan 2026 07:24:24 +0000 Subject: [PATCH 16/23] minor other change Signed-off-by: Varun Thumbe --- .../pytorch/csrc/extensions/padding.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index 00400e1bee0..65721be9de1 100644 --- 
a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -34,8 +34,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, input_row_list[tensor_id] * input.size(1) * input.element_size(); input_char_ptr += input_dptr_offset; d_input_ptr = reinterpret_cast(input_char_ptr); - auto input_shape = make_nvte_2d_shape(input_row_list[tensor_id], input.size(1)); - input_shape_list.push_back(input_shape); + input_shape_list.push_back(make_nvte_2d_shape(input_row_list[tensor_id], input.size(1))); input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); // Move the output pointer to the next split. @@ -45,8 +44,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, output_char_ptr += output_dptr_offset; d_output_ptr = reinterpret_cast(output_char_ptr); - auto output_shape = make_nvte_2d_shape(padded_input_row_list[tensor_id], output.size(1)); - output_shape_list.push_back(output_shape); + output_shape_list.push_back(make_nvte_2d_shape(padded_input_row_list[tensor_id], output.size(1))); } // Construct TE tensors @@ -109,8 +107,7 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, input_row_list[tensor_id] * input.size(1) * input.element_size(); input_char_ptr += input_dptr_offset; d_input_ptr = reinterpret_cast(input_char_ptr); - auto input_shape = make_nvte_2d_shape(input_row_list[tensor_id], input.size(1)); - input_shape_list.push_back(input_shape); + input_shape_list.push_back(make_nvte_2d_shape(input_row_list[tensor_id], input.size(1))); input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); // Move the output pointer to the next split. @@ -120,8 +117,7 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, output_char_ptr += output_dptr_offset; d_output_ptr = reinterpret_cast(output_char_ptr); - auto output_shape = make_nvte_2d_shape(unpadded_input_row_list[tensor_id], output.size(1)); - output_shape_list.push_back(output_shape); + output_shape_list.push_back(make_nvte_2d_shape(unpadded_input_row_list[tensor_id], output.size(1))); } // Construct TE tensors From 9026f1fd5d041d252279481b825f6fbcac832d8a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 1 Jan 2026 07:28:04 +0000 Subject: [PATCH 17/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/extensions/padding.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index 65721be9de1..0661803172e 100644 --- a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -44,7 +44,8 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, output_char_ptr += output_dptr_offset; d_output_ptr = reinterpret_cast(output_char_ptr); - output_shape_list.push_back(make_nvte_2d_shape(padded_input_row_list[tensor_id], output.size(1))); + output_shape_list.push_back( + make_nvte_2d_shape(padded_input_row_list[tensor_id], output.size(1))); } // Construct TE tensors @@ -117,7 +118,8 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, output_char_ptr += output_dptr_offset; d_output_ptr = reinterpret_cast(output_char_ptr); - output_shape_list.push_back(make_nvte_2d_shape(unpadded_input_row_list[tensor_id], output.size(1))); + 
output_shape_list.push_back( + make_nvte_2d_shape(unpadded_input_row_list[tensor_id], output.size(1))); } // Construct TE tensors From b833d15c96ac89beb2bce0288d50e49cce4c870d Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Thu, 1 Jan 2026 19:12:59 +0000 Subject: [PATCH 18/23] minor opt Signed-off-by: Varun Thumbe --- .../transformer_engine/transformer_engine.h | 35 +++++++ transformer_engine/pytorch/csrc/common.cpp | 21 ++--- transformer_engine/pytorch/csrc/common.h | 45 +-------- .../pytorch/csrc/extensions/cast.cpp | 40 ++++---- .../pytorch/csrc/extensions/gemm.cpp | 6 +- .../pytorch/csrc/extensions/recipe.cpp | 4 +- transformer_engine/pytorch/csrc/quantizer.cpp | 92 ++----------------- 7 files changed, 73 insertions(+), 170 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 246a5032e29..4cb6b293f10 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -548,6 +548,26 @@ class NVTEShapeWrapper { data.data[i++] = val; } } + + // Copy constructor + NVTEShapeWrapper(const NVTEShapeWrapper &other) : data(other.data) {} + + // Move constructor from another NVTEShapeWrapper + NVTEShapeWrapper(NVTEShapeWrapper &&other) noexcept : data(other.data) { + other.data.ndim = 0; + } + + // Move constructor from NVTEShape rvalue reference + NVTEShapeWrapper(NVTEShape &&shape) noexcept : data(shape) {} + + // Copy assignment operator + NVTEShapeWrapper &operator=(const NVTEShapeWrapper &other) { + if (this != &other) { + data = other.data; + } + return *this; + } + // In the NVTEShapeWrapper class definition: template NVTEShapeWrapper &operator=(const std::vector &shape_vec) { @@ -568,6 +588,21 @@ class NVTEShapeWrapper { return *this; } + // Move assignment operator from another NVTEShapeWrapper + NVTEShapeWrapper &operator=(NVTEShapeWrapper &&other) noexcept { + if (this != &other) { + data = other.data; + other.data.ndim = 0; + } + return *this; + } + + // Move assignment operator from NVTEShape rvalue reference + NVTEShapeWrapper &operator=(NVTEShape &&shape) noexcept { + data = shape; + return *this; + } + operator NVTEShape &() { return data; } operator const NVTEShape &() const { return data; } diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 0f19e952a76..1122772c865 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -27,10 +27,9 @@ NVTEShape convert_shape_back_from_fp4(const NVTEShape& shape, bool transpose) { return static_cast(ret); } -NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); } - -NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) { +NVTEShape getTensorShape(const at::Tensor& t) { NVTEShape ret; + const c10::IntArrayRef& torch_shape = t.sizes(); ret.ndim = torch_shape.size(); constexpr int max_dimensions = sizeof(ret.data) / sizeof(size_t); NVTE_CHECK(ret.ndim < max_dimensions, @@ -41,7 +40,7 @@ NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) { ret.data[i] = static_cast(v); } return ret; -} + } template NVTEShape make_nvte_1d_shape(T dim0) { @@ -174,11 +173,8 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( TensorWrapper ret(scaling_mode); ret.set_rowwise_data(data_ptr, type, shape); const size_t meta_shape_data[1] = {1}; - NVTEShape meta_shape; - 
meta_shape.ndim = 1; - meta_shape.data[0] = 1; - ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); - ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + ret.set_amax(amax_ptr, DType::kFloat32, TensorWrapper::defaultShape); + ret.set_scale(scale_ptr, DType::kFloat32, TensorWrapper::defaultShape); auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); @@ -194,11 +190,8 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( TensorWrapper ret(scaling_mode); ret.set_rowwise_data(data_ptr, type, shape); ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape); - NVTEShape meta_shape; - meta_shape.ndim = 1; - meta_shape.data[0] = 1; - ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); - ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + ret.set_amax(amax_ptr, DType::kFloat32, TensorWrapper::defaultShape); + ret.set_scale(scale_ptr, DType::kFloat32, TensorWrapper::defaultShape); auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat8E4M3 : DType::kFloat32; diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 6d4cab18c9c..9b5fb4ad901 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -100,8 +100,6 @@ class Quantizer { virtual void set_quantization_params(TensorWrapper* tensor) const = 0; /*! @brief Construct a tensor with uninitialized data */ - virtual std::pair create_tensor(const std::vector& shape, - DType dtype) const = 0; virtual std::pair create_tensor(const NVTEShapeWrapper& shape, DType dtype) const = 0; /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor @@ -135,16 +133,10 @@ class NoneQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override {} - std::pair create_tensor(const std::vector& shape, - DType dtype) const override; std::pair create_tensor(const NVTEShapeWrapper& shape, DType dtype) const override; - /*! @brief Construct a tensor with pre-initialized data */ - std::pair create_tensor(const std::vector& shape, DType dtype, - at::Tensor data) const; - /*! @brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const NVTEShapeWrapper& shape, DType dtype, at::Tensor data) const; @@ -154,9 +146,6 @@ class NoneQuantizer : public Quantizer { void quantize(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt) override; - private: - template - std::pair create_tensor_impl(const ShapeT& shape, DType dtype) const; }; class Float8Quantizer : public Quantizer { @@ -172,12 +161,10 @@ class Float8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor(const std::vector& shape, - DType dtype) const override; std::pair create_tensor(const NVTEShapeWrapper& shape, DType dtype) const override; /*! 
@brief Construct a tensor with pre-initialized data */ - std::pair create_tensor(const std::vector& shape, DType dtype, + std::pair create_tensor(const NVTEShapeWrapper& shape, DType dtype, std::optional data, std::optional transpose, std::optional scale_inv) const; @@ -187,11 +174,6 @@ class Float8Quantizer : public Quantizer { void quantize(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt) override; - private: - template - std::pair create_tensor_impl( - const ShapeT& shape, DType dtype, std::optional data, - std::optional transpose, std::optional scale_inv) const; }; class Float8CurrentScalingQuantizer : public Quantizer { @@ -211,8 +193,6 @@ class Float8CurrentScalingQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor(const std::vector& shape, - DType dtype) const override; std::pair create_tensor(const NVTEShapeWrapper& shape, DType dtype) const override; /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. @@ -237,10 +217,6 @@ class Float8CurrentScalingQuantizer : public Quantizer { void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt); - private: - template - std::pair create_tensor_impl(const ShapeT& shape, DType dtype) const; - void quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax); }; @@ -274,8 +250,6 @@ class Float8BlockQuantizer : public Quantizer { // Create a python Float8BlockQuantized tensor and C++ wrapper // for the tensor. Should set quantized data, scales for rowwise // and optionally columnwise usage. - std::pair create_tensor(const std::vector& shape, - DType dtype) const override; std::pair create_tensor(const NVTEShapeWrapper& shape, DType dtype) const override; @@ -287,10 +261,6 @@ class Float8BlockQuantizer : public Quantizer { std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; NVTEShapeWrapper get_scale_shape(const NVTEShapeWrapper& shape, bool columnwise) const; - private: - template - std::pair create_tensor_impl(const ShapeT& shape, DType dtype) const; - template ShapeT get_scale_shape_impl(const ShapeT& shape, bool columnwise) const; }; @@ -305,8 +275,6 @@ class MXFP8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor(const std::vector& shape, - DType dtype) const override; std::pair create_tensor(const NVTEShapeWrapper& shape, DType dtype) const override; @@ -318,10 +286,6 @@ class MXFP8Quantizer : public Quantizer { std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; NVTEShapeWrapper get_scale_shape(const NVTEShapeWrapper& shape, bool columnwise) const; - private: - template - std::pair create_tensor_impl(const ShapeT& shape, DType dtype) const; - template ShapeT get_scale_shape_impl(const ShapeT& shape, bool columnwise) const; }; @@ -349,8 +313,6 @@ class NVFP4Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor(const std::vector& shape, - DType dtype) const override; std::pair create_tensor(const NVTEShapeWrapper& shape, DType dtype) const override; /*! 
@brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer @@ -377,9 +339,6 @@ class NVFP4Quantizer : public Quantizer { std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; NVTEShapeWrapper get_scale_shape(const NVTEShapeWrapper& shape, bool columnwise) const; - private: - template - std::pair create_tensor_impl(const ShapeT& shape, DType dtype) const; template ShapeT get_scale_shape_impl(const ShapeT& shape, bool columnwise) const; @@ -529,8 +488,6 @@ std::vector convertShape(const NVTEShape& shape); size_t roundup(const size_t value, const size_t multiple); -NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); - NVTEShape convert_shape_back_from_fp4(const NVTEShape& shape, bool transpose); // unpack the PhiloxCudaState into CUDA tensor diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index ae0cc801fa1..d236d0c5bc2 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -328,17 +328,16 @@ std::tuple, std::vector> bulk_allocate_fp tensor_py_list.emplace_back(Float8BlockwiseQTensorClass( rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype, quantizer_py_list[i], is_2D_scaled, Float8BlockScaleTensorFormat::GEMM_READY)); - const NVTEShape &zero_shape = make_nvte_1d_shape(0); // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, - columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : TensorWrapper::emptyShape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : TensorWrapper::emptyShape, fp8_dtype, nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, - columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : TensorWrapper::emptyShape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : TensorWrapper::emptyShape, scaling_mode)); } @@ -476,17 +475,16 @@ std::tuple, std::vector> bulk_allocate_mx tensor_py_list.emplace_back(MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype, quantizer_py_list[i])); - const NVTEShape &zero_shape = make_nvte_1d_shape(0); // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, - columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : TensorWrapper::emptyShape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : TensorWrapper::emptyShape, fp8_dtype, nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, - columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, + rowwise_usage ? 
static_cast(rowwise_scale_shapes[i]) : TensorWrapper::emptyShape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : TensorWrapper::emptyShape, scaling_mode)); } @@ -537,7 +535,7 @@ std::tuple, std::vector, bool> bulk_alloc // Lambda function for converting NVTEShapeWrapper shape to NVFP4 shape (last dim divided by 2) auto to_fp4_shape = [](const NVTEShapeWrapper &shape) { - NVTEShapeWrapper fp4_shape(shape); + NVTEShapeWrapper fp4_shape{shape}; if (!fp4_shape.empty()) { fp4_shape.back() /= 2; } @@ -587,7 +585,6 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - const NVTEShape &amax_shape = make_nvte_1d_shape(1); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_list.emplace_back(make_torch_view(buffer, to_fp4_shape(rowwise_data_shapes[i]), @@ -595,7 +592,7 @@ std::tuple, std::vector, bool> bulk_alloc rowwise_scale_list.emplace_back( make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); amax_rowwise_list.emplace_back( - make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, TensorWrapper::defaultShape, amax_offsets[i], torch::kFloat32)); } } @@ -649,7 +646,6 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - const NVTEShape &amax_shape = make_nvte_1d_shape(1); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { columnwise_data_list.emplace_back(make_torch_view( @@ -657,7 +653,7 @@ std::tuple, std::vector, bool> bulk_alloc columnwise_scale_list.emplace_back( make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); amax_columnwise_list.emplace_back( - make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, TensorWrapper::defaultShape, amax_offsets[i], torch::kFloat32)); } } @@ -682,29 +678,27 @@ std::tuple, std::vector, bool> bulk_alloc // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, // then set the amax and amax_columnwise values. - const NVTEShape zero_shape = make_nvte_1d_shape(0); - const NVTEShape amax_shape = make_nvte_1d_shape(1); { auto tensor_wrapper = makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, - columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : TensorWrapper::emptyShape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : TensorWrapper::emptyShape, fp4_dtype, /*amax_ptr=*/nullptr, /*scale_ptr=*/nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, - columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : TensorWrapper::emptyShape, + columnwise_usage ? 
static_cast(columnwise_scale_shapes[i]) : TensorWrapper::emptyShape, scaling_mode); // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { - tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, amax_shape); + tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, TensorWrapper::defaultShape); } if (columnwise_usage) { tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, - amax_shape); + TensorWrapper::defaultShape); } tensor_cpp_list.emplace_back(std::move(tensor_wrapper)); } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 5de28696ee0..5500ff54411 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -211,7 +211,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } NVTEShape gelu_shape; if (!gelu) { - gelu_shape = make_nvte_1d_shape(0); + gelu_shape = TensorWrapper::defaultShape; } else { gelu_shape = D_shape; } @@ -271,7 +271,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans extra_output_tensor = makeTransformerEngineTensor(*extra_output); } else { extra_output_tensor = - makeTransformerEngineTensor(nullptr, make_nvte_1d_shape(0), DType::kByte); + makeTransformerEngineTensor(nullptr, TensorWrapper::emptyShape, DType::kByte); } // Direct GEMM call to the correct overlap @@ -381,7 +381,7 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, // TODO: D_scale_inv cannot be nullptr when D_type is FP8. auto te_D = makeTransformerEngineTensor(D.data_ptr(), make_nvte_2d_shape(D.size(0), D.size(1)), D_type, D_amax.data_ptr(), D_scale.data_ptr(), nullptr, - make_nvte_1d_shape(1)); + TensorWrapper::defaultShape); auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), make_nvte_1d_shape(bias.size(0)), bias_type); auto te_counter = makeTransformerEngineTensor(counter.data_ptr(), diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index 63c26ee303b..299b6748c66 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -42,14 +42,14 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio for (size_t i = 0; i < num_tensors; i++) { te_amax_histories.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING)); NVTETensor& amax_history = te_amax_histories.back(); - NVTEShape amax_shape = convertTorchShape(amax_histories[i].sizes()); + NVTEShape amax_shape = getTensorShape(amax_histories[i]); NVTEBasicTensor amax_history_data = {amax_histories[i].data_ptr(), static_cast(DType::kFloat32), amax_shape}; nvte_set_tensor_param(&amax_history, kNVTERowwiseData, &amax_history_data); te_scales.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING)); NVTETensor& scale = te_scales.back(); - NVTEShape scale_shape = convertTorchShape(scales[i].sizes()); + NVTEShape scale_shape = getTensorShape(scales[i]); NVTEBasicTensor scale_data = {scales[i].data_ptr(), static_cast(DType::kFloat32), scale_shape}; nvte_set_tensor_param(&scale, kNVTERowwiseData, &scale_data); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index e825fea32c7..377d75be901 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -70,33 +70,13 @@ 
Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti this->dtype = type; } -template -std::pair NoneQuantizer::create_tensor_impl(const ShapeT& shape, +std::pair NoneQuantizer::create_tensor(const NVTEShapeWrapper& shape, DType dtype) const { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); return create_tensor(shape, dtype, at::empty(shape_int64, opts)); } -std::pair NoneQuantizer::create_tensor(const std::vector& shape, - DType dtype) const { - return create_tensor_impl(shape, dtype); -} - -std::pair NoneQuantizer::create_tensor(const NVTEShapeWrapper& shape, - DType dtype) const { - return create_tensor_impl(shape, dtype); -} - -std::pair NoneQuantizer::create_tensor(const std::vector& shape, - DType dtype, - at::Tensor data) const { - TensorWrapper out_cpp; - out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape); - set_quantization_params(&out_cpp); - return {std::move(out_cpp), py::cast(data)}; -} - std::pair NoneQuantizer::create_tensor(const NVTEShapeWrapper& shape, DType dtype, at::Tensor data) const { @@ -130,23 +110,15 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { getTensorShape(amax)); } -std::pair Float8Quantizer::create_tensor( - const std::vector& shape, DType dtype) const { - const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - at::Tensor scale_inv = at::empty(std::vector{1}, opts); - return create_tensor_impl(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); -} - std::pair Float8Quantizer::create_tensor(const NVTEShapeWrapper& shape, DType dtype) const { const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); at::Tensor scale_inv = at::empty(std::vector{1}, opts); - return create_tensor_impl(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); + return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); } -template -std::pair Float8Quantizer::create_tensor_impl( - const ShapeT& shape, DType dtype, std::optional data, +std::pair Float8Quantizer::create_tensor( + const NVTEShapeWrapper& shape, DType dtype, std::optional data, std::optional transpose, std::optional scale_inv) const { using namespace pybind11::literals; int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); @@ -210,12 +182,6 @@ std::pair Float8Quantizer::create_tensor_impl( return {std::move(out_cpp), std::move(out_py)}; } -std::pair Float8Quantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional data, - std::optional transpose, std::optional scale_inv) const { - return create_tensor_impl(shape, dtype, std::move(data), std::move(transpose), - std::move(scale_inv)); -} std::pair Float8Quantizer::convert_and_update_tensor( py::object tensor) const { @@ -357,19 +323,9 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); } -std::pair Float8CurrentScalingQuantizer::create_tensor( - const std::vector& shape, DType dtype) const { - return create_tensor_impl(shape, dtype); -} std::pair Float8CurrentScalingQuantizer::create_tensor( const NVTEShapeWrapper& shape, DType dtype) const { - return create_tensor_impl(shape, dtype); -} - -template -std::pair Float8CurrentScalingQuantizer::create_tensor_impl( - const ShapeT& shape, DType dtype) const { using namespace pybind11::literals; 
// Initialize data tensor @@ -579,7 +535,7 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te NVTE_SCOPED_GIL_RELEASE({ nvte_compute_scale_from_amax(out.data(), quant_config, stream); }); // Cast to FP8 - out.set_amax(nullptr, DType::kFloat32, out.defaultShape); // Avoid atomic amax updates + out.set_amax(nullptr, DType::kFloat32, TensorWrapper::defaultShape); // Avoid atomic amax updates NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); }); } @@ -592,7 +548,7 @@ void Float8CurrentScalingQuantizer::quantize_with_amax( TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag) { NVTE_CHECK(input.get_amax().data_ptr == amax.data_ptr(), "Input does not use the appropriate amax tensor"); - input.set_amax(nullptr, DType::kFloat32, input.defaultShape); + input.set_amax(nullptr, DType::kFloat32, TensorWrapper::defaultShape); this->quantize_impl(input, out, noop_flag, false); } @@ -608,19 +564,8 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} -std::pair Float8BlockQuantizer::create_tensor( - const std::vector& shape, DType dtype) const { - return create_tensor_impl(shape, dtype); -} - std::pair Float8BlockQuantizer::create_tensor( const NVTEShapeWrapper& shape, DType dtype) const { - return create_tensor_impl(shape, dtype); -} - -template -std::pair Float8BlockQuantizer::create_tensor_impl(const ShapeT& shape, - DType dtype) const { using namespace pybind11::literals; std::vector torch_shape; for (auto s : shape) { @@ -959,19 +904,8 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {} -std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, - DType dtype) const { - return create_tensor_impl(shape, dtype); -} - std::pair MXFP8Quantizer::create_tensor(const NVTEShapeWrapper& shape, - DType dtype) const { - return create_tensor_impl(shape, dtype); -} - -template -std::pair MXFP8Quantizer::create_tensor_impl(const ShapeT& shape, - DType dtype) const { + DType dtype) const { using namespace pybind11::literals; // Tensor dimensions @@ -1248,18 +1182,8 @@ void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const { tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), columnwise_data.shape); } -std::pair NVFP4Quantizer::create_tensor(const std::vector& shape, - DType dtype) const { - return create_tensor_impl(shape, dtype); -} std::pair NVFP4Quantizer::create_tensor(const NVTEShapeWrapper& shape, - DType dtype) const { - return create_tensor_impl(shape, dtype); -} - -template -std::pair NVFP4Quantizer::create_tensor_impl(const ShapeT& shape, DType dtype) const { using namespace pybind11::literals; @@ -1786,7 +1710,7 @@ void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out NVTE_CHECK_CUDA(cudaMemcpyAsync(output_columnwise_amax_ptr, input_amax_ptr, sizeof(float), cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream())); } - input.set_amax(nullptr, DType::kFloat32, input.defaultShape); + input.set_amax(nullptr, DType::kFloat32, TensorWrapper::defaultShape); // Perform quantization this->quantize_impl(input, out, std::nullopt, false); From 5d77edafbb279d5ddb7b5249b3f846fce50cce58 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 1 Jan 2026 
19:14:55 +0000 Subject: [PATCH 19/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../transformer_engine/transformer_engine.h | 6 +-- transformer_engine/pytorch/csrc/common.cpp | 4 +- transformer_engine/pytorch/csrc/common.h | 4 -- .../pytorch/csrc/extensions/cast.cpp | 39 ++++++++++++------- transformer_engine/pytorch/csrc/quantizer.cpp | 7 ++-- 5 files changed, 33 insertions(+), 27 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 4cb6b293f10..5d53376b41f 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -553,16 +553,14 @@ class NVTEShapeWrapper { NVTEShapeWrapper(const NVTEShapeWrapper &other) : data(other.data) {} // Move constructor from another NVTEShapeWrapper - NVTEShapeWrapper(NVTEShapeWrapper &&other) noexcept : data(other.data) { - other.data.ndim = 0; - } + NVTEShapeWrapper(NVTEShapeWrapper &&other) noexcept : data(other.data) { other.data.ndim = 0; } // Move constructor from NVTEShape rvalue reference NVTEShapeWrapper(NVTEShape &&shape) noexcept : data(shape) {} // Copy assignment operator NVTEShapeWrapper &operator=(const NVTEShapeWrapper &other) { - if (this != &other) { + if (this != &other) { data = other.data; } return *this; diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 1122772c865..09a05882d33 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -27,7 +27,7 @@ NVTEShape convert_shape_back_from_fp4(const NVTEShape& shape, bool transpose) { return static_cast(ret); } -NVTEShape getTensorShape(const at::Tensor& t) { +NVTEShape getTensorShape(const at::Tensor& t) { NVTEShape ret; const c10::IntArrayRef& torch_shape = t.sizes(); ret.ndim = torch_shape.size(); @@ -40,7 +40,7 @@ NVTEShape getTensorShape(const at::Tensor& t) { ret.data[i] = static_cast(v); } return ret; - } +} template NVTEShape make_nvte_1d_shape(T dim0) { diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 9b5fb4ad901..fa1df3f49ed 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -133,7 +133,6 @@ class NoneQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override {} - std::pair create_tensor(const NVTEShapeWrapper& shape, DType dtype) const override; @@ -145,7 +144,6 @@ class NoneQuantizer : public Quantizer { void quantize(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt) override; - }; class Float8Quantizer : public Quantizer { @@ -173,7 +171,6 @@ class Float8Quantizer : public Quantizer { void quantize(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt) override; - }; class Float8CurrentScalingQuantizer : public Quantizer { @@ -339,7 +336,6 @@ class NVFP4Quantizer : public Quantizer { std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; NVTEShapeWrapper get_scale_shape(const NVTEShapeWrapper& shape, bool columnwise) const; - template ShapeT get_scale_shape_impl(const ShapeT& shape, bool columnwise) const; void quantize_impl(const TensorWrapper& input, TensorWrapper& out, diff --git 
a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index d236d0c5bc2..e7da1c911c2 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -332,12 +332,16 @@ std::tuple, std::vector> bulk_allocate_fp tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_data_shapes[i]) : TensorWrapper::emptyShape, - columnwise_usage ? static_cast(columnwise_data_shapes[i]) : TensorWrapper::emptyShape, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) + : TensorWrapper::emptyShape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) + : TensorWrapper::emptyShape, fp8_dtype, nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : TensorWrapper::emptyShape, - columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : TensorWrapper::emptyShape, + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) + : TensorWrapper::emptyShape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) + : TensorWrapper::emptyShape, scaling_mode)); } @@ -479,12 +483,16 @@ std::tuple, std::vector> bulk_allocate_mx tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_data_shapes[i]) : TensorWrapper::emptyShape, - columnwise_usage ? static_cast(columnwise_data_shapes[i]) : TensorWrapper::emptyShape, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) + : TensorWrapper::emptyShape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) + : TensorWrapper::emptyShape, fp8_dtype, nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : TensorWrapper::emptyShape, - columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : TensorWrapper::emptyShape, + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) + : TensorWrapper::emptyShape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) + : TensorWrapper::emptyShape, scaling_mode)); } @@ -682,19 +690,24 @@ std::tuple, std::vector, bool> bulk_alloc auto tensor_wrapper = makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_data_shapes[i]) : TensorWrapper::emptyShape, - columnwise_usage ? static_cast(columnwise_data_shapes[i]) : TensorWrapper::emptyShape, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) + : TensorWrapper::emptyShape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) + : TensorWrapper::emptyShape, fp4_dtype, /*amax_ptr=*/nullptr, /*scale_ptr=*/nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : TensorWrapper::emptyShape, - columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : TensorWrapper::emptyShape, + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) + : TensorWrapper::emptyShape, + columnwise_usage ? 
static_cast(columnwise_scale_shapes[i]) + : TensorWrapper::emptyShape, scaling_mode); // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { - tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, TensorWrapper::defaultShape); + tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, + TensorWrapper::defaultShape); } if (columnwise_usage) { tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 377d75be901..52f4fde2c2d 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -71,7 +71,7 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti } std::pair NoneQuantizer::create_tensor(const NVTEShapeWrapper& shape, - DType dtype) const { + DType dtype) const { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); return create_tensor(shape, dtype, at::empty(shape_int64, opts)); @@ -182,7 +182,6 @@ std::pair Float8Quantizer::create_tensor( return {std::move(out_cpp), std::move(out_py)}; } - std::pair Float8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); @@ -905,7 +904,7 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {} std::pair MXFP8Quantizer::create_tensor(const NVTEShapeWrapper& shape, - DType dtype) const { + DType dtype) const { using namespace pybind11::literals; // Tensor dimensions @@ -1184,7 +1183,7 @@ void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const { } std::pair NVFP4Quantizer::create_tensor(const NVTEShapeWrapper& shape, - DType dtype) const { + DType dtype) const { using namespace pybind11::literals; // Tensor dimensions From 29f84265739dd796a82b5c787423cd7ec8492f13 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Thu, 1 Jan 2026 19:30:09 +0000 Subject: [PATCH 20/23] minor cleanup Signed-off-by: Varun Thumbe --- .../pytorch/csrc/extensions/gemm.cpp | 19 ++++-------- .../pytorch/csrc/extensions/permutation.cpp | 29 ++++++------------- .../pytorch/csrc/extensions/transpose.cpp | 3 +- 3 files changed, 15 insertions(+), 36 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 5500ff54411..7378e2717b3 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -209,12 +209,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } } - NVTEShape gelu_shape; - if (!gelu) { - gelu_shape = TensorWrapper::defaultShape; - } else { - gelu_shape = D_shape; - } + const auto gelu_shape = gelu ? 
D_shape : TensorWrapper::emptyShape; auto te_pre_gelu_out = makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type); @@ -387,12 +382,9 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, auto te_counter = makeTransformerEngineTensor(counter.data_ptr(), make_nvte_1d_shape(counter.size(0)), DType::kInt32); - NVTEShape gelu_shape; - if (pre_gelu_out.data_ptr() == nullptr) { - gelu_shape = make_nvte_1d_shape(pre_gelu_out.size(0)); - } else { - gelu_shape = make_nvte_2d_shape(pre_gelu_out.size(0), pre_gelu_out.size(1)); - } + const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr ? make_nvte_1d_shape(pre_gelu_out.size(0)) + : make_nvte_2d_shape(pre_gelu_out.size(0), + pre_gelu_out.size(1)); auto te_pre_gelu_out = makeTransformerEngineTensor( pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), @@ -559,9 +551,8 @@ std::optional> te_general_grouped_gemm( std::vector te_workspace_vector; std::vector te_workspace_wrappers; - const NVTEShape& workspace_shape = make_nvte_1d_shape(workspaceSize); for (size_t i = 0; i < workspace.size(); i++) { - auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), workspace_shape, DType::kByte); + auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), make_nvte_1d_shape(workspaceSize), DType::kByte); te_workspace_vector.emplace_back(wsp.data()); te_workspace_wrappers.emplace_back(std::move(wsp)); } diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cpp b/transformer_engine/pytorch/csrc/extensions/permutation.cpp index dc26d9b959d..d0cb97c8f3a 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cpp @@ -91,13 +91,8 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - NVTEShape input_shape, unpermuted_output_shape; - input_shape.ndim = 2; - unpermuted_output_shape.ndim = 2; - input_shape.data[0] = static_cast(input.size(0)); - input_shape.data[1] = static_cast(input.size(1)); - unpermuted_output_shape.data[0] = static_cast(unpermuted_output.size(0)); - unpermuted_output_shape.data[1] = static_cast(num_cols); + NVTEShapeWrapper input_shape = make_nvte_2d_shape(input.size(0), input.size(1)); + NVTEShapeWrapper unpermuted_output_shape = make_nvte_2d_shape(unpermuted_output.size(0), unpermuted_output.size(1)); auto input_cu = makeTransformerEngineTensor(input.data_ptr(), input_shape, dtype); auto unpermuted_output_cu = makeTransformerEngineTensor(unpermuted_output.data_ptr(), unpermuted_output_shape, dtype); @@ -125,19 +120,13 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T {num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - NVTEShape input_bwd_shape, act_grad_shape, input_fwd_shape; - input_bwd_shape.ndim = 2; - act_grad_shape.ndim = 2; - input_fwd_shape.ndim = 2; - input_bwd_shape.data[0] = static_cast(input_bwd.size(0)); - input_bwd_shape.data[1] = static_cast(num_cols); - act_grad_shape.data[0] = static_cast(act_grad.size(0)); - act_grad_shape.data[1] = static_cast(num_cols); - input_fwd_shape.data[0] = static_cast(input_fwd.size(0)); - input_fwd_shape.data[1] = static_cast(num_cols); - auto input_bwd_cu = 
makeTransformerEngineTensor(input_bwd.data_ptr(), input_bwd_shape, dtype); - auto act_grad_cu = makeTransformerEngineTensor(act_grad.data_ptr(), act_grad_shape, dtype); - auto input_fwd_cu = makeTransformerEngineTensor(input_fwd.data_ptr(), input_fwd_shape, dtype); + + auto input_bwd_cu = makeTransformerEngineTensor(input_bwd.data_ptr(), + make_nvte_2d_shape(input_bwd.size(0), num_cols), dtype); + auto act_grad_cu = makeTransformerEngineTensor(act_grad.data_ptr(), + make_nvte_2d_shape(act_grad.size(0), num_cols), dtype); + auto input_fwd_cu = makeTransformerEngineTensor(input_fwd.data_ptr(), + make_nvte_2d_shape(input_fwd.size(0), num_cols), dtype); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); auto prob_cu = makeTransformerEngineTensor(prob); auto prob_grad_cu = makeTransformerEngineTensor(prob_grad); diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 78f40245d5f..306d878946b 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -19,8 +19,7 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional transpose_shape_int64; if (shape.size() > 0) { transpose_shape_int64.push_back(shape.back()); From 7d75815e3a6d3dce2f414e0784f0351442e103b5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 1 Jan 2026 19:31:07 +0000 Subject: [PATCH 21/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/csrc/extensions/gemm.cpp | 9 +++++---- .../pytorch/csrc/extensions/permutation.cpp | 15 ++++++++------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 7378e2717b3..f3548cdd186 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -382,9 +382,9 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, auto te_counter = makeTransformerEngineTensor(counter.data_ptr(), make_nvte_1d_shape(counter.size(0)), DType::kInt32); - const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr ? make_nvte_1d_shape(pre_gelu_out.size(0)) - : make_nvte_2d_shape(pre_gelu_out.size(0), - pre_gelu_out.size(1)); + const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr + ? 
make_nvte_1d_shape(pre_gelu_out.size(0)) : make_nvte_2d_shape(pre_gelu_out.size(0), pre_gelu_out.size(1)); auto te_pre_gelu_out = makeTransformerEngineTensor( pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), @@ -552,7 +552,8 @@ std::optional> te_general_grouped_gemm( std::vector te_workspace_vector; std::vector te_workspace_wrappers; for (size_t i = 0; i < workspace.size(); i++) { - auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), make_nvte_1d_shape(workspaceSize), DType::kByte); + auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), + make_nvte_1d_shape(workspaceSize), DType::kByte); te_workspace_vector.emplace_back(wsp.data()); te_workspace_wrappers.emplace_back(std::move(wsp)); } diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cpp b/transformer_engine/pytorch/csrc/extensions/permutation.cpp index d0cb97c8f3a..734a9754ead 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cpp @@ -92,7 +92,8 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row auto stream = at::cuda::getCurrentCUDAStream().stream(); NVTEShapeWrapper input_shape = make_nvte_2d_shape(input.size(0), input.size(1)); - NVTEShapeWrapper unpermuted_output_shape = make_nvte_2d_shape(unpermuted_output.size(0), unpermuted_output.size(1)); + NVTEShapeWrapper unpermuted_output_shape = + make_nvte_2d_shape(unpermuted_output.size(0), unpermuted_output.size(1)); auto input_cu = makeTransformerEngineTensor(input.data_ptr(), input_shape, dtype); auto unpermuted_output_cu = makeTransformerEngineTensor(unpermuted_output.data_ptr(), unpermuted_output_shape, dtype); @@ -121,12 +122,12 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T auto stream = at::cuda::getCurrentCUDAStream().stream(); - auto input_bwd_cu = makeTransformerEngineTensor(input_bwd.data_ptr(), - make_nvte_2d_shape(input_bwd.size(0), num_cols), dtype); - auto act_grad_cu = makeTransformerEngineTensor(act_grad.data_ptr(), - make_nvte_2d_shape(act_grad.size(0), num_cols), dtype); - auto input_fwd_cu = makeTransformerEngineTensor(input_fwd.data_ptr(), - make_nvte_2d_shape(input_fwd.size(0), num_cols), dtype); + auto input_bwd_cu = makeTransformerEngineTensor( + input_bwd.data_ptr(), make_nvte_2d_shape(input_bwd.size(0), num_cols), dtype); + auto act_grad_cu = makeTransformerEngineTensor( + act_grad.data_ptr(), make_nvte_2d_shape(act_grad.size(0), num_cols), dtype); + auto input_fwd_cu = makeTransformerEngineTensor( + input_fwd.data_ptr(), make_nvte_2d_shape(input_fwd.size(0), num_cols), dtype); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); auto prob_cu = makeTransformerEngineTensor(prob); auto prob_grad_cu = makeTransformerEngineTensor(prob_grad); From da650a5b2e08e057a2ec5eeed965e9784d791ca0 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Thu, 1 Jan 2026 19:36:14 +0000 Subject: [PATCH 22/23] remove unnecessary code Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/csrc/common.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 09a05882d33..af51d6f304d 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -172,7 +172,6 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( NVTEScalingMode scaling_mode) { TensorWrapper
ret(scaling_mode); ret.set_rowwise_data(data_ptr, type, shape); - const size_t meta_shape_data[1] = {1}; ret.set_amax(amax_ptr, DType::kFloat32, TensorWrapper::defaultShape); ret.set_scale(scale_ptr, DType::kFloat32, TensorWrapper::defaultShape); auto scale_inv_dtype = From 9283238661e6cdd47a3a3715d4ab27c42bed5923 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Thu, 1 Jan 2026 19:44:13 +0000 Subject: [PATCH 23/23] other minor cleanups Signed-off-by: Varun Thumbe --- .../pytorch/csrc/extensions/attention.cpp | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index b40cec24948..1e3b5b76609 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -190,9 +190,6 @@ std::vector fused_attn_fwd( cu_seqlens_kv_padded.value().data_ptr(), static_cast(cu_seqlens_kv_padded_shape), DType::kInt32); } - NVTEShape default_scale_inv_shape; - default_scale_inv_shape.ndim = 1; - default_scale_inv_shape.data[0] = 1; if ((page_table_k.has_value()) && (page_table_v.has_value())) { auto page_table_k_sizes = page_table_k.value().sizes().vec(); NVTEShapeWrapper page_table_k_shape{page_table_k_sizes}; @@ -200,10 +197,10 @@ std::vector fused_attn_fwd( NVTEShapeWrapper page_table_v_shape{page_table_v_sizes}; te_page_table_k = makeTransformerEngineTensor( page_table_k.value().data_ptr(), static_cast(page_table_k_shape), - DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape); + DType::kInt32, nullptr, nullptr, nullptr, TensorWrapper::defaultShape); te_page_table_v = makeTransformerEngineTensor( page_table_v.value().data_ptr(), static_cast(page_table_v_shape), - DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape); + DType::kInt32, nullptr, nullptr, nullptr, TensorWrapper::defaultShape); } // softmax offset @@ -213,7 +210,7 @@ std::vector fused_attn_fwd( NVTEShapeWrapper SoftmaxOffset_shape{SoftmaxOffset_sizes}; te_SoftmaxOffset = makeTransformerEngineTensor( SoftmaxOffset.value().data_ptr(), static_cast(SoftmaxOffset_shape), - DType::kFloat32, nullptr, nullptr, nullptr, default_scale_inv_shape); + DType::kFloat32, nullptr, nullptr, nullptr, TensorWrapper::defaultShape); } // extract rng seed and offset @@ -469,16 +466,13 @@ std::vector fused_attn_bwd( NVTEShapeWrapper cu_seqlens_q_shape{cu_seqlens_q_sizes}; auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); NVTEShapeWrapper cu_seqlens_kv_shape{cu_seqlens_kv_sizes}; - NVTEShape zero_scale_inv_shape; - zero_scale_inv_shape.ndim = 1; - zero_scale_inv_shape.data[0] = 0; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; te_cu_seqlens_q = makeTransformerEngineTensor( cu_seqlens_q.data_ptr(), static_cast(cu_seqlens_q_shape), DType::kInt32, nullptr, - nullptr, nullptr, zero_scale_inv_shape); + nullptr, nullptr, TensorWrapper::emptyShape); te_cu_seqlens_kv = makeTransformerEngineTensor( cu_seqlens_kv.data_ptr(), static_cast(cu_seqlens_kv_shape), DType::kInt32, - nullptr, nullptr, nullptr, zero_scale_inv_shape); + nullptr, nullptr, nullptr, TensorWrapper::emptyShape); TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) {