// see csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh
#pragma once

#include <cuda.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>

#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"

// ElementB is int32 (packed int4)
// ElementGroupScale is cutlass::Array<cutlass::float_e4m3_t, 8> (packed fp8)
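// Launched with a single block of num_experts threads; each thread computes
// the per-expert base pointers consumed by the grouped GEMM.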
template <typename ElementA, typename ElementB, typename ElementC,
          typename ElementAccumulator, typename ElementGroupScale>
__global__ void get_group_gemm_starts(
    int64_t* expert_offsets, ElementA** a_offsets, ElementB** b_offsets,
    ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
    ElementAccumulator** b_scales_offsets,
    ElementGroupScale** b_group_scales_offsets, ElementA* a_base_as_int,
    ElementB* b_base_as_int, ElementC* out_base_as_int,
    ElementAccumulator* a_scales_base_as_int,
    ElementAccumulator* b_scales_base_as_int,
    ElementGroupScale* b_group_scales_base_as_int, int64_t n, int64_t k,
    int64_t scale_k) {
  int expert_id = threadIdx.x;

  int64_t expert_offset = expert_offsets[expert_id];

  // same as w8a8
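  // (a_scales appear to be per token, hence the offset by the cumulative
  // token count; b_scales appear to be per output channel, i.e. n per expert)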
  a_offsets[expert_id] = a_base_as_int + expert_offset * k;
  out_offsets[expert_id] = out_base_as_int + expert_offset * n;
  a_scales_offsets[expert_id] = a_scales_base_as_int + expert_offset;
  b_scales_offsets[expert_id] = b_scales_base_as_int + (n * expert_id);

  // w4a8 specific
  constexpr int pack_factor = 8;  // pack 8 int4 into int32
  b_offsets[expert_id] = b_base_as_int + (expert_id * k * n / pack_factor);
  b_group_scales_offsets[expert_id] =
      b_group_scales_base_as_int + (expert_id * scale_k * n);
}

#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE)                   \
  else if (out_tensors.dtype() == TENSOR_C_TYPE) {                        \
    get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float,  \
                          cutlass::Array<cutlass::float_e4m3_t, 8>>       \
        <<<1, num_experts, 0, stream>>>(                                  \
            static_cast<int64_t*>(expert_offsets.data_ptr()),             \
            static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),      \
            static_cast<int32_t**>(b_ptrs.data_ptr()),                    \
            static_cast<C_TYPE**>(out_ptrs.data_ptr()),                   \
            static_cast<float**>(a_scales_ptrs.data_ptr()),               \
            static_cast<float**>(b_scales_ptrs.data_ptr()),               \
            static_cast<cutlass::Array<cutlass::float_e4m3_t, 8>**>(      \
                b_group_scales_ptrs.data_ptr()),                          \
            static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),    \
            static_cast<int32_t*>(b_tensors.data_ptr()),                  \
            static_cast<C_TYPE*>(out_tensors.data_ptr()),                 \
            static_cast<float*>(a_scales.data_ptr()),                     \
            static_cast<float*>(b_scales.data_ptr()),                     \
            static_cast<cutlass::Array<cutlass::float_e4m3_t, 8>*>(       \
                b_group_scales.data_ptr()),                               \
            n, k, scale_k);                                               \
  }
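
// The macro expands to an `else if` branch so that it can be chained after
// the empty `if (false) {}` dispatch opener in run_get_group_gemm_starts
// below.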

namespace {

void run_get_group_gemm_starts(
    torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
    torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
    torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor& out_tensors,
    torch::Tensor const& a_scales, torch::Tensor const& b_scales,
    torch::Tensor const& b_group_scales, const int64_t b_group_size) {
  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b_tensors.dtype() == torch::kInt32);  // int4 8x packed into int32
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_group_scales.dtype() ==
              torch::kFloat8_e4m3fn);  // the underlying torch type is e4m3
  TORCH_CHECK(out_tensors.dtype() ==
              torch::kBFloat16);  // only support bf16 for now
  // expect int64_t to avoid overflow during offset calculations
  TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);

  int num_experts = static_cast<int>(expert_offsets.size(0));
  // logical k, n
  int64_t n = out_tensors.size(1);
  int64_t k = a_tensors.size(1);
  int64_t scale_k = cutlass::ceil_div(k, b_group_size);
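  // scale_k: number of int4 quantization groups along k (one fp8 group scale
  // per group per output channel)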

  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

  if (false) {
  }
  __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
  __CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
  else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}

}  // namespace
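
// Illustrative usage sketch (hypothetical caller, not part of the original
// file): from a .cu file that includes this header, the per-expert pointer
// arrays can be allocated as int64 tensors that the kernel then fills with
// device pointers. The names and the group size below are assumptions.
//
//   const int64_t num_experts = expert_offsets.size(0);
//   auto ptr_opts =
//       torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
//   torch::Tensor a_ptrs = torch::empty({num_experts}, ptr_opts);
//   torch::Tensor b_ptrs = torch::empty({num_experts}, ptr_opts);
//   torch::Tensor out_ptrs = torch::empty({num_experts}, ptr_opts);
//   torch::Tensor a_scales_ptrs = torch::empty({num_experts}, ptr_opts);
//   torch::Tensor b_scales_ptrs = torch::empty({num_experts}, ptr_opts);
//   torch::Tensor b_group_scales_ptrs = torch::empty({num_experts}, ptr_opts);
//   run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
//                             a_scales_ptrs, b_scales_ptrs,
//                             b_group_scales_ptrs, a_tensors, b_tensors,
//                             out_tensors, a_scales, b_scales,
//                             b_group_scales, /*b_group_size=*/128);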