diff --git a/CMakeLists.txt b/CMakeLists.txt index 7c23a2799..9c133e09f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -70,6 +70,7 @@ elseif(${COMPUTE_BACKEND} STREQUAL "xpu") message(FATAL_ERROR "XPU is not supported on macOS" ) endif() set(BUILD_CUDA OFF) + set(BUILD_HIP OFF) set(BUILD_MPS OFF) set(BUILD_XPU ON) else() diff --git a/csrc/kernels.hip b/csrc/kernels.hip index ec3f7f025..bef6cffa6 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -19,37 +19,42 @@ #define NUM 4 #define NUM_BLOCK 4096 -__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; +__device__ static float fp4_dequantization_lut[8] = { + 0.0f, // 0b000 + 0.005208333333f, // 0b001 + 0.66666667f, // 0b010 + 1.0f, // 0b011 + 0.33333333f, // 0b100 + 0.5f, // 0b101 + 0.16666667f, // 0b110 + 0.25f // 0b111 +}; + +__device__ static float nf4_dequantization_lut[16] = { + -1.0f, // 0b0000 + -0.6961928009986877f, // 0b0001 + -0.5250730514526367f, // 0b0010 + -0.39491748809814453f, // 0b0011 + -0.28444138169288635f, // 0b0100 + -0.18477343022823334f, // 0b0101 + -0.09105003625154495f, // 0b0110 + 0.0f, // 0b0111 + 0.07958029955625534f, // 0b1000 + 0.16093020141124725f, // 0b1001 + 0.24611230194568634f, // 0b1010 + 0.33791524171829224f, // 0b1011 + 0.44070982933044434f, // 0b1100 + 0.5626170039176941f, // 0b1101 + 0.7229568362236023f, // 0b1110 + 1.0f // 0b1111 +}; // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda // Luckily we have atomicmax and atomicmin in ROCm - -__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) -{ - float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; - if((val & 0b0100) == 4) // 0 - if((val & 0b0010) == 2) //01 - if((val & 0b0001) == 1) // 111 - return 0.25000000f*absmax*sign; // 1111 - else - return 0.16666667f*absmax*sign; // 1110 - else - if((val & 0b0001) == 1) // 110 - return 0.50000000f*absmax*sign; // 1101 - else - return 0.33333333f*absmax*sign; // 1100 - else - if((val & 0b0010) == 2) //10 - if((val & 0b0001) == 1) // 101 - return 1.00000000f*absmax*sign; // 1011 - else - return 0.66666667f*absmax*sign; // 1010 - else - if((val & 0b0001) == 1) // 100 - return 5.208333333e-03f*absmax*sign; // 1001 - else - return 0.00000000f*absmax*sign; // 1000 +__device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) { + float sign = 1.0f - 2 * ((val & 0b1000) >> 3); + return fp4_dequantization_lut[val & 0b111] * sign; } __device__ unsigned char dQuantizeFP4(float x) @@ -101,61 +106,7 @@ __device__ unsigned char dQuantizeFP4(float x) return 0b0000+sign; } - -__device__ __forceinline__ float dDequantizeNF4(unsigned char val) -{ - - // the values for this tree was generated by test_normal_map_tree - // in the file tests/test_functional.py - if((val & 0b1000) == 8) - if((val & 0b0100) == 4) // 1 - if((val & 0b0010) == 2) // 11 - if((val & 0b0001) == 1) // 111 - return 1.0f; - else - return 0.7229568362236023f; - else - if((val & 0b0001) == 1) // 110 - return 0.5626170039176941f; - else - return 0.44070982933044434f; - else - if((val & 0b0010) == 2) //10 - if((val & 0b0001) == 1) // 101 - return 0.33791524171829224f; - else - return 0.24611230194568634f; - else - if((val & 0b0001) == 1) // 100 - return 0.16093020141124725f; - else - return 0.07958029955625534f; - - else - if((val & 0b0100) == 4) // 0 - if((val & 0b0010) == 2) //01 - if((val & 0b0001) == 1) // 011 - return 0.0f; - else - return -0.09105003625154495f; - else - if((val & 0b0001) == 1) // 010 - return -0.18477343022823334f; - else - return -0.28444138169288635f; - else - if((val & 0b0010) == 2) //00 - if((val & 0b0001) == 1) // 001 - return -0.39491748809814453f; - else - return -0.5250730514526367f; - else - if((val & 0b0001) == 1) // 000 - return -0.6961928009986877f; - else - return -1.0f; - -} +__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; } __device__ unsigned char dQuantizeNF4(float x) { @@ -456,7 +407,6 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); } - unsigned char packed_4bit = 0; switch(DATA_TYPE) { case General8bit: @@ -473,18 +423,16 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH/2; j++) { - packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; - packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); - qvals[j] = packed_4bit; + qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4; + qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max); } break; case NF4: #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH/2; j++) { - packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; - packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); - qvals[j] = packed_4bit; + qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4; + qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max); } break; } @@ -546,8 +494,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH; j++) { - vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); - vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); + vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4) * local_abs_max; + vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F) * local_abs_max; } break; case NF4: @@ -2109,7 +2057,11 @@ __global__ void kdequant_mm_int32_fp16( #define DENORM 1.0f/127.0f #define MAX_SPARSE_COUNT 32 #define SMEM_SIZE 8*256 -#define WARP_SIZE warpSize +#if defined(__GFX9__) + #define WARP_SIZE 64 +#else + #define WARP_SIZE 32 +#endif template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) { @@ -2503,7 +2455,7 @@ template __global__ void kgemm_4bit_inference(int M, i #pragma unroll 16 for(int i = 0; i < 16; i++) - quant_map[i] = nf4_data[i]; + quant_map[i] = nf4_dequantization_lut[i]; //__shared__ T quant_map[16*160]; T local_A[2]; @@ -2708,13 +2660,13 @@ template __global__ void kgemm_4bit_inferenc // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps] // 4 warps -> 4 loads per iter // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block - typedef hipcub::WarpReduce WarpReduce; - __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize]; + typedef hipcub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/WARP_SIZE]; - const int warp_idx = threadIdx.x / warpSize; - const int warp_lane = threadIdx.x % warpSize; - const int row_B = (THREADS/warpSize)*blockIdx.x + warp_idx; - const int offset_B = ldb*row_B; + const int warp_idx = threadIdx.x / WARP_SIZE; + const int warp_lane = threadIdx.x % WARP_SIZE; + const int row_B = (THREADS/WARP_SIZE)*blockIdx.x + warp_idx; + const int offset_B = ldb * row_B; const int num_values_8bit = num_values_4bit/2; float local_C = 0.0f; @@ -2732,7 +2684,7 @@ template __global__ void kgemm_4bit_inferenc // A: [1, K] // B: [M, K] - for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize*num_values_4bit) + for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += WARP_SIZE*num_values_4bit) { const int inner_idx_halved = inner_idx/2; diff --git a/csrc/ops.hip b/csrc/ops.hip index 260b74b30..b26d138e1 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -20,6 +20,12 @@ #define ERR_NOT_IMPLEMENTED 100 +#if defined(__GFX9__) + #define WARP_SIZE 64 +#else + #define WARP_SIZE 32 +#endif + using namespace BinSearch; using std::cout; using std::endl; @@ -692,7 +698,7 @@ template void gemm_4bit_inference_naive(int m, int n, int //warpsize - 32 int num_blocks = (m+3)/4; //warpsize - 64 - if (warpSize == 64) { + if (WARP_SIZE == 64) { num_blocks = (m+1)/2; } diff --git a/tests/test_functional.py b/tests/test_functional.py index fb67430ae..072e3b4f5 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -10,7 +10,7 @@ import bitsandbytes as bnb from bitsandbytes import functional as F -from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH +from bitsandbytes.cextension import HIP_ENVIRONMENT from tests.helpers import ( BOOLEAN_TUPLES, TRUE_FALSE, @@ -463,6 +463,7 @@ def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim): @pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim")) @pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim")) @pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose): def min_max(x): maxA = torch.amax(x, dim=2, keepdim=True) @@ -1408,10 +1409,7 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) - @pytest.mark.skipif( - HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a", - reason="this test is not supported on ROCm with gfx90a architecture yet", - ) + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_gemv_eye_4bit(self, device, storage_type, dtype): if device == "hpu" and not is_supported_on_hpu(storage_type, dtype): pytest.skip("This configuration is not supported on HPU.") diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 0e5f7bc18..51b4cf9cd 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -9,6 +9,7 @@ import torch import bitsandbytes as bnb +from bitsandbytes.cextension import HIP_ENVIRONMENT from bitsandbytes.nn.modules import Linear8bitLt from tests.helpers import ( TRUE_FALSE, @@ -233,6 +234,7 @@ def test_linear8bit_serialization(linear8bit): @pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph")) @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode")) @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4") +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode): if device == "cuda" and platform.system() == "Windows": pytest.skip("Triton is not officially supported on Windows") diff --git a/tests/test_ops.py b/tests/test_ops.py index 3b52bf284..02472630e 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -211,6 +211,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on HPU.")