Commit eb8e792

Authored by zhongbozhu, timmoon10, and pre-commit-ci[bot]
[PyTorch][NVFP4][MOE] NVFP4 Grouped Quantize with Hadamard Transform (#2411)
* rowwise colwise RHT group quant v1
* remove local array RW
* change wait_barrier
* fast math options
* use mult to replace div
* format
* bulk move random states
* greptile
* lint
* revert to use divides
* avoid fp32 bf16 round-trip in RHT cast fusion
* trigger fastmath by toggle NVTE_RHT_CAST_FUSION_USE_FAST_MATH
* integrate row col rht fusion, functional
* numerics aligned
* style
* remove device sync
* 128 padding
* revert colwise rng state creation because of row-col fused kernel
* fix CI, linter
* refactor RS for generating two random values
* Avoid invalid configs with templated kernel
* fix acc pipeline init with 0 arrival count
* restore rowwise-only mode
* switch to dynamic atomic scheduler
* Avoid instantiating group RHT+cast kernel without row-wise or col-wise output
* Include fast math option in quantization config
* Fix linter warnings and review nits
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Use TE license
* Fix bug where kernel is always launched on stream
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Restore BF16 intermediate downcast in fused RHT-cast kernels
* fix numerical test of grouped kernel
* Make sure row-wise and col-wise quantization use different RNG seeds
* Restore autoformatter
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

---------

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
1 parent 47902e9 commit eb8e792

19 files changed: +4205 −232 lines

benchmarks/linear/benchmark_grouped_linear.py

Lines changed: 20 additions & 9 deletions
@@ -53,7 +53,7 @@
 --set=full \
 --kernel-name "GroupHadamardAmaxTmaKernel" \
 -s 5 -c 5 \
-python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4 --profile
+python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4

 """

@@ -173,7 +173,9 @@ def benchmark_linear(
     return timing_ms


-def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None):
+def run_benchmark_linear(
+    mkns, recipe_name, use_bias, num_gemms=4, m_splits_provided=None, fwd_only=False
+):
     data = []
     assert not use_bias, "Bias is not supported for GroupedLinear benchmark"

@@ -182,22 +184,22 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
     device = "cuda"
     x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
     ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
-    assert m % num_gemms == 0
-    m_splits = [m // num_gemms] * num_gemms if m_splits is None else m_splits
+    m_splits = [m // num_gemms] * num_gemms if m_splits_provided is None else m_splits_provided
     # Bias is not supported for GroupedLinear benchmark
     bias = None

     # Run the benchmark
     print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}")
     print(f"m_splits: {m_splits}")
+    print(f"fwd_only: {fwd_only}")

     grouped_fwd_bwd_timing_ms = benchmark_linear(
         x,
         ws,
         m_splits,
         bias,
         recipe_name,
-        mode="fwd_bwd",
+        mode="fwd_only" if fwd_only else "fwd_bwd",
         num_gemms=num_gemms,
     )

@@ -213,6 +215,8 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
         ]
     )

+    timing_notation = "grouped_fwd_time_ms" if fwd_only else "grouped_fwd_bwd_time_ms"
+
     df = pd.DataFrame(
         data=data,
         columns=[
@@ -221,7 +225,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
             "n",
             "recipe",
             "num_gemms",
-            "grouped_fwd_bwd_time_ms",
+            timing_notation,
         ],
     )

@@ -234,7 +238,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
 parser = argparse.ArgumentParser()
 parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
 parser.add_argument(
-    "--output_dir",
+    "--output-dir",
     type=str,
     default="benchmark_output/",
     help="output path for report",
@@ -266,6 +270,12 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
     default=2048,
     help="Output dimension to use, default is 2048",
 )
+parser.add_argument(
+    "--fwd-only",
+    action="store_true",
+    default=False,
+    help="Run forward pass only, default is both forward and backward passes",
+)
 args = parser.parse_args()

 jagged_input_splits = None
@@ -297,7 +307,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
 if jagged_input_splits is not None:
     num_gemms_list = [len(jagged_input_splits)]

-token_dim_list = [65536]
+token_dim_list = [16384, 32768, 65536, 98304]
 hidden_dim_list = [7168]
 output_dim_list = [2048]

@@ -371,7 +381,8 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
         recipe_name,
         use_bias,
         num_gemms=num_gemms,
-        m_splits=jagged_input_splits,
+        m_splits_provided=jagged_input_splits,
+        fwd_only=args.fwd_only,
     )
     df_linears = pd.concat([df_linears, df])


tests/pytorch/nvfp4/test_nvfp4_group_quantize.py

Lines changed: 3 additions & 2 deletions
@@ -198,7 +198,7 @@ def check_group_quantization_nvfp4_versus_reference(

     for i in range(len(x_qx)):
         if split_sections[i] == 0:
-            # then just assert the same same and dtype because the buffer won't be zero out
+            # then just assert the same shape and dtype because the buffer won't be zero out
             assert_same_shape_and_dtype(x_amax_rowwise[i], x_amax_rowwise_ref[i])
             assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i])
             assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i])
@@ -221,7 +221,7 @@ def check_group_quantization_nvfp4_versus_reference(
     # assert with zero tolerance
     for i in range(len(x_qx_t)):
         if split_sections[i] == 0:
-            # then just assert the same same and dtype because the buffer won't be zero out
+            # then just assert the same shape and dtype because the buffer won't be zero out
             assert_same_shape_and_dtype(x_amax_colwise[i], x_amax_colwise_ref[i])
             assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i])
             assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i])
@@ -247,6 +247,7 @@ def check_group_quantization_nvfp4_versus_reference(
         (1024, 256),
         # larger sizes
         (8192, 1024),
+        (16384, 8192),
         (16384, 16384),
     ],
 )

transformer_engine/common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -174,6 +174,8 @@ list(APPEND transformer_engine_cuda_arch_specific_sources
      hadamard_transform/group_hadamard_transform.cu
      hadamard_transform/hadamard_transform.cu
      hadamard_transform/hadamard_transform_cast_fusion.cu
+     hadamard_transform/group_hadamard_transform_cast_fusion.cu
+     hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
      multi_tensor/compute_scale.cu
      recipe/mxfp8_scaling.cu
      transpose/quantize_transpose_square_blockwise.cu

transformer_engine/common/cast/cast.cu

Lines changed: 15 additions & 0 deletions
@@ -100,3 +100,18 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
     NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
   }
 }
+
+// Group quantize assumes contiguous inputs and outputs in memory allocation
+// TODO (zhongbo): find a better way to make it a more generalized API
+void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs,
+                                         const size_t *split_sections, const size_t num_tensors,
+                                         const NVTEQuantizationConfig quant_config,
+                                         cudaStream_t stream) {
+  NVTE_API_CALL(nvte_group_nvfp4_quantize_with_amax);
+  using namespace transformer_engine;
+
+  constexpr bool IS_ACT = false;
+
+  dispatch::group_quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, outputs, split_sections,
+                                                              num_tensors, quant_config, stream);
+}
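For context, a minimal sketch of how a caller might drive this new entry point is shown below. Only the `nvte_group_nvfp4_quantize_with_amax` signature and the contiguous-input assumption come from this commit; the include path, the `make_nvfp4_output` helper, and the buffer setup are placeholders. Passing `nullptr` for the quantization config is consistent with the dispatch helper, which only dereferences it when non-null.

```cpp
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>  // assumed include path for the C API

#include <vector>

// Placeholder (hypothetical): allocate one NVFP4 output tensor (data, scale-inv,
// amax) with NVTE_NVFP4_1D_SCALING for a group of `rows` x `cols`. The real
// allocation details are framework-specific and omitted here.
NVTETensor make_nvfp4_output(size_t rows, size_t cols);

void group_quantize_example(NVTETensor contiguous_input,        // [sum(splits), cols] input
                            const std::vector<size_t> &splits,  // rows per group/expert
                            size_t cols, cudaStream_t stream) {
  // One output tensor per group; groups are assumed to sit back-to-back in the
  // contiguous input, as the commit's comment requires.
  std::vector<NVTETensor> outputs;
  for (size_t rows : splits) {
    outputs.push_back(make_nvfp4_output(rows, cols));
  }

  // Optional quantization config (noop tensor, stochastic rounding, fast math, ...).
  NVTEQuantizationConfig quant_config = nullptr;

  // Grouped NVFP4 quantize added in this commit.
  nvte_group_nvfp4_quantize_with_amax(contiguous_input, outputs.data(), splits.data(),
                                      splits.size(), quant_config, stream);
}
```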

transformer_engine/common/cast/dispatch/quantize.cuh

Lines changed: 65 additions & 0 deletions
@@ -19,6 +19,7 @@
 #include "../core/common.cuh"
 #include "../fp8/quantize_fp8.cuh"
 #include "../mxfp8/quantize_mxfp8.cuh"
+#include "../nvfp4/group_quantize_transpose_nvfp4.cuh"
 #include "../nvfp4/quantize_nvfp4.cuh"
 #include "../nvfp4/quantize_transpose_nvfp4.cuh"

@@ -320,6 +321,70 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
   }
 }

+template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
+void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs,
+                               const size_t *split_sections, const size_t num_tensors,
+                               const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
+  using namespace detail;
+
+  const Tensor *input_tensor = convertNVTETensorCheck(input);
+  std::vector<Tensor *> output_tensors;
+  for (size_t i = 0; i < num_tensors; ++i) {
+    output_tensors.push_back(convertNVTETensorCheck(outputs[i]));
+  }
+
+  // Quantization config
+  QuantizationConfig quant_config_cpp;
+  if (quant_config != nullptr) {
+    quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
+  }
+
+  // Noop flag
+  Tensor dummy_tensor;
+  Tensor *noop_tensor = &dummy_tensor;
+  if (quant_config_cpp.noop_tensor != nullptr) {
+    noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
+  }
+
+  // Check for unsupported options
+  if (quant_config_cpp.stochastic_rounding) {
+    NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING,
+               "Stochastic rounding is only supported for NVFP4 quantization.");
+  }
+
+  // Take the scaling mode of the first output tensor
+  auto scaling_mode = output_tensors[0]->scaling_mode;
+
+  // Dispatch to quantization kernel depending on data format
+  switch (scaling_mode) {
+    case NVTE_NVFP4_1D_SCALING: {
+      NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING");
+
+      // Check tensors
+      CheckNoopTensor(*noop_tensor, "cast_noop");
+      CheckInputTensor(*input_tensor, "input");
+      // Skip checking output tensor list
+      // output list here is allowed to have empty tensor
+
+      // Choose kernel
+      int32_t rows = input_tensor->flat_first_dim();
+      int32_t cols = input_tensor->flat_last_dim();
+      auto dtype = input_tensor->dtype();
+
+      NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization,
+                 "2D quantization is not supported for group quantize.");
+
+      // Launch NVFP4 group quantize kernel
+      nvfp4::group_quantize_transpose</*use_2d_quantization*/ false>(
+          *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors,
+          &quant_config_cpp, stream);
+      break;
+    }
+    default:
+      NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
+  }
+}
+
 } // namespace dispatch
 } // namespace transformer_engine

