diff --git a/csrc/common.cpp b/csrc/common.cpp
index 0a9601689..834453370 100644
--- a/csrc/common.cpp
+++ b/csrc/common.cpp
@@ -26,10 +26,12 @@ void quantize_block(const quantize_block_args& args) {
         if (idx < 255) {
             float dist_left = fabs(normed_value - (args.code[idx]));
             float dist_right = fabs(normed_value - (args.code[idx + 1]));
-            if (dist_right < dist_left) { idx += 1; }
+            if (dist_right < dist_left) {
+                idx += 1;
+            }
         }
 
         // 5. store index
-        args.out[i] = (unsigned char) idx;
+        args.out[i] = (unsigned char)idx;
     }
 }
diff --git a/csrc/common.cuh b/csrc/common.cuh
index 8c85accfd..d454caa0e 100644
--- a/csrc/common.cuh
+++ b/csrc/common.cuh
@@ -2,47 +2,48 @@
 
 // TODO: Let's make some of these constexpr and put in a namespace.
 
-#define BNB_CC_MAXWELL 500
-#define BNB_CC_MAXWELL2 520
-#define BNB_CC_MAXWELL2_X1 530
-#define BNB_CC_PASCAL 600
-#define BNB_CC_PASCAL_X2 620
-#define BNB_CC_VOLTA 700
-#define BNB_CC_VOLTA_XAVIER 720
-#define BNB_CC_TURING 750
-#define BNB_CC_AMPERE 800
-#define BNB_CC_AMPERE2 860
-#define BNB_CC_AMPERE2_ORIN 870
-#define BNB_CC_ADA 890
-#define BNB_CC_HOPPER 900
-#define BNB_CC_BLACKWELL 1000
+#define BNB_CC_MAXWELL 500
+#define BNB_CC_MAXWELL2 520
+#define BNB_CC_MAXWELL2_X1 530
+#define BNB_CC_PASCAL 600
+#define BNB_CC_PASCAL_X2 620
+#define BNB_CC_VOLTA 700
+#define BNB_CC_VOLTA_XAVIER 720
+#define BNB_CC_TURING 750
+#define BNB_CC_AMPERE 800
+#define BNB_CC_AMPERE2 860
+#define BNB_CC_AMPERE2_ORIN 870
+#define BNB_CC_ADA 890
+#define BNB_CC_HOPPER 900
+#define BNB_CC_BLACKWELL 1000
 
-#define BNB_FP16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_MAXWELL2_X1)
-#define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA)
-#define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER)
-#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE)
-#define BNB_FP8_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_ADA)
+#define BNB_FP16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_MAXWELL2_X1)
+#define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA)
+#define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER)
+#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE)
+#define BNB_FP8_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_ADA)
 
-#define BNB_WARP_SIZE 32
+#define BNB_WARP_SIZE 32
 
 // The maximum number of resident threads per SM varies by arch.
 // For A100/H100 and all prior to Turing, it is 2048, which allows
 // for 2 full blocks of 1024 threads per SM.
-// Reference: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
+// Reference:
+// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
 #if __CUDA_ARCH__ == 750
-#define BNB_MAX_THREADS_PER_SM 1024
+#define BNB_MAX_THREADS_PER_SM 1024
 #elif __CUDA_ARCH__ >= 860 && __CUDA_ARCH__ <= 890
-#define BNB_MAX_THREADS_PER_SM 1536
+#define BNB_MAX_THREADS_PER_SM 1536
 #else
-#define BNB_MAX_THREADS_PER_SM 2048
+#define BNB_MAX_THREADS_PER_SM 2048
 #endif
 
 // Maximum resident warps per SM is always directly related to the number of threads.
-#define BNB_MAX_WARPS_PER_SM ((BNB_MAX_THREADS_PER_SM) / (BNB_WARP_SIZE))
+#define BNB_MAX_WARPS_PER_SM ((BNB_MAX_THREADS_PER_SM) / (BNB_WARP_SIZE))
 
 // Maximum resident blocks per SM may vary.
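// Illustrative sketch (an aside, not a line of this patch): how the limits above
// combine, using the values from the #if ladder just above. BNB_MAX_WARPS_PER_SM
// is simply the resident-thread limit divided by the warp size:
//   CC 7.5:            1024 / 32 = 32 warps per SM
//   CC 8.6, 8.7, 8.9:  1536 / 32 = 48 warps per SM
//   all other archs:   2048 / 32 = 64 warps per SM
// A host-side equivalent; bnb_warps_per_sm is a hypothetical helper name:
constexpr int bnb_warps_per_sm(int max_threads_per_sm, int warp_size = 32) {
    return max_threads_per_sm / warp_size;
}
static_assert(bnb_warps_per_sm(1536) == 48, "CC 8.6-8.9 parts");
static_assert(bnb_warps_per_sm(2048) == 64, "A100/H100 and pre-Turing parts");
// The resident-block limit defined next is derived from the warp limit, with
// CC 8.6/8.7 special-cased to their hardware cap of 16 blocks per SM.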
#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 -#define BNB_MAX_BLOCKS_PER_SM 16 +#define BNB_MAX_BLOCKS_PER_SM 16 #else -#define BNB_MAX_BLOCKS_PER_SM ((BNB_MAX_WARPS_PER_SM) / 2) +#define BNB_MAX_BLOCKS_PER_SM ((BNB_MAX_WARPS_PER_SM) / 2) #endif diff --git a/csrc/common.h b/csrc/common.h index e513f2875..c0c9a43be 100644 --- a/csrc/common.h +++ b/csrc/common.h @@ -5,21 +5,18 @@ using namespace BinSearch; -#define BLOCK_SIZE 16384 - struct quantize_block_args { - BinAlgo *bin_searcher; - float *code; - float *A; - float *absmax; - unsigned char *out; + BinAlgo* bin_searcher; + float* code; + float* A; + float* absmax; + unsigned char* out; long long block_end; long long block_idx; long long threadidx; - long long blocksize; + long long blocksize; }; - void quantize_block(const quantize_block_args& args); #endif diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index e67135360..5c2bc6332 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -4,7 +4,7 @@ using namespace BinSearch; -void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n) { +void dequantize_cpu(float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n) { for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; long long block_end = block_idx + valid_items; @@ -13,8 +13,7 @@ void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, lo } } -void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n) -{ +void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) { // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below code[0] = -1.0f; @@ -28,36 +27,35 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long int thread_wave_size = 256; // we chunk the threads into waves of 256 since the max limit is // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size) - for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size) - { - long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset; - std::vector threads(valid_chunks); - std::vector args(valid_chunks); - - int chunks_processed = 0; - for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize) - { - long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; - long long block_end = block_idx + valid_items; - - struct quantize_block_args& arg = args[chunks_processed]; - arg.bin_searcher = &bin_searcher; - arg.code = code; - arg.A = A; - arg.absmax = absmax; - arg.out = out; - arg.block_end = block_end; - arg.block_idx = block_idx; - arg.threadidx = block_idx / blocksize; - arg.blocksize = blocksize; - - threads[chunks_processed] = std::thread([arg] { quantize_block(arg); }); - chunks_processed += 1; - if(chunks_processed == valid_chunks){ break; } - } - - for (int i = 0; i < valid_chunks; i++) - threads[i].join(); + for (long long offset = 0; offset < num_blocks; offset += thread_wave_size) { + long long valid_chunks = num_blocks - offset >= thread_wave_size ? 
thread_wave_size : num_blocks - offset; + std::vector threads(valid_chunks); + std::vector args(valid_chunks); + + int chunks_processed = 0; + for (long long block_idx = offset * blocksize; block_idx < n; block_idx += blocksize) { + long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; + long long block_end = block_idx + valid_items; + + struct quantize_block_args& arg = args[chunks_processed]; + arg.bin_searcher = &bin_searcher; + arg.code = code; + arg.A = A; + arg.absmax = absmax; + arg.out = out; + arg.block_end = block_end; + arg.block_idx = block_idx; + arg.threadidx = block_idx / blocksize; + arg.blocksize = blocksize; + + threads[chunks_processed] = std::thread([arg] { quantize_block(arg); }); + chunks_processed += 1; + if (chunks_processed == valid_chunks) { + break; + } + } + + for (int i = 0; i < valid_chunks; i++) + threads[i].join(); } - } diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 2ddf81e49..3c10e6d13 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -4,7 +4,7 @@ #include #include -void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n); -void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n); +void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); +void dequantize_cpu(float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n); #endif diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 7eba3f884..649f2ee1f 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3,234 +3,218 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -#include "kernels.cuh" #include "common.cuh" -#include -#include -#include -#include +#include "kernels.cuh" #include -#include +#include +#include #include +#include #include +#include +#include #include #include - #define HLF_MAX 65504 #define TH 1024 #define NUM 4 #define NUM_BLOCK 4096 -__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; +__device__ static float nf4_data[16] = { + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0 +}; // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda __device__ float atomicMax(float* address, float val) { - int* address_as_i = reinterpret_cast(address); - int old = *address_as_i, assumed; - do { - assumed = old; - old = atomicCAS( - reinterpret_cast(address), assumed, - __float_as_int(fmaxf(val, __int_as_float(assumed)))); - } while (assumed != old); - return __int_as_float(old); + int* address_as_i = reinterpret_cast(address); + int old = *address_as_i, assumed; + do { + assumed = old; + old = atomicCAS(reinterpret_cast(address), assumed, __float_as_int(fmaxf(val, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); } -__device__ float dDequantizeFP4Tree(unsigned char val, 
float absmax) -{ - float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; - if((val & 0b0100) == 4) // 0 - if((val & 0b0010) == 2) //01 - if((val & 0b0001) == 1) // 111 - return 0.25000000f*absmax*sign; // 1111 - else - return 0.16666667f*absmax*sign; // 1110 - else - if((val & 0b0001) == 1) // 110 - return 0.50000000f*absmax*sign; // 1101 - else - return 0.33333333f*absmax*sign; // 1100 - else - if((val & 0b0010) == 2) //10 - if((val & 0b0001) == 1) // 101 - return 1.00000000f*absmax*sign; // 1011 - else - return 0.66666667f*absmax*sign; // 1010 +__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) { + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if ((val & 0b0100) == 4) // 0 + if ((val & 0b0010) == 2) // 01 + if ((val & 0b0001) == 1) // 111 + return 0.25000000f * absmax * sign; // 1111 + else + return 0.16666667f * absmax * sign; // 1110 + else if ((val & 0b0001) == 1) // 110 + return 0.50000000f * absmax * sign; // 1101 + else + return 0.33333333f * absmax * sign; // 1100 + else if ((val & 0b0010) == 2) // 10 + if ((val & 0b0001) == 1) // 101 + return 1.00000000f * absmax * sign; // 1011 + else + return 0.66666667f * absmax * sign; // 1010 + else if ((val & 0b0001) == 1) // 100 + return 5.208333333e-03f * absmax * sign; // 1001 else - if((val & 0b0001) == 1) // 100 - return 5.208333333e-03f*absmax*sign; // 1001 - else - return 0.00000000f*absmax*sign; // 1000 + return 0.00000000f * absmax * sign; // 1000 } -__device__ unsigned char dQuantizeFP4(float x) -{ - // FP4 with bias of 3 - // first bit is a sign - // subnormals - // 0b000 = 0 - // 0b001 = 0.0625 - // 0b110 = 2 - // 0b111 = 3 - // 0b100 = 4 - // 0b101 = 6 - // 0b010 = 8 - // 0b011 = 12 - - - // we do a binary search - // the pivots are divided by 12 (the FP4 absmax) - // since we assume input data is in [-1.0, 1.0] - - // !be careful here, its easy to make a mistake - // that is difficult to notice if you add an extra - // zero somewhere! - - int sign = x < 0 ? 0b1000 : 0b0000; - x = fabsf(x); - if(x > 0.29166667f) - if( x > 0.583333f) - if( x > 0.8333333f) - return 0b0011+sign; - else - return 0b0010+sign; - else - if(x > 0.4166667f) - return 0b101+sign; - else - return 0b100+sign; - else - if(x > 0.0859375f) - if(x > 0.20833333f) - return 0b0111+sign; - else - return 0b0110+sign; +__device__ unsigned char dQuantizeFP4(float x) { + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assume input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to notice if you add an extra + // zero somewhere! + + int sign = x < 0 ? 
0b1000 : 0b0000; + x = fabsf(x); + if (x > 0.29166667f) + if (x > 0.583333f) + if (x > 0.8333333f) + return 0b0011 + sign; + else + return 0b0010 + sign; + else if (x > 0.4166667f) + return 0b101 + sign; + else + return 0b100 + sign; + else if (x > 0.0859375f) + if (x > 0.20833333f) + return 0b0111 + sign; + else + return 0b0110 + sign; + else if (x > 0.00260417f) + return 0b0001 + sign; else - if(x > 0.00260417f) - return 0b0001+sign; - else - return 0b0000+sign; + return 0b0000 + sign; } -__device__ __forceinline__ float dDequantizeNF4(unsigned char val) -{ +__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { - // the values for this tree was generated by test_normal_map_tree - // in the file tests/test_functional.py - if((val & 0b1000) == 8) - if((val & 0b0100) == 4) // 1 - if((val & 0b0010) == 2) // 11 - if((val & 0b0001) == 1) // 111 - return 1.0f; - else - return 0.7229568362236023f; - else - if((val & 0b0001) == 1) // 110 - return 0.5626170039176941f; - else - return 0.44070982933044434f; - else - if((val & 0b0010) == 2) //10 - if((val & 0b0001) == 1) // 101 - return 0.33791524171829224f; - else - return 0.24611230194568634f; - else - if((val & 0b0001) == 1) // 100 - return 0.16093020141124725f; + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if ((val & 0b1000) == 8) + if ((val & 0b0100) == 4) // 1 + if ((val & 0b0010) == 2) // 11 + if ((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else if ((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else if ((val & 0b0010) == 2) // 10 + if ((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else if ((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; else - return 0.07958029955625534f; + return 0.07958029955625534f; - else - if((val & 0b0100) == 4) // 0 - if((val & 0b0010) == 2) //01 - if((val & 0b0001) == 1) // 011 - return 0.0f; + else if ((val & 0b0100) == 4) // 0 + if ((val & 0b0010) == 2) // 01 + if ((val & 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else if ((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; else - return -0.09105003625154495f; - else - if((val & 0b0001) == 1) // 010 - return -0.18477343022823334f; + return -0.28444138169288635f; + else if ((val & 0b0010) == 2) // 00 + if ((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; else - return -0.28444138169288635f; + return -0.5250730514526367f; + else if ((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; else - if((val & 0b0010) == 2) //00 - if((val & 0b0001) == 1) // 001 - return -0.39491748809814453f; - else - return -0.5250730514526367f; - else - if((val & 0b0001) == 1) // 000 - return -0.6961928009986877f; - else - return -1.0f; - + return -1.0f; } -__device__ unsigned char dQuantizeNF4(float x) -{ +__device__ unsigned char dQuantizeNF4(float x) { - // the values for this tree was generated by test_normal_map_tree - // in the file tests/test_functional.py - if(x > 0.03979014977812767f) - if(x > 0.3893125355243683f) // 1 - if(x > 0.6427869200706482f) // 11 - if(x > 0.8614784181118011f) // 111 - return 0b1111; - else - return 0b1110; - else - if(x > 0.5016634166240692f) // 110 - return 0b1101; - else - return 0b1100; - else - if(x > 0.2035212516784668f) // 10 - if(x > 0.2920137718319893f) // 101 - return 0b1011; - else - return 0b1010; - else - if(x > 0.1202552504837513f) // 100 - return 0b1001; + 
// the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if (x > 0.03979014977812767f) + if (x > 0.3893125355243683f) // 1 + if (x > 0.6427869200706482f) // 11 + if (x > 0.8614784181118011f) // 111 + return 0b1111; + else + return 0b1110; + else if (x > 0.5016634166240692f) // 110 + return 0b1101; + else + return 0b1100; + else if (x > 0.2035212516784668f) // 10 + if (x > 0.2920137718319893f) // 101 + return 0b1011; + else + return 0b1010; + else if (x > 0.1202552504837513f) // 100 + return 0b1001; else - return 0b1000; - else - if(x > -0.33967943489551544f) // 0 - if(x > -0.13791173323988914f) // 01 - if(x > -0.045525018125772476f) // 011 - return 0b0111; + return 0b1000; + else if (x > -0.33967943489551544f) // 0 + if (x > -0.13791173323988914f) // 01 + if (x > -0.045525018125772476f) // 011 + return 0b0111; + else + return 0b0110; + else if (x > -0.23460740596055984f) // 010 + return 0b0101; else - return 0b0110; - else - if(x > -0.23460740596055984f) // 010 - return 0b0101; + return 0b0100; + else if (x > -0.6106329262256622f) // 00 + if (x > -0.4599952697753906f) // 001 + return 0b0011; else - return 0b0100; + return 0b0010; + else if (x > -0.8480964004993439f) // 000 + return 0b0001; else - if(x > -0.6106329262256622f) // 00 - if(x > -0.4599952697753906f) // 001 - return 0b0011; - else - return 0b0010; - else - if(x > -0.8480964004993439f) // 000 - return 0b0001; - else - return 0b0000; + return 0b0000; } + // sign function for lion // taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA -template __device__ int sgn(T val) -{ - return (T(0) < val) - (val < T(0)); -} +template __device__ int sgn(T val) { return (T(0) < val) - (val < T(0)); } -template -__device__ unsigned char dQuantize(float* smem_code, const float rand, float x) -{ +template __device__ unsigned char dQuantize(float* smem_code, const float rand, float x) { int pivot = 127; int upper_pivot = 255; int lower_pivot = 0; @@ -240,71 +224,60 @@ __device__ unsigned char dQuantize(float* smem_code, const float rand, float x) float val = smem_code[pivot]; // i>>=1 = {32, 16, 8, 4, 2, 1} - for(int i = 64; i > 0; i>>=1) - { - if(x > val) - { + for (int i = 64; i > 0; i >>= 1) { + if (x > val) { lower_pivot = pivot; lower = val; - pivot+=i; - } - else - { + pivot += i; + } else { upper_pivot = pivot; upper = val; - pivot-=i; + pivot -= i; } val = smem_code[pivot]; } - if(upper_pivot == 255) + if (upper_pivot == 255) upper = smem_code[upper_pivot]; - if(lower_pivot == 0) + if (lower_pivot == 0) lower = smem_code[lower_pivot]; - if(!STOCHASTIC) - { - if(x > val) - { - float midpoint = (upper+val)*0.5f; - if(x > midpoint) - { - return upper_pivot; + if (!STOCHASTIC) { + if (x > val) { + float midpoint = (upper + val) * 0.5f; + if (x > midpoint) { + return upper_pivot; + } else + return pivot; + } else { + float midpoint = (lower + val) * 0.5f; + if (x < midpoint) + return lower_pivot; + else + return pivot; + } + } else { + if (x > val) { + float dist_to_upper = fabsf(upper - x); + float dist_full = upper - val; + if (rand >= dist_to_upper / dist_full) + return upper_pivot; + else + return pivot; + } else { + float dist_to_lower = fabsf(lower - x); + float dist_full = val - lower; + if (rand >= dist_to_lower / dist_full) + return lower_pivot; + else + return pivot; } - else - return pivot; - } - else - { - float midpoint = (lower+val)*0.5f; - if(x < midpoint) - return lower_pivot; - else - return pivot; - } - } - else - { - if(x > 
val) - { - float dist_to_upper = fabsf(upper-x); - float dist_full = upper-val; - if(rand >= dist_to_upper/dist_full) return upper_pivot; - else return pivot; - } - else - { - float dist_to_lower = fabsf(lower-x); - float dist_full = val-lower; - if(rand >= dist_to_lower/dist_full) return lower_pivot; - else return pivot; - } } } template -__device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x) -{ +__device__ __forceinline__ unsigned char + quantize_2D(float* __restrict__ quadrants, float* __restrict__ const smem_code, float x) { int pivot = 127; int upper_pivot = 255; int lower_pivot = 0; @@ -317,445 +290,414 @@ __device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadran int offset = 1; // i>>=1 = {32, 16, 8, 4, 2, 1} - for(int i = 64; i > 0; i>>=1) - { - if(x > val) - { + for (int i = 64; i > 0; i >>= 1) { + if (x > val) { lower_pivot = pivot; lower = val; - pivot+=i; - //val = i == 64 ? quadrants[2] : smem_code[pivot]; + pivot += i; + // val = i == 64 ? quadrants[2] : smem_code[pivot]; local_pivot += offset; - } - else - { + } else { upper_pivot = pivot; upper = val; - pivot-=i; - //val = i == 64 ? quadrants[0] : smem_code[pivot]; + pivot -= i; + // val = i == 64 ? quadrants[0] : smem_code[pivot]; local_pivot -= offset; } val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot]; offset -= 1; } - if(x > val) - { - midpoint = (upper+val)*0.5f; - if(x > midpoint) - return upper_pivot; - else - return pivot; - } - else - { - midpoint = (lower+val)*0.5f; - if(x < midpoint) - return lower_pivot; - else - return pivot; + if (x > val) { + midpoint = (upper + val) * 0.5f; + if (x > midpoint) + return upper_pivot; + else + return pivot; + } else { + midpoint = (lower + val) * 0.5f; + if (x < midpoint) + return lower_pivot; + else + return pivot; } } -__launch_bounds__(TH, 4) -__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n) -{ - const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); - int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*NUM_BLOCK) : NUM_BLOCK; - const int base_idx = (blockIdx.x * NUM_BLOCK); - - float vals[NUM]; - unsigned char qvals[NUM]; - //const int lane_id = threadIdx.x % 2; - - typedef cub::BlockLoad LoadFloat; - typedef cub::BlockStore StoreChar; +__launch_bounds__(TH, 4) __global__ + void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n) { + const int n_full = (NUM_BLOCK * (n / NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); + int valid_items = (blockIdx.x + 1 == gridDim.x) ? 
n - (blockIdx.x * NUM_BLOCK) : NUM_BLOCK; + const int base_idx = (blockIdx.x * NUM_BLOCK); - __shared__ typename LoadFloat::TempStorage loadf; - __shared__ typename StoreChar::TempStorage storec; - __shared__ float smem_code[256]; - //__shared__ float smem_code[2][257]; + float vals[NUM]; + unsigned char qvals[NUM]; + // const int lane_id = threadIdx.x % 2; - if(threadIdx.x < 256) - { - smem_code[threadIdx.x] = code[threadIdx.x]; - //smem_code[0][threadIdx.x] = code[threadIdx.x]; - //smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x]; - } + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockStore StoreChar; + __shared__ typename LoadFloat::TempStorage loadf; + __shared__ typename StoreChar::TempStorage storec; + __shared__ float smem_code[256]; + //__shared__ float smem_code[2][257]; - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_BLOCK) - { - // number of values already processed in blocks + - // number of values already processed in this block + - // rand_offset % mod value - valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; + if (threadIdx.x < 256) { + smem_code[threadIdx.x] = code[threadIdx.x]; + // smem_code[0][threadIdx.x] = code[threadIdx.x]; + // smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x]; + } - __syncthreads(); - LoadFloat(loadf).Load(&(A[i]), vals, valid_items); + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_BLOCK) { + // number of values already processed in blocks + + // number of values already processed in this block + + // rand_offset % mod value + valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; + __syncthreads(); + LoadFloat(loadf).Load(&(A[i]), vals, valid_items); - #pragma unroll 4 - for(int j = 0; j < NUM; j++) - qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]); +#pragma unroll 4 + for (int j = 0; j < NUM; j++) + qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]); - __syncthreads(); - StoreChar(storec).Store(&(out[i]), qvals, valid_items); - } + __syncthreads(); + StoreChar(storec).Store(&(out[i]), qvals, valid_items); + } } -template +template //__launch_bounds__(TH, 4) -__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) -{ - const int n_full = gridDim.x * BLOCK_SIZE; - int valid_items = 0; - const int base_idx = (blockIdx.x * BLOCK_SIZE); - - T vals[NUM_PER_TH]; - float rand_vals[NUM_PER_TH]; - unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH]; - //float local_abs_max = -FLT_MAX; - float local_abs_max = 0.0f; - int local_rand_idx = 0; - - typedef cub::BlockLoad LoadT; - typedef cub::BlockStore 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar; - typedef cub::BlockReduce BlockReduce; - typedef cub::BlockLoad LoadFloat; - - __shared__ typename LoadT::TempStorage loadt; - __shared__ typename LoadFloat::TempStorage loadf; - __shared__ typename StoreChar::TempStorage storec; - __shared__ typename BlockReduce::TempStorage reduce; - __shared__ float smem_code[256]; - __shared__ float smem_absmax_value[1]; - - if(DATA_TYPE == General8bit) - for(int i = threadIdx.x; i < 256; i+=blockDim.x) - smem_code[i] = code[i]; - - for (int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) - { - valid_items = n - i > BLOCK_SIZE ? 
BLOCK_SIZE : n - i; - local_abs_max = -FLT_MAX; +__global__ void kQuantizeBlockwise( + float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, + const int rand_offset, const int n +) { + const int n_full = gridDim.x * BLOCK_SIZE; + int valid_items = 0; + const int base_idx = (blockIdx.x * BLOCK_SIZE); - __syncthreads(); - LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f); + T vals[NUM_PER_TH]; + float rand_vals[NUM_PER_TH]; + unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH]; + // float local_abs_max = -FLT_MAX; + float local_abs_max = 0.0f; + int local_rand_idx = 0; + + typedef cub::BlockLoad LoadT; + typedef cub::BlockStore< + unsigned char, BLOCK_SIZE / NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH, + cub::BLOCK_STORE_WARP_TRANSPOSE> + StoreChar; + typedef cub::BlockReduce BlockReduce; + typedef cub::BlockLoad LoadFloat; + + __shared__ typename LoadT::TempStorage loadt; + __shared__ typename LoadFloat::TempStorage loadf; + __shared__ typename StoreChar::TempStorage storec; + __shared__ typename BlockReduce::TempStorage reduce; + __shared__ float smem_code[256]; + __shared__ float smem_absmax_value[1]; + + if (DATA_TYPE == General8bit) + for (int i = threadIdx.x; i < 256; i += blockDim.x) + smem_code[i] = code[i]; + + for (int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) { + valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; + local_abs_max = -FLT_MAX; - // 1. compute local max - // 2. broadcast local max - // 3. normalize inputs and quantize + __syncthreads(); + LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f); - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); + // 1. compute local max + // 2. broadcast local max + // 3. 
normalize inputs and quantize - local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items); +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) + local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); - if (threadIdx.x == 0) { - smem_absmax_value[0] = 1.0f / local_abs_max; - absmax[i / BLOCK_SIZE] = local_abs_max; - } - __syncthreads(); + local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items); - local_abs_max = smem_absmax_value[0]; + if (threadIdx.x == 0) { + smem_absmax_value[0] = 1.0f / local_abs_max; + absmax[i / BLOCK_SIZE] = local_abs_max; + } + __syncthreads(); - if(STOCHASTIC) - { - local_rand_idx = ((blockIdx.x*NUM_BLOCK) + (threadIdx.x*NUM) + rand_offset) % (1024-4); - LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); - } + local_abs_max = smem_absmax_value[0]; - unsigned char packed_4bit = 0; - switch(DATA_TYPE) - { + if (STOCHASTIC) { + local_rand_idx = ((blockIdx.x * NUM_BLOCK) + (threadIdx.x * NUM) + rand_offset) % (1024 - 4); + LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); + } + + unsigned char packed_4bit = 0; + switch (DATA_TYPE) { case General8bit: - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - { - if(!STOCHASTIC) - qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + if (!STOCHASTIC) + qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j]) * local_abs_max); else - qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j]) * local_abs_max); } break; case FP4: - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH/2; j++) - { - packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; - packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); - qvals[j] = packed_4bit; +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH / 2; j++) { + packed_4bit |= dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4; + packed_4bit |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max); + qvals[j] = packed_4bit; } break; case NF4: - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH/2; j++) - { - packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; - packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); - qvals[j] = packed_4bit; +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH / 2; j++) { + packed_4bit |= dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4; + packed_4bit |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max); + qvals[j] = packed_4bit; } break; - } + } - __syncthreads(); - StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items); - } + __syncthreads(); + StoreChar(storec).Store( + &(out[(DATA_TYPE > 0) ? i / 2 : i]), qvals, (DATA_TYPE > 0) ? 
(valid_items + 1) / 2 : valid_items + ); + } } -template -__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n) -{ - - const int n_load = (gridDim.x * TILE_SIZE); - int valid_items_load = 0; - int valid_items_store = 0; - const int base_idx = (blockIdx.x * TILE_SIZE); +template +__global__ void + kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n) { + + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); + + T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)]; + unsigned char qvals[NUM_PER_TH]; + float local_abs_max = -FLT_MAX; + + typedef cub::BlockLoad LoadChar; + typedef cub::BlockStore 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; + + __shared__ typename LoadChar::TempStorage loadchar; + __shared__ typename StoreT::TempStorage storet; + + for (int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) { + if (DATA_TYPE > 0) { + valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i); + valid_items_store = min(TILE_SIZE * 2, n - i * 2); + } else { + valid_items_load = min(TILE_SIZE, n - i); + valid_items_store = valid_items_load; + } - T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; - unsigned char qvals[NUM_PER_TH]; - float local_abs_max = -FLT_MAX; + // Since blocksize will always be a power-of-2, we avoid more expensive + // division by the blocksize and instead use a shift operation. + // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. + local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) >> (31 - __clz(blocksize))]); - typedef cub::BlockLoad LoadChar; - typedef cub::BlockStore 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; + __syncthreads(); + LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); - __shared__ typename LoadChar::TempStorage loadchar; - __shared__ typename StoreT::TempStorage storet; + switch (DATA_TYPE) { + case General8bit: +// load code through read-only cache via __ldg +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) + vals[j] = __ldg(&code[qvals[j]]) * local_abs_max; + break; + case FP4: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); + vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); + } + break; + case NF4: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + vals[j * 2] = dDequantizeNF4(qvals[j] >> 4) * local_abs_max; + vals[j * 2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max; + } + break; + } - for (int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) - { - if (DATA_TYPE > 0) - { - valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i); - valid_items_store = min(TILE_SIZE * 2, n - i * 2); - } - else - { - valid_items_load = min(TILE_SIZE, n - i); - valid_items_store = valid_items_load; + __syncthreads(); + StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i * 2 : i]), vals, valid_items_store); } +} - // Since blocksize will always be a power-of-2, we avoid more expensive - // division by the blocksize and instead use a shift operation. - // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. 
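// Illustrative sketch (not part of the patch) of the power-of-two shift trick the
// comment above describes. On the host, GCC/Clang's __builtin_clz stands in for
// CUDA's __clz: for a power-of-two blocksize, 31 - clz(blocksize) == log2(blocksize),
// so shifting right by that amount is the same as dividing by blocksize.
// block_index is a hypothetical helper name.
constexpr unsigned int block_index(unsigned int element_index, unsigned int blocksize) {
    return element_index >> (31 - __builtin_clz(blocksize)); // == element_index / blocksize
}
static_assert(block_index(4095, 4096) == 0, "last element of the first 4096-wide block");
static_assert(block_index(123456, 4096) == 123456 / 4096, "shift matches integer division");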
- local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH) >> (31 - __clz(blocksize))]); - - __syncthreads(); - LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); +__global__ void kDequantize(float* code, unsigned char* A, float* out, const int n) { + const unsigned int numThreads = blockDim.x * gridDim.x; + const int idx = (blockIdx.x * blockDim.x) + threadIdx.x; - switch (DATA_TYPE) - { - case General8bit: - // load code through read-only cache via __ldg - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - vals[j] = __ldg(&code[qvals[j]])*local_abs_max; - break; - case FP4: - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - { - vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); - vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); - } - break; - case NF4: - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - { - vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max; - vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max; - } - break; + __shared__ float smem_code[256]; + if (threadIdx.x < 256) { + smem_code[threadIdx.x] = code[threadIdx.x]; } __syncthreads(); - StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store); - } + + for (int i = idx; i < n; i += numThreads) { + out[i] = smem_code[A[i]]; + } } -__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n) -{ - const unsigned int numThreads = blockDim.x * gridDim.x; - const int idx = (blockIdx.x * blockDim.x) + threadIdx.x; +template +__launch_bounds__(BLOCK_SIZE / NUM_VALS, 1) __global__ void kPreconditionOptimizer32bit2State( + T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n +) { + + const int n_full = (BLOCK_SIZE * (n / BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; - __shared__ float smem_code[256]; - if(threadIdx.x < 256) - { - smem_code[threadIdx.x] = code[threadIdx.x]; - } + float s1_vals[NUM_VALS]; + float s2_vals[NUM_VALS]; - __syncthreads(); + const float correction1 = 1.0f / (1.0f - powf(beta1, step)); + const float correction2 = 1.0f / (1.0f - powf(beta2, step)); - for (int i = idx;i < n; i += numThreads) - { - out[i] = smem_code[A[i]]; - } -} + typedef cub::BlockLoad Load; + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockReduce BlockReduce; + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) { + valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; -template -__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) -__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, - float* state1, float* state2, float *unorm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n) -{ - - const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 
0 : BLOCK_SIZE); - const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); - int valid_items = 0; - - T g_vals[NUM_VALS]; - - float s1_vals[NUM_VALS]; - float s2_vals[NUM_VALS]; - - const float correction1 = 1.0f/(1.0f - powf(beta1, step)); - const float correction2 = 1.0f/(1.0f - powf(beta2, step)); - - typedef cub::BlockLoad Load; - typedef cub::BlockLoad LoadFloat; - typedef cub::BlockReduce BlockReduce; - - __shared__ union { - typename Load::TempStorage load; - typename LoadFloat::TempStorage loadf; - typename BlockReduce::TempStorage reduce; - } temp_storage; - - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) - { - valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; - - __syncthreads(); - Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); - __syncthreads(); - LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); - __syncthreads(); - LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); - - # pragma unroll NUM_VALS - for(unsigned int j = 0; j < NUM_VALS; j++) - g_vals[j] = gnorm_scale*((float)g_vals[j]); - - # pragma unroll NUM_VALS - for(unsigned int j = 0; j < NUM_VALS; j++) - { - switch(OPTIMIZER) - { - case ADAM: - s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); - s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); - s1_vals[j] *= correction1; - s2_vals[j] *= correction2; - s1_vals[j] = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update - s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update) - break; - } - } - - # pragma unroll NUM_VALS-1 - for(unsigned int j = 1; j < NUM_VALS; j++) - s1_vals[0] += s1_vals[j]; - - __syncthreads(); - s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]); - - if(threadIdx.x == 0) - atomicAdd(&unorm[0], s1_vals[0]); - - __syncwarp(); - } -} + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); + +#pragma unroll NUM_VALS + for (unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale * ((float)g_vals[j]); + +#pragma unroll NUM_VALS + for (unsigned int j = 0; j < NUM_VALS; j++) { + switch (OPTIMIZER) { + case ADAM: + s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j])); + s2_vals[j] = s2_vals[j] * beta2 + ((1.0f - beta2) * (((float)g_vals[j]) * ((float)g_vals[j]))); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + s1_vals[j] = s1_vals[j] / (sqrtf(s2_vals[j]) + eps); // update + s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update) + break; + } + } + +#pragma unroll NUM_VALS - 1 + for (unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]); + if (threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + __syncwarp(); + } +} #define NUM_PER_THREAD 4 -template -__launch_bounds__(TH, 1) -__global__ void kOptimizer32bit2State(T* g, T* p, - float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) -{ - - const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % 
(TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); - const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); - int valid_items = 0; - float update_scale = 0.0f; - T g_vals[NUM_PER_THREAD]; - T p_vals[NUM_PER_THREAD]; - - - float s1_vals[NUM_PER_THREAD]; - float s2_vals[NUM_PER_THREAD]; - - // AdEMAMix has an additional state buffer, which we packed - // into state1. We need thread-local storage here for these. - // TODO: Mark with [[maybe_unused]] after upgrade to min compiler. - float s3_vals[NUM_PER_THREAD]; - - - const float correction1 = 1.0f - powf(beta1, step); - const float correction2 = sqrtf(1.0f - powf(beta2, step)); - const float step_size = -lr*correction2/correction1; - - if(max_unorm > 0.0f) - { - update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; - if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } - else{ update_scale = 1.0f; } - } - else{ update_scale = 1.0f; } - - typedef cub::BlockLoad Load; - typedef cub::BlockStore Store; - - typedef cub::BlockLoad LoadFloat; - typedef cub::BlockStore StoreFloat; - - __shared__ union { - typename Load::TempStorage load; - typename Store::TempStorage store; - typename LoadFloat::TempStorage loadf; - typename StoreFloat::TempStorage storef; - } temp_storage; - - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) - { - valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; - - __syncthreads(); - Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); - __syncthreads(); - LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); - __syncthreads(); - LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); - __syncthreads(); - Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); - - // Load additional state1 data for AdEMAMix - // TODO: Make constexpr after updating min compiler - if (OPTIMIZER == ADEMAMIX) { +template +__launch_bounds__(TH, 1) __global__ void kOptimizer32bit2State( + T* g, T* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +) { + + const int n_full = ((TH * NUM_PER_THREAD) * (n / (TH * NUM_PER_THREAD))) + + (n % (TH * NUM_PER_THREAD) == 0 ? 0 : (TH * NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + float s2_vals[NUM_PER_THREAD]; + + // AdEMAMix has an additional state buffer, which we packed + // into state1. We need thread-local storage here for these. + // TODO: Mark with [[maybe_unused]] after upgrade to min compiler. + float s3_vals[NUM_PER_THREAD]; + + const float correction1 = 1.0f - powf(beta1, step); + const float correction2 = sqrtf(1.0f - powf(beta2, step)); + const float step_size = -lr * correction2 / correction1; + + if (max_unorm > 0.0f) { + update_scale = max_unorm > 0.0f ? 
sqrtf(unorm[0]) : 1.0f; + if (update_scale > max_unorm * param_norm) { + update_scale = (max_unorm * param_norm) / update_scale; + } else { + update_scale = 1.0f; + } + } else { + update_scale = 1.0f; + } + + typedef cub::BlockLoad Load; + typedef cub::BlockStore Store; + + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * TH * NUM_PER_THREAD) { + valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); __syncthreads(); - LoadFloat(temp_storage.loadf).Load(&(state1[n + i]), s3_vals, valid_items); - } - - # pragma unroll 4 - for(unsigned int j = 0; j < NUM_PER_THREAD; j++) - g_vals[j] = gnorm_scale*((float)g_vals[j]); - - # pragma unroll 4 - for(unsigned int j = 0; j < NUM_PER_THREAD; j++) - { - switch(OPTIMIZER) - { - case ADEMAMIX: + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + // Load additional state1 data for AdEMAMix + // TODO: Make constexpr after updating min compiler + if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[n + i]), s3_vals, valid_items); + } + +#pragma unroll 4 + for (unsigned int j = 0; j < NUM_PER_THREAD; j++) + g_vals[j] = gnorm_scale * ((float)g_vals[j]); + +#pragma unroll 4 + for (unsigned int j = 0; j < NUM_PER_THREAD; j++) { + switch (OPTIMIZER) { + case ADEMAMIX: // m1 update: m1 = beta1 * m1 + (1-beta1) * g s1_vals[j] = (s1_vals[j] * beta1) + ((1.0f - beta1) * (float)g_vals[j]); @@ -765,241 +707,231 @@ __global__ void kOptimizer32bit2State(T* g, T* p, // nu update: nu = beta2 * nu + (1-beta2) * g^2 s2_vals[j] = (s2_vals[j] * beta2) + ((1.0f - beta2) * (float)g_vals[j] * (float)g_vals[j]); - p_vals[j] = (float)p_vals[j] - lr * ( - ((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / ( - (sqrtf(s2_vals[j]) / correction2) + eps - ) - ); + p_vals[j] = (float)p_vals[j] - lr * (((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / + ((sqrtf(s2_vals[j]) / correction2) + eps)); if (weight_decay > 0.0f) p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay)); - break; - case ADAM: - - if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) - { - s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); - s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); - p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2)))); - - if(weight_decay > 0.0f) - p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); - } - break; - } - } - - __syncthreads(); - Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); - __syncthreads(); - StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); - __syncthreads(); - StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); - - if (OPTIMIZER == ADEMAMIX) { + break; + case ADAM: + + if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { + s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j])); + s2_vals[j] = s2_vals[j] * beta2 + ((1.0f - 
beta2) * (((float)g_vals[j]) * ((float)g_vals[j]))); + p_vals[j] = ((float)p_vals[j]) + + (update_scale * step_size * (s1_vals[j] / (sqrtf(s2_vals[j]) + (eps * correction2)))); + + if (weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay)); + } + break; + } + } + __syncthreads(); - StoreFloat(temp_storage.storef).Store(&(state1[n + i]), s3_vals, valid_items); - } - } + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); + + if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[n + i]), s3_vals, valid_items); + } + } } -template -__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) -__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, - float* state1, float *unorm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n) -{ - - const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); - const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); - int valid_items = 0; - - T g_vals[NUM_VALS]; - - float s1_vals[NUM_VALS]; - - typedef cub::BlockLoad Load; - typedef cub::BlockLoad LoadFloat; - typedef cub::BlockReduce BlockReduce; - - __shared__ union { - typename Load::TempStorage load; - typename LoadFloat::TempStorage loadf; - typename BlockReduce::TempStorage reduce; - } temp_storage; - - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) - { - valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; - - __syncthreads(); - Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); - __syncthreads(); - LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); - - # pragma unroll NUM_VALS - for(unsigned int j = 0; j < NUM_VALS; j++) - g_vals[j] = gnorm_scale*((float)g_vals[j]); - - # pragma unroll NUM_VALS - for(unsigned int j = 0; j < NUM_VALS; j++) - { - switch(OPTIMIZER) - { - case MOMENTUM: - if(step == 1) +template +__launch_bounds__(BLOCK_SIZE / NUM_VALS, 1) __global__ void kPreconditionOptimizer32bit1State( + T* g, T* p, float* state1, float* unorm, const float beta1, const float beta2, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n +) { + + const int n_full = (BLOCK_SIZE * (n / BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + + typedef cub::BlockLoad Load; + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockReduce BlockReduce; + + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) { + valid_items = n - i >= (BLOCK_SIZE) ? 
(BLOCK_SIZE) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + +#pragma unroll NUM_VALS + for (unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale * ((float)g_vals[j]); + +#pragma unroll NUM_VALS + for (unsigned int j = 0; j < NUM_VALS; j++) { + switch (OPTIMIZER) { + case MOMENTUM: + if (step == 1) s1_vals[j] = (float)g_vals[j]; // state update - else - s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update - s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm - break; - case LION: - s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update - break; - case RMSPROP: - s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update - s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value - s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm - break; - case ADAGRAD: - s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update - s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value - s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm - break; - } - } - - # pragma unroll - for(unsigned int j = 1; j < NUM_VALS; j++) - s1_vals[0] += s1_vals[j]; - - __syncthreads(); - s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); - - if(threadIdx.x == 0) - atomicAdd(&unorm[0], s1_vals[0]); - - __syncwarp(); - } -} + else + s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]); // state update + s1_vals[j] = s1_vals[j] * s1_vals[j]; // update norm + break; + case LION: + s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * (float)g_vals[j]); // state update + break; + case RMSPROP: + s1_vals[j] = + s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j]) * ((float)g_vals[j])); // state update + s1_vals[j] = __fdividef((float)g_vals[j], sqrtf(s1_vals[j]) + eps); // update value + s1_vals[j] = s1_vals[j] * s1_vals[j]; // update norm + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j]) * ((float)g_vals[j]); // state update + s1_vals[j] = __fdividef((float)g_vals[j], sqrtf(s1_vals[j]) + eps); // update value + s1_vals[j] = s1_vals[j] * s1_vals[j]; // update norm + break; + } + } -template -__launch_bounds__(TH, 1) -__global__ void kOptimizer32bit1State(T *g, T *p, - float *state1, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) -{ - - const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); - const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); - int valid_items = 0; - float update_scale = 0.0f; - - if(max_unorm > 0.0f) - { - update_scale = max_unorm > 0.0f ? 
sqrtf(unorm[0]) : 1.0f; - if(update_scale > max_unorm*param_norm+eps){ update_scale = (max_unorm*param_norm+eps)/update_scale; } - else{ update_scale = 1.0f; } - } - else{ update_scale = 1.0f; } - - T g_vals[NUM_PER_THREAD]; - T p_vals[NUM_PER_THREAD]; - - float s1_vals[NUM_PER_THREAD]; - - typedef cub::BlockLoad Load; - typedef cub::BlockStore Store; - - typedef cub::BlockLoad LoadFloat; - typedef cub::BlockStore StoreFloat; - - __shared__ union { - typename Load::TempStorage load; - typename Store::TempStorage store; - typename LoadFloat::TempStorage loadf; - typename StoreFloat::TempStorage storef; - } temp_storage; - - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) - { - valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; - - __syncthreads(); - Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); - __syncthreads(); - LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); - __syncthreads(); - Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); - - # pragma unroll 4 - for(unsigned int j = 0; j < NUM_PER_THREAD; j++) - { - g_vals[j] = gnorm_scale*((float)g_vals[j]); - if(weight_decay > 0.0f) - g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay); - } - - # pragma unroll 4 - for(unsigned int j = 0; j < NUM_PER_THREAD; j++) - { - if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) - { - switch(OPTIMIZER) - { - case MOMENTUM: - if(step == 1) - s1_vals[j] = (float)g_vals[j]; - else - s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); - - p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); - break; - case LION: - p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])))); - s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j])); - break; - case RMSPROP: - s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); - p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); - break; - case ADAGRAD: - s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); - p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps); - break; - } - } - } - - __syncthreads(); - Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); - __syncthreads(); - StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); - } +#pragma unroll + for (unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); + + if (threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + __syncwarp(); + } } +template +__launch_bounds__(TH, 1) __global__ void kOptimizer32bit1State( + T* g, T* p, float* state1, float* unorm, const float max_unorm, const float param_norm, const float beta1, + const float beta2, const float eps, const float weight_decay, const int step, const float lr, + const float gnorm_scale, const bool skip_zeros, const int n +) { + + const int n_full = ((TH * NUM_PER_THREAD) * (n / (TH * NUM_PER_THREAD))) + + (n % (TH * NUM_PER_THREAD) == 0 ? 0 : (TH * NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + + if (max_unorm > 0.0f) { + update_scale = max_unorm > 0.0f ? 
sqrtf(unorm[0]) : 1.0f; + if (update_scale > max_unorm * param_norm + eps) { + update_scale = (max_unorm * param_norm + eps) / update_scale; + } else { + update_scale = 1.0f; + } + } else { + update_scale = 1.0f; + } + + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + + typedef cub::BlockLoad Load; + typedef cub::BlockStore Store; + + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * TH * NUM_PER_THREAD) { + valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + +#pragma unroll 4 + for (unsigned int j = 0; j < NUM_PER_THREAD; j++) { + g_vals[j] = gnorm_scale * ((float)g_vals[j]); + if (weight_decay > 0.0f) + g_vals[j] = (float)g_vals[j] + (((float)p_vals[j]) * weight_decay); + } + +#pragma unroll 4 + for (unsigned int j = 0; j < NUM_PER_THREAD; j++) { + if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { + switch (OPTIMIZER) { + case MOMENTUM: + if (step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + update_scale * (-lr * (s1_vals[j])); + break; + case LION: + p_vals[j] = + ((float)p_vals[j]) - + update_scale * (lr * sgn(((float)s1_vals[j]) * beta1 + ((1.0f - beta1) * ((float)g_vals[j])))); + s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * ((float)g_vals[j])); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j]) * ((float)g_vals[j])); + p_vals[j] = ((float)p_vals[j]) - + update_scale * (lr * __fdividef((float)g_vals[j], sqrtf((float)s1_vals[j]) + eps)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j]) * ((float)g_vals[j]); + p_vals[j] = ((float)p_vals[j]) - lr * __fdividef((float)g_vals[j], sqrtf((float)s1_vals[j]) + eps); + break; + } + } + } + + __syncthreads(); + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + } +} #define NUM8BIT 16 #define NUM_THREADS 256 #define NUM_PER_BLOCK 4096 -template -__global__ void -__launch_bounds__(NUM_THREADS, 2) -kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, - float *unorm, - const float beta1, const float beta2, - const float eps, const int step, - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* max1, float* max2, float* new_max1, float* new_max2, - const float gnorm_scale, const int n) -{ +template +__global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8bit2State( + T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, unsigned char* __restrict__ const state2, + float* unorm, const float beta1, const float beta2, const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, + float* new_max1, float* new_max2, 
const float gnorm_scale, const int n +) { const int n_full = gridDim.x * NUM_PER_BLOCK; const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); - int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); + int valid_items = + n - (blockIdx.x * NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x * NUM_PER_BLOCK); float g_val = 0.0f; float local_max_s1 = -FLT_MAX; float local_max_s2 = -FLT_MAX; @@ -1015,7 +947,6 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c typedef cub::BlockLoad LoadUInt8; typedef cub::BlockReduce BlockReduce; - __shared__ union { typename LoadT::TempStorage loadh; typename LoadUInt8::TempStorage loadc; @@ -1025,17 +956,15 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c __shared__ float smem_quantiles1[256]; __shared__ float smem_quantiles2[256]; - if(threadIdx.x < 256) - { + if (threadIdx.x < 256) { smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x]; } __syncthreads(); - for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*gridDim.x*NUM8BIT) - { - valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS * gridDim.x * NUM8BIT) { + valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i; LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); __syncthreads(); @@ -1044,38 +973,34 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128); __syncthreads(); - #pragma unroll 16 - for(int j = 0; j < NUM8BIT; j++) - { +#pragma unroll 16 + for (int j = 0; j < NUM8BIT; j++) { g_val = g_vals[j]; g_val *= gnorm_scale; - s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]*beta1; - s1_vals[j] += (1.0f-beta1)*g_val; + s1_vals[j] = smem_quantiles1[m_c1[j]] * max1[0] * beta1; + s1_vals[j] += (1.0f - beta1) * g_val; local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); } - #pragma unroll 16 - for(int j = 0; j < NUM8BIT; j++) - { +#pragma unroll 16 + for (int j = 0; j < NUM8BIT; j++) { g_val = g_vals[j]; g_val *= gnorm_scale; - s2_vals[j] = smem_quantiles2[r_c2[j]]*max2[0]*beta2; - s2_vals[j] += (1.0f-beta2)*g_val*g_val; + s2_vals[j] = smem_quantiles2[r_c2[j]] * max2[0] * beta2; + s2_vals[j] += (1.0f - beta2) * g_val * g_val; local_max_s2 = fmaxf(local_max_s2, fabsf(s2_vals[j])); } - if(unorm != NULL) - { - #pragma unroll 16 - for(int j = 0; j < NUM8BIT; j++) - { - float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step)); - float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step)); - s1_vals[j] *= correction1; - s2_vals[j] *= correction2; - float update_val = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update - local_unorm += update_val*update_val; - } + if (unorm != NULL) { +#pragma unroll 16 + for (int j = 0; j < NUM8BIT; j++) { + float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step)); + float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step)); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + float update_val = s1_vals[j] / (sqrtf(s2_vals[j]) + eps); // update + local_unorm += update_val * update_val; + } } } @@ -1083,17 +1008,17 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items); __syncthreads(); local_max_s2 = 
BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cub::Max(), valid_items); - if(unorm != NULL) - { - __syncthreads(); - local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items); + if (unorm != NULL) { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items); } - if(threadIdx.x == 0) - { + if (threadIdx.x == 0) { atomicMax(&new_max1[0], local_max_s1); atomicMax(&new_max2[0], local_max_s2); - if(unorm != NULL){ atomicAdd(&unorm[0], local_unorm); } + if (unorm != NULL) { + atomicAdd(&unorm[0], local_unorm); + } } } @@ -1101,20 +1026,15 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c #define NUM_THREADS2 1024 #define NUM_PER_BLOCK2 4096 -template -__global__ void -__launch_bounds__(NUM_THREADS2, 1) -kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, - const float *unorm, const float max_unorm, const float param_norm, \ - const float beta1, const float beta2, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* max1, float* max2, float* new_max1, float* new_max2, - float weight_decay, - const float gnorm_scale, const int n) -{ - - const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2; +template +__global__ void __launch_bounds__(NUM_THREADS2, 1) kOptimizerStatic8bit2State( + T* p, T* const g, unsigned char* state1, unsigned char* state2, const float* unorm, const float max_unorm, + const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, + float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n +) { + + const int n_full = (blockDim.x * gridDim.x) * NUM_PER_THREAD2; const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); int valid_items = 0; float g_val = 0.0f; @@ -1122,19 +1042,22 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha float s2_vals[NUM_PER_THREAD2]; const float correction1 = 1.0f - powf(beta1, step); const float correction2 = sqrtf(1.0f - powf(beta2, step)); - const float step_size = -lr*correction2/correction1; - //const float step_size = -lr*correction2/correction1; - float new_max_val1 = 1.0f/new_max1[0]; - float new_max_val2 = 1.0f/new_max2[0]; + const float step_size = -lr * correction2 / correction1; + // const float step_size = -lr*correction2/correction1; + float new_max_val1 = 1.0f / new_max1[0]; + float new_max_val2 = 1.0f / new_max2[0]; float update_scale = 1.0f; - if(max_unorm > 0.0f) - { - update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; - if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } - else{ update_scale = 1.0f; } + if (max_unorm > 0.0f) { + update_scale = max_unorm > 0.0f ? 
sqrtf(unorm[0]) : 1.0f; + if (update_scale > max_unorm * param_norm) { + update_scale = (max_unorm * param_norm) / update_scale; + } else { + update_scale = 1.0f; + } + } else { + update_scale = 1.0f; } - else{ update_scale = 1.0f; } unsigned char c1s[NUM_PER_THREAD2]; unsigned char c2s[NUM_PER_THREAD2]; @@ -1156,19 +1079,17 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha typename StoreT::TempStorage storeh; } temp_storage; - if(threadIdx.x < 512) - { - if(threadIdx.x < 256) + if (threadIdx.x < 512) { + if (threadIdx.x < 256) smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; else - smem_quantiles2[threadIdx.x-256] = quantiles2[threadIdx.x-256]; + smem_quantiles2[threadIdx.x - 256] = quantiles2[threadIdx.x - 256]; } __syncthreads(); - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) - { - valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_THREADS2 * NUM_PER_THREAD2) { + valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i; LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); __syncthreads(); LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); @@ -1177,42 +1098,42 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha __syncthreads(); LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); - if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + if ((i + (threadIdx.x * NUM_PER_THREAD2) + NUM_PER_THREAD2) > n) { + continue; + } - # pragma unroll 4 - for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) - { +#pragma unroll 4 + for (unsigned int j = 0; j < NUM_PER_THREAD2; j++) { g_val = float(g_vals[j]); g_val *= gnorm_scale; s1_vals[j] = smem_quantiles1[c1s[j]]; - s1_vals[j] = s1_vals[j]*max1[0]; + s1_vals[j] = s1_vals[j] * max1[0]; - s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + s1_vals[j] = (s1_vals[j] * beta1) + (((1.0f - beta1) * g_val)); - c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j] * new_max_val1); // make sure state1 term has still the same sign after quantization // (not needed for state2 term which has only positive values) - if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) - { - if(s1_vals[j] > 0.0f) - c1s[j] += 1; - else - c1s[j] -= 1; + if (signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) { + if (s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; } s2_vals[j] = smem_quantiles2[c2s[j]]; - s2_vals[j] = s2_vals[j]*max2[0]; - s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); - c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2); + s2_vals[j] = s2_vals[j] * max2[0]; + s2_vals[j] = (s2_vals[j] * beta2) + (((1.0f - beta2) * g_val * g_val)); + c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j] * new_max_val2); } - # pragma unroll 4 - for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) - { - p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps)))))); - if(weight_decay > 0.0f) - p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay)); +#pragma unroll 4 + for (unsigned int j = 0; j < NUM_PER_THREAD2; j++) { + p_vals[j] = (T)(((float)p_vals[j]) + + ((update_scale * step_size * (s1_vals[j] / (sqrtf(s2_vals[j]) + (correction2 * eps)))))); + if (weight_decay > 0.0f) + p_vals[j] = 
update_scale * ((float)p_vals[j]) * (1.0f - (lr * weight_decay)); } StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); @@ -1224,22 +1145,16 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha } } - -template -__global__ void -__launch_bounds__(NUM_THREADS, 2) -kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, - float *unorm, - const float beta1, const float beta2, - const float eps, const int step, - float* __restrict__ const quantiles1, - float* max1, float* new_max1, - const float weight_decay, - const float gnorm_scale, const int n) -{ +template +__global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8bit1State( + T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, const float beta1, + const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1, + float* new_max1, const float weight_decay, const float gnorm_scale, const int n +) { const int n_full = gridDim.x * NUM_PER_BLOCK; const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); - int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); + int valid_items = + n - (blockIdx.x * NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x * NUM_PER_BLOCK); float g_val = 0.0f; float local_max_s1 = -FLT_MAX; float local_unorm = 0.0f; @@ -1252,7 +1167,6 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c typedef cub::BlockLoad LoadUInt8; typedef cub::BlockReduce BlockReduce; - __shared__ union { typename LoadT::TempStorage loadh; typename LoadUInt8::TempStorage loadc; @@ -1261,43 +1175,40 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c __shared__ float smem_quantiles1[256]; - if(threadIdx.x < 256) - smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + if (threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; __syncthreads(); - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS*NUM8BIT) - { - valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_THREADS * NUM8BIT) { + valid_items = n - i >= (TH * NUM_PER_THREAD) ? 
(TH * NUM_PER_THREAD) : n - i; __syncthreads(); LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); __syncthreads(); LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); - #pragma unroll 16 - for(int j = 0; j < NUM8BIT; j++) - { +#pragma unroll 16 + for (int j = 0; j < NUM8BIT; j++) { g_val = g_vals[j]; g_val *= gnorm_scale; - s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]; - switch(OPTIMIZER) - { - case ADAGRAD: - case MOMENTUM: - if(step == 1) - s1_vals[j] = (float)g_vals[j]; - else - s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); - if(unorm != NULL) - local_unorm += s1_vals[j]*s1_vals[j]; - break; - case LION: - s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); - break; - case RMSPROP: - s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); - break; + s1_vals[j] = smem_quantiles1[m_c1[j]] * max1[0]; + switch (OPTIMIZER) { + case ADAGRAD: + case MOMENTUM: + if (step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]); + if (unorm != NULL) + local_unorm += s1_vals[j] * s1_vals[j]; + break; + case LION: + s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * (g_val * g_val)); + break; } local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); @@ -1306,44 +1217,44 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c __syncthreads(); local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items); - if(threadIdx.x == 0){ atomicMax(&new_max1[0], local_max_s1); } - if(unorm != NULL) - { - __syncthreads(); - local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items); - if(threadIdx.x == 0){ atomicAdd(&unorm[0], local_unorm); } + if (threadIdx.x == 0) { + atomicMax(&new_max1[0], local_max_s1); + } + if (unorm != NULL) { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items); + if (threadIdx.x == 0) { + atomicAdd(&unorm[0], local_unorm); + } } - } -template -__global__ void -__launch_bounds__(1024, 1) -kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, - const float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, - float* max1, float* new_max1, - float weight_decay, - const float gnorm_scale, const int n) -{ - - const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2; +template +__global__ void __launch_bounds__(1024, 1) kOptimizerStatic8bit1State( + T* p, T* const g, unsigned char* state1, const float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale, + const int n +) { + + const int n_full = (blockDim.x * gridDim.x) * NUM_PER_THREAD2; const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); int valid_items = 0; float g_val = 0.0f; float s1_vals[NUM_PER_THREAD2]; - float new_max_val1 = 1.0f/new_max1[0]; + float new_max_val1 = 1.0f / new_max1[0]; float update_scale = 1.0f; - if(max_unorm > 0.0f) - { - update_scale = max_unorm > 0.0f ? 
sqrtf(unorm[0]) : 1.0f; - if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } - else{ update_scale = 1.0f; } + if (max_unorm > 0.0f) { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if (update_scale > max_unorm * param_norm) { + update_scale = (max_unorm * param_norm) / update_scale; + } else { + update_scale = 1.0f; + } + } else { + update_scale = 1.0f; } - else{ update_scale = 1.0f; } unsigned char c1s[NUM_PER_THREAD2]; T p_vals[NUM_PER_THREAD2]; @@ -1363,72 +1274,72 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, typename StoreT::TempStorage storeh; } temp_storage; - if(threadIdx.x < 256) + if (threadIdx.x < 256) smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; __syncthreads(); - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) - { - valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_THREADS2 * NUM_PER_THREAD2) { + valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i; LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); __syncthreads(); LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); __syncthreads(); LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); - if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + if ((i + (threadIdx.x * NUM_PER_THREAD2) + NUM_PER_THREAD2) > n) { + continue; + } - # pragma unroll 4 - for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) - { +#pragma unroll 4 + for (unsigned int j = 0; j < NUM_PER_THREAD2; j++) { g_val = float(g_vals[j]); g_val *= gnorm_scale; - if(weight_decay > 0.0f) { - switch(OPTIMIZER) { - case ADAGRAD: + if (weight_decay > 0.0f) { + switch (OPTIMIZER) { + case ADAGRAD: case MOMENTUM: case RMSPROP: - g_val += ((float)p_vals[j])*weight_decay; - break; + g_val += ((float)p_vals[j]) * weight_decay; + break; case LION: - p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); - break; - } + p_vals[j] = ((float)p_vals[j]) * (1.0f - lr * weight_decay); + break; + } } - s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0]; + s1_vals[j] = smem_quantiles1[c1s[j]] * max1[0]; - switch(OPTIMIZER){ - case ADAGRAD: - case MOMENTUM: - if(step == 1) + switch (OPTIMIZER) { + case ADAGRAD: + case MOMENTUM: + if (step == 1) s1_vals[j] = g_vals[j]; - else - s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); - - p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); - break; - case LION: - p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val)))); - s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); - break; - case RMSPROP: - s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); - p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps)); - break; + else + s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + (-lr * update_scale * (s1_vals[j])); + break; + case LION: + p_vals[j] = + ((float)p_vals[j]) - (lr * sgn(((float)s1_vals[j]) * beta1 + ((1.0f - beta1) * ((float)g_val)))); + s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * (g_val * g_val)); + p_vals[j] = ((float)p_vals[j]) - (lr * __fdividef(g_val, sqrtf(s1_vals[j]) + eps)); + break; } - c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + c1s[j] = dQuantize<0>(smem_quantiles1, 
0.0f, s1_vals[j] * new_max_val1); // make sure state1 term has still the same sign after quantization - if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) - { - if(s1_vals[j] > 0.0f) - c1s[j] += 1; - else - c1s[j] -= 1; + if (signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) { + if (s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; } } @@ -1439,80 +1350,56 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, } } +template +__global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n) { + const int n_full = (BLOCK_SIZE * (n / BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + int valid_items = 0; + + typedef cub::BlockReduce BlockReduce; + typedef cub::BlockLoad LoadT; -template -__global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n) -{ - const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); - int valid_items = 0; - - typedef cub::BlockReduce BlockReduce; - typedef cub::BlockLoad LoadT; - - __shared__ typename BlockReduce::TempStorage reduce; - - __shared__ typename LoadT::TempStorage loadT; - T vals[NUM_VALS]; - float local_sum = 0.0f; - - for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x*BLOCK_SIZE) - { - valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; - local_sum = 0.0f; - - __syncthreads(); - LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); - - #pragma unroll NUM_VALS - for(int j = 0; j < NUM_VALS; j++) - local_sum += ((float)vals[j])*((float)vals[j]); - - local_sum = BlockReduce(reduce).Sum(local_sum, valid_items); - if(threadIdx.x == 0) - { - if(step == 1) - { - // initialize with the same norm for all positions - //#pragma unroll 10 - for(int j = 0; j < 100; j++) - atomicAdd(&gnorm_vec[j], local_sum); - } - else - atomicAdd(&gnorm_vec[step % 100], local_sum); - } + __shared__ typename BlockReduce::TempStorage reduce; - } -} + __shared__ typename LoadT::TempStorage loadT; + T vals[NUM_VALS]; + float local_sum = 0.0f; + + for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x * BLOCK_SIZE) { + valid_items = n - i > BLOCK_SIZE ? 
BLOCK_SIZE : n - i; + local_sum = 0.0f; + __syncthreads(); + LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); + +#pragma unroll NUM_VALS + for (int j = 0; j < NUM_VALS; j++) + local_sum += ((float)vals[j]) * ((float)vals[j]); + + local_sum = BlockReduce(reduce).Sum(local_sum, valid_items); + if (threadIdx.x == 0) { + if (step == 1) { + // initialize with the same norm for all positions + // #pragma unroll 10 + for (int j = 0; j < 100; j++) + atomicAdd(&gnorm_vec[j], local_sum); + } else + atomicAdd(&gnorm_vec[step % 100], local_sum); + } + } +} #define LANES 2 #define QUAD 3 -template -__launch_bounds__(256, 3) -__global__ void -kOptimizerStatic8bit2StateBlockwise( - T* p, - T* __restrict__ const g, - unsigned char* state1, - unsigned char* state2, - const float beta1, - const float beta2, - const float beta3, - const float alpha, - const float eps, - const int step, - const float lr, - float* __restrict__ const quantiles1, - float* __restrict__ const quantiles2, - float* absmax1, - float* absmax2, - float weight_decay, - const float gnorm_scale, - const bool skip_zeros, - const int n + +template +__launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit2StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2, + const float beta3, const float alpha, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, + float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n ) { - //const int n_full = n + (n%BLOCK_SIZE); + // const int n_full = n + (n%BLOCK_SIZE); const int n_full = gridDim.x * BLOCK_SIZE; const int base_idx = (blockIdx.x * BLOCK_SIZE); int valid_items = 0; @@ -1523,8 +1410,8 @@ kOptimizerStatic8bit2StateBlockwise( // 2-5% const float correction1 = 1.0f - __powf(beta1, step); - const float correction2 = sqrtf(1.0f -__powf(beta2, step)); - const float step_size = __fdividef(-lr*correction2,correction1); + const float correction2 = sqrtf(1.0f - __powf(beta2, step)); + const float step_size = __fdividef(-lr * correction2, correction1); const int lane_id = threadIdx.x % LANES; float new_local_abs_max1 = -FLT_MAX; float new_local_abs_max2 = -FLT_MAX; @@ -1538,23 +1425,23 @@ kOptimizerStatic8bit2StateBlockwise( T g_vals[N_PER_TH]; T p_vals[N_PER_TH]; - typedef cub::BlockLoad LoadT; - typedef cub::BlockLoad LoadChar; + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadChar; - typedef cub::BlockStore StoreChar; - typedef cub::BlockStore StoreT; + typedef cub::BlockStore StoreChar; + typedef cub::BlockStore StoreT; __shared__ float smem_quantiles1[LANES][257]; __shared__ float smem_quantiles2[LANES][257]; - typedef cub::BlockReduce BlockReduce1; - typedef cub::BlockReduce BlockReduce2; - typedef cub::BlockReduce BlockReduce3; + typedef cub::BlockReduce BlockReduce1; + typedef cub::BlockReduce BlockReduce2; + typedef cub::BlockReduce BlockReduce3; __shared__ typename BlockReduce1::TempStorage reduce1; __shared__ typename BlockReduce2::TempStorage reduce2; __shared__ typename BlockReduce2::TempStorage reduce3; __shared__ float smem_exchange1[1]; __shared__ float smem_exchange2[1]; - __shared__ float smem_exchange3[1]; // [[maybe_unused]] + __shared__ float smem_exchange3[1]; // [[maybe_unused]] __shared__ union { typename LoadT::TempStorage loadh; @@ -1562,30 +1449,27 @@ kOptimizerStatic8bit2StateBlockwise( typename StoreChar::TempStorage storec; typename StoreT::TempStorage 
storeh; } temp_storage; + // init: 0.2 -> 0.23 // 0.23 -> 0.23 - smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; - smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x]; - # pragma unroll - for(unsigned int j = 1; j < LANES; j++) - { + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; + smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x]; +#pragma unroll + for (unsigned int j = 1; j < LANES; j++) { smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x]; - } + } __syncthreads(); - #pragma unroll - for(int k = 0; k < QUAD; k++) - { - quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; - quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; +#pragma unroll + for (int k = 0; k < QUAD; k++) { + quadrants1[k] = smem_quantiles1[lane_id][(k * 256 / (QUAD + 1)) + (256 / (QUAD + 1) - 1)]; + quadrants2[k] = smem_quantiles2[lane_id][(k * 256 / (QUAD + 1)) + (256 / (QUAD + 1) - 1)]; } - - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) - { + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) { // loads: 0.23 -> 0.85/1.44 valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; __syncthreads(); @@ -1597,146 +1481,134 @@ kOptimizerStatic8bit2StateBlockwise( // AdEMAMix has an additional state packed into state1. if (OPTIMIZER == ADEMAMIX) { - __syncthreads(); - LoadChar(temp_storage.loadc).Load(&(state1[n + i]), c3s, valid_items, 128); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[n + i]), c3s, valid_items, 128); } new_local_abs_max1 = -FLT_MAX; new_local_abs_max2 = -FLT_MAX; new_local_abs_max3 = -FLT_MAX; - // update: 2.48/1.57 -> 2.51/1.60 - # pragma unroll N_PER_TH - for(unsigned int j = 0; j < N_PER_TH; j++) - { - if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) - { - s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; - g_val = g_vals[j]; - //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); - //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; - g_val *= gnorm_scale; - - s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); - - s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; - s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); - - if (OPTIMIZER == ADEMAMIX) { - // The absmax for the third state is appended to absmax1 - s3_vals[j] = smem_quantiles1[lane_id][c3s[j]] * absmax1[(n + i)/BLOCK_SIZE]; - s3_vals[j] = (s3_vals[j] * beta3) + (((1.0f - beta3) * g_val)); - } - } - else - { - s1_vals[j] = 0.0f; - s2_vals[j] = 0.0f; +// update: 2.48/1.57 -> 2.51/1.60 +#pragma unroll N_PER_TH + for (unsigned int j = 0; j < N_PER_TH; j++) { + if (!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) { + s2_vals[j] = smem_quantiles2[lane_id][c2s[j]] * absmax2[i / BLOCK_SIZE]; + g_val = g_vals[j]; + // float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); + // g_val = ratio > 2.0f ? 
2.0f*g_val/ratio : g_val; + g_val *= gnorm_scale; + + s2_vals[j] = (s2_vals[j] * beta2) + (((1.0f - beta2) * g_val * g_val)); + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]] * absmax1[i / BLOCK_SIZE]; + s1_vals[j] = (s1_vals[j] * beta1) + (((1.0f - beta1) * g_val)); + + if (OPTIMIZER == ADEMAMIX) { + // The absmax for the third state is appended to absmax1 + s3_vals[j] = smem_quantiles1[lane_id][c3s[j]] * absmax1[(n + i) / BLOCK_SIZE]; + s3_vals[j] = (s3_vals[j] * beta3) + (((1.0f - beta3) * g_val)); + } + } else { + s1_vals[j] = 0.0f; + s2_vals[j] = 0.0f; - if (OPTIMIZER == ADEMAMIX) { - s3_vals[j] = 0.0f; - } + if (OPTIMIZER == ADEMAMIX) { + s3_vals[j] = 0.0f; + } } new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); if (OPTIMIZER == ADEMAMIX) { - new_local_abs_max3 = fmaxf(new_local_abs_max3, fabsf(s3_vals[j])); + new_local_abs_max3 = fmaxf(new_local_abs_max3, fabsf(s3_vals[j])); } } - // reduce: 2.51/1.60 -> 2.67/1.69 new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max()); new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cub::Max()); if (OPTIMIZER == ADEMAMIX) { - new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, cub::Max()); + new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, cub::Max()); } - if(threadIdx.x == 0) - { - smem_exchange1[0] = new_local_abs_max1; - smem_exchange2[0] = new_local_abs_max2; + if (threadIdx.x == 0) { + smem_exchange1[0] = new_local_abs_max1; + smem_exchange2[0] = new_local_abs_max2; - if (OPTIMIZER == ADEMAMIX) { - smem_exchange3[0] = new_local_abs_max3; - } + if (OPTIMIZER == ADEMAMIX) { + smem_exchange3[0] = new_local_abs_max3; + } } __syncthreads(); - if(threadIdx.x == 0) - { - absmax1[i/BLOCK_SIZE] = new_local_abs_max1; - absmax2[i/BLOCK_SIZE] = new_local_abs_max2; + if (threadIdx.x == 0) { + absmax1[i / BLOCK_SIZE] = new_local_abs_max1; + absmax2[i / BLOCK_SIZE] = new_local_abs_max2; - if (OPTIMIZER == ADEMAMIX) { - absmax1[(n + i)/BLOCK_SIZE] = new_local_abs_max3; - } - } - else - { - new_local_abs_max1 = smem_exchange1[0]; - new_local_abs_max2 = smem_exchange2[0]; + if (OPTIMIZER == ADEMAMIX) { + absmax1[(n + i) / BLOCK_SIZE] = new_local_abs_max3; + } + } else { + new_local_abs_max1 = smem_exchange1[0]; + new_local_abs_max2 = smem_exchange2[0]; - if (OPTIMIZER == ADEMAMIX) { - new_local_abs_max3 = smem_exchange3[0]; - } + if (OPTIMIZER == ADEMAMIX) { + new_local_abs_max3 = smem_exchange3[0]; + } } __syncthreads(); LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); - // reduce: 2.67/1.69 -> 2.67/1.70 - # pragma unroll N_PER_TH - for(unsigned int j = 0; j < N_PER_TH; j++) - { - //if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) - if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) - { - if (OPTIMIZER == ADEMAMIX) { - p_vals[j] = T((float)p_vals[j] - lr * ( - ((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / ( - (sqrtf(s2_vals[j]) / correction2) + eps - ) - )); - } else { - p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); - } - - if(weight_decay > 0.0f) - p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); - } +// reduce: 2.67/1.69 -> 2.67/1.70 +#pragma unroll N_PER_TH + for (unsigned int j = 0; j < N_PER_TH; j++) { + // if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + if (!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) { + if (OPTIMIZER == ADEMAMIX) { + 
p_vals[j] = + T((float)p_vals[j] - lr * (((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / + ((sqrtf(s2_vals[j]) / correction2) + eps))); + } else { + p_vals[j] = + (T)(((float)p_vals[j]) + + ((step_size * (__fdividef(s1_vals[j], (sqrtf(s2_vals[j]) + (correction2 * eps))))))); + } + + if (weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay)); + } } // store: 0.85/1.44 -> 2.48/1.57 __syncthreads(); StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); - // quantizaztion: 2.67/1.70 -> 3.4/3.3 - # pragma unroll N_PER_TH - for(unsigned int j = 0; j < N_PER_TH; j++) - { - c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); - c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j],new_local_abs_max2)); +// quantizaztion: 2.67/1.70 -> 3.4/3.3 +#pragma unroll N_PER_TH + for (unsigned int j = 0; j < N_PER_TH; j++) { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j], new_local_abs_max1)); + c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j], new_local_abs_max2)); // make sure state1 term has still the same sign after quantization // (not needed for state2 term which has only positive values) - if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) - { - if(s1_vals[j] > 0.0f) - c1s[j] += 1; - else - c1s[j] -= 1; + if (signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) { + if (s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; } if (OPTIMIZER == ADEMAMIX) { - c3s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s3_vals[j],new_local_abs_max3)); + c3s[j] = + quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s3_vals[j], new_local_abs_max3)); - if (signbit(smem_quantiles1[lane_id][c3s[j]]) != signbit(s3_vals[j])) { - c3s[j] += (s3_vals[j] > 0.0f) ? 1 : -1; - } + if (signbit(smem_quantiles1[lane_id][c3s[j]]) != signbit(s3_vals[j])) { + c3s[j] += (s3_vals[j] > 0.0f) ? 
1 : -1; + } } } @@ -1746,28 +1618,23 @@ kOptimizerStatic8bit2StateBlockwise( StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); if (OPTIMIZER == ADEMAMIX) { - __syncthreads(); - StoreChar(temp_storage.storec).Store(&(state1[n + i]), c3s, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[n + i]), c3s, valid_items); } } } - #define LANES 2 #define QUAD 3 -template -__launch_bounds__(256, 3) -__global__ void -kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, - const float beta1, const float beta2, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, - float* absmax1, - float weight_decay, - const float gnorm_scale, const bool skip_zeros, const int n) -{ - - //const int n_full = n + (n%BLOCK_SIZE); + +template +__launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit1StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps, + const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n +) { + + // const int n_full = n + (n%BLOCK_SIZE); const int n_full = gridDim.x * BLOCK_SIZE; const int base_idx = (blockIdx.x * BLOCK_SIZE); int valid_items = 0; @@ -1780,16 +1647,16 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char unsigned char c1s[N_PER_TH]; T g_vals[N_PER_TH]; - T p_vals[N_PER_TH]; + T p_vals[N_PER_TH]; - typedef cub::BlockLoad LoadT; - typedef cub::BlockLoad LoadChar; + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadChar; - typedef cub::BlockStore StoreChar; - typedef cub::BlockStore StoreT; + typedef cub::BlockStore StoreChar; + typedef cub::BlockStore StoreT; __shared__ float smem_quantiles1[LANES][257]; - typedef cub::BlockReduce BlockReduce1; + typedef cub::BlockReduce BlockReduce1; __shared__ typename BlockReduce1::TempStorage reduce1; __shared__ float smem_exchange1[1]; @@ -1799,22 +1666,22 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char typename StoreChar::TempStorage storec; typename StoreT::TempStorage storeh; } temp_storage; + // init: 0.2 -> 0.23 // 0.23 -> 0.23 - smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; - # pragma unroll - for(unsigned int j = 1; j < LANES; j++) - smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; +#pragma unroll + for (unsigned int j = 1; j < LANES; j++) + smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; __syncthreads(); - #pragma unroll - for(int k = 0; k < QUAD; k++) - quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; +#pragma unroll + for (int k = 0; k < QUAD; k++) + quadrants1[k] = smem_quantiles1[lane_id][(k * 256 / (QUAD + 1)) + (256 / (QUAD + 1) - 1)]; - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) - { + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) { // loads: 0.23 -> 0.85/1.44 valid_items = n - i >= BLOCK_SIZE ? 
BLOCK_SIZE : n - i; __syncthreads(); @@ -1826,112 +1693,104 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char new_local_abs_max1 = -FLT_MAX; - // update: 2.48/1.57 -> 2.51/1.60 - # pragma unroll N_PER_TH - for(unsigned int j = 0; j < N_PER_TH; j++) - { +// update: 2.48/1.57 -> 2.51/1.60 +#pragma unroll N_PER_TH + for (unsigned int j = 0; j < N_PER_TH; j++) { g_val = float(g_vals[j]); g_val *= gnorm_scale; - if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) - { - if(weight_decay > 0.0f) { - switch(OPTIMIZER) { - case MOMENTUM: - case ADAGRAD: - case RMSPROP: - g_val += ((float)p_vals[j])*weight_decay; + if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { + if (weight_decay > 0.0f) { + switch (OPTIMIZER) { + case MOMENTUM: + case ADAGRAD: + case RMSPROP: + g_val += ((float)p_vals[j]) * weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j]) * (1.0f - lr * weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]] * absmax1[i / BLOCK_SIZE]; + + switch (OPTIMIZER) { + case MOMENTUM: + if (step == 1) + s1_vals[j] = g_val; + else + s1_vals[j] = (s1_vals[j] * beta1) + g_val; + break; + case LION: + // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, + // before the momentum is updated by beta2 + g_vals[j] = lr * sgn(((float)s1_vals[j]) * beta1 + ((1.0f - beta1) * g_val)); + s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * g_val); break; - case LION: - p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + case RMSPROP: + s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * (g_val * g_val)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + (g_val * g_val); break; } - } - - s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; - - switch(OPTIMIZER) - { - case MOMENTUM: - if(step == 1) - s1_vals[j] = g_val; - else - s1_vals[j] = (s1_vals[j]*beta1) + g_val; - break; - case LION: - // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2 - g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val)); - s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); - break; - case RMSPROP: - s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); - break; - case ADAGRAD: - s1_vals[j] = s1_vals[j] + (g_val*g_val); - break; - } - } + } new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); } - // reduce: 2.51/1.60 -> 2.67/1.69 new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max()); - if(threadIdx.x == 0) - smem_exchange1[0] = new_local_abs_max1; + if (threadIdx.x == 0) + smem_exchange1[0] = new_local_abs_max1; __syncthreads(); - if(threadIdx.x == 0) - absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + if (threadIdx.x == 0) + absmax1[i / BLOCK_SIZE] = new_local_abs_max1; else - new_local_abs_max1 = smem_exchange1[0]; - - // reduce: 2.67/1.69 -> 2.67/1.70 - # pragma unroll N_PER_TH - for(unsigned int j = 0; j < N_PER_TH; j++) - { - if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) - { - switch(OPTIMIZER) - { - case MOMENTUM: - p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); - break; - case LION: - p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); - break; - case RMSPROP: - g_val = g_vals[j]; - p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); - break; - case ADAGRAD: - g_val = g_vals[j]; - p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); - break; 
- } - } - } + new_local_abs_max1 = smem_exchange1[0]; + +// reduce: 2.67/1.69 -> 2.67/1.70 +#pragma unroll N_PER_TH + for (unsigned int j = 0; j < N_PER_TH; j++) { + if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { + switch (OPTIMIZER) { + case MOMENTUM: + p_vals[j] = ((float)p_vals[j]) - lr * (s1_vals[j]); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); + break; + case RMSPROP: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr * (__fdividef(g_val, sqrtf(s1_vals[j]) + eps)); + break; + case ADAGRAD: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr * (__fdividef(g_val, sqrtf(s1_vals[j]) + eps)); + break; + } + } + } // store: 0.85/1.44 -> 2.48/1.57 __syncthreads(); StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); - // quantizaztion: 2.67/1.70 -> 3.4/3.3 - # pragma unroll N_PER_TH - for(unsigned int j = 0; j < N_PER_TH; j++) - { - c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); +// quantizaztion: 2.67/1.70 -> 3.4/3.3 +#pragma unroll N_PER_TH + for (unsigned int j = 0; j < N_PER_TH; j++) { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j], new_local_abs_max1)); // make sure state1 term has still the same sign after quantization // (not needed for state2 term which has only positive values) - if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) - { - if(s1_vals[j] > 0.0f) - c1s[j] += 1; - else - c1s[j] -= 1; + if (signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) { + if (s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; } } @@ -1945,978 +1804,1010 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char // Outputs: // rowStats [rows] // out [rows, cols] -template -__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) -__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) { +template +__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__ + void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) { - // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32. - // Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped. + // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32. + // Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped. #if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR >= 2) || BNB_FP16_AVAILABLE - using TReduction = T; + using TReduction = T; #else - using TReduction = float; + using TReduction = float; #endif - using BlockReduceT = cub::BlockReduce; + using BlockReduceT = cub::BlockReduce; - // One block per row. - // Threads load column values in a striped arrangement. - // e.g. t0 reads row[0], row[0+nthreads], .. - // and t1 reads row[1], row[1+nthreads], .. - // Each thread will determine its local absmax. - // We then do a blockwise reduction to determine the row's absmax. + // One block per row. + // Threads load column values in a striped arrangement. + // e.g. t0 reads row[0], row[0+nthreads], .. + // and t1 reads row[1], row[1+nthreads], .. + // Each thread will determine its local absmax. + // We then do a blockwise reduction to determine the row's absmax. 
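// [Editorial aside, not part of this patch] The comment block above describes the
// kInt8VectorQuant access pattern: one block per row, striped column loads, a
// per-thread local absmax, then a single blockwise max-reduction. For reference, a
// minimal standalone sketch of just that pattern follows; the kernel name, template
// parameters and launch shape are illustrative assumptions, not code from this repo.

#include <cub/block/block_reduce.cuh>

template <typename T, int THREADS>
__global__ void row_absmax_sketch(const T* __restrict__ A, float* rowAbsmax, int cols) {
    using BlockReduceT = cub::BlockReduce<float, THREADS>;
    __shared__ typename BlockReduceT::TempStorage temp_storage;

    // One block per row; thread t reads columns t, t + THREADS, t + 2*THREADS, ...
    const T* row = A + (size_t)blockIdx.x * cols;
    float local_absmax = 0.0f;
    for (int i = threadIdx.x; i < cols; i += THREADS)
        local_absmax = fmaxf(local_absmax, fabsf((float)row[i]));

    // Blockwise max-reduction; the reduced value is only valid in thread 0.
    const float row_absmax = BlockReduceT(temp_storage).Reduce(local_absmax, cub::Max());
    if (threadIdx.x == 0)
        rowAbsmax[blockIdx.x] = row_absmax;
}

// A one-block-per-row launch, e.g. row_absmax_sketch<float, 1024><<<rows, 1024>>>(A, stats, cols),
// mirrors the grid shape implied by the comments above.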
- __shared__ typename BlockReduceT::TempStorage temp_storage; - __shared__ TReduction smem_row_absmax; + __shared__ typename BlockReduceT::TempStorage temp_storage; + __shared__ TReduction smem_row_absmax; - const int row_id = blockIdx.x; - const T* row_data = A + (row_id * cols); + const int row_id = blockIdx.x; + const T* row_data = A + (row_id * cols); - // Threads will read the row values in a striped access pattern and find a local absmax. - TReduction row_local_absmax = -FLT_MIN; - for (int i = threadIdx.x; i < cols; i += THREADS) { - const TReduction absval = fabsf(__ldcs(&(row_data[i]))); + // Threads will read the row values in a striped access pattern and find a local absmax. + TReduction row_local_absmax = -FLT_MIN; + for (int i = threadIdx.x; i < cols; i += THREADS) { + const TReduction absval = fabsf(__ldcs(&(row_data[i]))); - // For sparse decomposition, values outside of the threshold are not to be - // included when calculating the row's absmax. - if constexpr (SPARSE_DECOMP) { - row_local_absmax = fmaxf(row_local_absmax, absval < TReduction(threshold) ? absval : row_local_absmax); - } else { - row_local_absmax = fmaxf(row_local_absmax, absval); - } - } - - // Reduce thread-local absmax across the block. - const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols); - if (threadIdx.x == 0) { - // Save our block's absmax to shared memory for the quantization step. - rowStats[row_id] = smem_row_absmax = row_absmax; - } - __syncthreads(); - - // Quantize row-wise. - const float scale = __fdividef(127.0f, smem_row_absmax); - for (int i = threadIdx.x; i < cols; i += THREADS) { - float val = row_data[i]; - - if constexpr (SPARSE_DECOMP) { - // For sparse decomposition, we do not want to quantize the outliers. - // Instead they're zeroed out. - out[row_id * cols + i] = fabs(val) < threshold ? __float2int_rn(val * scale) : 0; - } else { - out[row_id * cols + i] = __float2int_rn(val * scale); + // For sparse decomposition, values outside of the threshold are not to be + // included when calculating the row's absmax. + if constexpr (SPARSE_DECOMP) { + row_local_absmax = fmaxf(row_local_absmax, absval < TReduction(threshold) ? absval : row_local_absmax); + } else { + row_local_absmax = fmaxf(row_local_absmax, absval); + } } - } -} - -template -__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) -__global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) { - using BlockReduceT = cub::BlockReduce; - // One block per row. - // Threads load column values in a striped arrangement. - // e.g. t0 reads row[0], row[0+nthreads], .. - // and t1 reads row[1], row[1+nthreads], .. - // Each thread will determine its local absmax. - // We then do a blockwise reduction to determine the row's absmax. - - __shared__ typename BlockReduceT::TempStorage temp_storage; + // Reduce thread-local absmax across the block. + const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols); + if (threadIdx.x == 0) { + // Save our block's absmax to shared memory for the quantization step. + rowStats[row_id] = smem_row_absmax = row_absmax; + } + __syncthreads(); - const int row_id = blockIdx.x; - const T* __restrict__ row_data = A + (row_id * cols); + // Quantize row-wise. 
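// [Editorial aside, not part of this patch] The row-wise quantization step that
// follows amounts to: scale = 127 / row_absmax, q = round_to_nearest(val * scale),
// with the SPARSE_DECOMP variant writing 0 for outliers (|val| >= threshold) instead
// of quantizing them. A minimal host-side sketch of that arithmetic, under an
// illustrative function name:

#include <cmath>
#include <cstdint>

inline int8_t quantize_rowwise_sketch(float val, float row_absmax, float threshold, bool sparse_decomp) {
    const float scale = 127.0f / row_absmax;       // device code uses __fdividef(127.0f, smem_row_absmax)
    if (sparse_decomp && std::fabs(val) >= threshold)
        return 0;                                  // outlier: excluded from the int8 tensor
    return (int8_t)std::lrintf(val * scale);       // round to nearest, like __float2int_rn
}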
+ const float scale = __fdividef(127.0f, smem_row_absmax); + for (int i = threadIdx.x; i < cols; i += THREADS) { + float val = row_data[i]; + + if constexpr (SPARSE_DECOMP) { + // For sparse decomposition, we do not want to quantize the outliers. + // Instead they're zeroed out. + out[row_id * cols + i] = fabs(val) < threshold ? __float2int_rn(val * scale) : 0; + } else { + out[row_id * cols + i] = __float2int_rn(val * scale); + } + } +} - // Threads will read the row values in a striped access pattern and find a local absmax. - float row_local_absmax = -FLT_MIN; - for (int i = threadIdx.x; i < cols; i += THREADS) { - const float absval = fabsf(row_data[i]); +template +__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__ + void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols) { + using BlockReduceT = cub::BlockReduce; + + // One block per row. + // Threads load column values in a striped arrangement. + // e.g. t0 reads row[0], row[0+nthreads], .. + // and t1 reads row[1], row[1+nthreads], .. + // Each thread will determine its local absmax. + // We then do a blockwise reduction to determine the row's absmax. + + __shared__ typename BlockReduceT::TempStorage temp_storage; + + const int row_id = blockIdx.x; + const T* __restrict__ row_data = A + (row_id * cols); + + // Threads will read the row values in a striped access pattern and find a local absmax. + float row_local_absmax = -FLT_MIN; + for (int i = threadIdx.x; i < cols; i += THREADS) { + const float absval = fabsf(row_data[i]); + + // For sparse decomposition, values outside of the threshold are not to be + // included when calculating the row's absmax. + if constexpr (SPARSE_DECOMP) { + row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax); + } else { + row_local_absmax = fmaxf(row_local_absmax, absval); + } + } - // For sparse decomposition, values outside of the threshold are not to be - // included when calculating the row's absmax. - if constexpr (SPARSE_DECOMP) { - row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax); - } else { - row_local_absmax = fmaxf(row_local_absmax, absval); + // Reduce thread-local absmax across the block. + // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY + const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols); + if (threadIdx.x == 0) { + // Save our block's absmax to shared memory for the quantization step. + rowStats[row_id] = row_absmax; } - } - - // Reduce thread-local absmax across the block. - // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY - const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols); - if (threadIdx.x == 0) { - // Save our block's absmax to shared memory for the quantization step. 
- rowStats[row_id] = row_absmax; - } } -template __global__ void kgetRowStats(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols); -template __global__ void kgetRowStats(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols); - -template __global__ void kInt8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); -template __global__ void kInt8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); +template __global__ void + kgetRowStats(half* __restrict__ A, float* rowStats, float threshold, int rows, int cols); +template __global__ void + kgetRowStats(half* __restrict__ A, float* rowStats, float threshold, int rows, int cols); +template __global__ void kInt8VectorQuant( + half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols +); +template __global__ void kInt8VectorQuant( + half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols +); -#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) +#define MM_DEQUANT_CONST 6.200012e-05f // 1.0f/(127.0f*127.0f) template __global__ void kdequant_mm_int32_fp16( - int* __restrict__ const A, - float *__restrict__ const rowStats, - float *__restrict__ const colStats, - half *out, - half *__restrict__ const bias, - const int numRows, - const int numCols, - const int n + int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out, + half* __restrict__ const bias, const int numRows, const int numCols, const int n ) { - const int n_out = numRows * numCols; + const int n_out = numRows * numCols; - int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD; - int thread_offset = threadIdx.x * ITEMS_PER_THREAD; + int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD; + int thread_offset = threadIdx.x * ITEMS_PER_THREAD; - int local_values[ITEMS_PER_THREAD]; - half local_output[ITEMS_PER_THREAD]; + int local_values[ITEMS_PER_THREAD]; + half local_output[ITEMS_PER_THREAD]; - float local_rowStats[ITEMS_PER_THREAD]; - float local_colStats[ITEMS_PER_THREAD]; - float local_biasValue[ITEMS_PER_THREAD]; + float local_rowStats[ITEMS_PER_THREAD]; + float local_colStats[ITEMS_PER_THREAD]; + float local_biasValue[ITEMS_PER_THREAD]; - typedef cub::BlockLoad LoadInt32; - __shared__ typename LoadInt32::TempStorage loadint32; + typedef cub::BlockLoad LoadInt32; + __shared__ typename LoadInt32::TempStorage loadint32; - int row_idx, col_idx; + int row_idx, col_idx; - #pragma unroll ITEMS_PER_THREAD - for (int j = 0; j < ITEMS_PER_THREAD; ++j) { +#pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; ++j) { - row_idx = (block_offset + thread_offset + j) / numCols; - col_idx = (block_offset + thread_offset + j) % numCols; + row_idx = (block_offset + thread_offset + j) / numCols; + col_idx = (block_offset + thread_offset + j) % numCols; - local_colStats[j] = col_idx >= numCols ? 0.0f : __ldg(&colStats[col_idx]); - local_rowStats[j] = row_idx >= numRows ? 0.0f : __ldg(&rowStats[row_idx]); - local_biasValue[j] = ((bias == nullptr) || col_idx >= numCols) ? 0.0f : __half2float(bias[col_idx]); - } + local_colStats[j] = col_idx >= numCols ? 0.0f : __ldg(&colStats[col_idx]); + local_rowStats[j] = row_idx >= numRows ? 0.0f : __ldg(&rowStats[row_idx]); + local_biasValue[j] = ((bias == nullptr) || col_idx >= numCols) ? 
0.0f : __half2float(bias[col_idx]); + } - // Each block loads THREADS * ITEMS_PER_THREAD values from A - int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out - ? THREADS * ITEMS_PER_THREAD - : n_out - block_offset; - LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0); + // Each block loads THREADS * ITEMS_PER_THREAD values from A + int valid_items = + block_offset + THREADS * ITEMS_PER_THREAD < n_out ? THREADS * ITEMS_PER_THREAD : n_out - block_offset; + LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0); - #pragma unroll ITEMS_PER_THREAD - for (int j = 0; j < ITEMS_PER_THREAD; ++j) { - local_output[j] = __float2half( - fmaf(local_values[j] * local_rowStats[j] * local_colStats[j], MM_DEQUANT_CONST, local_biasValue[j]) - ); - } +#pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; ++j) { + local_output[j] = __float2half( + fmaf(local_values[j] * local_rowStats[j] * local_colStats[j], MM_DEQUANT_CONST, local_biasValue[j]) + ); + } - #pragma unroll ITEMS_PER_THREAD - for (int j = 0; j < ITEMS_PER_THREAD; j++) { - int outIdx = block_offset + thread_offset + j; - if (outIdx < n_out) { - out[outIdx] = local_output[j]; +#pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; j++) { + int outIdx = block_offset + thread_offset + j; + if (outIdx < n_out) { + out[outIdx] = local_output[j]; + } } - } } -#define DENORM 1.0f/127.0f +#define DENORM 1.0f / 127.0f #define MAX_SPARSE_COUNT 32 -#define SMEM_SIZE 8*256 +#define SMEM_SIZE 8 * 256 + template -__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) -{ - - // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block - // If a block finishes, the next one is scheduled. Since the last blocks like have fewer - // elements they finish faster "fillin up" the gaps left by larger blocks - - // without tensor cores - // 1. use rowidx_length to find what to load (as many blocks as there are rows) - // 2. Load A into registers - // 3. each warp loads all required rows of B but each warp is offset by k - // 4. Do mma operations that accumulate into registers - // 5. Each warp stores its output row into matrix C - - const int count = max_count[blockIdx.x]; - const int local_max_idx = max_idx[blockIdx.x]; - const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1]; - const int local_row_idx = rowidx[offset]; - - const int warp_id = threadIdx.x / 32; - const int warp_idx = threadIdx.x % 32; - const int warp_offset = (warp_id*32)*SPMM_ITEMS; - const int num_items = BITS == 8 ? 8 : 8; - int idx_col_B = warp_offset; - int local_idx_col_B_offset = 0; - - half local_valA[MAX_SPARSE_COUNT]; - int local_colidxA[MAX_SPARSE_COUNT]; - half local_valC[SPMM_ITEMS]; - T local_valsB[num_items]; - half local_valOut[num_items]; - // 128 byte loads per warp == 4 bytes per thread - - // 2. Load A into registers - for(int j = 0; j < MAX_SPARSE_COUNT; j++) - { - local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f); - local_colidxA[j] = j < count ? colidx[offset+j] : 0; - } - - // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 
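// ---------------------------------------------------------------------------
// Editor's note (illustration only, not part of the patch): the per-element math
// of kdequant_mm_int32_fp16 above, written as a scalar helper. The name
// dequant_mm_element_ref is hypothetical. MM_DEQUANT_CONST folds the two 1/127
// scale factors from the row-wise and column-wise int8 quantization into a
// single multiplier.
// ---------------------------------------------------------------------------
#include <cuda_fp16.h>

__device__ __forceinline__ half dequant_mm_element_ref(int a, float row_stat, float col_stat, half bias) {
    const float inv_127_sq = 1.0f / (127.0f * 127.0f); // == MM_DEQUANT_CONST
    return __float2half(fmaf((float)a * row_stat * col_stat, inv_127_sq, __half2float(bias)));
}
// (end of editor's note)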
32*256=x192 - // we expect each warp to be SPMM_ITEMS*32 apart - // we have a total of 128 bytes for the bank with a bank size of 4 bytes - // added 3 bytes = 6 values between warps should reduce bank conflicts - __shared__ half smem_dequant_stats[SMEM_SIZE]; - - - while(idx_col_B < colsB) - { - - if(dequant_stats != NULL) - { - for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x) - if((idx_col_B+i-local_idx_col_B_offset) < colsB) - smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset]; - - __syncthreads(); +__global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +) { + + // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block + // If a block finishes, the next one is scheduled. Since the last blocks like have fewer + // elements they finish faster "fillin up" the gaps left by larger blocks + + // without tensor cores + // 1. use rowidx_length to find what to load (as many blocks as there are rows) + // 2. Load A into registers + // 3. each warp loads all required rows of B but each warp is offset by k + // 4. Do mma operations that accumulate into registers + // 5. Each warp stores its output row into matrix C + + const int count = max_count[blockIdx.x]; + const int local_max_idx = max_idx[blockIdx.x]; + const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx - 1]; + const int local_row_idx = rowidx[offset]; + + const int warp_id = threadIdx.x / 32; + const int warp_idx = threadIdx.x % 32; + const int warp_offset = (warp_id * 32) * SPMM_ITEMS; + const int num_items = BITS == 8 ? 8 : 8; + int idx_col_B = warp_offset; + int local_idx_col_B_offset = 0; + + half local_valA[MAX_SPARSE_COUNT]; + int local_colidxA[MAX_SPARSE_COUNT]; + half local_valC[SPMM_ITEMS]; + T local_valsB[num_items]; + half local_valOut[num_items]; + // 128 byte loads per warp == 4 bytes per thread + + // 2. Load A into registers + for (int j = 0; j < MAX_SPARSE_COUNT; j++) { + local_valA[j] = j < count ? values[offset + j] : __float2half(0.0f); + local_colidxA[j] = j < count ? colidx[offset + j] : 0; } - #pragma unroll SPMM_ITEMS - for(int j = 0; j < SPMM_ITEMS; j++) - local_valC[j] = 0.0f; - - #pragma unroll - for(int i = 0; i < count; i++) - { - // 3. each warp loads all required rows of B but each warp is offset by k - int row_offset = colsB*local_colidxA[i]; - - #pragma unroll SPMM_ITEMS - for(int j = 0; j < SPMM_ITEMS; j+=num_items) - { - // 4. 
Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached - int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; - if(idx >= colsB){ break; } - if((idx+num_items < colsB)) - { - if(BITS == 8) - reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; - else - reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; - } - else - { - #pragma unroll num_items - for(int k = 0; k < num_items; k++) - if(idx+k < colsB) - local_valsB[k] = B[row_offset+idx+k]; - else - local_valsB[k] = 0.0f; - } - #pragma unroll num_items - for(int k = 0; k < num_items; k++) - { - if(BITS == 8 && dequant_stats != NULL) - // we do texture cache reads (__ldg) on dequant_stats which should be super fast - { - float valB = local_valsB[k]; - float valA = local_valA[i]; - if(valB != 0.0 && valA != 0.0) - local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA; + // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192 + // we expect each warp to be SPMM_ITEMS*32 apart + // we have a total of 128 bytes for the bank with a bank size of 4 bytes + // added 3 bytes = 6 values between warps should reduce bank conflicts + __shared__ half smem_dequant_stats[SMEM_SIZE]; + + while (idx_col_B < colsB) { + + if (dequant_stats != NULL) { + for (int i = threadIdx.x; i < SMEM_SIZE; i += blockDim.x) + if ((idx_col_B + i - local_idx_col_B_offset) < colsB) + smem_dequant_stats[i] = dequant_stats[idx_col_B + i - local_idx_col_B_offset]; + + __syncthreads(); + } + +#pragma unroll SPMM_ITEMS + for (int j = 0; j < SPMM_ITEMS; j++) + local_valC[j] = 0.0f; + +#pragma unroll + for (int i = 0; i < count; i++) { + // 3. each warp loads all required rows of B but each warp is offset by k + int row_offset = colsB * local_colidxA[i]; + +#pragma unroll SPMM_ITEMS + for (int j = 0; j < SPMM_ITEMS; j += num_items) { + // 4. 
Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached + int idx = idx_col_B + (warp_idx * SPMM_ITEMS) + j; + if (idx >= colsB) { + break; + } + if ((idx + num_items < colsB)) { + if (BITS == 8) + reinterpret_cast(local_valsB)[0] = + reinterpret_cast(B)[(row_offset + idx) / num_items]; + else + reinterpret_cast(local_valsB)[0] = + reinterpret_cast(B)[(row_offset + idx) / num_items]; + } else { +#pragma unroll num_items + for (int k = 0; k < num_items; k++) + if (idx + k < colsB) + local_valsB[k] = B[row_offset + idx + k]; + else + local_valsB[k] = 0.0f; + } +#pragma unroll num_items + for (int k = 0; k < num_items; k++) { + if (BITS == 8 && dequant_stats != NULL) + // we do texture cache reads (__ldg) on dequant_stats which should be super fast + { + float valB = local_valsB[k]; + float valA = local_valA[i]; + if (valB != 0.0 && valA != 0.0) + local_valC[j + k] = + (float)local_valC[j + k] + + ((float)smem_dequant_stats[idx + k - local_idx_col_B_offset]) * DENORM * valB * valA; + } else + local_valC[j + k] = (float)local_valC[j + k] + (float)local_valsB[k] * (float)local_valA[i]; + } } - else - local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i]; - } } - } - int idx_row_C = (colsB*local_row_idx); - - #pragma unroll SPMM_ITEMS - for(int j = 0; j < SPMM_ITEMS; j+=num_items) - { - //int idx_col_C = idx_col_B + (32*j) + warp_idx; - int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j; - int idx_val = idx_col_C + idx_row_C; - - if(idx_col_C +num_items < colsB) - { - - // load outputs to do inplace addition - reinterpret_cast(local_valOut)[0] = reinterpret_cast(out)[idx_val/num_items]; - - #pragma unroll num_items - for(int k = 0; k < num_items; k++) - local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k]; - - reinterpret_cast(out)[idx_val/num_items] = reinterpret_cast(local_valC)[j/num_items]; - } - else - { - #pragma unroll num_items - for(int k = 0; k < num_items; k++) - if(idx_col_C + k < colsB) - out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k]; - } - } + int idx_row_C = (colsB * local_row_idx); + +#pragma unroll SPMM_ITEMS + for (int j = 0; j < SPMM_ITEMS; j += num_items) { + // int idx_col_C = idx_col_B + (32*j) + warp_idx; + int idx_col_C = idx_col_B + warp_idx * SPMM_ITEMS + j; + int idx_val = idx_col_C + idx_row_C; + + if (idx_col_C + num_items < colsB) { + + // load outputs to do inplace addition + reinterpret_cast(local_valOut)[0] = + reinterpret_cast(out)[idx_val / num_items]; + +#pragma unroll num_items + for (int k = 0; k < num_items; k++) + local_valC[(j / num_items) + k] = (float)local_valC[(j / num_items) + k] + (float)local_valOut[k]; + + reinterpret_cast(out)[idx_val / num_items] = + reinterpret_cast(local_valC)[j / num_items]; + } else { +#pragma unroll num_items + for (int k = 0; k < num_items; k++) + if (idx_col_C + k < colsB) + out[idx_val + k] = (float)out[idx_val + k] + (float)local_valC[j + k]; + } + } - idx_col_B += blockDim.x*SPMM_ITEMS; - local_idx_col_B_offset += blockDim.x*SPMM_ITEMS; - } + idx_col_B += blockDim.x * SPMM_ITEMS; + local_idx_col_B_offset += blockDim.x * SPMM_ITEMS; + } } #define WARPS 3 -template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) -{ + +template +__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc) { #if __CUDA_ARCH__ >= 750 - using namespace nvcuda; - int col_offset = blockIdx.x *32; - const 
int warp_id = threadIdx.x / 32; - const int half_warp_id = threadIdx.x / 16; - const int half_warp_lane = threadIdx.x % 16; - const int batch_size_warps = (WARPS-1)*2; - const int val_per_iter = blockDim.x-32; - - T local_A[4]; - T local_B[128]; - - const int a_tile_offset = 16; - const int b_tile_offset = (16*32 + 16); - - __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; - __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; - //__shared__ T smem_C[8*32]; - - wmma::fragment a_frag; - wmma::fragment b_frag; - wmma::fragment c_frag; - wmma::fill_fragment(c_frag, 0.0f); - - int ticktock = 0; - int idx = 0 + threadIdx.x; - int loaded_values = 0; - // prefetch - if(idx < K && warp_id < (WARPS-1)) - { - if(loaded_values == 0) - { - local_A[0] = A[idx]; - local_A[1] = A[idx+(1*val_per_iter)]; - local_A[2] = A[idx+(2*val_per_iter)]; - local_A[3] = A[idx+(3*val_per_iter)]; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - { - local_B[col] = B[(col_offset+col)*ldb+idx]; - local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; - local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; - local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; - } - loaded_values = 3; - } - else - { - - if(loaded_values == 3) - { - local_A[0] = local_A[1]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(32)]; - } - else if(loaded_values == 2) - { - local_A[0] = local_A[2]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(64)]; - } - else - { - local_A[0] = local_A[3]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(96)]; - } - loaded_values--; - } + using namespace nvcuda; + int col_offset = blockIdx.x * 32; + const int warp_id = threadIdx.x / 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS - 1) * 2; + const int val_per_iter = blockDim.x - 32; + + T local_A[4]; + T local_B[128]; + + const int a_tile_offset = 16; + const int b_tile_offset = (16 * 32 + 16); + + __shared__ T smem_A[8 * 16 + (2 * 16 * (batch_size_warps - 1))]; + __shared__ T smem_B[2 * batch_size_warps * 16 * 32 + (2 * 16 * (batch_size_warps - 1))]; + //__shared__ T smem_C[8*32]; + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if (idx < K && warp_id < (WARPS - 1)) { + if (loaded_values == 0) { + local_A[0] = A[idx]; + local_A[1] = A[idx + (1 * val_per_iter)]; + local_A[2] = A[idx + (2 * val_per_iter)]; + local_A[3] = A[idx + (3 * val_per_iter)]; + +#pragma unroll 32 + for (int col = 0; col < 32; col++) { + local_B[col] = B[(col_offset + col) * ldb + idx]; + local_B[col + 32] = B[(col_offset + col) * ldb + idx + (1 * val_per_iter)]; + local_B[col + 64] = B[(col_offset + col) * ldb + idx + (2 * val_per_iter)]; + local_B[col + 96] = B[(col_offset + col) * ldb + idx + (3 * val_per_iter)]; + } + loaded_values = 3; + } else { + + if (loaded_values == 3) { + local_A[0] = local_A[1]; +#pragma unroll 32 + for (int col = 0; col < 32; col++) + local_B[col] = local_B[col + (32)]; + } else if (loaded_values == 2) { + local_A[0] = local_A[2]; +#pragma unroll 32 + for (int col = 0; col < 32; col++) + local_B[col] = local_B[col + (64)]; + } else { + local_A[0] = local_A[3]; +#pragma unroll 32 + for (int col = 0; col < 32; col++) + local_B[col] = 
local_B[col + (96)]; + } + loaded_values--; + } - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; - } - else if(warp_id < (WARPS-1)) - { - local_A[0] = T(0.0); - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; +#pragma unroll 32 + for (int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = + local_B[col]; + } else if (warp_id < (WARPS - 1)) { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = 0.0f; +#pragma unroll 32 + for (int col = 0; col < 32; col++) + local_B[col] = 0.0f; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; - } - ticktock = ticktock == 0 ? 1 : 0; +#pragma unroll 32 + for (int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = + 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; - //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) - for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) - { - idx = base_idx + threadIdx.x; + // for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for (int base_idx = blockDim.x - 32; base_idx < K; base_idx += blockDim.x - 32) { + idx = base_idx + threadIdx.x; - __syncthreads(); - if(idx < K && warp_id < (WARPS-1)) - { - //local_A[0] = A[idx]; - - //#pragma unroll 32 - //for(int col = 0; col < 32; col++) - // local_B[col] = B[(col_offset+col)*ldb+idx]; - if(loaded_values == 0) - { - local_A[0] = A[idx]; - local_A[1] = A[idx+(1*val_per_iter)]; - local_A[2] = A[idx+(2*val_per_iter)]; - local_A[3] = A[idx+(3*val_per_iter)]; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - { - local_B[col] = B[(col_offset+col)*ldb+idx]; - local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; - local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; - local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; - } - loaded_values = 3; - - } - else - { - - if(loaded_values == 3) - { - local_A[0] = local_A[1]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(32)]; - } - else if(loaded_values == 2) - { - local_A[0] = local_A[2]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(64)]; - } - else - { - local_A[0] = local_A[3]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(96)]; - } - loaded_values--; - } + __syncthreads(); + if (idx < K && warp_id < (WARPS - 1)) { + // local_A[0] = A[idx]; + + // #pragma unroll 32 + // for(int col = 0; col < 32; col++) + // local_B[col] = B[(col_offset+col)*ldb+idx]; + if (loaded_values == 0) { + local_A[0] = A[idx]; + local_A[1] = A[idx + (1 * val_per_iter)]; + local_A[2] = A[idx + (2 * val_per_iter)]; + local_A[3] = A[idx + (3 * val_per_iter)]; + +#pragma unroll 32 + for (int col = 0; col < 32; col++) { + local_B[col] = B[(col_offset + col) * ldb + idx]; + 
local_B[col + 32] = B[(col_offset + col) * ldb + idx + (1 * val_per_iter)]; + local_B[col + 64] = B[(col_offset + col) * ldb + idx + (2 * val_per_iter)]; + local_B[col + 96] = B[(col_offset + col) * ldb + idx + (3 * val_per_iter)]; + } + loaded_values = 3; + + } else { + + if (loaded_values == 3) { + local_A[0] = local_A[1]; +#pragma unroll 32 + for (int col = 0; col < 32; col++) + local_B[col] = local_B[col + (32)]; + } else if (loaded_values == 2) { + local_A[0] = local_A[2]; +#pragma unroll 32 + for (int col = 0; col < 32; col++) + local_B[col] = local_B[col + (64)]; + } else { + local_A[0] = local_A[3]; +#pragma unroll 32 + for (int col = 0; col < 32; col++) + local_B[col] = local_B[col + (96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0]; + +#pragma unroll 32 + for (int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = + local_B[col]; + } else if (warp_id < (WARPS - 1)) { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f; - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; +#pragma unroll 32 + for (int col = 0; col < 32; col++) + local_B[col] = 0.0f; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; +#pragma unroll 32 + for (int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = + 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if (warp_id == (WARPS - 1)) + for (int k = 0; k < batch_size_warps; k++) { + wmma::load_matrix_sync( + a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16 + ); // 111 mu + wmma::load_matrix_sync( + b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16 + ); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } } - else if(warp_id < (WARPS-1)) - { - local_A[0] = T(0.0); - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + + __syncthreads(); + if (warp_id != (WARPS - 1)) { + return; } - ticktock = ticktock == 0 ? 1 : 0; + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; - if(warp_id == (WARPS-1)) - for(int k = 0; k < batch_size_warps; k++) - { - wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu - wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + ticktock = ticktock == 0 ? 1 : 0; + for (int k = 0; k < batch_size_warps; k++) { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16); // 35 mu wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - } - } - - __syncthreads(); - if(warp_id != (WARPS-1)){ return; } - // only warp_id == (WARPS-1) from here - int warp_lane = threadIdx.x % 32; - - ticktock = ticktock == 0 ? 
1 : 0; - for(int k = 0; k < batch_size_warps; k++) - { - wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu - wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - } - - // 129 mu - if(warp_id == (WARPS-1)) - wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); - - if(col_offset + warp_lane < M) - out[col_offset + warp_lane] = smem_A[warp_lane]; + } + + // 129 mu + if (warp_id == (WARPS - 1)) + wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); + + if (col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_A[warp_lane]; #endif } - -template __device__ void printnonzero(T *A, int num_values, const char * strval) -{ - for(int i = 0; i < num_values; i++) - if((float)A[i] != 0.0) - printf("%s %i %f\n", strval, i, (float)A[i]); +template __device__ void printnonzero(T* A, int num_values, const char* strval) { + for (int i = 0; i < num_values; i++) + if ((float)A[i] != 0.0) + printf("%s %i %f\n", strval, i, (float)A[i]); } +template +__global__ void kgemm_4bit_inference( + int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, + int blocksize +) { -template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) -{ - - //// element-wise kernel - //// 1. Load batch x k into registers - //// 2. Load k x k into registers - //// 3. dequantize and store in second pair of k x k - //// 4. matmul - //// 5. sum with cub - //// 6. store outputs - //// TC kernel - //// use k warps per thread block - //// 1. threadblock use read-only cache to read in register tile for A into shared memory - //// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments - //// 3. each warp reads a segment of values 16x32 from B - //// 4. do dequantization from register of B into second pair of registers - //// 5. store (4) into fragment - //// 6. matmul aggregate into fragment C - //// 7. aggregate files of C into shared memory block C - //// 8. sum (7) - //// 9. write outputs to matmul output matrix + //// element-wise kernel + //// 1. Load batch x k into registers + //// 2. Load k x k into registers + //// 3. dequantize and store in second pair of k x k + //// 4. matmul + //// 5. sum with cub + //// 6. store outputs + //// TC kernel + //// use k warps per thread block + //// 1. threadblock use read-only cache to read in register tile for A into shared memory + //// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments + //// 3. each warp reads a segment of values 16x32 from B + //// 4. do dequantization from register of B into second pair of registers + //// 5. store (4) into fragment + //// 6. matmul aggregate into fragment C + //// 7. aggregate files of C into shared memory block C + //// 8. sum (7) + //// 9. 
write outputs to matmul output matrix #if __CUDA_ARCH__ >= 750 - using namespace nvcuda; - int col_offset = blockIdx.x *32; - const int warp_id = threadIdx.x / 32; - const int warp_idx = threadIdx.x % 32; - const int half_warp_id = threadIdx.x / 16; - const int half_warp_lane = threadIdx.x % 16; - const int batch_size_warps = (WARPS-1)*2; - - T quant_map[16]; - - #pragma unroll 16 - for(int i = 0; i < 16; i++) - quant_map[i] = nf4_data[i]; - //__shared__ T quant_map[16*160]; - - T local_A[2]; - T local_B[64]; - unsigned char local_B_4bit[32]; - - - const int a_tile_offset = 16; - const int b_tile_offset = (16*32 + 16); - - __shared__ T smem_A[8*16 + (16*(batch_size_warps-1))]; - __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; - __shared__ T smem_C[8*32]; - - wmma::fragment a_frag; - wmma::fragment b_frag; - wmma::fragment c_frag; - wmma::fill_fragment(c_frag, 0.0f); - - for(int i = threadIdx.x; i < (8*32); i+=blockDim.x) - smem_C[i] = 0.0f; - - __syncthreads(); - - int ticktock = 0; - int idx = 0 + threadIdx.x; - int loaded_values = 0; - // prefetch - if(idx < K && warp_id < (WARPS-1)) - { - if(loaded_values == 0) - { - local_A[0] = A[idx]; - local_A[1] = A[idx+blockDim.x-32]; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; - - loaded_values = 1; - } - else - { - local_A[0] = local_A[1]; - loaded_values--; - - #pragma unroll 64 - for(int col = 0; col < 64; col+=2) - { - //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); - //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); - //local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0); - //local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0); - //local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0); - //local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0); - - //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0); - //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0); - local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0); - local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0); - } - } + using namespace nvcuda; + int col_offset = blockIdx.x * 32; + const int warp_id = threadIdx.x / 32; + const int warp_idx = threadIdx.x % 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS - 1) * 2; - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; - } - else if(warp_id < (WARPS-1)) - { - local_A[0] = T(0.0); - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; - } - ticktock = ticktock == 0 ? 
1 : 0; - //if(threadIdx.x == 0) - //printf("aa %i %i\n", idx, loaded_values); - - //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) - for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) - { - idx = base_idx + threadIdx.x; - //if(threadIdx.x == 0) - //printf("%i %i\n", idx, loaded_values); - - //__syncthreads(); - if(idx < K && warp_id < (WARPS-1)) - { - if(loaded_values == 0) - { - local_A[0] = A[idx]; - local_A[1] = A[idx+blockDim.x-32]; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - { - local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; - local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx]; - } + T quant_map[16]; + +#pragma unroll 16 + for (int i = 0; i < 16; i++) + quant_map[i] = nf4_data[i]; + //__shared__ T quant_map[16*160]; - loaded_values = 1; - } - else - { - local_A[0] = local_A[1]; - loaded_values--; - - int absidx = (idx + col_offset)/blocksize; - half local_absmax = __ldg(&(absmax[absidx])); - - #pragma unroll 64 - for(int col = 0; col < 64; col+=2) - { - //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); - //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); - //local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx); - //local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx); - - //local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax); - //local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax); - local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx); - local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx); + T local_A[2]; + T local_B[64]; + unsigned char local_B_4bit[32]; + + const int a_tile_offset = 16; + const int b_tile_offset = (16 * 32 + 16); + + __shared__ T smem_A[8 * 16 + (16 * (batch_size_warps - 1))]; + __shared__ T smem_B[2 * batch_size_warps * 16 * 32 + (2 * 16 * (batch_size_warps - 1))]; + __shared__ T smem_C[8 * 32]; + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + + for (int i = threadIdx.x; i < (8 * 32); i += blockDim.x) + smem_C[i] = 0.0f; + + __syncthreads(); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if (idx < K && warp_id < (WARPS - 1)) { + if (loaded_values == 0) { + local_A[0] = A[idx]; + local_A[1] = A[idx + blockDim.x - 32]; + +#pragma unroll 32 + for (int col = 0; col < 32; col++) + local_B_4bit[col] = B[(col_offset + col) * ldb + idx]; + + loaded_values = 1; + } else { + local_A[0] = local_A[1]; + loaded_values--; + +#pragma unroll 64 + for (int col = 0; col < 64; col += 2) { + // local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); + // local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); + // local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0); + // local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + // local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0); + // local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + + // local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0); + // local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0); + local_B[col] = quant_map[160 * (local_B_4bit[col / 2] >> 4) + warp_idx] * T(17.0); + local_B[col + 1] = quant_map[160 * (local_B_4bit[col / 2] & 0x0F) + warp_idx] * T(17.0); + } } - //printnonzero(local_B, 128, ""); - } - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 
local_A[0]; + smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; - } - else if(warp_id < (WARPS-1)) - { - local_A[0] = T(0.0); - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; +#pragma unroll 32 + for (int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = + local_B[col]; + } else if (warp_id < (WARPS - 1)) { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f; + +#pragma unroll 32 + for (int col = 0; col < 32; col++) + local_B[col] = 0.0f; + +#pragma unroll 32 + for (int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = + 0.0f; } ticktock = ticktock == 0 ? 1 : 0; + // if(threadIdx.x == 0) + // printf("aa %i %i\n", idx, loaded_values); + + // for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for (int base_idx = blockDim.x - 32; base_idx < K; base_idx += blockDim.x - 32) { + idx = base_idx + threadIdx.x; + // if(threadIdx.x == 0) + // printf("%i %i\n", idx, loaded_values); + + //__syncthreads(); + if (idx < K && warp_id < (WARPS - 1)) { + if (loaded_values == 0) { + local_A[0] = A[idx]; + local_A[1] = A[idx + blockDim.x - 32]; + +#pragma unroll 32 + for (int col = 0; col < 32; col++) { + local_B_4bit[col] = B[(col_offset + col) * ldb + idx]; + local_B_4bit[col + 16] = B[(col_offset + col) * ldb + idx]; + } + + loaded_values = 1; + } else { + local_A[0] = local_A[1]; + loaded_values--; + + int absidx = (idx + col_offset) / blocksize; + half local_absmax = __ldg(&(absmax[absidx])); + +#pragma unroll 64 + for (int col = 0; col < 64; col += 2) { + // local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); + // local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); + // local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx); + // local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx); + + // local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax); + // local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax); + local_B[col] = quant_map[(local_B_4bit[col / 2] >> 4)] * T(absidx); + local_B[col + 1] = quant_map[(local_B_4bit[col / 2] & 0x0F)] * T(absidx); + } + // printnonzero(local_B, 128, ""); + } + + smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0]; - if(warp_id == (WARPS-1)) - for(int k = 0; k < batch_size_warps; k++) - { - wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu - wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu +#pragma unroll 32 + for (int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = + local_B[col]; + } else if (warp_id < (WARPS - 1)) { + local_A[0] = T(0.0); + smem_A[half_warp_lane + 
(((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f; + +#pragma unroll 32 + for (int col = 0; col < 32; col++) + local_B[col] = 0.0f; + +#pragma unroll 32 + for (int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = + 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if (warp_id == (WARPS - 1)) + for (int k = 0; k < batch_size_warps; k++) { + wmma::load_matrix_sync( + a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16 + ); // 111 mu + wmma::load_matrix_sync( + b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16 + ); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + // if(threadIdx.x == 0) + //{ + // printnonzero(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: "); + // printnonzero(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: "); + // } + if (warp_id != (WARPS - 1)) { + return; + } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 1 : 0; + for (int k = 0; k < batch_size_warps; k++) { + // if(warp_lane == 0) + // printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16); // 35 mu wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - } - } - - __syncthreads(); - //if(threadIdx.x == 0) - //{ - // printnonzero(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: "); - // printnonzero(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: "); - //} - if(warp_id != (WARPS-1)){ return; } - // only warp_id == (WARPS-1) from here - int warp_lane = threadIdx.x % 32; - - ticktock = ticktock == 0 ? 
1 : 0; - for(int k = 0; k < batch_size_warps; k++) - { - //if(warp_lane == 0) - //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); - wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu - wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - } - - // 129 mu - if(warp_id == (WARPS-1)) - wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); - - //printnonzero(smem_C, 32, ""); - - if(col_offset + warp_lane < M) - out[col_offset + warp_lane] = smem_C[warp_lane]; + } + + // 129 mu + if (warp_id == (WARPS - 1)) + wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); + + // printnonzero(smem_C, 32, ""); + + if (col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_C[warp_lane]; #endif } #define num_values_4bit 32 -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) -{ - - // per threadblock: - // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] - // 4 warps -> 4 loads per iter - // 1x32 * 32x4 -> 1x4 outputs per thread block - typedef cub::WarpReduce WarpReduce; - __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32]; - - const int warp_idx = threadIdx.x / 32; - const int warp_lane = threadIdx.x % 32; - const int row_B = (THREADS/32)*blockIdx.x + warp_idx; - const int offset_B = ldb*row_B; - const int num_values_8bit = num_values_4bit/2; - float local_C = 0.0f; - - unsigned char local_B_4bit[num_values_8bit]; - T local_B[num_values_4bit/4]; - T local_A[num_values_4bit/4]; - __shared__ T quant_map[16]; - T local_absmax = T(0.0f); - - if (threadIdx.x < 16) - quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x])); - //for(int i = threadIdx.x; i < 16; i++) - //quant_map[i] = T(__ldg(&datatype[i])); - __syncthreads(); - - // A: [1, K] - // B: [N, K] - for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit) - { - const int inner_idx_halved = inner_idx/2; - - // Since blocksize will always be a power-of-2, we avoid more expensive - // division by the blocksize and instead use a shift operation. - // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. 
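// ---------------------------------------------------------------------------
// Editor's note (illustration only, not part of the patch): a scalar host
// reference for one output row of kgemm_4bit_inference_naive, using the
// hypothetical helper name gemv_4bit_ref. B packs two 4-bit code indices per
// byte (high nibble = even element); each index selects an entry of the
// 16-value datatype table, which is rescaled by the per-block absmax.
// ---------------------------------------------------------------------------
static float gemv_4bit_ref(const float* A, const unsigned char* B_row, const float* absmax,
                           const float* datatype, int K, int blocksize, long long row_elem_offset) {
    float acc = 0.0f;
    for (int k = 0; k < K; ++k) {
        const unsigned char packed = B_row[k / 2];
        const int code = (k % 2 == 0) ? (packed >> 4) : (packed & 0x0F);
        const float b = datatype[code] * absmax[(row_elem_offset + k) / blocksize];
        acc += A[k] * b;
    }
    // In the kernel this sum is produced by a warp reduction and written to
    // out[row] by lane 0.
    return acc;
}
// (end of editor's note)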
- const int absidx = ((2*offset_B)+inner_idx) >> (31 - __clz(blocksize)); - - local_absmax = __ldg(&(absmax[absidx])); - - if(row_B < M) - { - if((inner_idx_halved + num_values_8bit) < (K/2)) - { - // this is the most important for performance considerations - reinterpret_cast(local_B_4bit)[0] = reinterpret_cast(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)]; - } - else - { - #pragma unroll - for(int j = 0; j < (num_values_8bit); j++) - if((inner_idx_halved) + j < (K/2)) - local_B_4bit[j] = B[offset_B+inner_idx_halved + j]; - else - local_B_4bit[j] = 0b01110111; - } - } - else - { - #pragma unroll - for(int j = 0; j < (num_values_8bit); j++) - local_B_4bit[j] = 0b01110111; - } - for(int i = 0; i < 4; i++) - { - #pragma unroll - for(int k = 0; k < num_values_8bit/4; k++) - { - #if BNB_BF16_AVAILABLE - local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax; - local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax; - #else - // bf16 multipliation not supported - local_B[k*2] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*(float)local_absmax); - local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*(float)local_absmax); - #endif - } - - if(inner_idx+(num_values_4bit/4) + (i*num_values_4bit/4) < K) - { - // this is also relatively important for performance - if(BITS==16) - { - reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/4) + i]; - } - else - { - reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 0]; - reinterpret_cast(local_A)[1] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 1]; +template +__global__ void kgemm_4bit_inference_naive( + int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out, + int lda, int ldb, int ldc, int blocksize +) { + + // per threadblock: + // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] + // 4 warps -> 4 loads per iter + // 1x32 * 32x4 -> 1x4 outputs per thread block + typedef cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32]; + + const int warp_idx = threadIdx.x / 32; + const int warp_lane = threadIdx.x % 32; + const int row_B = (THREADS / 32) * blockIdx.x + warp_idx; + const int offset_B = ldb * row_B; + const int num_values_8bit = num_values_4bit / 2; + float local_C = 0.0f; + + unsigned char local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit / 4]; + T local_A[num_values_4bit / 4]; + __shared__ T quant_map[16]; + T local_absmax = T(0.0f); + + if (threadIdx.x < 16) + quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x])); + // for(int i = threadIdx.x; i < 16; i++) + // quant_map[i] = T(__ldg(&datatype[i])); + __syncthreads(); + + // A: [1, K] + // B: [N, K] + for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) { + const int inner_idx_halved = inner_idx / 2; + + // Since blocksize will always be a power-of-2, we avoid more expensive + // division by the blocksize and instead use a shift operation. + // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. 
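// Editor's note (illustration only): for a power-of-two blocksize,
// 31 - __clz(blocksize) equals log2(blocksize), so the shift below is an exact
// replacement for the division. Example: blocksize = 64 gives __clz(64) = 25,
// 31 - 25 = 6, and (x >> 6) == x / 64 for non-negative x.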
+ const int absidx = ((2 * offset_B) + inner_idx) >> (31 - __clz(blocksize)); + + local_absmax = __ldg(&(absmax[absidx])); + + if (row_B < M) { + if ((inner_idx_halved + num_values_8bit) < (K / 2)) { + // this is the most important for performance considerations + reinterpret_cast(local_B_4bit)[0] = + reinterpret_cast(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; + } else { +#pragma unroll + for (int j = 0; j < (num_values_8bit); j++) + if ((inner_idx_halved) + j < (K / 2)) + local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } else { +#pragma unroll + for (int j = 0; j < (num_values_8bit); j++) + local_B_4bit[j] = 0b01110111; } - } - else - #pragma unroll - for(int k = 0; k < num_values_4bit/4; k++) - if(inner_idx + (i*num_values_4bit/4) + k < K) - local_A[k] = A[inner_idx + k + (i*num_values_4bit/4)]; - else - local_A[k] = T(0.0f); - - - // accumulate in float; small performance hit for Ampere, but lower error for outputs - #pragma unroll - for(int k = 0; k < num_values_4bit/4; k++) - { - #if BNB_BF16_AVAILABLE - local_C += (float)(local_A[k]*local_B[k]); - #else - // bf16 multipliation not supported - local_C += ((float)local_A[k]*(float)local_B[k]); - #endif - } - } - } + for (int i = 0; i < 4; i++) { +#pragma unroll + for (int k = 0; k < num_values_8bit / 4; k++) { +#if BNB_BF16_AVAILABLE + local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; + local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; +#else + // bf16 multipliation not supported + local_B[k * 2] = + T((float)quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * (float)local_absmax); + local_B[k * 2 + 1] = + T((float)quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * (float)local_absmax); +#endif + } - local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { + // this is also relatively important for performance + if (BITS == 16) { + reinterpret_cast(local_A)[0] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 4) + i]; + } else { + reinterpret_cast(local_A)[0] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0]; + reinterpret_cast(local_A)[1] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; + } - if(row_B < M && warp_lane == 0) - out[row_B] = T(local_C); + } else +#pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) + if (inner_idx + (i * num_values_4bit / 4) + k < K) + local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; + else + local_A[k] = T(0.0f); +// accumulate in float; small performance hit for Ampere, but lower error for outputs +#pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) { +#if BNB_BF16_AVAILABLE + local_C += (float)(local_A[k] * local_B[k]); +#else + // bf16 multipliation not supported + local_C += ((float)local_A[k] * (float)local_B[k]); +#endif + } + } + } + + local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + + if (row_B < M && warp_lane == 0) + out[row_B] = T(local_C); } -template __global__ void kfunc(T *A, T *B, T value, long n) -{ - for(long i = (blockDim.x*blockIdx.x) + threadIdx.x; i < n; i+=(blockDim.x*gridDim.x)) - { - switch(FUNC) - { - case FILL: - A[i] = (T)value; - break; - case ARANGE: - A[i] = (T)i; - break; - case _MUL: - A[i] = A[i]*B[i]; - break; +template __global__ void kfunc(T* A, T* B, T value, long n) { + for (long i = (blockDim.x * blockIdx.x) + 
threadIdx.x; i < n; i += (blockDim.x * gridDim.x)) { + switch (FUNC) { + case FILL: + A[i] = (T)value; + break; + case ARANGE: + A[i] = (T)i; + break; + case _MUL: + A[i] = A[i] * B[i]; + break; + } } - } } - //============================================================== // TEMPLATE DEFINITIONS //============================================================== -template __global__ void kfunc(float *A, float *B, float value, long n); -template __global__ void kfunc(unsigned char *A, unsigned char *B, unsigned char value, long n); -template __global__ void kfunc(float *A, float *B, float value, long n); -template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(float* A, float* B, float value, long n); +template __global__ void kfunc(unsigned char* A, unsigned char* B, unsigned char value, long n); +template __global__ void kfunc(float* A, float* B, float value, long n); +template __global__ void kfunc(float* A, float* B, float value, long n); // these are not used and make no sense, but the compiler needs them -//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +// template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, +// float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device( + int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc +); +template __global__ void gemm_device( + int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc +); +template __global__ void gemm_device( + int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc +); +template __global__ void gemm_device( + int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc +); +// template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, +// float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device( + int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc +); +template __global__ void gemm_device( + int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc +); +template __global__ void gemm_device( + int M, int N, int K, half* __restrict__ const A, half* B, half* 
out, int lda, int ldb, int ldc +); // these are not used and make no sense, but the compiler needs them -//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); - -template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); - -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); - -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void 
kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); - -template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); +// template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, +// float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device( + int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc +); +template __global__ void gemm_device( + int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc +); +template __global__ void gemm_device( + int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc +); +template __global__ void gemm_device( + int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc +); +// template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, +// float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device( + int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc +); +template __global__ void gemm_device( + int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc +); +template __global__ void gemm_device( + int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc +); + +template __global__ void kgemm_4bit_inference( + int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb, + int ldc, int blocksize +); +template __global__ void kgemm_4bit_inference( + int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb, + int ldc, int blocksize +); +template __global__ void kgemm_4bit_inference( + int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb, + int ldc, int blocksize +); +template __global__ void kgemm_4bit_inference( + int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb, + int ldc, int blocksize +); + +template __global__ void kgemm_4bit_inference_naive( + int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, half* out, + int lda, int ldb, int ldc, int blocksize +); +template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>( + int M, int N, int K, __nv_bfloat16* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, + __nv_bfloat16* out, int lda, int ldb, int ldc, int blocksize +); +template __global__ void 
kgemm_4bit_inference_naive( + int M, int N, int K, float* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, + float* out, int lda, int ldb, int ldc, int blocksize +); + +template __global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +); +template __global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +); +template __global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +); +template __global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +); +template __global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +); +template __global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +); + +template __global__ void kdequant_mm_int32_fp16<4, 512>( + int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out, + half* __restrict__ const bias, const int numRows, const int numCols, const int n +); template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); -#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ -template __global__ void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ - float* state1, float *unorm, \ - const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n); \ +#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ + template __global__ void kPreconditionOptimizer32bit1State( \ + gtype * g, gtype * p, float* state1, float* unorm, const float beta1, const float beta2, const float eps, \ + const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n \ + ); MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) @@ -2931,9 +2822,12 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16) -#define MAKE_Optimizer32bit1State(oname, gtype) \ -template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ 
+#define MAKE_Optimizer32bit1State(oname, gtype) \ + template __global__ void kOptimizer32bit1State( \ + gtype * g, gtype * p, float* state1, float* unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, const int step, \ + const float lr, const float gnorm_scale, const bool skip_zeros, const int n \ + ); MAKE_Optimizer32bit1State(MOMENTUM, half) MAKE_Optimizer32bit1State(MOMENTUM, float) @@ -2948,11 +2842,12 @@ MAKE_Optimizer32bit1State(ADAGRAD, half) MAKE_Optimizer32bit1State(ADAGRAD, float) MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16) -#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ -template __global__ void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ - float* state1, float* state2, float *unorm, \ - const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n); \ +#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ + template __global__ void kPreconditionOptimizer32bit2State( \ + gtype * g, gtype * p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, \ + const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, \ + const int n \ + ); MAKE_PreconditionOptimizer32bit2State(ADAM, float) MAKE_PreconditionOptimizer32bit2State(ADAM, half) @@ -2961,31 +2856,49 @@ MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float) MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half) MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, __nv_bfloat16) -template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template 
__global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); - - -#define MAKE_PreconditionStatic8bit1State(oname, gtype) \ -template __global__ void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ - float *unorm, \ - const float beta1, \ - const float beta2, \ - const float eps, const int step, \ - float* __restrict__ const quantiles1, \ - float* max1, float* new_max1, \ - const float weight_decay, \ - const float gnorm_scale, \ - const int n); \ +template __global__ void kOptimizer32bit2State( + float* g, float* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +); +template __global__ void kOptimizer32bit2State( + half* g, half* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +); +template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>( + __nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float* unorm, const float max_unorm, + const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +); +template __global__ void kOptimizer32bit2State( + float* g, float* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +); +template __global__ void kOptimizer32bit2State( + half* g, half* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +); +template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>( + __nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float* unorm, const float max_unorm, + const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +); + +#define MAKE_PreconditionStatic8bit1State(oname, gtype) \ + template __global__ void kPreconditionOptimizerStatic8bit1State( \ + gtype * p, gtype* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, \ + const float beta1, const float beta2, const float eps, const int step, float* 
__restrict__ const quantiles1, \ + float* max1, float* new_max1, const float weight_decay, const float gnorm_scale, const int n \ + ); MAKE_PreconditionStatic8bit1State(MOMENTUM, half) MAKE_PreconditionStatic8bit1State(MOMENTUM, float) @@ -2996,17 +2909,13 @@ MAKE_PreconditionStatic8bit1State(LION, float) MAKE_PreconditionStatic8bit1State(ADAGRAD, half) MAKE_PreconditionStatic8bit1State(ADAGRAD, float) -#define MAKE_optimizerStatic8bit1State(oname, gtype) \ -template __global__ void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1, \ - const float *unorm, const float max_unorm, const float param_norm, \ - const float beta1, \ - const float beta2, \ - const float eps, const int step, const float lr, \ - float* __restrict__ const quantiles1, \ - float* max1, float* new_max1, \ - float weight_decay, \ - const float gnorm_scale, \ - const int n); \ +#define MAKE_optimizerStatic8bit1State(oname, gtype) \ + template __global__ void kOptimizerStatic8bit1State( \ + gtype * p, gtype* const g, unsigned char* state1, const float* unorm, const float max_unorm, \ + const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, \ + const float gnorm_scale, const int n \ + ); MAKE_optimizerStatic8bit1State(MOMENTUM, half) MAKE_optimizerStatic8bit1State(MOMENTUM, float) @@ -3017,126 +2926,143 @@ MAKE_optimizerStatic8bit1State(LION, float) MAKE_optimizerStatic8bit1State(ADAGRAD, half) MAKE_optimizerStatic8bit1State(ADAGRAD, float) - -#define MAKE_PreconditionStatic8bit2State(oname, gtype) \ -template __global__ void kPreconditionOptimizerStatic8bit2State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \ - float *unorm, \ - const float beta1, const float beta2, \ - const float eps, const int step, \ - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ - float* max1, float* max2, float* new_max1, float* new_max2, \ - const float gnorm_scale, \ - const int n); \ +#define MAKE_PreconditionStatic8bit2State(oname, gtype) \ + template __global__ void kPreconditionOptimizerStatic8bit2State( \ + gtype * p, gtype* __restrict__ const g, unsigned char* __restrict__ const state1, \ + unsigned char* __restrict__ const state2, float* unorm, const float beta1, const float beta2, const float eps, \ + const int step, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, \ + float* max2, float* new_max1, float* new_max2, const float gnorm_scale, const int n \ + ); MAKE_PreconditionStatic8bit2State(ADAM, half) MAKE_PreconditionStatic8bit2State(ADAM, float) -#define MAKE_optimizerStatic8bit2State(oname, gtype) \ -template __global__ void kOptimizerStatic8bit2State(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \ - const float *unorm, const float max_unorm, const float param_norm, \ - const float beta1, const float beta2, \ - const float eps, const int step, const float lr, \ - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ - float* max1, float* max2, float* new_max1, float* new_max2, \ - float weight_decay, \ - const float gnorm_scale, \ - const int n); \ +#define MAKE_optimizerStatic8bit2State(oname, gtype) \ + template __global__ void kOptimizerStatic8bit2State( \ + gtype * p, gtype* const g, unsigned char* state1, unsigned char* state2, const float* unorm, \ + const float 
max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, \ + const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, \ + const int n \ + ); MAKE_optimizerStatic8bit2State(ADAM, half) MAKE_optimizerStatic8bit2State(ADAM, float) -template __global__ void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n); -template __global__ void kPercentileClipping(half * __restrict__ g, float *gnorm_vec, int step, const int n); - -#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ -template __global__ void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \ - -MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit) -MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) -MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) -MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) -MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) -MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) -MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) -MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) -MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) -MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) -MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) -MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) -MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) -MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) -MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) +template __global__ void + kPercentileClipping(float* __restrict__ g, float* gnorm_vec, int step, const int n); +template __global__ void + kPercentileClipping(half* __restrict__ g, float* gnorm_vec, int step, const int n); + +#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ + template __global__ void kQuantizeBlockwise( \ + float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \ + const int rand_offset, const int n \ + ); + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) 
+MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) -MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) -MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) -MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) -MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) -MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) -MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) -MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) -MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, General8bit) MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 1, General8bit) MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, General8bit) MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, General8bit) MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, FP4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, FP4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, FP4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, FP4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, FP4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, FP4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, FP4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, NF4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, NF4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4) 
-MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4) - -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n); - -#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ -template __global__ void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ - const float beta1, const float beta2, const float beta3, const float alpha, \ - const float eps, const int step, const float lr, \ - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ - float* absmax1, float* absmax2, \ - float weight_decay, \ - const float gnorm_scale, const bool skip_zeros, const int n); \ +MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4) + +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, FP4>( + float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General8bit>( + float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, const int blocksize, 
const int n +); +template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>( + float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, const int blocksize, const int n +); + +#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ + template __global__ void kOptimizerStatic8bit2StateBlockwise( \ + gtype * p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, \ + const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, \ + float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n \ + ); MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 256, 1) MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256, 1) @@ -3145,15 +3071,12 @@ MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 256, 1) MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256, 1) MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1) -#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ -template __global__ void kOptimizerStatic8bit1StateBlockwise( \ - gtype* p, gtype* __restrict__ const g, unsigned char* state1, \ - const float beta1, const float beta2, \ - const float eps, const int step, const float lr, \ - float* __restrict__ const quantiles1, \ - float* absmax1, \ - float weight_decay, \ - const float gnorm_scale, const bool skip_zeros, const int n); \ +#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ + template __global__ void kOptimizerStatic8bit1StateBlockwise( \ + gtype * p, gtype* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, \ + const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, \ + float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n \ + ); MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256, 1) @@ -3168,5 +3091,5 @@ MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, __nv_bfloat16, 256, 1) -template __device__ void printnonzero(float *A, int num_values, const char*strval); -template __device__ void printnonzero(half *A, int num_values, const char*strval); +template __device__ void printnonzero(float* A, int num_values, const char* strval); +template __device__ void printnonzero(half* A, int num_values, const char* strval); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index c5b996262..f60e6fdd0 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -9,116 +9,129 @@ #ifndef kernels #define kernels - -__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); -__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); - -template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); - -template -__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, - float* state1, float* 
state2, float *unorm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n); - -template -__global__ void kOptimizer32bit2State(T* g, T* p, - float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, - const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); - -template -__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, - float* state1, float *unorm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n); - -template -__global__ void kOptimizer32bit1State(T* g, T* p, - float* state1, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); - -template -__global__ void -kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, - float *unorm, - const float beta1, const float beta2, - const float eps, const int step, - float* __restrict__ const quantiles1, - float* max1, float* new_max1, - const float weight_decay, - const float gnorm_scale, const int n); - - -template -__global__ void -kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, - const float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, - float* max1, float* new_max1, - float weight_decay, const float gnorm_scale, const int n); - - - -template -__global__ void -kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, - float *unorm, - const float beta1, const float beta2, - const float eps, const int step, - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* max1, float* max2, float* new_max1, float* new_max2, - const float gnorm_scale, const int n); - - -template +__global__ void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n); +__global__ void kDequantize(float* code, unsigned char* A, float* out, const int n); + +template +__global__ void kQuantizeBlockwise( + float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, + const int rand_offset, const int n +); +template __global__ void -kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, - const float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* max1, float* max2, float* new_max1, float* new_max2, - float weight_decay, const float gnorm_scale, const int n); - -template __global__ void kOptimizerStatic8bit2StateBlockwise( - T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, 
float* __restrict__ const quantiles2, - float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); - -template __global__ void kOptimizerStatic8bit1StateBlockwise( - T* p, T* __restrict__ const g, unsigned char* state1, - const float beta1, const float beta2, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, - float* absmax1, - float weight_decay, - const float gnorm_scale, const bool skip_zeros, const int n); - - -template __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); - -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); - -template __global__ void kdequant_mm_int32_fp16( - int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, - half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); - -template __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols); -template __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); - -template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); - -template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); -template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); - -template __global__ void kfunc(T *A, T *B, T value, long n); + kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n); + +template +__global__ void kPreconditionOptimizer32bit2State( + T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizer32bit2State( + T* g, T* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +); + +template +__global__ void kPreconditionOptimizer32bit1State( + T* g, T* p, float* state1, float* unorm, const float beta1, const float beta2, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizer32bit1State( + T* g, T* p, float* state1, float* unorm, const float max_unorm, const float param_norm, const float beta1, + const float beta2, const float eps, const float weight_decay, const int step, const float lr, + const float gnorm_scale, const bool skip_zeros, const int n +); + +template +__global__ void kPreconditionOptimizerStatic8bit1State( + T* p, T* __restrict__ const g, 
unsigned char* __restrict__ const state1, float* unorm, const float beta1, + const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1, + float* new_max1, const float weight_decay, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizerStatic8bit1State( + T* p, T* const g, unsigned char* state1, const float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale, + const int n +); + +template +__global__ void kPreconditionOptimizerStatic8bit2State( + T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, unsigned char* __restrict__ const state2, + float* unorm, const float beta1, const float beta2, const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, + float* new_max1, float* new_max2, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizerStatic8bit2State( + T* p, T* const g, unsigned char* state1, unsigned char* state2, const float* unorm, const float max_unorm, + const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, + float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizerStatic8bit2StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2, + const float beta3, const float alpha, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, + float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n +); + +template +__global__ void kOptimizerStatic8bit1StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps, + const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n +); + +template +__global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n); + +template +__global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +); + +template +__global__ void kdequant_mm_int32_fp16( + int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out, + half* __restrict__ const bias, const int numRows, const int numCols, const int n +); + +template +__global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols); +template +__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols); + +template +__global__ void kTransformRowToFormat( + char* __restrict__ const A, char* out, int rows, int cols, int tiledCols, int outRows, int outCols +); + +template +__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int 
ldc); +template +__global__ void kgemm_4bit_inference( + int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, + int blocksize +); +template +__global__ void kgemm_4bit_inference_naive( + int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out, + int lda, int ldb, int ldc, int blocksize +); + +template __global__ void kfunc(T* A, T* B, T value, long n); #endif diff --git a/csrc/mps_ops.mm b/csrc/mps_ops.mm index d198b3552..85ed1b1e4 100644 --- a/csrc/mps_ops.mm +++ b/csrc/mps_ops.mm @@ -5,63 +5,58 @@ #define NUM 4 #define NUM_BLOCK 4096 -static inline MPSGraph* get_graph() -{ - static MPSGraph* cur = nil; - if(!cur) { - cur = [[MPSGraph alloc] init]; - } - return cur; +static inline MPSGraph* get_graph() { + static MPSGraph* cur = nil; + if (!cur) { + cur = [[MPSGraph alloc] init]; + } + return cur; } -static inline id get_device() -{ - NSError *error = nil; - static id device = nil; - if(!device) { - device = MTLCreateSystemDefaultDevice(); - } - if(!device) { - NSLog(@"Failed to get MPS device"); - abort(); - } - return device; +static inline id get_device() { + NSError* error = nil; + static id device = nil; + if (!device) { + device = MTLCreateSystemDefaultDevice(); + } + if (!device) { + NSLog(@"Failed to get MPS device"); + abort(); + } + return device; } -static inline id get_library() -{ - NSError *error = nil; - static id library = nil; - if(!library) { - library = [get_device() newLibraryWithURL:[NSURL fileURLWithPath:@"bitsandbytes.metallib"] error:&error]; - } - if(!library) { - NSLog(@"Failed to load bitsandbytes.metallib"); - abort(); - } - return library; +static inline id get_library() { + NSError* error = nil; + static id library = nil; + if (!library) { + library = [get_device() newLibraryWithURL:[NSURL fileURLWithPath:@"bitsandbytes.metallib"] error:&error]; + } + if (!library) { + NSLog(@"Failed to load bitsandbytes.metallib"); + abort(); + } + return library; } /*MPSGraphTensor* dequantize_mps(MPSGraphTensor* code, MPSGraphTensor* A, int n) { - id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0 dataType:MPSDataTypeInt8 axis:0 name:@"out"]; - return out; + id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0 +dataType:MPSDataTypeInt8 axis:0 name:@"out"]; return out; }*/ - // MPSGraph function for quantize -extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n) -{ - id device = get_device(); - id library = get_library(); - static id kernel = nil; - if(!kernel) { - kernel = [library newFunctionWithName:@"quantize"]; - if(!kernel) { - NSLog(@"Failed to load bitsandbytes.metallib"); - abort(); +extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n) { + id device = get_device(); + id library = get_library(); + static id kernel = nil; + if (!kernel) { + kernel = [library newFunctionWithName:@"quantize"]; + if (!kernel) { + NSLog(@"Failed to load bitsandbytes.metallib"); + abort(); + } } - } - NSLog(@"Not implemented"); - return nil; + NSLog(@"Not implemented"); + return nil; } diff --git a/csrc/ops.cu b/csrc/ops.cu index a99df1a06..71256719f 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -3,175 +3,195 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. 
-#include -#include -#include -#include #include #include #include +#include +#include +#include +#include #define ERR_NOT_IMPLEMENTED 100 - using namespace BinSearch; using std::cout; using std::endl; - -void quantize(float *code, float *A, unsigned char *out, int n) -{ - int num_blocks = n/1024; - num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; - kQuantize<<>>(code, A, out, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +void quantize(float* code, float* A, unsigned char* out, int n) { + int num_blocks = n / 1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + kQuantize<<>>(code, A, out, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream) -{ - int num_blocks = n/1024; - num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; - kDequantize<<>>(code, A, out, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +void dequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream) { + int num_blocks = n / 1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + kDequantize<<>>(code, A, out, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) -{ - int num_blocks = n/blocksize; - num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; - - if(blocksize == 4096) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); - else if(blocksize == 2048) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); - else if(blocksize == 1024) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); - else if(blocksize == 512) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); - else if(blocksize == 256) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); - else if(blocksize == 128) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); - else if(blocksize == 64) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); - - - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +template +void quantizeBlockwise( + float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +) { + int num_blocks = n / blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + + if (blocksize == 4096) + kQuantizeBlockwise + <<>>(code, A, absmax, out, rand, rand_offset, n); + else if (blocksize == 2048) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if (blocksize == 1024) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if (blocksize == 512) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if (blocksize == 256) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if (blocksize == 128) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if (blocksize == 64) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n, cudaStream_t stream) -{ - // printf("stream==%d\n",stream); - int num_blocks = n/blocksize; - num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; - int tile_size = (DATA_TYPE > 0) ? 
1024 : 512; - if(DATA_TYPE > 0) - kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize/2, n); - else - kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize, n); - - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +template +void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, cudaStream_t stream +) { + // printf("stream==%d\n",stream); + int num_blocks = n / blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + int tile_size = (DATA_TYPE > 0) ? 1024 : 512; + if (DATA_TYPE > 0) + kDequantizeBlockwise + <<<(n + tile_size - 1) / tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize / 2, n); + else + kDequantizeBlockwise + <<<(n + tile_size - 1) / tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize, n); + + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } - - -template void optimizer32bit(T* g, T* p, - float* state1, float* state2, float *unorm, float max_unorm, float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) -{ - int num_blocks = n/4096; - num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; - switch(OPTIMIZER) - { - case ADAM: +template +void optimizer32bit( + T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, const float beta1, + const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, const int step, + const float lr, const float gnorm_scale, bool skip_zeros, const int n +) { + int num_blocks = n / 4096; + num_blocks = n % 4096 == 0 ? 
num_blocks : num_blocks + 1; + switch (OPTIMIZER) { + case ADAM: case ADEMAMIX: - if(max_unorm > 0.0f) - { - CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); - kPreconditionOptimizer32bit2State<<>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + if (max_unorm > 0.0f) { + CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); + kPreconditionOptimizer32bit2State<<>>( + g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n + ); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + kOptimizer32bit2State<<>>( + g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, + gnorm_scale, skip_zeros, n + ); CUDA_CHECK_RETURN(cudaPeekAtLastError()); - } - kOptimizer32bit2State<<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; - case MOMENTUM: + break; + case MOMENTUM: case RMSPROP: case ADAGRAD: - if(max_unorm > 0.0f) - { - CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); - kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + if (max_unorm > 0.0f) { + CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); + kPreconditionOptimizer32bit1State + <<>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + + kOptimizer32bit1State<<>>( + g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, + skip_zeros, n + ); CUDA_CHECK_RETURN(cudaPeekAtLastError()); - } - - kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; + break; case LION: - // in lion, the momentum update after the parameter update - kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - - if(max_unorm > 0.0f) - { - CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); - kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + // in lion, the momentum update after the parameter update + kOptimizer32bit1State<<>>( + g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, + skip_zeros, n + ); CUDA_CHECK_RETURN(cudaPeekAtLastError()); - } - break; - } + + if (max_unorm > 0.0f) { + CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); + kPreconditionOptimizer32bit1State + <<>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + break; + } } -template void optimizerStatic8bit(T* p, T* g, - unsigned char* state1, unsigned char* state2, - float *unorm, float max_unorm, float param_norm, - float beta1, float beta2, - float eps, int step, float lr, - float* quantiles1, float* quantiles2, - float* max1, float* max2, float* new_max1, float* new_max2, - float weight_decay, - const float gnorm_scale, int n) -{ - int num_blocks = n/4096; - num_blocks = n % 4096 == 0 ? 
num_blocks : num_blocks + 1; - - if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); } - - switch(OPTIMIZER) - { - case ADAM: - CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); - CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float))); - kPreconditionOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - kOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, - quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; - case MOMENTUM: +template +void optimizerStatic8bit( + T* p, T* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, float param_norm, + float beta1, float beta2, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1, + float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, int n +) { + int num_blocks = n / 4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + + if (max_unorm > 0.0f) { + CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); + } + + switch (OPTIMIZER) { + case ADAM: + CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float))); + CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1 * sizeof(float))); + kPreconditionOptimizerStatic8bit2State<<>>( + p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, + new_max2, gnorm_scale, n + ); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + kOptimizerStatic8bit2State<<>>( + p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, + max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n + ); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case MOMENTUM: case RMSPROP: case ADAGRAD: - CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); - kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, - quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; + CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float))); + kPreconditionOptimizerStatic8bit1State<<>>( + p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n + ); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + kOptimizerStatic8bit1State<<>>( + p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, + weight_decay, gnorm_scale, n + ); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; case LION: - // in lion, the momentum update happens after the parameter update - kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, - quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - - CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); - kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; - default: - 
break; - } + // in lion, the momentum update happens after the parameter update + kOptimizerStatic8bit1State<<>>( + p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, + weight_decay, gnorm_scale, n + ); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + + CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float))); + kPreconditionOptimizerStatic8bit1State<<>>( + p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n + ); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + default: + break; + } } #define BLOCKSIZE_2STATE 256 @@ -179,148 +199,120 @@ template void optimizerStatic8bit(T* p, T* g, #define BLOCKSIZE_1STATE 256 #define NUM_1STATE 1 -template void optimizerStatic8bitBlockwise( - T* p, - T* g, - unsigned char* state1, - unsigned char* state2, - float beta1, - float beta2, - float beta3, - float alpha, - float eps, - int step, - float lr, - float* quantiles1, - float* quantiles2, - float* absmax1, - float* absmax2, - float weight_decay, - const float gnorm_scale, - bool skip_zeros, - int n +template +void optimizerStatic8bitBlockwise( + T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, + float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, + float weight_decay, const float gnorm_scale, bool skip_zeros, int n ) { - int num_blocks = 0; - switch(OPTIMIZER) - { - case ADAM: + int num_blocks = 0; + switch (OPTIMIZER) { + case ADAM: case ADEMAMIX: - num_blocks = n/BLOCKSIZE_2STATE; - num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; - kOptimizerStatic8bit2StateBlockwise<<>>( - p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, - quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, - skip_zeros, n - ); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; - case MOMENTUM: - case RMSPROP: + num_blocks = n / BLOCKSIZE_2STATE; + num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; + kOptimizerStatic8bit2StateBlockwise + <<>>( + p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, + absmax2, weight_decay, gnorm_scale, skip_zeros, n + ); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: case ADAGRAD: case LION: - num_blocks = n/BLOCKSIZE_1STATE; - num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; - kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr, - quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; - } + num_blocks = n / BLOCKSIZE_1STATE; + num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; + kOptimizerStatic8bit1StateBlockwise + <<>>( + p, g, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n + ); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + } } - - -template void percentileClipping(T * g, float *gnorm_vec, int step, const int n) -{ - int num_blocks = n/2048; - num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1; - CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float))); - kPercentileClipping<<>>(g, gnorm_vec, step, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +template void percentileClipping(T* g, float* gnorm_vec, int step, const int n) { + int num_blocks = n / 2048; + num_blocks = n % 2048 == 0 ? 
num_blocks : num_blocks + 1; + CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1 * sizeof(float))); + kPercentileClipping<<>>(g, gnorm_vec, step, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) -{ - const int falpha = 1; - const int fbeta = 0; - const void * alpha = &falpha; - const void * beta = &fbeta; - cublasStatus_t status; - - status = cublasGemmEx(context->m_handle, - transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, - transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, - m, n, k, - alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta, - C, CUDA_R_32I, ldc, - CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP); - - if (status != CUBLAS_STATUS_SUCCESS) - { - std::cout << "CUBLAS ERROR: Status " << status << std::endl; +void gemmex( + Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, + int ldb, int ldc +) { + const int falpha = 1; + const int fbeta = 0; + const void* alpha = &falpha; + const void* beta = &fbeta; + cublasStatus_t status; + + status = cublasGemmEx( + context->m_handle, transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, m, n, k, + alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta, C, CUDA_R_32I, ldc, CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP + ); + + if (status != CUBLAS_STATUS_SUCCESS) { + std::cout << "CUBLAS ERROR: Status " << status << std::endl; } - } -void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, - long long int strideA, long long int strideB, long long int strideC, int batchCount) -{ - const int falpha = 1; - const int fbeta = 0; - const void * alpha = &falpha; - const void * beta = &fbeta; - cublasStatus_t status; - - //cout << transposeA << transposeB << endl; - //printf("%i %i %i\n", m,n,k); - //printf("%i %i %i\n", lda,ldb,ldc); - //printf("%i %i %i\n", strideA, strideB, strideC); - //printf("%i\n", batchCount); - - status = cublasGemmStridedBatchedEx(context->m_handle, - transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, - transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, - m, n, k, - alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta, - C, CUDA_R_32I, ldc, (long long int)strideC, batchCount, - CUDA_R_32I, CUBLAS_GEMM_DEFAULT); - - if (status != CUBLAS_STATUS_SUCCESS) - { - std::cout << "CUBLAS ERROR: Status " << status << std::endl; +void strided_gemmex( + Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, + int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount +) { + const int falpha = 1; + const int fbeta = 0; + const void* alpha = &falpha; + const void* beta = &fbeta; + cublasStatus_t status; + + // cout << transposeA << transposeB << endl; + // printf("%i %i %i\n", m,n,k); + // printf("%i %i %i\n", lda,ldb,ldc); + // printf("%i %i %i\n", strideA, strideB, strideC); + // printf("%i\n", batchCount); + + status = cublasGemmStridedBatchedEx( + context->m_handle, transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, transposeB ? 
CUBLAS_OP_T : CUBLAS_OP_N, m, n, k, + alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta, C, + CUDA_R_32I, ldc, (long long int)strideC, batchCount, CUDA_R_32I, CUBLAS_GEMM_DEFAULT + ); + + if (status != CUBLAS_STATUS_SUCCESS) { + std::cout << "CUBLAS ERROR: Status " << status << std::endl; } - -} - -int roundoff(int v, int d) { - return (v + d - 1) / d * d; } +int roundoff(int v, int d) { return (v + d - 1) / d * d; } -template cublasLtOrder_t get_order() -{ - switch(ORDER) - { - case ROW: - return CUBLASLT_ORDER_ROW; - break; +template cublasLtOrder_t get_order() { + switch (ORDER) { + case ROW: + return CUBLASLT_ORDER_ROW; + break; case COL: - return CUBLASLT_ORDER_COL; - break; + return CUBLASLT_ORDER_COL; + break; case COL32: - return CUBLASLT_ORDER_COL32; - break; + return CUBLASLT_ORDER_COL32; + break; case COL_TURING: - return CUBLASLT_ORDER_COL4_4R2_8C; - break; + return CUBLASLT_ORDER_COL4_4R2_8C; + break; case COL_AMPERE: - return CUBLASLT_ORDER_COL32_2R_4R4; - break; - default: - break; - } + return CUBLASLT_ORDER_COL32_2R_4R4; + break; + default: + break; + } - return CUBLASLT_ORDER_ROW; + return CUBLASLT_ORDER_ROW; } template cublasLtOrder_t get_order(); @@ -329,355 +321,394 @@ template cublasLtOrder_t get_order(); template cublasLtOrder_t get_order(); template cublasLtOrder_t get_order(); - -template int get_leading_dim(int dim1, int dim2) -{ - switch(ORDER) - { - case ROW: - return dim2; - break; +template int get_leading_dim(int dim1, int dim2) { + switch (ORDER) { + case ROW: + return dim2; + break; case COL: - return dim1; - break; + return dim1; + break; case COL32: - // 32*row tiles - return dim1*32; - break; + // 32*row tiles + return dim1 * 32; + break; case COL_TURING: - return 32*roundoff(dim1, 8); - break; + return 32 * roundoff(dim1, 8); + break; case COL_AMPERE: - // 32*32 tiles - return 32*roundoff(dim1, 32); - break; - default: - return 0; - break; - } + // 32*32 tiles + return 32 * roundoff(dim1, 32); + break; + default: + return 0; + break; + } } -template int igemmlt( - cublasLtHandle_t ltHandle, - int m, int n, int k, - const int8_t * A, - const int8_t * B, - void * C, - float * row_scale, - int lda, int ldb, int ldc, - cudaStream_t stream +template +int igemmlt( + cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, cudaStream_t stream ) { - // Calculate C = A^T @ B, in col-major layout. - // - // Use the IMMA kernels requires: - // * A must be transposed and B must be non-transposed. - // * Dimensions m and k must be multiples of 4. - // * All pointers must be 4-byte aligned; 16-byte alignment preferred. - - int has_error = 0; - - cublasLtMatmulDesc_t matmulDesc; - cublasLtMatrixLayout_t aDesc, bDesc, cDesc; - cublasOperation_t opT = CUBLAS_OP_T; - - cudaDataType_t outType = DTYPE_OUT == 32 ? CUDA_R_32I : CUDA_R_8I; - cudaDataType_t scaleType = DTYPE_OUT == 32 ? 
CUDA_R_32I : CUDA_R_32F; - - cublasLtPointerMode_t pointerMode = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; - - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&aDesc, CUDA_R_8I, m, k, lda)); - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&bDesc, CUDA_R_8I, m, n, ldb)); - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&cDesc, outType, k, n, ldc)); - - // Default layout order is col major - - has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, scaleType)); - has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT))); - - if (DTYPE_OUT == 32) { - int alpha = 1, beta = 0; - has_error |= checkCublasStatus(cublasLtMatmul( - ltHandle, matmulDesc, - &alpha, A, aDesc, - B, bDesc, &beta, - (int32_t*)C, cDesc, - (int32_t*)C, cDesc, - NULL, NULL, 0, stream - )); - } else { - // This path is unlikely to be used, as 8-bit accumulation can lead to likely overflows. - - if (!SCALE_ROWS) { - float alpha = 1.0f, beta = 0.0f; - has_error |= checkCublasStatus(cublasLtMatmul( - ltHandle, matmulDesc, - &alpha, A, aDesc, - B, bDesc, &beta, - (int8_t*)C, cDesc, - (int8_t*)C, cDesc, - NULL, NULL, 0, stream - )); } else { - cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST; - float beta = 0.0f; - has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute( - matmulDesc, - CUBLASLT_MATMUL_DESC_POINTER_MODE, - &pointerMode, - sizeof(alphaVec) - )); - has_error |= checkCublasStatus(cublasLtMatmul( - ltHandle, matmulDesc, - row_scale, A, aDesc, - B, bDesc, &beta, - (int8_t*)C, cDesc, - (int8_t*)C, cDesc, - NULL, NULL, 0, stream - )); + // Calculate C = A^T @ B, in col-major layout. + // + // Using the IMMA kernels requires: + // * A must be transposed and B must be non-transposed. + // * Dimensions m and k must be multiples of 4. + // * All pointers must be 4-byte aligned; 16-byte alignment preferred. + + int has_error = 0; + + cublasLtMatmulDesc_t matmulDesc; + cublasLtMatrixLayout_t aDesc, bDesc, cDesc; + cublasOperation_t opT = CUBLAS_OP_T; + + cudaDataType_t outType = DTYPE_OUT == 32 ? CUDA_R_32I : CUDA_R_8I; + cudaDataType_t scaleType = DTYPE_OUT == 32 ? CUDA_R_32I : CUDA_R_32F; + + cublasLtPointerMode_t pointerMode = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&aDesc, CUDA_R_8I, m, k, lda)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&bDesc, CUDA_R_8I, m, n, ldb)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&cDesc, outType, k, n, ldc)); + + // Default layout order is col major + + has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, scaleType)); + has_error |= + checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT))); + + if (DTYPE_OUT == 32) { + int alpha = 1, beta = 0; + has_error |= checkCublasStatus(cublasLtMatmul( + ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int32_t*)C, cDesc, (int32_t*)C, cDesc, NULL, NULL, + 0, stream + )); } else { + // This path is unlikely to be used, as 8-bit accumulation is prone to overflow.
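// [Editor's note, not part of this patch] A minimal caller-side sketch of the IMMA
// preconditions listed in the comment block above, assuming a caller wants to fail fast
// before reaching cuBLASLt. It reuses only the function's own parameters (m, k, A, B, C);
// <cassert> and <cstdint> would be needed, and the checks are illustrative rather than
// something this change introduces:
//
//     assert(m % 4 == 0 && k % 4 == 0);                      // dims must be multiples of 4
//     assert(reinterpret_cast<std::uintptr_t>(A) % 4 == 0);  // 4-byte alignment (16 preferred)
//     assert(reinterpret_cast<std::uintptr_t>(B) % 4 == 0);
//     assert(reinterpret_cast<std::uintptr_t>(C) % 4 == 0);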
+ + if (!SCALE_ROWS) { + float alpha = 1.0f, beta = 0.0f; + has_error |= checkCublasStatus(cublasLtMatmul( + ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int8_t*)C, cDesc, (int8_t*)C, cDesc, NULL, + NULL, 0, stream + )); + } else { + cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST; + float beta = 0.0f; + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute( + matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, sizeof(alphaVec) + )); + has_error |= checkCublasStatus(cublasLtMatmul( + ltHandle, matmulDesc, row_scale, A, aDesc, B, bDesc, &beta, (int8_t*)C, cDesc, (int8_t*)C, cDesc, NULL, + NULL, 0, stream + )); + } } - } - has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(cDesc)); - has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(bDesc)); - has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(aDesc)); - has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(cDesc)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(bDesc)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(aDesc)); + has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc)); - if(has_error == 1) - printf("error detected"); + if (has_error == 1) + printf("error detected"); - return has_error; + return has_error; } -int fill_up_to_nearest_multiple(int value, int multiple) -{ - return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); +int fill_up_to_nearest_multiple(int value, int multiple) { + return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); } -void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half *bias, int numRows, int numCols, cudaStream_t stream) -{ - const int threads = 512; - const int num_per_thread = 4; - const int num_per_block = threads * num_per_thread; - const int n = numRows*numCols; - const int num_blocks = (n + num_per_block - 1) / num_per_block; - - kdequant_mm_int32_fp16<<>>(A, rowStats, colStats, out, bias, numRows, numCols, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +void dequant_mm_int32_fp16( + int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream +) { + const int threads = 512; + const int num_per_thread = 4; + const int num_per_block = threads * num_per_thread; + const int n = numRows * numCols; + const int num_blocks = (n + num_per_block - 1) / num_per_block; + + kdequant_mm_int32_fp16 + <<>>(A, rowStats, colStats, out, bias, numRows, numCols, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) { - if (threshold == 0.0) { - kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); - } else { - kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); - } - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +void int8VectorQuant( + half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream +) { + if (threshold == 0.0) { + kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); + } else { + kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); + } + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) { - if (threshold == 0.0) - kgetRowStats<<>>(A, rowStats, threshold, rows, 
cols); - else - kgetRowStats<<>>(A, rowStats, threshold, rows, cols); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream) { + if (threshold == 0.0) + kgetRowStats<<>>(A, rowStats, threshold, rows, cols); + else + kgetRowStats<<>>(A, rowStats, threshold, rows, cols); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) -{ +void spmm_coo( + cusparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, + int ldb, half* B, int ldc, half* C, bool transposed_B +) { cusparseSpMatDescr_t descA; cusparseDnMatDescr_t descB, descC; float alpha = 1.0f; float beta = 0.0f; - void *dBuffer = NULL; + void* dBuffer = NULL; size_t bufferSize = 0; - CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz, - A_rowidx, A_colidx, A_vals, - CUSPARSE_INDEX_32I, - CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) ); + CHECK_CUSPARSE(cusparseCreateCoo( + &descA, A_rows, A_cols, A_nnz, A_rowidx, A_colidx, A_vals, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, + CUDA_R_16F + )); // Create dense matrix C - CHECK_CUSPARSE( cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, - CUDA_R_16F, CUSPARSE_ORDER_ROW) ); + CHECK_CUSPARSE(cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, CUDA_R_16F, CUSPARSE_ORDER_ROW)); // Create dense matrix B - if(transposed_B) - { - int tmp = A_cols; - A_cols = B_cols; - B_cols = tmp; + if (transposed_B) { + int tmp = A_cols; + A_cols = B_cols; + B_cols = tmp; } - CHECK_CUSPARSE( cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, - CUDA_R_16F, CUSPARSE_ORDER_ROW) ); + CHECK_CUSPARSE(cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, CUDA_R_16F, CUSPARSE_ORDER_ROW)); // allocate an external buffer if needed - CHECK_CUSPARSE( cusparseSpMM_bufferSize( - handle, - CUSPARSE_OPERATION_NON_TRANSPOSE, - transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, - &alpha, descA, descB, &beta, descC, CUDA_R_32F, - CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize) ); - CUDA_CHECK_RETURN( cudaMalloc(&dBuffer, bufferSize) ); + CHECK_CUSPARSE(cusparseSpMM_bufferSize( + handle, CUSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, descA, descB, &beta, + descC, CUDA_R_32F, CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize + )); + CUDA_CHECK_RETURN(cudaMalloc(&dBuffer, bufferSize)); // execute SpMM - CHECK_CUSPARSE( cusparseSpMM(handle, - CUSPARSE_OPERATION_NON_TRANSPOSE, - transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, - &alpha, descA, descB, &beta, descC, CUDA_R_32F, - CUSPARSE_SPMM_ALG_DEFAULT, dBuffer)); + CHECK_CUSPARSE(cusparseSpMM( + handle, CUSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? 
CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, descA, descB, &beta, + descC, CUDA_R_32F, CUSPARSE_SPMM_ALG_DEFAULT, dBuffer + )); // destroy matrix/vector descriptors - CHECK_CUSPARSE( cusparseDestroySpMat(descA) ); - CHECK_CUSPARSE( cusparseDestroyDnMat(descB) ); - CHECK_CUSPARSE( cusparseDestroyDnMat(descC) ); - CUDA_CHECK_RETURN( cudaFree(dBuffer) ); + CHECK_CUSPARSE(cusparseDestroySpMat(descA)); + CHECK_CUSPARSE(cusparseDestroyDnMat(descB)); + CHECK_CUSPARSE(cusparseDestroyDnMat(descC)); + CUDA_CHECK_RETURN(cudaFree(dBuffer)); } -template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) -{ +template +void spmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, + float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB +) { - kspmm_coo_very_sparse_naive<<>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + kspmm_coo_very_sparse_naive<<>>( + max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB + ); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) -{ +template void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits) { - int num_blocks = (m+31)/32; + int num_blocks = (m + 31) / 32; - if(bits == 32) - gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - if(bits == 16) - gemm_device<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + if (bits == 32) + gemm_device<<>>(m, n, k, A, B, out, lda, ldb, ldc); + if (bits == 16) + gemm_device<<>>(m, n, k, A, B, out, lda, ldb, ldc); } -template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) -{ +template +void gemm_4bit_inference( + int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize +) { - int num_blocks = (m+31)/32; + int num_blocks = (m + 31) / 32; - kgemm_4bit_inference<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + kgemm_4bit_inference<<>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } -template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) -{ +template +void gemm_4bit_inference_naive( + int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, + int blocksize, cudaStream_t stream +) { - int num_blocks = (m+3)/4; - kgemm_4bit_inference_naive<<< num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + int num_blocks = (m + 3) / 4; + kgemm_4bit_inference_naive + <<>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void func(T *A, T *B, T value, long n) -{ - int threads = 512; - int blocks = n/threads; - blocks = n % threads == 0 ? blocks : blocks + 1; - blocks = blocks > 65535 ? 
65535 : blocks; - kfunc<<>>(A, B, value, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +template void func(T* A, T* B, T value, long n) { + int threads = 512; + int blocks = n / threads; + blocks = n % threads == 0 ? blocks : blocks + 1; + blocks = blocks > 65535 ? 65535 : blocks; + kfunc<<>>(A, B, value, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } //============================================================== // TEMPLATE DEFINITIONS //============================================================== -template void func(float *A, float *B, float value, long n); -template void func(unsigned char *A, unsigned char *B, unsigned char value, long n); -template void func(float *A, float *B, float value, long n); -template void func(float *A, float *B, float value, long n); - -template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream); -template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream); -template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream); - -//template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); -template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); - -template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); -template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); - -template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); -template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); -template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); - -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int 
rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); - -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream); -template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream); -template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream); -template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream); - -#define MAKE_optimizer32bit(name, gtype) \ -template void optimizer32bit(gtype* g, gtype* p, \ - float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ - const float beta1, const float beta2, const float beta3, const float alpha, \ - const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); - -MAKE_optimizer32bit(ADAM, half) -MAKE_optimizer32bit(ADAM, float) -MAKE_optimizer32bit(ADAM, __nv_bfloat16) -MAKE_optimizer32bit(MOMENTUM, half) -MAKE_optimizer32bit(MOMENTUM, float) -MAKE_optimizer32bit(MOMENTUM, __nv_bfloat16) -MAKE_optimizer32bit(RMSPROP, half) -MAKE_optimizer32bit(RMSPROP, float) -MAKE_optimizer32bit(RMSPROP, __nv_bfloat16) -MAKE_optimizer32bit(LION, half) -MAKE_optimizer32bit(LION, float) -MAKE_optimizer32bit(LION, __nv_bfloat16) -MAKE_optimizer32bit(ADAGRAD, half) -MAKE_optimizer32bit(ADAGRAD, float) 
-MAKE_optimizer32bit(ADAGRAD, __nv_bfloat16) -MAKE_optimizer32bit(ADEMAMIX, half) -MAKE_optimizer32bit(ADEMAMIX, __nv_bfloat16) -MAKE_optimizer32bit(ADEMAMIX, float) - -#define MAKE_optimizerStatic8bit(name, gtype) \ -template void optimizerStatic8bit(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ - float *unorm, float max_unorm, float param_norm, \ - float beta1, float beta2, \ - float eps, int step, float lr, \ - float* quantiles1, float* quantiles2, \ - float* max1, float* max2, float* new_max1, float* new_max2, \ - float weight_decay, \ - const float gnorm_scale, int n); \ - -MAKE_optimizerStatic8bit(ADAM, half) -MAKE_optimizerStatic8bit(ADAM, float) -MAKE_optimizerStatic8bit(MOMENTUM, half) -MAKE_optimizerStatic8bit(MOMENTUM, float) -MAKE_optimizerStatic8bit(RMSPROP, half) -MAKE_optimizerStatic8bit(RMSPROP, float) -MAKE_optimizerStatic8bit(LION, half) -MAKE_optimizerStatic8bit(LION, float) -MAKE_optimizerStatic8bit(ADAGRAD, half) -MAKE_optimizerStatic8bit(ADAGRAD, float) - - -#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ -template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ - unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \ - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ - -MAKE_optimizerStatic8bitBlockwise(half, ADAM); +template void func(float* A, float* B, float value, long n); +template void func(unsigned char* A, unsigned char* B, unsigned char value, long n); +template void func(float* A, float* B, float value, long n); +template void func(float* A, float* B, float value, long n); + +template void gemm_4bit_inference( + int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize +); +template void gemm_4bit_inference_naive( + int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb, + int ldc, int blocksize, cudaStream_t stream +); +template void gemm_4bit_inference_naive<__nv_bfloat16, 16>( + int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out, + int lda, int ldb, int ldc, int blocksize, cudaStream_t stream +); +template void gemm_4bit_inference_naive( + int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, + int ldc, int blocksize, cudaStream_t stream +); + +// template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, +// int bits); +template void gemm_host(int m, int n, int k, half* A, half* B, half* out, int lda, int ldb, int ldc, int bits); + +template void spmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, + float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB +); +template void spmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, + float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB +); + +template int igemmlt<32, 0>( + cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, cudaStream_t stream +); +template int igemmlt<8, 0>( + cublasLtHandle_t ltHandle, int m, int n, int k, 
const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, cudaStream_t stream +); +template int igemmlt<8, 1>( + cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, cudaStream_t stream +); + +template void quantizeBlockwise( + float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template void quantizeBlockwise( + float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template void quantizeBlockwise( + float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template void quantizeBlockwise( + float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template void quantizeBlockwise( + float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template void quantizeBlockwise( + float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template void quantizeBlockwise( + float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template void quantizeBlockwise( + float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>( + float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, + const int n +); +template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>( + float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, + const int n +); +template void quantizeBlockwise<__nv_bfloat16, 0, FP4>( + float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, + const int n +); +template void quantizeBlockwise<__nv_bfloat16, 0, NF4>( + float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, + const int n +); + +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream +); +template void dequantizeBlockwise<__nv_bfloat16, General8bit>( + float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream +); +template void dequantizeBlockwise<__nv_bfloat16, FP4>( + float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t 
stream +); +template void dequantizeBlockwise<__nv_bfloat16, NF4>( + float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream +); + +#define MAKE_optimizer32bit(name, gtype) \ + template void optimizer32bit( \ + gtype * g, gtype * p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, \ + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, \ + const int n \ + ); + +MAKE_optimizer32bit(ADAM, half) MAKE_optimizer32bit(ADAM, float) MAKE_optimizer32bit(ADAM, __nv_bfloat16) MAKE_optimizer32bit(MOMENTUM, half) MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit( + MOMENTUM, __nv_bfloat16 +) MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, float) MAKE_optimizer32bit(RMSPROP, __nv_bfloat16) MAKE_optimizer32bit(LION, half) MAKE_optimizer32bit(LION, float) MAKE_optimizer32bit(LION, __nv_bfloat16) MAKE_optimizer32bit(ADAGRAD, half) MAKE_optimizer32bit(ADAGRAD, float) MAKE_optimizer32bit(ADAGRAD, __nv_bfloat16) MAKE_optimizer32bit(ADEMAMIX, half) MAKE_optimizer32bit(ADEMAMIX, __nv_bfloat16) MAKE_optimizer32bit(ADEMAMIX, float) + +#define MAKE_optimizerStatic8bit(name, gtype) \ + template void optimizerStatic8bit( \ + gtype * p, gtype * g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, \ + float param_norm, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, \ + float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, \ + const float gnorm_scale, int n \ + ); + + MAKE_optimizerStatic8bit(ADAM, half) MAKE_optimizerStatic8bit(ADAM, float) MAKE_optimizerStatic8bit(MOMENTUM, half) MAKE_optimizerStatic8bit(MOMENTUM, float) MAKE_optimizerStatic8bit( + RMSPROP, half + ) MAKE_optimizerStatic8bit(RMSPROP, float) MAKE_optimizerStatic8bit(LION, half) MAKE_optimizerStatic8bit(LION, float) MAKE_optimizerStatic8bit(ADAGRAD, half) MAKE_optimizerStatic8bit(ADAGRAD, float) + +#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ + template void optimizerStatic8bitBlockwise( \ + gtype * p, gtype * g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, \ + float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, \ + float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n \ + ); + + MAKE_optimizerStatic8bitBlockwise(half, ADAM); MAKE_optimizerStatic8bitBlockwise(float, ADAM); MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM); MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); @@ -696,8 +727,8 @@ MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX); MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADEMAMIX); MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX); -template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); -template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); +template void percentileClipping(float* g, float* gnorm_vec, int step, const int n); +template void percentileClipping(half* g, float* gnorm_vec, int step, const int n); template int get_leading_dim(int dim1, int dim2); template int get_leading_dim(int dim1, int dim2); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 99a24a209..01e11ff31 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -3,41 +3,41 @@ // This source code is licensed under the MIT 
license found in the // LICENSE file in the root directory of this source tree. - #ifndef ops_H #define ops_H +#include #include -#include #include -#include +#include -#include -#include -#include #include +#include +#include +#include #include -#include #include +#include +#define CUDA_CHECK_RETURN(value) \ + { \ + cudaError_t _m_cudaStat = value; \ + if (_m_cudaStat != cudaSuccess) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ + exit(1); \ + } \ + } -#define CUDA_CHECK_RETURN(value) { \ - cudaError_t _m_cudaStat = value; \ - if (_m_cudaStat != cudaSuccess) { \ - fprintf(stderr, "Error %s at line %d in file %s\n", \ - cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ - exit(1); \ - } } - - -#define CHECK_CUSPARSE(value) { \ - cusparseStatus_t _m_cudaStat = value; \ - if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \ - fprintf(stderr, "Error %s at line %d in file %s\n", \ - cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ - exit(1); \ - } } - +#define CHECK_CUSPARSE(value) \ + { \ + cusparseStatus_t _m_cudaStat = value; \ + if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \ + fprintf( \ + stderr, "Error %s at line %d in file %s\n", cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__ \ + ); \ + exit(1); \ + } \ + } inline void checkCudaStatus(cudaError_t status) { if (status != cudaSuccess) { @@ -49,140 +49,163 @@ inline void checkCudaStatus(cudaError_t status) { inline int checkCublasStatus(cublasStatus_t status) { if (status != CUBLAS_STATUS_SUCCESS) { printf("cuBLAS API failed with status %d\n", status); - //throw std::logic_error("cuBLAS API failed"); + // throw std::logic_error("cuBLAS API failed"); return 1; } return 0; } -typedef enum Operations_t -{ - ksmul = 0, +typedef enum Operations_t { + ksmul = 0, } Operations_t; -typedef enum Optimizer_t -{ - ADAM = 0, - MOMENTUM = 1, - RMSPROP = 2, - LARS = 3, - ADAGRAD = 4, - LION = 5, - ADEMAMIX = 6 +typedef enum Optimizer_t { + ADAM = 0, + MOMENTUM = 1, + RMSPROP = 2, + LARS = 3, + ADAGRAD = 4, + LION = 5, + ADEMAMIX = 6 } Optimizer_t; -typedef enum Transform_t -{ - ROW = 0, - COL = 1, - COL32 = 2, - COL_TURING = 3, - COL_AMPERE = 4, +typedef enum Transform_t { + ROW = 0, + COL = 1, + COL32 = 2, + COL_TURING = 3, + COL_AMPERE = 4, } Transform_t; -typedef enum DataType_t -{ - General8bit = 0, - FP4 = 1, - NF4 = 2, +typedef enum DataType_t { + General8bit = 0, + FP4 = 1, + NF4 = 2, } DataType_t; -typedef enum Funcs_t -{ - FILL = 0, - ARANGE = 1, - _MUL = 2, +typedef enum Funcs_t { + FILL = 0, + ARANGE = 1, + _MUL = 2, } Funcs_t; -class Context -{ - public: - cublasHandle_t m_handle; - - Context() - { - cublasHandle_t handle; - cublasCreate_v2(&handle); - m_handle = handle; - } +class Context { + public: + cublasHandle_t m_handle; + Context() { + cublasHandle_t handle; + cublasCreate_v2(&handle); + m_handle = handle; + } }; -class ContextLt -{ - public: - cublasLtHandle_t m_handle; - - ContextLt() - { - cublasLtHandle_t handle; - cublasLtCreate(&handle); - m_handle = handle; - } +class ContextLt { + public: + cublasLtHandle_t m_handle; + ContextLt() { + cublasLtHandle_t handle; + cublasLtCreate(&handle); + m_handle = handle; + } }; -class ContextCusparse -{ - public: - cusparseHandle_t m_handle; - - ContextCusparse() - { - cusparseHandle_t handle; - cusparseCreate(&handle); - m_handle = handle; - } +class ContextCusparse { + public: + cusparseHandle_t m_handle; + ContextCusparse() { + cusparseHandle_t handle; + cusparseCreate(&handle); + m_handle = handle; + 
} }; -void quantize(float *code, float *A, unsigned char *out, int n); -void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream); -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n, cudaStream_t stream); - -template void optimizer32bit(T* g, T* p, - float* state1, float* state2, float *unorm, float max_unorm, float param_norm, - float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, - int step, float lr, const float gnorm_scale, bool skip_zeros, int n); - -template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, - float *unorm, float max_unorm, float param_norm, - float beta1, float beta2, - float eps, int step, float lr, - float* quantiles1, float* quantiles2, - float* max1, float* max2, float* new_max1, float* new_max2, - float weight_decay, - const float gnorm_scale, int n); - -template void optimizerStatic8bitBlockwise(T* p, T* g, - unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, - bool skip_zeros, int n); - -template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); - -void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); -void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, - long long int strideA, long long int strideB, long long int strideC, int batchCount); - -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); - -void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); -void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream); -void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream); -void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream); - -void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B); - -template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); - -void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); - -template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); -template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t 
stream); - -template void func(T *A, T *B, T value, long n); +void quantize(float* code, float* A, unsigned char* out, int n); +void dequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream); +template +void quantizeBlockwise( + float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template +void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, T* out, int block_size, const int n, cudaStream_t stream +); + +template +void optimizer32bit( + T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, float beta1, float beta2, + float beta3, float alpha, float eps, float weight_decay, int step, float lr, const float gnorm_scale, + bool skip_zeros, int n +); + +template +void optimizerStatic8bit( + T* p, T* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, float param_norm, + float beta1, float beta2, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1, + float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, int n +); + +template +void optimizerStatic8bitBlockwise( + T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, + float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, + float weight_decay, const float gnorm_scale, bool skip_zeros, int n +); + +template void percentileClipping(T* g, float* gnorm_vec, int step, const int n); + +void gemmex( + Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, + int ldb, int ldc +); +void strided_gemmex( + Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, + int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount +); + +template +int igemmlt( + cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, cudaStream_t stream +); + +void cutlass_igemm( + bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, int ldb, int ldc +); +void dequant_mm_int32_fp16( + int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream +); +void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream); +void int8VectorQuant( + half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream +); + +void spmm_coo( + cusparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, + int ldb, half* B, int ldc, half* C, bool transposed_B +); + +template +void spmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, + float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB +); + +void matmul4bite(half* A, unsigned char* B, half* out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits); +template +void gemm_4bit_inference( + int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize +); +template +void gemm_4bit_inference_naive( + int m, int n, int 
k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, + int blocksize, cudaStream_t stream +); + +template void func(T* A, T* B, T value, long n); #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 0b8b1942b..63f46a20c 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -20,39 +20,60 @@ #if BUILD_CUDA -//void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) +// void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) //{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 32); } -void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc) -{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 16); } +void gemm_host_fp16(int M, int N, int K, half* A, half* B, half* out, int lda, int ldb, int ldc) { + gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 16); +} -void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) -{ gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } +void gemm_4bit_inference( + int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize +) { + gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); +} -void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) -{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } +void gemm_4bit_inference_naive_fp16( + int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb, + int ldc, int blocksize, cudaStream_t stream +) { + gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} -void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) -{ gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } +void gemm_4bit_inference_naive_bf16( + int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out, + int lda, int ldb, int ldc, int blocksize, cudaStream_t stream +) { + gemm_4bit_inference_naive<__nv_bfloat16, 16>( + m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream + ); +} -void gemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) -{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } +void gemm_4bit_inference_naive_fp32( + int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, + int ldc, int blocksize, cudaStream_t stream +) { + gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} -#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ -void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func(A, B, value, n); } \ +#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ + void 
fname##_##type_name(ctype* A, ctype* B, ctype value, long n) { func<ctype, FUNC>(A, B, value, n); }

 MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
 MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
 MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
 MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
-
-#define MAKE_FUNC32(fname, oname, gtype, gbits) \
-void fname##32bit_grad_##gbits(gtype *g, gtype *p, \
-               float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
-               const float beta1, const float beta2, const float beta3, const float alpha, \
-               const float eps, const float weight_decay, \
-               const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \
-{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
+#define MAKE_FUNC32(fname, oname, gtype, gbits) \
+    void fname##32bit_grad_##gbits( \
+        gtype* g, gtype* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
+        const float beta1, const float beta2, const float beta3, const float alpha, const float eps, \
+        const float weight_decay, const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n \
+    ) { \
+        optimizer32bit<gtype, oname>( \
+            g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, \
+            lr, gnorm_scale, skip_zeros, n \
+        ); \
+    }

 MAKE_FUNC32(momentum, MOMENTUM, float, 32)
 MAKE_FUNC32(momentum, MOMENTUM, half, 16)
@@ -70,19 +91,18 @@ MAKE_FUNC32(ademamix, ADEMAMIX, float, fp32)
 MAKE_FUNC32(ademamix, ADEMAMIX, half, fp16)
 MAKE_FUNC32(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
-
-#define MAKE_FUNC8(fname, oname, gtype, gbits) \
-void fname##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
-                float *unorm, float max_unorm, float param_norm, \
-                float beta1, float beta2, \
-                float eps, int step, float lr, \
-                float* quantiles1, float* quantiles2, \
-                float* max1, float* max2, float* new_max1, float* new_max2, \
-                float weight_decay, float gnorm_scale, int n) \
-{ \
-    optimizerStatic8bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
-                quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
-} \
+#define MAKE_FUNC8(fname, oname, gtype, gbits) \
+    void fname##_static_8bit_grad_##gbits( \
+        gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, \
+        float param_norm, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, \
+        float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, \
+        float gnorm_scale, int n \
+    ) { \
+        optimizerStatic8bit<gtype, oname>( \
+            g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, \
+            max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n \
+        ); \
+    }

 MAKE_FUNC8(adam, ADAM, float, 32)
 MAKE_FUNC8(adam, ADAM, half, 16)
@@ -93,11 +113,17 @@ MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
 MAKE_FUNC8(lion, LION, float, 32)
 MAKE_FUNC8(lion, LION, half, 16)
-#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
-void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
-                unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
-                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
-{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
+#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
+    void fname##_8bit_blockwise_grad_##gbits( \
+        gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, \
+        float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, \
+        float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n \
+    ) { \
+        optimizerStatic8bitBlockwise<gtype, optim_name>( \
+            p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, \
+            weight_decay, gnorm_scale, skip_zeros, n \
+        ); \
+    }

 MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
 MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
@@ -118,239 +144,511 @@ MAKE_BLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
 MAKE_BLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
 MAKE_BLOCKWISE8(ademamix, ADEMAMIX, float, fp32)
+void percentileClipping_g32(float* g, float* gnorm_vec, int step, const int n) {
+    percentileClipping(g, gnorm_vec, step, n);
+}
-void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); }
-void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); }
+void percentileClipping_g16(half* g, float* gnorm_vec, int step, const int n) {
+    percentileClipping(g, gnorm_vec, step, n);
+}
+
+void quantizeBlockwise_fp16(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
+    quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n);
+}
+
+void quantizeBlockwise_fp16_fp4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
+    quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
+}
+
+void quantizeBlockwise_fp16_nf4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
+    quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
+}
+
+void quantizeBlockwise_bf16(
+    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
+) {
+    quantizeBlockwise<__nv_bfloat16, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n);
+}
+
+void quantizeBlockwise_bf16_fp4(
+    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
+) {
+    quantizeBlockwise<__nv_bfloat16, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
+}
+
+void quantizeBlockwise_bf16_nf4(
+    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
+) {
+    quantizeBlockwise<__nv_bfloat16, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
+}
+
+void quantizeBlockwise_fp32(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) {
+    quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n);
+}
+
+void quantizeBlockwise_fp32_fp4(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) {
+    quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
+}
+
+void quantizeBlockwise_fp32_nf4(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) {
+    quantizeBlockwise<float, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n);
+}
+
+void dequantizeBlockwise_fp16(
+    float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n, stream);
+}
+
+void dequantizeBlockwise_fp16_fp4(
+    float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise<half, FP4>(NULL, A, absmax, out, blocksize, n, stream);
+}
+
+void dequantizeBlockwise_fp16_nf4(
+    float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n, stream);
+}
-void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
-void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
-void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
+void dequantizeBlockwise_fp32(
+    float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n, stream);
+}
-void quantizeBlockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
-void quantizeBlockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
-void quantizeBlockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
+void dequantizeBlockwise_fp32_fp4(
+    float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n, stream);
+}
-void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
-void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
-void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
+void dequantizeBlockwise_fp32_nf4(
+    float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n, stream);
+}
-void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n, stream); } \
-void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<half, FP4>(NULL, A, absmax, out, blocksize, n, stream); } \
-void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n, stream); } \
+void dequantizeBlockwise_bf16(
+    float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n, stream);
+}
-void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n, stream); }
-void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n, stream); }
-void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n, stream); }
+void dequantizeBlockwise_bf16_fp4(
+    float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream);
+}
-void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n, stream); }
-void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream); }
-void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream); }
+void dequantizeBlockwise_bf16_nf4(
+    float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream);
+}
-int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
+int igemmlt_32(
+    cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
+    int lda, int ldb, int ldc, cudaStream_t stream
+) {
     return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
 }
-int igemmlt_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
+
+int igemmlt_8(
+    cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
+    int lda, int ldb, int ldc, cudaStream_t stream
+) {
     return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
 }
-int igemmlt_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
+
+int igemmlt_8_rowscale(
+    cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
+    int lda, int ldb, int ldc, cudaStream_t stream
+) {
     return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
 }
-void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
-{ spmm_coo_very_sparse_naive<half, 16>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
+void spmm_coo_very_sparse_naive_fp16(
+    int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
+    float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
+) {
+    spmm_coo_very_sparse_naive<half, 16>(
+        max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB,
+        colsB
+    );
+}
-void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
-{ spmm_coo_very_sparse_naive<signed char, 8>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
+void spmm_coo_very_sparse_naive_int8(
+    int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out,
+    float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
+) {
+    spmm_coo_very_sparse_naive<signed char, 8>(
+        max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB,
+        colsB
+    );
+}

 #endif

-extern "C"
-{
+extern "C" {
 #if BUILD_CUDA
-    void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
-    void cdequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream){ dequantize(code, A, out, n, stream); }
-
-    void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream); }
-    void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream); }
-    void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream); }
-
-    void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
-    void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
-    void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
-
-    void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
-    void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
-    void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
-
-    void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream); }
-    void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream); }
-    void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream); }
-
-    void cquantize_blockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); }
-    void cquantize_blockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); }
-    void cquantize_blockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); }
-
-    void cdequantize_blockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream); }
-    void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream); }
-    void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); }
-
-    #define MAKE_CFUNC32(name, gtype, gbits) \
-    void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \
-               float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
-               const float beta1, const float beta2, const float beta3, const float alpha, \
-               const float eps, const float weight_decay, \
-               const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
-    { name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
-
-    MAKE_CFUNC32(adam, float, fp32)
-    MAKE_CFUNC32(adam, half, fp16)
-    MAKE_CFUNC32(adam, __nv_bfloat16, bf16)
-    MAKE_CFUNC32(momentum, float, 32)
-    MAKE_CFUNC32(momentum, half, 16)
-    MAKE_CFUNC32(rmsprop, float, 32)
-    MAKE_CFUNC32(rmsprop, half, 16)
-    MAKE_CFUNC32(lion, float, fp32)
-    MAKE_CFUNC32(lion, half, fp16)
-    MAKE_CFUNC32(lion, __nv_bfloat16, bf16)
-    MAKE_CFUNC32(adagrad, float, 32)
-    MAKE_CFUNC32(adagrad, half, 16)
-    MAKE_CFUNC32(ademamix, float, fp32)
-    MAKE_CFUNC32(ademamix, half, fp16)
-    MAKE_CFUNC32(ademamix, __nv_bfloat16, bf16)
-
-    #define MAKE_CFUNC8(name, gtype, gbits) \
-    void c##name##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
-                float *unorm, float max_unorm, float param_norm, \
-                float beta1, float beta2, \
-                float eps, int step, float lr, \
-                float* quantiles1, float* quantiles2, \
-                float* max1, float* max2, float* new_max1, float* new_max2, \
-                float weight_decay, float gnorm_scale, int n) \
-    { \
-        name##_static_8bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
-                quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
-    } \
-
-    MAKE_CFUNC8(adam, float, 32)
-    MAKE_CFUNC8(adam, half, 16)
-    MAKE_CFUNC8(momentum, float, 32)
-    MAKE_CFUNC8(momentum, half, 16)
-    MAKE_CFUNC8(rmsprop, float, 32)
-    MAKE_CFUNC8(rmsprop, half, 16)
-    MAKE_CFUNC8(lion, float, 32)
-    MAKE_CFUNC8(lion, half, 16)
-
-    #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
-    void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
-                unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
-                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
-    { fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
-
-    MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
-    MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
-    MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
-    MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16)
-    MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32)
-    MAKE_CBLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16)
-    MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16)
-    MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32)
-    MAKE_CBLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16)
-    MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16)
-    MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32)
-    MAKE_CBLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16)
-    MAKE_CBLOCKWISE8(lion, LION, half, fp16)
-    MAKE_CBLOCKWISE8(lion, LION, float, fp32)
-    MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
-    MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
-    MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, float, fp32)
-    MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
-
-    void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
-    void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
-
-    void cigemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
-    { gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc); }
-    void cbatched_igemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
-                   long strideA, long strideB, long strideC, int batchCount)
-    { strided_gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount); }
-
-    Context *get_context(){ return new Context(); }
-    ContextCusparse *get_cusparse(){ return new ContextCusparse(); }
-
-    int cigemmlt_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
-        return igemmlt_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
-    }
-    int cigemmlt_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
-        return igemmlt_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
-    }
-    int cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
-        return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
-    }
-    void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream)
-    { dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream); }
-    void cget_row_stats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
-        getRowStats(A, rowStats, threshold, rows, cols, stream);
-    }
-    void cint8_vector_quant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
-        int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream);
-    }
-
-    void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
-    { spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); }
-
-    void cspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
-    { spmm_coo_very_sparse_naive_fp16(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
-
-    void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
-    { spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
-
-    //void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
-    //{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
-
-    void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
-    { gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); }
-
-    void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
-    { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
-
-    void *cget_managed_ptr(size_t bytes)
-    {
-        void *ptr;
-        CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost));
-        CUDA_CHECK_RETURN(cudaPeekAtLastError());
-
-        return ptr;
-    }
-
-    void cprefetch(void *ptr, size_t bytes, int device)
-    {
-
-        int hasPrefetch = 0;
-        CUDA_CHECK_RETURN(cudaDeviceGetAttribute(&hasPrefetch, cudaDevAttrConcurrentManagedAccess, device)); // 40ns overhead
-        if (hasPrefetch == 0) return;
-
-        CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0));
-        CUDA_CHECK_RETURN(cudaPeekAtLastError());
-    }
-
-    #define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
-    void c##fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ fname##_##type_name(A, B, value, n); } \
-
-    CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
-    CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
-    CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
-    CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
-
-    void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
-    { gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); }
-
-    void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
-    { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); }
-
-    void cgemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
-    { gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); }
+void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); }
+
+void cdequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream) {
+    dequantize(code, A, out, n, stream);
+}
+
+void cdequantize_blockwise_fp16_fp4(
+    float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream);
+}
+
+void cdequantize_blockwise_fp16(
+    float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream);
+}
+
+void cdequantize_blockwise_fp16_nf4(
+    float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream);
+}
+
+void cquantize_blockwise_fp16(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
+    quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n);
+}
+
+void cquantize_blockwise_fp16_fp4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
+    quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n);
+}
+
+void cquantize_blockwise_fp16_nf4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
+    quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n);
+}
+
+void cquantize_blockwise_fp32(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) {
+    quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n);
+}
+
+void cquantize_blockwise_fp32_fp4(
+    float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n
+) {
+    quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n);
+}
+
+void cquantize_blockwise_fp32_nf4(
+    float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n
+) {
+    quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n);
+}
+
+void cdequantize_blockwise_fp32(
+    float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream);
+}
+
+void cdequantize_blockwise_fp32_fp4(
+    float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream);
+}
+
+void cdequantize_blockwise_fp32_nf4(
+    float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream);
+}
+
+void cquantize_blockwise_bf16(
+    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
+) {
+    quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n);
+}
+
+void cquantize_blockwise_bf16_fp4(
+    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
+) {
+    quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n);
+}
+
+void cquantize_blockwise_bf16_nf4(
+    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n
+) {
+    quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n);
+}
+
+void cdequantize_blockwise_bf16(
+    float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream);
+}
+
+void cdequantize_blockwise_bf16_fp4(
+    float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream);
+}
+
+void cdequantize_blockwise_bf16_nf4(
+    float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream
+) {
+    dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream);
+}
+
+#define MAKE_CFUNC32(name, gtype, gbits) \
+    void c##name##32bit_grad_##gbits( \
+        gtype* g, gtype* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
+        const float beta1, const float beta2, const float beta3, const float alpha, const float eps, \
+        const float weight_decay, const int step, const float lr, const float gnorm_scale, bool skip_zeros, \
+        const int n \
+    ) { \
+        name##32bit_grad_##gbits( \
+            g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, \
+            lr, gnorm_scale, skip_zeros, n \
+        ); \
+    }
+
+MAKE_CFUNC32(adam, float, fp32)
+MAKE_CFUNC32(adam, half, fp16)
+MAKE_CFUNC32(adam, __nv_bfloat16, bf16)
+MAKE_CFUNC32(momentum, float, 32)
+MAKE_CFUNC32(momentum, half, 16)
+MAKE_CFUNC32(rmsprop, float, 32)
+MAKE_CFUNC32(rmsprop, half, 16)
+MAKE_CFUNC32(lion, float, fp32)
+MAKE_CFUNC32(lion, half, fp16)
+MAKE_CFUNC32(lion, __nv_bfloat16, bf16)
+MAKE_CFUNC32(adagrad, float, 32)
+MAKE_CFUNC32(adagrad, half, 16)
+MAKE_CFUNC32(ademamix, float, fp32)
+MAKE_CFUNC32(ademamix, half, fp16)
+MAKE_CFUNC32(ademamix, __nv_bfloat16, bf16)
+
+#define MAKE_CFUNC8(name, gtype, gbits) \
+    void c##name##_static_8bit_grad_##gbits( \
+        gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, \
+        float param_norm, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, \
+        float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, \
+        float gnorm_scale, int n \
+    ) { \
+        name##_static_8bit_grad_##gbits( \
+            g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, \
+            max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n \
+        ); \
+    }
+
+MAKE_CFUNC8(adam, float, 32)
+MAKE_CFUNC8(adam, half, 16)
+MAKE_CFUNC8(momentum, float, 32)
+MAKE_CFUNC8(momentum, half, 16)
+MAKE_CFUNC8(rmsprop, float, 32)
+MAKE_CFUNC8(rmsprop, half, 16)
+MAKE_CFUNC8(lion, float, 32)
+MAKE_CFUNC8(lion, half, 16)
+
+#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
+    void c##fname##_8bit_blockwise_grad_##gbits( \
+        gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, \
+        float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, \
+        float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n \
+    ) { \
+        fname##_8bit_blockwise_grad_##gbits( \
+            p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, \
+            weight_decay, gnorm_scale, skip_zeros, n \
+        ); \
+    }
+
+MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
+MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
+MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
+MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16)
+MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32)
+MAKE_CBLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16)
+MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16)
+MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32)
+MAKE_CBLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16)
+MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16)
+MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32)
+MAKE_CBLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16)
+MAKE_CBLOCKWISE8(lion, LION, half, fp16)
+MAKE_CBLOCKWISE8(lion, LION, float, fp32)
+MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
+MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
+MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, float, fp32)
+MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
+
+void cpercentile_clipping_g32(float* g, float* gnorm_vec, int step, const int n) {
+    percentileClipping_g32(g, gnorm_vec, step, n);
+}
+
+void cpercentile_clipping_g16(half* g, float* gnorm_vec, int step, const int n) {
+    percentileClipping_g16(g, gnorm_vec, step, n);
+}
+
+void cigemm(
+    Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
+    int ldb, int ldc
+) {
+    gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc);
+}
+
+void cbatched_igemm(
+    Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
+    int ldb, int ldc, long strideA, long strideB, long strideC, int batchCount
+) {
+    strided_gemmex(
+        context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount
+    );
+}
+
+Context* get_context() { return new Context(); }
+
+ContextCusparse* get_cusparse() { return new ContextCusparse(); }
+
+int cigemmlt_32(
+    Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda,
+    int ldb, int ldc, cudaStream_t stream
+) {
+    return igemmlt_32((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
+}
+
+int cigemmlt_8(
+    Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda,
+    int ldb, int ldc, cudaStream_t stream
+) {
+    return igemmlt_8((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
+}
+
+int cigemmlt_8_rowscale(
+    Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda,
+    int ldb, int ldc, cudaStream_t stream
+) {
+    return igemmlt_8_rowscale((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
+}
+
+void cdequant_mm_int32_fp16(
+    int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream
+) {
+    dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream);
+}
+
+void cget_row_stats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
+    getRowStats(A, rowStats, threshold, rows, cols, stream);
+}
+
+void cint8_vector_quant(
+    half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream
+) {
+    int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream);
+}
+
+void cspmm_coo(
+    ContextCusparse* context, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols,
+    int ldb, half* B, int ldc, half* C, bool transposed_B
+) {
+    spmm_coo(
+        (cusparseHandle_t)context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C,
+        transposed_B
+    );
+}
+
+void cspmm_coo_very_sparse_naive_fp16(
+    int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
+    float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
+) {
+    spmm_coo_very_sparse_naive_fp16(
+        max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB,
+        colsB
+    );
+}
+
+void cspmm_coo_very_sparse_naive_int8(
+    int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out,
+    float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
+) {
+    spmm_coo_very_sparse_naive_int8(
+        max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB,
+        colsB
+    );
+}
+
+// void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
+//{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
+
+void cgemm_host_fp16(int M, int N, int K, half* A, half* B, half* out, int lda, int ldb, int ldc) {
+    gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc);
+}
+
+void cgemm_4bit_inference(
+    int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize
+) {
+    gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+}
+
+void* cget_managed_ptr(size_t bytes) {
+    void* ptr;
+    CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost));
+    CUDA_CHECK_RETURN(cudaPeekAtLastError());
+
+    return ptr;
+}
+
+void cprefetch(void* ptr, size_t bytes, int device) {
+
+    int hasPrefetch = 0;
+    CUDA_CHECK_RETURN(
+        cudaDeviceGetAttribute(&hasPrefetch, cudaDevAttrConcurrentManagedAccess, device)
+    ); // 40ns overhead
+    if (hasPrefetch == 0)
+        return;
+
+    CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0));
+    CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
+#define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
+    void c##fname##_##type_name(ctype* A, ctype* B, ctype value, long n) { fname##_##type_name(A, B, value, n); }
+
+CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
+CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
+CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
+CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
+
+void cgemm_4bit_inference_naive_fp16(
+    int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb,
+    int ldc, int blocksize, cudaStream_t stream
+) {
+    gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
+}
+
+void cgemm_4bit_inference_naive_bf16(
+    int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out,
+    int lda, int ldb, int ldc, int blocksize, cudaStream_t stream
+) {
+    gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
+}
+
+void cgemm_4bit_inference_naive_fp32(
+    int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
+    int ldc, int blocksize, cudaStream_t stream
+) {
+    gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
+}
 #endif

-    void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
-    void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
+void cquantize_blockwise_cpu_fp32(
+    float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n
+) {
+    quantize_cpu(code, A, absmax, out, blocksize, n);
+}
+
+void cdequantize_blockwise_cpu_fp32(
+    float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n
+) {
+    dequantize_cpu(code, A, absmax, out, blocksize, n);
+}
 }
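
Not part of the diff: a minimal host-side sketch of how the two CPU entry points exported above (cquantize_blockwise_cpu_fp32 / cdequantize_blockwise_cpu_fp32) might be exercised, assuming the program is linked against the built bitsandbytes CPU library. The linear 256-entry code table is a placeholder for the real quantization map that the Python layer normally supplies, so the round-trip error shown is only illustrative.

// quantize_roundtrip.cpp -- illustrative caller, not part of the library.
#include <cstdio>
#include <vector>

extern "C" {
// Signatures copied from the extern "C" block in the diff above.
void cquantize_blockwise_cpu_fp32(
    float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n
);
void cdequantize_blockwise_cpu_fp32(
    float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n
);
}

int main() {
    const long long n = 8, blocksize = 4;

    // Toy linear code covering [-1, 1]; the real map is non-uniform and is
    // generated on the Python side. quantize_cpu forces code[0] to -1.0f anyway.
    std::vector<float> code(256);
    for (int i = 0; i < 256; ++i)
        code[i] = -1.0f + 2.0f * i / 255.0f;

    std::vector<float> A = {0.1f, -0.5f, 0.9f, -1.2f, 2.0f, -0.3f, 0.7f, 0.0f};
    std::vector<unsigned char> q(n);          // 8-bit indices into the code table
    std::vector<float> absmax(n / blocksize); // one scale per block
    std::vector<float> out(n);

    cquantize_blockwise_cpu_fp32(code.data(), A.data(), absmax.data(), q.data(), blocksize, n);
    cdequantize_blockwise_cpu_fp32(code.data(), q.data(), absmax.data(), out.data(), blocksize, n);

    for (long long i = 0; i < n; ++i)
        std::printf("%+.3f -> %+.3f\n", A[i], out[i]);
    return 0;
}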