Commit f6227c2

Authored by czhu-cohere
[Kernel]Support W4A8 Grouped GEMM on Hopper (#29691)
Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
1 parent ea657f2 commit f6227c2

22 files changed: +2045 additions, -101 deletions


CMakeLists.txt

Lines changed: 4 additions & 1 deletion
@@ -874,7 +874,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
     set(SRCS
-      "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu")
+      "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
+      "csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
+      "csrc/quantization/cutlass_w4a8/w4a8_utils.cu"
+    )
 
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"

csrc/ops.h

Lines changed: 2 additions & 1 deletion
@@ -262,7 +262,8 @@ void get_cutlass_moe_mm_data(
 void get_cutlass_moe_mm_problem_sizes(
     const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
     torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
-    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
+    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
+    std::optional<bool> force_swap_ab = std::nullopt);
 
 void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
                                   torch::Tensor& problem_sizes1,
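The added parameter is optional and defaults to std::nullopt, so existing call sites keep compiling; a non-empty value, judging by the name, pins the swap-AB decision instead of leaving it to the existing heuristic. A minimal sketch of a call site that uses the new argument; the tensor and scalar names here are illustrative, not taken from the commit:

    // Hypothetical call site: topk_ids, problem_sizes1/2, num_experts, n, k
    // are assumed to be prepared exactly as for the previous signature.
    get_cutlass_moe_mm_problem_sizes(
        topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k,
        /*blockscale_offsets=*/std::nullopt,
        /*force_swap_ab=*/true);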
New file: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
// see csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh
#pragma once

#include <cuda.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>

#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"

// ElementB is int32 (packed int4)
// ElementGroupScale is cutlass::Array<cutlass::float_e4m3_t, 8> (packed fp8)
template <typename ElementA, typename ElementB, typename ElementC,
          typename ElementAccumulator, typename ElementGroupScale>
__global__ void get_group_gemm_starts(
    int64_t* expert_offsets, ElementA** a_offsets, ElementB** b_offsets,
    ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
    ElementAccumulator** b_scales_offsets,
    ElementGroupScale** b_group_scales_offsets, ElementA* a_base_as_int,
    ElementB* b_base_as_int, ElementC* out_base_as_int,
    ElementAccumulator* a_scales_base_as_int,
    ElementAccumulator* b_scales_base_as_int,
    ElementGroupScale* b_group_scales_base_as_int, int64_t n, int64_t k,
    int64_t scale_k) {
  int expert_id = threadIdx.x;

  int64_t expert_offset = expert_offsets[expert_id];

  // same as w8a8
  a_offsets[expert_id] = a_base_as_int + expert_offset * k;
  out_offsets[expert_id] = out_base_as_int + expert_offset * n;
  a_scales_offsets[expert_id] = a_scales_base_as_int + expert_offset;
  b_scales_offsets[expert_id] = b_scales_base_as_int + (n * expert_id);

  // w4a8 specific
  constexpr int pack_factor = 8;  // pack 8 int4 into int32
  b_offsets[expert_id] = b_base_as_int + (expert_id * k * n / pack_factor);
  b_group_scales_offsets[expert_id] =
      b_group_scales_base_as_int + (expert_id * scale_k * n);
}
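The kernel launches one thread per expert (see the <<<1, num_experts, 0, stream>>> launch in the macro below). A, the output, and the per-token A scales are advanced by the running token count in expert_offsets, while B, its per-channel scales, and its group scales are strided purely by expert_id. A small host-side sanity check of those per-expert strides, using made-up shapes (k = 256, n = 512, group size 128) that are not from the commit:

    #include <cassert>
    #include <cstdint>

    int main() {
      // Illustrative shapes only; k = reduction dim, n = output columns.
      const int64_t k = 256, n = 512, group_size = 128;
      const int64_t scale_k = (k + group_size - 1) / group_size;  // ceil_div -> 2
      const int64_t pack_factor = 8;  // 8 int4 values per packed int32

      // Per-expert strides in elements, mirroring the kernel's pointer math.
      const int64_t b_stride = k * n / pack_factor;      // int32 words per expert
      const int64_t b_group_scale_stride = scale_k * n;  // packed-fp8 groups per expert
      const int64_t b_channel_scale_stride = n;          // float channel scales per expert

      assert(b_stride == 16384);
      assert(b_group_scale_stride == 1024);
      assert(b_channel_scale_stride == 512);
      return 0;
    }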

#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE)                  \
  else if (out_tensors.dtype() == TENSOR_C_TYPE) {                       \
    get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float, \
                          cutlass::Array<cutlass::float_e4m3_t, 8>>      \
        <<<1, num_experts, 0, stream>>>(                                 \
            static_cast<int64_t*>(expert_offsets.data_ptr()),            \
            static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),     \
            static_cast<int32_t**>(b_ptrs.data_ptr()),                   \
            static_cast<C_TYPE**>(out_ptrs.data_ptr()),                  \
            static_cast<float**>(a_scales_ptrs.data_ptr()),              \
            static_cast<float**>(b_scales_ptrs.data_ptr()),              \
            static_cast<cutlass::Array<cutlass::float_e4m3_t, 8>**>(     \
                b_group_scales_ptrs.data_ptr()),                         \
            static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),   \
            static_cast<int32_t*>(b_tensors.data_ptr()),                 \
            static_cast<C_TYPE*>(out_tensors.data_ptr()),                \
            static_cast<float*>(a_scales.data_ptr()),                    \
            static_cast<float*>(b_scales.data_ptr()),                    \
            static_cast<cutlass::Array<cutlass::float_e4m3_t, 8>*>(      \
                b_group_scales.data_ptr()),                              \
            n, k, scale_k);                                              \
  }
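Because the macro body starts with else if, the dispatch in run_get_group_gemm_starts below chains the two instantiations behind an empty if (false) {} and closes with a plain else. After preprocessing it is an ordinary if/else-if ladder, roughly as follows (kernel arguments elided here for brevity):

    if (false) {
    } else if (out_tensors.dtype() == torch::kBFloat16) {
      get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, cutlass::bfloat16_t,
                            float, cutlass::Array<cutlass::float_e4m3_t, 8>>
          <<<1, num_experts, 0, stream>>>(/* pointer-array and shape args */);
    } else if (out_tensors.dtype() == torch::kFloat16) {
      get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, half, float,
                            cutlass::Array<cutlass::float_e4m3_t, 8>>
          <<<1, num_experts, 0, stream>>>(/* same args, fp16 output type */);
    } else {
      TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
    }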

namespace {

void run_get_group_gemm_starts(
    torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
    torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
    torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor& out_tensors,
    torch::Tensor const& a_scales, torch::Tensor const& b_scales,
    torch::Tensor const& b_group_scales, const int64_t b_group_size) {
  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b_tensors.dtype() == torch::kInt32);  // int4 8x packed into int32
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_group_scales.dtype() ==
              torch::kFloat8_e4m3fn);  // the underlying torch type is e4m3
  TORCH_CHECK(out_tensors.dtype() ==
              torch::kBFloat16);  // only support bf16 for now
  // expect int64_t to avoid overflow during offset calculations
  TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);

  int num_experts = static_cast<int>(expert_offsets.size(0));
  // logical k, n
  int64_t n = out_tensors.size(1);
  int64_t k = a_tensors.size(1);
  int64_t scale_k = cutlass::ceil_div(k, b_group_size);

  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

  if (false) {
  }
  __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
  __CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
  else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}

}  // namespace
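A hedged sketch of how a caller (for example the grouped-GEMM entry point added in this commit) might prepare the pointer-array tensors this helper fills in. The allocation pattern mirrors the w8a8 MoE path, where each pointer array is an int64 tensor with one 8-byte slot per expert, but the actual call site is not shown in this extract, so treat the variable names and the group size of 128 as assumptions:

    // Assumed context: a_tensors, b_tensors, out_tensors, a_scales, b_scales,
    // b_group_scales, and expert_offsets already live on the GPU with the
    // dtypes checked above; num_experts = expert_offsets.size(0).
    auto ptr_opts =
        torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
    torch::Tensor a_ptrs = torch::empty({num_experts}, ptr_opts);
    torch::Tensor b_ptrs = torch::empty({num_experts}, ptr_opts);
    torch::Tensor out_ptrs = torch::empty({num_experts}, ptr_opts);
    torch::Tensor a_scales_ptrs = torch::empty({num_experts}, ptr_opts);
    torch::Tensor b_scales_ptrs = torch::empty({num_experts}, ptr_opts);
    torch::Tensor b_group_scales_ptrs = torch::empty({num_experts}, ptr_opts);

    run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
                              a_scales_ptrs, b_scales_ptrs, b_group_scales_ptrs,
                              a_tensors, b_tensors, out_tensors, a_scales,
                              b_scales, b_group_scales,
                              /*b_group_size=*/128);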
