diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
index 8b077f6f1f..1ef1f81e82 100644
--- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
+++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
@@ -1125,8 +1125,9 @@ template <...>
-  uint32_t *tile_scheduler_workspace = nullptr;
-  NVTE_CHECK_CUDA(cudaMallocAsync(&tile_scheduler_workspace, sizeof(uint32_t), stream));
-  NVTE_CHECK_CUDA(cudaMemsetAsync(tile_scheduler_workspace, 0, sizeof(uint32_t), stream));
+  NVTE_CHECK_CUDA(cudaMemsetAsync(reinterpret_cast<void *>(tile_scheduler_workspace), 0,
+                                  sizeof(uint32_t), stream));
 
   // Launch kernel
   cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream};
@@ -1308,8 +1308,6 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_size,
       tile_scheduler_workspace, mma, rng_state);
   NVTE_CHECK_CUDA(cudaGetLastError());
   NVTE_CHECK(status == cutlass::Status::kSuccess, "Kernel launch failed.");
-
-  NVTE_CHECK_CUDA(cudaFreeAsync(tile_scheduler_workspace, stream));
 }
 
 }  // namespace
@@ -1318,7 +1316,8 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_size,
 void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector<Tensor *> &output_list,
                                           const size_t *split_sections, size_t num_tensors,
                                           const Tensor &hadamard_matrix_,
-                                          QuantizationConfig &quant_config, cudaStream_t stream) {
+                                          QuantizationConfig &quant_config, Tensor &quant_workspace,
+                                          cudaStream_t stream) {
   NVTE_API_CALL(group_hadamard_transform_cast_fusion);
 
   using transformer_engine::detail::kMaxTensorsPerKernel;
@@ -1399,6 +1398,12 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector<Tensor *> &output_list,
     rng_state = reinterpret_cast<const size_t *>(rng_state_tensor.data.dptr);
   }
 
+  uint32_t *tile_scheduler_workspace = nullptr;
+  NVTE_CHECK(quant_workspace.data.dptr != nullptr, "Quantization workspace must be provided.");
+  NVTE_CHECK(quant_workspace.data.buffer_size_bytes() >= sizeof(uint32_t),
+             "Quantization workspace must be at least 4 bytes.");
+  tile_scheduler_workspace = reinterpret_cast<uint32_t *>(quant_workspace.data.dptr);
+
   // Template arguments
   using TA = cute::bfloat16_t;
   using TB = cute::bfloat16_t;
@@ -1461,7 +1466,9 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector<Tensor *> &output_list,
         /*D=*/reinterpret_cast<TD *>(rowwise_data_base_ptr),
         /*SFA=*/reinterpret_cast<TSFA *>(rowwise_scale_inv_base_ptr),
         /*args=*/kernel_args,
-        /*rng_state=*/rng_state, /*sm_count=*/sm_count,
+        /*rng_state=*/rng_state,
+        /*tile_scheduler_workspace=*/tile_scheduler_workspace,
+        /*sm_count=*/sm_count,
         /*stream=*/stream, /*k_tile_size=*/k_tile_size);
   } else {
     NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=",
@@ -1478,7 +1485,7 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor* outputs,
                                                const size_t *split_sections,
                                                const size_t num_tensors,
                                                const NVTEQuantizationConfig quant_config,
-                                               cudaStream_t stream) {
+                                               NVTETensor quant_workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_group_hadamard_transform_cast_fusion);
   using namespace transformer_engine;
   NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0.");
@@ -1489,6 +1496,8 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor* outputs,
     output_list[i] = convertNVTETensorCheck(outputs[i]);
   }
 
+  Tensor *quant_workspace_tensor = convertNVTETensorCheck(quant_workspace);
+
   QuantizationConfig quant_config_cpp;
   if (quant_config != nullptr) {
     quant_config_cpp = *reinterpret_cast<const QuantizationConfig *>(quant_config);
@@ -1497,5 +1506,5 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor* outputs,
   // Call the multi-tensor Hadamard transform amax implementation.
   group_hadamard_transform_cast_fusion(*input_tensor, output_list, split_sections, num_tensors,
                                        *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp,
-                                       stream);
+                                       *quant_workspace_tensor, stream);
 }
diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h
index b6e9719aad..13103cc388 100644
--- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h
+++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h
@@ -115,13 +115,14 @@ void nvte_group_hadamard_transform_cast_fusion_columnwise(
  *  \param[in]  split_sections   Array specifying splits in dimension 0 for each output tensor.
  *  \param[in]  num_tensors      Number of output tensors, must be > 0.
  *  \param[in]  quant_config     Quantization configuration.
+ *  \param[in]  quant_workspace  Workspace buffer for the tile scheduler; must be at least 4 bytes.
  *  \param[in]  stream           CUDA stream used for the operation.
  */
 void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor* outputs,
                                                const NVTETensor hadamard_matrix,
                                                const size_t* split_sections, size_t num_tensors,
                                                const NVTEQuantizationConfig quant_config,
-                                               cudaStream_t stream);
+                                               NVTETensor quant_workspace, cudaStream_t stream);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index 3bbc99b444..4e5e5223f7 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -872,10 +872,16 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input,
   auto rht_matrix_nvte = makeTransformerEngineTensor(quantizer.rht_matrix);
 
   if (all_aligned_token_dim) {
+    // allocate a 4-byte tile scheduler workspace for the fused kernel
+    auto tile_scheduler_workspace_torch =
+        at::empty({1}, at::device(at::kCUDA).dtype(torch::kInt32));
+    auto nvte_tile_scheduler_workspace =
+        makeTransformerEngineTensor(tile_scheduler_workspace_torch);
     // call the fully-fused grouped kernel for rowwise quantization & colwise RHT quantization transpose
     nvte_group_hadamard_transform_cast_fusion(
         input.data(), reinterpret_cast<NVTETensor *>(nvte_tensor_output_list.data()),
-        rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0], stream);
+        rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0],
+        nvte_tile_scheduler_workspace.data(), stream);
   } else {
     // Separate quantization for rowwise usage and columnwise usage
     // Rowwise quantization fusion with grouped version
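Reviewer note: after this change the fused kernel no longer allocates or frees the tile-scheduler counter itself; the caller must supply a device buffer of at least 4 bytes, which the kernel zeroes with `cudaMemsetAsync` on the launch stream. The sketch below shows one way a caller could amortize the per-call `at::empty` in `cast.cpp`. It is illustrative only: the helper name `get_tile_scheduler_workspace` and the caching strategy are assumptions, not part of this PR.

```cpp
// Hypothetical caller-side cache for the 4-byte tile-scheduler workspace.
// Assumes a single device and serialized launches on one stream: the fused
// kernel resets the counter with cudaMemsetAsync on the launch stream, so no
// host-side reset is needed between calls. Concurrent launches on different
// streams would each need their own buffer.
#include <torch/torch.h>

static at::Tensor &get_tile_scheduler_workspace() {
  static at::Tensor workspace =
      at::empty({1}, at::device(at::kCUDA).dtype(torch::kInt32));
  return workspace;
}

// Usage, mirroring the cast.cpp call site above:
//   auto nvte_ws = makeTransformerEngineTensor(get_tile_scheduler_workspace());
//   nvte_group_hadamard_transform_cast_fusion(..., nvte_ws.data(), stream);
```

Since `at::empty` already goes through PyTorch's caching allocator, the per-call allocation in this PR is likely cheap either way; a cached buffer mainly avoids repeated allocator calls in tight loops.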