From c9f4400e8d28c3f1552ca58bb79d9bf05446b97a Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 9 Dec 2025 12:40:37 -0500 Subject: [PATCH 1/2] CUDA/ROCm: Remove dead code --- csrc/kernels.cu | 488 --------------------------------------- csrc/kernels.cuh | 7 - csrc/kernels.hip | 465 ------------------------------------- csrc/kernels_hip.cuh | 7 - csrc/ops.cu | 27 --- csrc/ops.cuh | 7 - csrc/ops.hip | 23 -- csrc/ops_hip.cuh | 7 - csrc/pythonInterface.cpp | 25 -- 9 files changed, 1056 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 804e9db40..5c021ade1 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2025,429 +2025,12 @@ __global__ void kspmm_coo_very_sparse_naive( } } -#define WARPS 3 - -template -__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc) { - -#if __CUDA_ARCH__ >= 750 - using namespace nvcuda; - int col_offset = blockIdx.x * 32; - const int warp_id = threadIdx.x / 32; - const int half_warp_id = threadIdx.x / 16; - const int half_warp_lane = threadIdx.x % 16; - const int batch_size_warps = (WARPS - 1) * 2; - const int val_per_iter = blockDim.x - 32; - - T local_A[4]; - T local_B[128]; - - const int a_tile_offset = 16; - const int b_tile_offset = (16 * 32 + 16); - - __shared__ T smem_A[8 * 16 + (2 * 16 * (batch_size_warps - 1))]; - __shared__ T smem_B[2 * batch_size_warps * 16 * 32 + (2 * 16 * (batch_size_warps - 1))]; - //__shared__ T smem_C[8*32]; - - wmma::fragment a_frag; - wmma::fragment b_frag; - wmma::fragment c_frag; - wmma::fill_fragment(c_frag, 0.0f); - - int ticktock = 0; - int idx = 0 + threadIdx.x; - int loaded_values = 0; - // prefetch - if (idx < K && warp_id < (WARPS - 1)) { - if (loaded_values == 0) { - local_A[0] = A[idx]; - local_A[1] = A[idx + (1 * val_per_iter)]; - local_A[2] = A[idx + (2 * val_per_iter)]; - local_A[3] = A[idx + (3 * val_per_iter)]; - -#pragma unroll 32 - for (int col = 0; col < 32; col++) { - local_B[col] = B[(col_offset + col) * ldb + idx]; - local_B[col + 32] = B[(col_offset + col) * ldb + idx + (1 * val_per_iter)]; - local_B[col + 64] = B[(col_offset + col) * ldb + idx + (2 * val_per_iter)]; - local_B[col + 96] = B[(col_offset + col) * ldb + idx + (3 * val_per_iter)]; - } - loaded_values = 3; - } else { - - if (loaded_values == 3) { - local_A[0] = local_A[1]; -#pragma unroll 32 - for (int col = 0; col < 32; col++) - local_B[col] = local_B[col + (32)]; - } else if (loaded_values == 2) { - local_A[0] = local_A[2]; -#pragma unroll 32 - for (int col = 0; col < 32; col++) - local_B[col] = local_B[col + (64)]; - } else { - local_A[0] = local_A[3]; -#pragma unroll 32 - for (int col = 0; col < 32; col++) - local_B[col] = local_B[col + (96)]; - } - loaded_values--; - } - - smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0]; - -#pragma unroll 32 - for (int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = - local_B[col]; - } else if (warp_id < (WARPS - 1)) { - local_A[0] = T(0.0); - smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f; - -#pragma unroll 32 - for (int col = 0; col < 32; col++) - local_B[col] = 0.0f; - -#pragma unroll 32 - for (int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = - 0.0f; - } - ticktock = ticktock == 0 ? 1 : 0; - - // for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) - for (int base_idx = blockDim.x - 32; base_idx < K; base_idx += blockDim.x - 32) { - idx = base_idx + threadIdx.x; - - __syncthreads(); - if (idx < K && warp_id < (WARPS - 1)) { - // local_A[0] = A[idx]; - - // #pragma unroll 32 - // for(int col = 0; col < 32; col++) - // local_B[col] = B[(col_offset+col)*ldb+idx]; - if (loaded_values == 0) { - local_A[0] = A[idx]; - local_A[1] = A[idx + (1 * val_per_iter)]; - local_A[2] = A[idx + (2 * val_per_iter)]; - local_A[3] = A[idx + (3 * val_per_iter)]; - -#pragma unroll 32 - for (int col = 0; col < 32; col++) { - local_B[col] = B[(col_offset + col) * ldb + idx]; - local_B[col + 32] = B[(col_offset + col) * ldb + idx + (1 * val_per_iter)]; - local_B[col + 64] = B[(col_offset + col) * ldb + idx + (2 * val_per_iter)]; - local_B[col + 96] = B[(col_offset + col) * ldb + idx + (3 * val_per_iter)]; - } - loaded_values = 3; - - } else { - - if (loaded_values == 3) { - local_A[0] = local_A[1]; -#pragma unroll 32 - for (int col = 0; col < 32; col++) - local_B[col] = local_B[col + (32)]; - } else if (loaded_values == 2) { - local_A[0] = local_A[2]; -#pragma unroll 32 - for (int col = 0; col < 32; col++) - local_B[col] = local_B[col + (64)]; - } else { - local_A[0] = local_A[3]; -#pragma unroll 32 - for (int col = 0; col < 32; col++) - local_B[col] = local_B[col + (96)]; - } - loaded_values--; - } - - smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0]; - -#pragma unroll 32 - for (int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = - local_B[col]; - } else if (warp_id < (WARPS - 1)) { - local_A[0] = T(0.0); - smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f; - -#pragma unroll 32 - for (int col = 0; col < 32; col++) - local_B[col] = 0.0f; - -#pragma unroll 32 - for (int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = - 0.0f; - } - ticktock = ticktock == 0 ? 1 : 0; - - if (warp_id == (WARPS - 1)) - for (int k = 0; k < batch_size_warps; k++) { - wmma::load_matrix_sync( - a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16 - ); // 111 mu - wmma::load_matrix_sync( - b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16 - ); // 35 mu - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - } - } - - __syncthreads(); - if (warp_id != (WARPS - 1)) { - return; - } - // only warp_id == (WARPS-1) from here - int warp_lane = threadIdx.x % 32; - - ticktock = ticktock == 0 ? 1 : 0; - for (int k = 0; k < batch_size_warps; k++) { - wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16); // 111 mu - wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16); // 35 mu - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - } - - // 129 mu - if (warp_id == (WARPS - 1)) - wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); - - if (col_offset + warp_lane < M) - out[col_offset + warp_lane] = smem_A[warp_lane]; -#endif -} - template __device__ void printnonzero(T* A, int num_values, const char* strval) { for (int i = 0; i < num_values; i++) if ((float)A[i] != 0.0) printf("%s %i %f\n", strval, i, (float)A[i]); } -template -__global__ void kgemm_4bit_inference( - int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, - int blocksize -) { - - //// element-wise kernel - //// 1. Load batch x k into registers - //// 2. Load k x k into registers - //// 3. dequantize and store in second pair of k x k - //// 4. matmul - //// 5. sum with cub - //// 6. store outputs - //// TC kernel - //// use k warps per thread block - //// 1. threadblock use read-only cache to read in register tile for A into shared memory - //// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments - //// 3. each warp reads a segment of values 16x32 from B - //// 4. do dequantization from register of B into second pair of registers - //// 5. store (4) into fragment - //// 6. matmul aggregate into fragment C - //// 7. aggregate files of C into shared memory block C - //// 8. sum (7) - //// 9. write outputs to matmul output matrix -#if __CUDA_ARCH__ >= 750 - using namespace nvcuda; - int col_offset = blockIdx.x * 32; - const int warp_id = threadIdx.x / 32; - const int warp_idx = threadIdx.x % 32; - const int half_warp_id = threadIdx.x / 16; - const int half_warp_lane = threadIdx.x % 16; - const int batch_size_warps = (WARPS - 1) * 2; - - T quant_map[16]; - -#pragma unroll 16 - for (int i = 0; i < 16; i++) - quant_map[i] = nf4_dequantization_lut[i]; - //__shared__ T quant_map[16*160]; - - T local_A[2]; - T local_B[64]; - unsigned char local_B_4bit[32]; - - const int a_tile_offset = 16; - const int b_tile_offset = (16 * 32 + 16); - - __shared__ T smem_A[8 * 16 + (16 * (batch_size_warps - 1))]; - __shared__ T smem_B[2 * batch_size_warps * 16 * 32 + (2 * 16 * (batch_size_warps - 1))]; - __shared__ T smem_C[8 * 32]; - - wmma::fragment a_frag; - wmma::fragment b_frag; - wmma::fragment c_frag; - wmma::fill_fragment(c_frag, 0.0f); - - for (int i = threadIdx.x; i < (8 * 32); i += blockDim.x) - smem_C[i] = 0.0f; - - __syncthreads(); - - int ticktock = 0; - int idx = 0 + threadIdx.x; - int loaded_values = 0; - // prefetch - if (idx < K && warp_id < (WARPS - 1)) { - if (loaded_values == 0) { - local_A[0] = A[idx]; - local_A[1] = A[idx + blockDim.x - 32]; - -#pragma unroll 32 - for (int col = 0; col < 32; col++) - local_B_4bit[col] = B[(col_offset + col) * ldb + idx]; - - loaded_values = 1; - } else { - local_A[0] = local_A[1]; - loaded_values--; - -#pragma unroll 64 - for (int col = 0; col < 64; col += 2) { - // local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); - // local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); - // local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0); - // local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0); - // local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0); - // local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0); - - // local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0); - // local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0); - local_B[col] = quant_map[160 * (local_B_4bit[col / 2] >> 4) + warp_idx] * T(17.0); - local_B[col + 1] = quant_map[160 * (local_B_4bit[col / 2] & 0x0F) + warp_idx] * T(17.0); - } - } - - smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0]; - -#pragma unroll 32 - for (int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = - local_B[col]; - } else if (warp_id < (WARPS - 1)) { - local_A[0] = T(0.0); - smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f; - -#pragma unroll 32 - for (int col = 0; col < 32; col++) - local_B[col] = 0.0f; - -#pragma unroll 32 - for (int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = - 0.0f; - } - ticktock = ticktock == 0 ? 1 : 0; - // if(threadIdx.x == 0) - // printf("aa %i %i\n", idx, loaded_values); - - // for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) - for (int base_idx = blockDim.x - 32; base_idx < K; base_idx += blockDim.x - 32) { - idx = base_idx + threadIdx.x; - // if(threadIdx.x == 0) - // printf("%i %i\n", idx, loaded_values); - - //__syncthreads(); - if (idx < K && warp_id < (WARPS - 1)) { - if (loaded_values == 0) { - local_A[0] = A[idx]; - local_A[1] = A[idx + blockDim.x - 32]; - -#pragma unroll 32 - for (int col = 0; col < 32; col++) { - local_B_4bit[col] = B[(col_offset + col) * ldb + idx]; - local_B_4bit[col + 16] = B[(col_offset + col) * ldb + idx]; - } - - loaded_values = 1; - } else { - local_A[0] = local_A[1]; - loaded_values--; - - int absidx = (idx + col_offset) / blocksize; - half local_absmax = __ldg(&(absmax[absidx])); - -#pragma unroll 64 - for (int col = 0; col < 64; col += 2) { - // local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); - // local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); - // local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx); - // local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx); - - // local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax); - // local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax); - local_B[col] = quant_map[(local_B_4bit[col / 2] >> 4)] * T(absidx); - local_B[col + 1] = quant_map[(local_B_4bit[col / 2] & 0x0F)] * T(absidx); - } - // printnonzero(local_B, 128, ""); - } - - smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0]; - -#pragma unroll 32 - for (int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = - local_B[col]; - } else if (warp_id < (WARPS - 1)) { - local_A[0] = T(0.0); - smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f; - -#pragma unroll 32 - for (int col = 0; col < 32; col++) - local_B[col] = 0.0f; - -#pragma unroll 32 - for (int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = - 0.0f; - } - ticktock = ticktock == 0 ? 1 : 0; - - if (warp_id == (WARPS - 1)) - for (int k = 0; k < batch_size_warps; k++) { - wmma::load_matrix_sync( - a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16 - ); // 111 mu - wmma::load_matrix_sync( - b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16 - ); // 35 mu - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - } - } - - __syncthreads(); - // if(threadIdx.x == 0) - //{ - // printnonzero(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: "); - // printnonzero(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: "); - // } - if (warp_id != (WARPS - 1)) { - return; - } - // only warp_id == (WARPS-1) from here - int warp_lane = threadIdx.x % 32; - - ticktock = ticktock == 0 ? 1 : 0; - for (int k = 0; k < batch_size_warps; k++) { - // if(warp_lane == 0) - // printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); - wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16); // 111 mu - wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16); // 35 mu - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - } - - // 129 mu - if (warp_id == (WARPS - 1)) - wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); - - // printnonzero(smem_C, 32, ""); - - if (col_offset + warp_lane < M) - out[col_offset + warp_lane] = smem_C[warp_lane]; -#endif -} - #define num_values_4bit 32 template @@ -2592,77 +2175,6 @@ template __global__ void kfunc(unsigned char* A, unsigned c template __global__ void kfunc(float* A, float* B, float value, long n); template __global__ void kfunc(float* A, float* B, float value, long n); -// these are not used and make no sense, but the compiler needs them -// template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, -// float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device( - int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc -); -template __global__ void gemm_device( - int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc -); -template __global__ void gemm_device( - int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc -); -template __global__ void gemm_device( - int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc -); -// template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, -// float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device( - int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc -); -template __global__ void gemm_device( - int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc -); -template __global__ void gemm_device( - int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc -); -// these are not used and make no sense, but the compiler needs them - -// template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, -// float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device( - int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc -); -template __global__ void gemm_device( - int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc -); -template __global__ void gemm_device( - int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc -); -template __global__ void gemm_device( - int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc -); -// template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, -// float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device( - int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc -); -template __global__ void gemm_device( - int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc -); -template __global__ void gemm_device( - int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc -); - -template __global__ void kgemm_4bit_inference( - int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb, - int ldc, int blocksize -); -template __global__ void kgemm_4bit_inference( - int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb, - int ldc, int blocksize -); -template __global__ void kgemm_4bit_inference( - int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb, - int ldc, int blocksize -); -template __global__ void kgemm_4bit_inference( - int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb, - int ldc, int blocksize -); - template __global__ void kgemm_4bit_inference_naive( int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, half* out, int lda, int ldb, int ldc, int blocksize diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 1d9b8b82e..558b46236 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -112,13 +112,6 @@ __global__ void kdequant_mm_int32_fp16( template __global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols); -template -__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc); -template -__global__ void kgemm_4bit_inference( - int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, - int blocksize -); template __global__ void kgemm_4bit_inference_naive( int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out, diff --git a/csrc/kernels.hip b/csrc/kernels.hip index eb139c6ce..b8eb9195d 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -2162,216 +2162,6 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o } } -#define WARPS 3 -template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) -{ - -#if __CUDA_ARCH__ >= 750 - using namespace nvcuda; - int col_offset = blockIdx.x *32; - const int warp_id = threadIdx.x / 32; - const int half_warp_id = threadIdx.x / 16; - const int half_warp_lane = threadIdx.x % 16; - const int batch_size_warps = (WARPS-1)*2; - const int val_per_iter = blockDim.x-32; - - T local_A[4]; - T local_B[128]; - - const int a_tile_offset = 16; - const int b_tile_offset = (16*32 + 16); - - __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; - __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; - //__shared__ T smem_C[8*32]; - - rocwmma::fragment a_frag; - rocwmma::fragment b_frag; - rocwmma::fragment c_frag; - rocwmma::fill_fragment(c_frag, 0.0f); - - int ticktock = 0; - int idx = 0 + threadIdx.x; - int loaded_values = 0; - // prefetch - if(idx < K && warp_id < (WARPS-1)) - { - if(loaded_values == 0) - { - local_A[0] = A[idx]; - local_A[1] = A[idx+(1*val_per_iter)]; - local_A[2] = A[idx+(2*val_per_iter)]; - local_A[3] = A[idx+(3*val_per_iter)]; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - { - local_B[col] = B[(col_offset+col)*ldb+idx]; - local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; - local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; - local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; - } - loaded_values = 3; - } - else - { - - if(loaded_values == 3) - { - local_A[0] = local_A[1]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(32)]; - } - else if(loaded_values == 2) - { - local_A[0] = local_A[2]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(64)]; - } - else - { - local_A[0] = local_A[3]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(96)]; - } - loaded_values--; - } - - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; - } - else if(warp_id < (WARPS-1)) - { - local_A[0] = T(0.0); - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; - } - ticktock = ticktock == 0 ? 1 : 0; - - //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) - for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) - { - idx = base_idx + threadIdx.x; - - __syncthreads(); - if(idx < K && warp_id < (WARPS-1)) - { - //local_A[0] = A[idx]; - - //#pragma unroll 32 - //for(int col = 0; col < 32; col++) - // local_B[col] = B[(col_offset+col)*ldb+idx]; - if(loaded_values == 0) - { - local_A[0] = A[idx]; - local_A[1] = A[idx+(1*val_per_iter)]; - local_A[2] = A[idx+(2*val_per_iter)]; - local_A[3] = A[idx+(3*val_per_iter)]; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - { - local_B[col] = B[(col_offset+col)*ldb+idx]; - local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; - local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; - local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; - } - loaded_values = 3; - - } - else - { - - if(loaded_values == 3) - { - local_A[0] = local_A[1]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(32)]; - } - else if(loaded_values == 2) - { - local_A[0] = local_A[2]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(64)]; - } - else - { - local_A[0] = local_A[3]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(96)]; - } - loaded_values--; - } - - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; - } - else if(warp_id < (WARPS-1)) - { - local_A[0] = T(0.0); - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; - } - ticktock = ticktock == 0 ? 1 : 0; - - if(warp_id == (WARPS-1)) - for(int k = 0; k < batch_size_warps; k++) - { - rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu - rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu - rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - } - } - - __syncthreads(); - if(warp_id != (WARPS-1)){ return; } - // only warp_id == (WARPS-1) from here - int warp_lane = threadIdx.x % 32; - - ticktock = ticktock == 0 ? 1 : 0; - for(int k = 0; k < batch_size_warps; k++) - { - rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu - rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu - rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - } - - // 129 mu - if(warp_id == (WARPS-1)) - rocwmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, rocwmma::mem_row_major); - - if(col_offset + warp_lane < M) - out[col_offset + warp_lane] = smem_A[warp_lane]; -#endif -} - - template __device__ void printnonzero(T *A, int num_values, const char * strval) { for(int i = 0; i < num_values; i++) @@ -2379,234 +2169,6 @@ template __device__ void printnonzero(T *A, int num_values, const c printf("%s %i %f\n", strval, i, (float)A[i]); } -template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) -{ - - //// element-wise kernel - //// 1. Load batch x k into registers + //// 2. Load k x k into registers - //// 3. dequantize and store in second pair of k x k + //// 4. matmul - //// 5. sum with cub - //// 6. store outputs - //// TC kernel - //// use k warps per thread block - //// 1. threadblock use read-only cache to read in register tile for A into shared memory - //// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments - //// 3. each warp reads a segment of values 16x32 from B - //// 4. do dequantization from register of B into second pair of registers - //// 5. store (4) into fragment - //// 6. matmul aggregate into fragment C - //// 7. aggregate files of C into shared memory block C - //// 8. sum (7) - //// 9. write outputs to matmul output matrix -#if __CUDA_ARCH__ >= 750 - using namespace nvcuda; - int col_offset = blockIdx.x *32; - const int warp_id = threadIdx.x / 32; - const int warp_idx = threadIdx.x % 32; - const int half_warp_id = threadIdx.x / 16; - const int half_warp_lane = threadIdx.x % 16; - const int batch_size_warps = (WARPS-1)*2; - - T quant_map[16]; - - #pragma unroll 16 - for(int i = 0; i < 16; i++) - quant_map[i] = nf4_dequantization_lut[i]; - //__shared__ T quant_map[16*160]; - - T local_A[2]; - T local_B[64]; - unsigned char local_B_4bit[32]; - - - const int a_tile_offset = 16; - const int b_tile_offset = (16*32 + 16); - - __shared__ T smem_A[8*16 + (16*(batch_size_warps-1))]; - __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; - __shared__ T smem_C[8*32]; - - rocwmma::fragment a_frag; - rocwmma::fragment b_frag; - rocwmma::fragment c_frag; - rocwmma::fill_fragment(c_frag, 0.0f); - - for(int i = threadIdx.x; i < (8*32); i+=blockDim.x) - smem_C[i] = 0.0f; - - __syncthreads(); - - int ticktock = 0; - int idx = 0 + threadIdx.x; - int loaded_values = 0; - // prefetch - if(idx < K && warp_id < (WARPS-1)) - { - if(loaded_values == 0) - { - local_A[0] = A[idx]; - local_A[1] = A[idx+blockDim.x-32]; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; - - loaded_values = 1; - } - else - { - local_A[0] = local_A[1]; - loaded_values--; - - #pragma unroll 64 - for(int col = 0; col < 64; col+=2) - { - //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); - //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); - //local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0); - //local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0); - //local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0); - //local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0); - - //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0); - //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0); - local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0); - local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0); - } - } - - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; - } - else if(warp_id < (WARPS-1)) - { - local_A[0] = T(0.0); - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; - } - ticktock = ticktock == 0 ? 1 : 0; - //if(threadIdx.x == 0) - //printf("aa %i %i\n", idx, loaded_values); - - //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) - for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) - { - idx = base_idx + threadIdx.x; - //if(threadIdx.x == 0) - //printf("%i %i\n", idx, loaded_values); - - //__syncthreads(); - if(idx < K && warp_id < (WARPS-1)) - { - if(loaded_values == 0) - { - local_A[0] = A[idx]; - local_A[1] = A[idx+blockDim.x-32]; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - { - local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; - local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx]; - } - - loaded_values = 1; - } - else - { - local_A[0] = local_A[1]; - loaded_values--; - - int absidx = (idx + col_offset)/blocksize; - half local_absmax = __ldg(&(absmax[absidx])); - - #pragma unroll 64 - for(int col = 0; col < 64; col+=2) - { - //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); - //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); - //local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx); - //local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx); - - //local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax); - //local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax); - local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx); - local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx); - } - //printnonzero(local_B, 128, ""); - } - - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; - } - else if(warp_id < (WARPS-1)) - { - local_A[0] = T(0.0); - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; - } - ticktock = ticktock == 0 ? 1 : 0; - - if(warp_id == (WARPS-1)) - for(int k = 0; k < batch_size_warps; k++) - { - rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu - rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu - rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - } - } - - __syncthreads(); - //if(threadIdx.x == 0) - //{ - // printnonzero(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: "); - // printnonzero(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: "); - //} - if(warp_id != (WARPS-1)){ return; } - // only warp_id == (WARPS-1) from here - int warp_lane = threadIdx.x % 32; - - ticktock = ticktock == 0 ? 1 : 0; - for(int k = 0; k < batch_size_warps; k++) - { - //if(warp_lane == 0) - //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); - rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu - rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu - rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - } - - // 129 mu - if(warp_id == (WARPS-1)) - rocwmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, rocwmma::mem_row_major); - - //printnonzero(smem_C, 32, ""); - - if(col_offset + warp_lane < M) - out[col_offset + warp_lane] = smem_C[warp_lane]; -#endif -} - // No of 4bit values processed by each thread #define num_values_4bit 32 template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) @@ -2764,33 +2326,6 @@ template __global__ void kfunc(unsigned char *A, unsigned c template __global__ void kfunc(float *A, float *B, float value, long n); template __global__ void kfunc(float *A, float *B, float value, long n); -// these are not used and make no sense, but the compiler needs them -//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -// these are not used and make no sense, but the compiler needs them - -//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); - -template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); - template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, hip_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); diff --git a/csrc/kernels_hip.cuh b/csrc/kernels_hip.cuh index efd8a9048..1430d6441 100644 --- a/csrc/kernels_hip.cuh +++ b/csrc/kernels_hip.cuh @@ -114,13 +114,6 @@ __global__ void kdequant_mm_int32_fp16( template __global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols); -template -__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc); -template -__global__ void kgemm_4bit_inference( - int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, - int blocksize -); template __global__ void kgemm_4bit_inference_naive( int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out, diff --git a/csrc/ops.cu b/csrc/ops.cu index 7b0c60cca..226ed10f6 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -451,26 +451,6 @@ void spmm_coo_very_sparse_naive( CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits) { - - int num_blocks = (m + 31) / 32; - - if (bits == 32) - gemm_device<<>>(m, n, k, A, B, out, lda, ldb, ldc); - if (bits == 16) - gemm_device<<>>(m, n, k, A, B, out, lda, ldb, ldc); -} - -template -void gemm_4bit_inference( - int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize -) { - - int num_blocks = (m + 31) / 32; - - kgemm_4bit_inference<<>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); -} - template void gemm_4bit_inference_naive( int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, @@ -501,9 +481,6 @@ template void func(unsigned char* A, unsigned char* B, unsi template void func(float* A, float* B, float value, long n); template void func(float* A, float* B, float value, long n); -template void gemm_4bit_inference( - int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize -); template void gemm_4bit_inference_naive( int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream @@ -517,10 +494,6 @@ template void gemm_4bit_inference_naive( int ldc, int blocksize, cudaStream_t stream ); -// template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, -// int bits); -template void gemm_host(int m, int n, int k, half* A, half* B, half* out, int lda, int ldb, int ldc, int bits); - template void spmm_coo_very_sparse_naive( int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB diff --git a/csrc/ops.cuh b/csrc/ops.cuh index a9c9bbb12..709432dcb 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -179,13 +179,6 @@ void spmm_coo_very_sparse_naive( float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB ); -void matmul4bite(half* A, unsigned char* B, half* out, int lda, int ldb, int rowsA, int colsA, int colsB); - -template void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits); -template -void gemm_4bit_inference( - int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize -); template void gemm_4bit_inference_naive( int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, diff --git a/csrc/ops.hip b/csrc/ops.hip index 26a2362e3..dc3dc091e 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -589,25 +589,6 @@ template void spmm_coo_very_sparse_naive(int *max_count, CUDA_CHECK_RETURN(hipPeekAtLastError()); } -template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) -{ - - int num_blocks = (m+31)/32; - - if(bits == 32) - hipLaunchKernelGGL(( gemm_device), dim3(num_blocks), dim3(32), 0, 0, m, n, k, A, B, out, lda, ldb, ldc); - if(bits == 16) - hipLaunchKernelGGL(( gemm_device), dim3(num_blocks), dim3(160), 0, 0, m, n, k, A, B, out, lda, ldb, ldc); -} - -template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) -{ - - int num_blocks = (m+31)/32; - - hipLaunchKernelGGL(( kgemm_4bit_inference), dim3(num_blocks), dim3(96), 0, 0, m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); -} - template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream) { @@ -641,14 +622,10 @@ template void func(unsigned char *A, unsigned char *B, unsi template void func(float *A, float *B, float value, long n); template void func(float *A, float *B, float value, long n); -template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); template void gemm_4bit_inference_naive(int m, int n, int k, hip_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); -//template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); -template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); - template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh index 72cdf4e01..4eb446206 100644 --- a/csrc/ops_hip.cuh +++ b/csrc/ops_hip.cuh @@ -181,13 +181,6 @@ void spmm_coo_very_sparse_naive( float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB ); -void matmul4bite(half* A, unsigned char* B, half* out, int lda, int ldb, int rowsA, int colsA, int colsB); - -template void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits); -template -void gemm_4bit_inference( - int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize -); template void gemm_4bit_inference_naive( int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 07c79fc95..340f06145 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -42,18 +42,6 @@ #if BUILD_CUDA || BUILD_HIP -// void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) -//{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 32); } -void gemm_host_fp16(int M, int N, int K, half* A, half* B, half* out, int lda, int ldb, int ldc) { - gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 16); -} - -void gemm_4bit_inference( - int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize -) { - gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); -} - void gemm_4bit_inference_naive_fp16( int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream @@ -677,19 +665,6 @@ void cspmm_coo_very_sparse_naive_int8( ); } -// void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) -//{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); } - -void cgemm_host_fp16(int M, int N, int K, half* A, half* B, half* out, int lda, int ldb, int ldc) { - gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); -} - -void cgemm_4bit_inference( - int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize -) { - gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); -} - void* cget_managed_ptr(size_t bytes) { void* ptr; CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost)); From 25e6a70821e21fb83357e566d144112e51aa4169 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 9 Dec 2025 13:14:08 -0500 Subject: [PATCH 2/2] more cleanup --- csrc/kernels.cu | 9 --------- csrc/kernels.hip | 10 ---------- 2 files changed, 19 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 5c021ade1..8af27075e 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2025,12 +2025,6 @@ __global__ void kspmm_coo_very_sparse_naive( } } -template __device__ void printnonzero(T* A, int num_values, const char* strval) { - for (int i = 0; i < num_values; i++) - if ((float)A[i] != 0.0) - printf("%s %i %f\n", strval, i, (float)A[i]); -} - #define num_values_4bit 32 template @@ -2508,6 +2502,3 @@ MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, __nv_bfloat16, 256, 1) - -template __device__ void printnonzero(float* A, int num_values, const char* strval); -template __device__ void printnonzero(half* A, int num_values, const char* strval); diff --git a/csrc/kernels.hip b/csrc/kernels.hip index b8eb9195d..f4bbfdd79 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -2162,13 +2162,6 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o } } -template __device__ void printnonzero(T *A, int num_values, const char * strval) -{ - for(int i = 0; i < num_values; i++) - if((float)A[i] != 0.0) - printf("%s %i %f\n", strval, i, (float)A[i]); -} - // No of 4bit values processed by each thread #define num_values_4bit 32 template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) @@ -2621,6 +2614,3 @@ MAKE_OptimizerStatic8bit1StateBlockwise(LION, hip_bfloat16, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, hip_bfloat16, 256, 1) - -template __device__ void printnonzero(float *A, int num_values, const char*strval); -template __device__ void printnonzero(half *A, int num_values, const char*strval);