-
Notifications
You must be signed in to change notification settings - Fork 603
[NVFP4][MOE] Bug Fix for NVFP4 Grouped Quant #2564
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1125,8 +1125,9 @@ template <bool kEnableStochasticRounding, bool kEnableRHTColQuant, bool kEnableR | |
| void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_size, TA const *A, | ||
| TB const *B, TQA *QA, TSFA *SFA, | ||
| MultiAmaxHadamardCastFusionArgs &args, | ||
| const size_t *rng_state, uint32_t sm_count, | ||
| cudaStream_t stream, int k_tile_size = 1024) { | ||
| const size_t *rng_state, uint32_t *tile_scheduler_workspace, | ||
| uint32_t sm_count, cudaStream_t stream, | ||
| int k_tile_size = 1024) { | ||
| using namespace cute; | ||
| static int constexpr SFVecSize = 16; | ||
| static int constexpr RhtTensorSize = 16; | ||
|
|
@@ -1295,10 +1296,9 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz | |
| NVTE_CHECK_CUDA( | ||
| cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); | ||
|
|
||
| // Allocate workspace and set to zero | ||
| void *tile_scheduler_workspace = nullptr; | ||
| NVTE_CHECK_CUDA(cudaMallocAsync(&tile_scheduler_workspace, sizeof(uint32_t), stream)); | ||
| NVTE_CHECK_CUDA(cudaMemsetAsync(tile_scheduler_workspace, 0, sizeof(uint32_t), stream)); | ||
| // Set workspace and set to zero | ||
| NVTE_CHECK_CUDA(cudaMemsetAsync(reinterpret_cast<void *>(tile_scheduler_workspace), 0, | ||
| sizeof(uint32_t), stream)); | ||
|
|
||
| // Launch kernel | ||
| cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream}; | ||
|
|
@@ -1308,8 +1308,6 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz | |
| tile_scheduler_workspace, mma, rng_state); | ||
| NVTE_CHECK_CUDA(cudaGetLastError()); | ||
| NVTE_CHECK(status == cutlass::Status::kSuccess, "Kernel launch failed."); | ||
|
|
||
| NVTE_CHECK_CUDA(cudaFreeAsync(tile_scheduler_workspace, stream)); | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
@@ -1318,7 +1316,8 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz | |
| void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector<Tensor *> &output_list, | ||
| const size_t *split_sections, size_t num_tensors, | ||
| const Tensor &hadamard_matrix_, | ||
| QuantizationConfig &quant_config, cudaStream_t stream) { | ||
| QuantizationConfig &quant_config, Tensor &quant_workspace, | ||
| cudaStream_t stream) { | ||
| NVTE_API_CALL(group_hadamard_transform_cast_fusion); | ||
|
|
||
| using transformer_engine::detail::kMaxTensorsPerKernel; | ||
|
|
@@ -1399,6 +1398,12 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector<Tens | |
| rng_state = reinterpret_cast<const size_t *>(rng_state_tensor.data.dptr); | ||
| } | ||
|
|
||
| uint32_t *tile_scheduler_workspace = nullptr; | ||
| NVTE_CHECK(quant_workspace.data.dptr != nullptr, "Quantization workspace must be provided."); | ||
| NVTE_CHECK(quant_workspace.data.buffer_size_bytes() >= sizeof(uint32_t), | ||
| "Quantization workspace must be at least 4 bytes."); | ||
| tile_scheduler_workspace = reinterpret_cast<uint32_t *>(quant_workspace.data.dptr); | ||
|
|
||
| // Template arguments | ||
| using TA = cute::bfloat16_t; | ||
| using TB = cute::bfloat16_t; | ||
|
|
@@ -1461,7 +1466,9 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector<Tens | |
| /*QA=*/reinterpret_cast<TQA *>(rowwise_data_base_ptr), | ||
| /*SFA=*/reinterpret_cast<TSFA *>(rowwise_scale_inv_base_ptr), | ||
| /*args=*/kernel_args, | ||
| /*rng_state=*/rng_state, /*sm_count=*/sm_count, | ||
| /*rng_state=*/rng_state, | ||
| /*tile_scheduler_workspace=*/tile_scheduler_workspace, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would prefer a more generic workspace name to be honest. Proper handling of this would also require having some function that would return size of the required workspace.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. from API level, it's called quant_workspace now |
||
| /*sm_count=*/sm_count, | ||
| /*stream=*/stream, /*k_tile_size=*/k_tile_size); | ||
| } else { | ||
| NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=", | ||
|
|
@@ -1478,7 +1485,7 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso | |
| const size_t *split_sections, | ||
| const size_t num_tensors, | ||
| const NVTEQuantizationConfig quant_config, | ||
| cudaStream_t stream) { | ||
| NVTETensor quant_workspace, cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_group_hadamard_transform_cast_fusion); | ||
| using namespace transformer_engine; | ||
| NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0."); | ||
|
|
@@ -1489,6 +1496,8 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso | |
| output_list[i] = convertNVTETensorCheck(outputs[i]); | ||
| } | ||
|
|
||
| Tensor *quant_workspace_tensor = convertNVTETensorCheck(quant_workspace); | ||
|
|
||
| QuantizationConfig quant_config_cpp; | ||
| if (quant_config != nullptr) { | ||
| quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config); | ||
|
|
@@ -1497,5 +1506,5 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso | |
| // Call the multi-tensor Hadamard transform amax implementation. | ||
| group_hadamard_transform_cast_fusion(*input_tensor, output_list, split_sections, num_tensors, | ||
| *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, | ||
| stream); | ||
| *quant_workspace_tensor, stream); | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we wanted to be fancy, we could add an option to query the workspace size, similar to how we do it for LayerNorm. If the workspace is not provided, we set the
NVTETensorwith the required size. This way the caller doesn't need to know the details of the workspace size.That said, I think this approach is fine for now.