From 83216e03977554c9d8523ad823bbead42934769f Mon Sep 17 00:00:00 2001 From: badaoui Date: Wed, 18 Feb 2026 14:36:10 +0100 Subject: [PATCH 1/2] first commit --- csrc/compat.cuh | 297 ++++++++++++++ csrc/examples/CMakeLists_changes.md | 52 +++ csrc/examples/common_unified.cuh | 94 +++++ csrc/examples/kernels_unified.cu | 600 ++++++++++++++++++++++++++++ csrc/examples/ops_unified.cu | 333 +++++++++++++++ csrc/examples/ops_unified.cuh | 183 +++++++++ 6 files changed, 1559 insertions(+) create mode 100644 csrc/compat.cuh create mode 100644 csrc/examples/CMakeLists_changes.md create mode 100644 csrc/examples/common_unified.cuh create mode 100644 csrc/examples/kernels_unified.cu create mode 100644 csrc/examples/ops_unified.cu create mode 100644 csrc/examples/ops_unified.cuh diff --git a/csrc/compat.cuh b/csrc/compat.cuh new file mode 100644 index 000000000..53188e3a7 --- /dev/null +++ b/csrc/compat.cuh @@ -0,0 +1,297 @@ +// compat.cuh — Platform abstraction layer for CUDA/HIP portability +// +// This header resolves ALL mechanical differences between CUDA and HIP. +// Kernel code should include this header and use the bnb_* types/macros +// instead of cuda*/hip* identifiers directly. +// +// The guard macro is BNB_HIP, which is defined when compiling for ROCm/HIP +// (set via CMakeLists.txt's add_compile_definitions(__HIP_PLATFORM_AMD__)). + +#pragma once + +// ============================================================================ +// Platform detection +// ============================================================================ + +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) +#define BNB_HIP 1 +#else +#define BNB_HIP 0 +#endif + +// ============================================================================ +// Runtime and FP16/BF16 headers +// ============================================================================ + +#if BNB_HIP + +#include +#include +#include + +#else // CUDA + +#include +#include +#include +#include +#include + +#endif + +// ============================================================================ +// CUB / hipCUB — namespace alias +// +// Usage: bnb_cub::BlockLoad<...>, bnb_cub::BlockReduce<...>, etc. +// This single alias eliminates ~90% of the cub:: vs hipcub:: differences. 
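+//
+// For example, a block load/store pair that previously needed two spellings
+// (cub::BlockLoad<...> on CUDA, hipcub::BlockLoad<...> on HIP) is written once
+// against the alias (illustrative sketch, not a kernel from this PR):
+//
+//     typedef bnb_cub::BlockLoad<float, 256, 4, bnb_cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
+//     typedef bnb_cub::BlockStore<unsigned char, 256, 4, bnb_cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;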
+// ============================================================================ + +#if BNB_HIP + +#include +namespace bnb_cub = hipcub; + +#else // CUDA + +#include +#include +#include +#include +#include +#include +#include +namespace bnb_cub = cub; + +#endif + +// ============================================================================ +// Reduction operators — CUB's Max()/Sum() API differs across versions +// ============================================================================ + +#if BNB_HIP + +#define BNB_MAX_OP hipcub::Max() +#define BNB_SUM_OP hipcub::Sum() + +#else // CUDA + +// CCCL 2.8.2+ moved to cuda::maximum<>{}, older versions use cub::Max() +#if defined(CCCL_VERSION) && CCCL_VERSION >= 2008002 +#include +#define BNB_MAX_OP \ + cuda::maximum<> {} +#else +#define BNB_MAX_OP cub::Max() +#endif +#define BNB_SUM_OP cub::Sum() + +#endif + +// ============================================================================ +// Stream and error types +// ============================================================================ + +#if BNB_HIP + +using bnb_stream_t = hipStream_t; +using bnb_error_t = hipError_t; + +#define BNB_SUCCESS hipSuccess +#define BNB_PEEK_LAST_ERROR() hipPeekAtLastError() +#define BNB_GET_ERROR_STRING(e) hipGetErrorString(e) +#define BNB_DEVICE_MALLOC(p, s) hipMalloc(p, s) +#define BNB_DEVICE_FREE(p) hipFree(p) + +#else // CUDA + +using bnb_stream_t = cudaStream_t; +using bnb_error_t = cudaError_t; + +#define BNB_SUCCESS cudaSuccess +#define BNB_PEEK_LAST_ERROR() cudaPeekAtLastError() +#define BNB_GET_ERROR_STRING(e) cudaGetErrorString(e) +#define BNB_DEVICE_MALLOC(p, s) cudaMalloc(p, s) +#define BNB_DEVICE_FREE(p) cudaFree(p) + +#endif + +// ============================================================================ +// Error checking macro (unified name, platform-specific implementation) +// ============================================================================ + +#define BNB_CHECK_RETURN(value) \ + { \ + bnb_error_t _bnb_stat = value; \ + if (_bnb_stat != BNB_SUCCESS) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", BNB_GET_ERROR_STRING(_bnb_stat), __LINE__, __FILE__); \ + exit(1); \ + } \ + } + +// Keep backward compat for existing code during migration +#define CUDA_CHECK_RETURN(value) BNB_CHECK_RETURN(value) + +// ============================================================================ +// BFloat16 type alias +// +// CUDA uses __nv_bfloat16, HIP uses hip_bfloat16. Unified as bnb_bfloat16. 
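+//
+// A templated kernel instantiated on the alias builds unchanged under nvcc and
+// hipcc (illustrative sketch only; kScale is not part of this PR):
+//
+//     template <typename T> __global__ void kScale(T* x, float s, int n) {
+//         int i = blockIdx.x * blockDim.x + threadIdx.x;
+//         if (i < n) x[i] = (T)((float)x[i] * s);
+//     }
+//     template __global__ void kScale<bnb_bfloat16>(bnb_bfloat16*, float, int);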
+// ============================================================================ + +#if BNB_HIP +using bnb_bfloat16 = hip_bfloat16; +#else +using bnb_bfloat16 = __nv_bfloat16; +#endif + +// ============================================================================ +// Data type enum aliases for BLAS/Sparse libraries +// ============================================================================ + +#if BNB_HIP + +#define BNB_R_16F HIP_R_16F +#define BNB_R_32F HIP_R_32F +#define BNB_R_8I HIP_R_8I +#define BNB_R_32I HIP_R_32I + +#else // CUDA + +#define BNB_R_16F CUDA_R_16F +#define BNB_R_32F CUDA_R_32F +#define BNB_R_8I CUDA_R_8I +#define BNB_R_32I CUDA_R_32I + +#endif + +// ============================================================================ +// BLAS Lt types and functions +// ============================================================================ + +#if BNB_HIP + +#ifndef NO_HIPBLASLT +#include +#endif + +using bnb_blasLt_handle_t = hipblasLtHandle_t; +using bnb_blasLt_matmul_desc_t = hipblasLtMatmulDesc_t; +using bnb_blasLt_layout_t = hipblasLtMatrixLayout_t; +using bnb_blasLt_preference_t = hipblasLtMatmulPreference_t; + +#define BNB_BLASLT_OP_T HIPBLAS_OP_T +#define BNB_BLASLT_COMPUTE_32I HIPBLAS_COMPUTE_32I + +#define bnb_blasLtCreate hipblasLtCreate +#define bnb_blasLtMatmulDescCreate hipblasLtMatmulDescCreate +#define bnb_blasLtMatmulDescSetAttr hipblasLtMatmulDescSetAttribute +#define bnb_blasLtLayoutCreate hipblasLtMatrixLayoutCreate +#define bnb_blasLtLayoutDestroy hipblasLtMatrixLayoutDestroy +#define bnb_blasLtMatmulDescDestroy hipblasLtMatmulDescDestroy +#define bnb_blasLtMatmul hipblasLtMatmul +#define bnb_blasLtPrefCreate hipblasLtMatmulPreferenceCreate +#define bnb_blasLtPrefSetAttr hipblasLtMatmulPreferenceSetAttribute +#define bnb_blasLtAlgoGetHeuristic hipblasLtMatmulAlgoGetHeuristic + +#define BNB_BLASLT_DESC_TRANSA HIPBLASLT_MATMUL_DESC_TRANSA +#define BNB_BLASLT_DESC_POINTER_MODE HIPBLASLT_MATMUL_DESC_POINTER_MODE +#define BNB_BLASLT_PREF_MAX_WORKSPACE HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES +#define BNB_BLASLT_PTR_MODE_ALPHA_VEC HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST + +using bnb_blasLt_heuristic_t = hipblasLtMatmulHeuristicResult_t; +using bnb_blas_status_t = hipblasStatus_t; +#define BNB_BLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS + +#else // CUDA + +#include +#include + +using bnb_blasLt_handle_t = cublasLtHandle_t; +using bnb_blasLt_matmul_desc_t = cublasLtMatmulDesc_t; +using bnb_blasLt_layout_t = cublasLtMatrixLayout_t; + +#define BNB_BLASLT_OP_T CUBLAS_OP_T +#define BNB_BLASLT_COMPUTE_32I CUBLAS_COMPUTE_32I + +#define bnb_blasLtCreate cublasLtCreate +#define bnb_blasLtMatmulDescCreate cublasLtMatmulDescCreate +#define bnb_blasLtMatmulDescSetAttr cublasLtMatmulDescSetAttribute +#define bnb_blasLtLayoutCreate cublasLtMatrixLayoutCreate +#define bnb_blasLtLayoutDestroy cublasLtMatrixLayoutDestroy +#define bnb_blasLtMatmulDescDestroy cublasLtMatmulDescDestroy +#define bnb_blasLtMatmul cublasLtMatmul + +#define BNB_BLASLT_DESC_TRANSA CUBLASLT_MATMUL_DESC_TRANSA +#define BNB_BLASLT_DESC_POINTER_MODE CUBLASLT_MATMUL_DESC_POINTER_MODE +#define BNB_BLASLT_PTR_MODE_ALPHA_VEC CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO + +using bnb_blas_status_t = cublasStatus_t; +#define BNB_BLAS_STATUS_SUCCESS CUBLAS_STATUS_SUCCESS + +#endif + +// ============================================================================ +// Sparse library types +// ============================================================================ + +#if BNB_HIP + 
+#include + +using bnb_sparse_handle_t = hipsparseHandle_t; + +#define bnb_sparseCreate hipsparseCreate +#define bnb_sparseCreateCoo hipsparseCreateCoo +#define bnb_sparseCreateDnMat hipsparseCreateDnMat +#define bnb_sparseSpMM_bufSize hipsparseSpMM_bufferSize +#define bnb_sparseSpMM hipsparseSpMM +#define bnb_sparseDestroySpMat hipsparseDestroySpMat +#define bnb_sparseDestroyDnMat hipsparseDestroyDnMat + +#define BNB_SPARSE_INDEX_32I HIPSPARSE_INDEX_32I +#define BNB_SPARSE_INDEX_BASE_ZERO HIPSPARSE_INDEX_BASE_ZERO +#define BNB_SPARSE_ORDER_ROW HIPSPARSE_ORDER_ROW +#define BNB_SPARSE_OP_NON_TRANSPOSE HIPSPARSE_OPERATION_NON_TRANSPOSE +#define BNB_SPARSE_OP_TRANSPOSE HIPSPARSE_OPERATION_TRANSPOSE +#define BNB_SPARSE_SPMM_ALG_DEFAULT HIPSPARSE_SPMM_ALG_DEFAULT + +#define CHECK_SPARSE(value) \ + { \ + hipsparseStatus_t _stat = value; \ + if (_stat != HIPSPARSE_STATUS_SUCCESS) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", hipsparseGetErrorString(_stat), __LINE__, __FILE__); \ + exit(1); \ + } \ + } + +#else // CUDA + +#include + +using bnb_sparse_handle_t = cusparseHandle_t; + +#define bnb_sparseCreate cusparseCreate +#define bnb_sparseCreateCoo cusparseCreateCoo +#define bnb_sparseCreateDnMat cusparseCreateDnMat +#define bnb_sparseSpMM_bufSize cusparseSpMM_bufferSize +#define bnb_sparseSpMM cusparseSpMM +#define bnb_sparseDestroySpMat cusparseDestroySpMat +#define bnb_sparseDestroyDnMat cusparseDestroyDnMat + +#define BNB_SPARSE_INDEX_32I CUSPARSE_INDEX_32I +#define BNB_SPARSE_INDEX_BASE_ZERO CUSPARSE_INDEX_BASE_ZERO +#define BNB_SPARSE_ORDER_ROW CUSPARSE_ORDER_ROW +#define BNB_SPARSE_OP_NON_TRANSPOSE CUSPARSE_OPERATION_NON_TRANSPOSE +#define BNB_SPARSE_OP_TRANSPOSE CUSPARSE_OPERATION_TRANSPOSE +#define BNB_SPARSE_SPMM_ALG_DEFAULT CUSPARSE_SPMM_ALG_DEFAULT + +#define CHECK_SPARSE(value) \ + { \ + cusparseStatus_t _stat = value; \ + if (_stat != CUSPARSE_STATUS_SUCCESS) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", cusparseGetErrorString(_stat), __LINE__, __FILE__); \ + exit(1); \ + } \ + } + +#endif diff --git a/csrc/examples/CMakeLists_changes.md b/csrc/examples/CMakeLists_changes.md new file mode 100644 index 000000000..99a43d521 --- /dev/null +++ b/csrc/examples/CMakeLists_changes.md @@ -0,0 +1,52 @@ +# CMakeLists.txt Changes for Unified Kernels + +## Summary of changes + +Replace separate `CUDA_FILES` and `HIP_FILES` with a single `GPU_FILES` list. +For HIP builds, tell CMake to compile `.cu` files using the HIP language. + +## Diff + +```diff + # Define included source files + set(CPP_FILES csrc/cpu_ops.cpp csrc/pythonInterface.cpp) +-set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) +-set(HIP_FILES csrc/ops.hip csrc/kernels.hip) ++set(GPU_FILES csrc/ops.cu csrc/kernels.cu) + set(MPS_FILES csrc/mps_ops.mm) + set(METAL_FILES csrc/mps_kernels.metal) + set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp) +``` + +```diff + if(BUILD_CUDA) + # ... (CUDA setup unchanged) +- list(APPEND SRC_FILES ${CUDA_FILES}) ++ list(APPEND SRC_FILES ${GPU_FILES}) + string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") + add_compile_definitions(BUILD_CUDA) + elseif(BUILD_HIP) + # ... (HIP setup unchanged) +- list(APPEND SRC_FILES ${HIP_FILES}) ++ list(APPEND SRC_FILES ${GPU_FILES}) + string(APPEND BNB_OUTPUT_NAME "_rocm") + # ... +``` + +```diff + if(BUILD_HIP) + # ... +- set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP) ++ set_source_files_properties(${GPU_FILES} PROPERTIES LANGUAGE HIP) + set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX) + # ... 
+ endif() +``` + +## Files to delete after migration + +- `csrc/common_hip.cuh` +- `csrc/kernels.hip` +- `csrc/kernels_hip.cuh` +- `csrc/ops.hip` +- `csrc/ops_hip.cuh` diff --git a/csrc/examples/common_unified.cuh b/csrc/examples/common_unified.cuh new file mode 100644 index 000000000..081dc7780 --- /dev/null +++ b/csrc/examples/common_unified.cuh @@ -0,0 +1,94 @@ +// common_unified.cuh — Merged architecture constants for CUDA and HIP +// +// This replaces both csrc/common.cuh and csrc/common_hip.cuh. +// Platform detection uses compat.cuh's BNB_HIP macro. + +#pragma once + +#include "compat.cuh" + +// ============================================================================ +// Warp size +// ============================================================================ + +#if BNB_HIP +// AMD GFX9 (CDNA) uses 64-wide warps; RDNA uses 32-wide +#ifdef __GFX9__ +#define BNB_WARP_SIZE 64 +#else +#define BNB_WARP_SIZE 32 +#endif +#else +#define BNB_WARP_SIZE 32 +#endif + +// ============================================================================ +// BF16 availability +// ============================================================================ + +#if BNB_HIP +// BF16 is available on all currently-supported ROCm architectures (CDNA2+, RDNA3+) +#define BNB_BF16_AVAILABLE true +#else +#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE) +#endif + +// ============================================================================ +// CUDA compute capability constants (CUDA-only, but harmless to define on HIP) +// ============================================================================ + +#define BNB_CC_PASCAL 600 +#define BNB_CC_PASCAL_X2 620 +#define BNB_CC_VOLTA 700 +#define BNB_CC_VOLTA_XAVIER 720 +#define BNB_CC_TURING 750 +#define BNB_CC_AMPERE 800 +#define BNB_CC_AMPERE2 860 +#define BNB_CC_AMPERE2_ORIN 870 +#define BNB_CC_ADA 890 +#define BNB_CC_HOPPER 900 +#define BNB_CC_BLACKWELL 1000 + +// ============================================================================ +// Feature availability based on arch (CUDA uses __CUDA_ARCH__, HIP is simpler) +// ============================================================================ + +#if BNB_HIP +// HIP: MMA not supported via mma.h; FP8 support varies by arch +#define BNB_FP16_MMA_AVAILABLE 0 +#define BNB_INT8_MMA_AVAILABLE 0 +#define BNB_FP8_AVAILABLE 0 +#else +#define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA) +#define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER) +#define BNB_FP8_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_ADA) +#endif + +// ============================================================================ +// Maximum threads per SM/CU +// ============================================================================ + +#if BNB_HIP +// For currently supported ROCm architectures (CDNA2, RDNA3) +#define BNB_MAX_THREADS_PER_SM 2048 +#else +// The maximum number of resident threads per SM varies by NVIDIA arch. 
+// Reference: CUDA Programming Guide, Technical Specifications per Compute Capability +#if __CUDA_ARCH__ == 750 +#define BNB_MAX_THREADS_PER_SM 1024 +#elif __CUDA_ARCH__ >= 860 && __CUDA_ARCH__ <= 890 +#define BNB_MAX_THREADS_PER_SM 1536 +#else +#define BNB_MAX_THREADS_PER_SM 2048 +#endif +#endif + +// Maximum resident warps per SM/CU +#define BNB_MAX_WARPS_PER_SM ((BNB_MAX_THREADS_PER_SM) / (BNB_WARP_SIZE)) + +// Maximum resident blocks per SM/CU +#if !BNB_HIP && (defined(__CUDA_ARCH__)) && (__CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870) +#define BNB_MAX_BLOCKS_PER_SM 16 +#else +#define BNB_MAX_BLOCKS_PER_SM ((BNB_MAX_WARPS_PER_SM) / 2) +#endif diff --git a/csrc/examples/kernels_unified.cu b/csrc/examples/kernels_unified.cu new file mode 100644 index 000000000..7244dbce2 --- /dev/null +++ b/csrc/examples/kernels_unified.cu @@ -0,0 +1,600 @@ +// kernels_unified.cu — EXAMPLE of merged CUDA/HIP kernel source +// +// This file demonstrates how kernels.cu and kernels.hip can be unified +// into a single source file. It shows representative kernels covering +// all categories of differences: +// +// 1. Shared code (identical on both platforms) — kQuantize, kQuantizeBlockwise +// 2. Platform-specific atomics — atomicMax (CUDA needs custom, HIP has native) +// 3. Warp-size-dependent kernels — kQuantizeBlockwiseSmall (replaces +// kQuantizeBlockwise32 on CUDA and kQuantizeBlockwise64 on HIP) +// 4. Template instantiations — bnb_bfloat16 alias for __nv_bfloat16 / hip_bfloat16 +// +// Key principles: +// - Include "compat.cuh" for all platform abstractions +// - Use bnb_cub:: instead of cub:: or hipcub:: +// - Use BNB_MAX_OP / BNB_SUM_OP instead of cub::Max() / hipcub::Max() +// - Use bnb_bfloat16 instead of __nv_bfloat16 / hip_bfloat16 +// - Use #if BNB_HIP for truly divergent sections +// - <<>> syntax works on both platforms (HIP supports it natively) +// +// This file compiles as: +// - CUDA: nvcc compiles it as .cu (default) +// - HIP: CMake sets LANGUAGE HIP on this .cu file, hipcc compiles it +// +// Copyright (c) Facebook, Inc. and its affiliates. +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
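+//
+// To make the pattern concrete, here is a minimal sketch of a kernel plus host
+// launcher that compiles unchanged under nvcc and hipcc through compat.cuh
+// (illustrative only; kAbsMax/absmax are not kernels from this PR):
+//
+//     __global__ void kAbsMax(const float* x, float* out, int n) {
+//         typedef bnb_cub::BlockReduce<float, 256> BlockReduce;
+//         __shared__ typename BlockReduce::TempStorage temp;
+//         int i = blockIdx.x * 256 + threadIdx.x;
+//         float v = (i < n) ? fabsf(x[i]) : 0.0f;
+//         float m = BlockReduce(temp).Reduce(v, BNB_MAX_OP);
+//         if (threadIdx.x == 0)
+//             atomicMax(out, m); // custom CAS overload on CUDA (defined below), native on HIP
+//     }
+//
+//     void absmax(const float* x, float* out, int n, bnb_stream_t stream) {
+//         kAbsMax<<<(n + 255) / 256, 256, 0, stream>>>(x, out, n);
+//         BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR());
+//     }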
+ +#include "common.cuh" // merged common_unified.cuh in the real version +#include "compat.cuh" +#include "kernels.cuh" // merged kernel declarations +#include // DataType_t enum + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + +// ============================================================================ +// Lookup tables — identical on both platforms +// ============================================================================ + +__device__ static float fp4_dequantization_lut[8] = { + 0.0f, // 0b000 + 0.005208333333f, // 0b001 + 0.66666667f, // 0b010 + 1.0f, // 0b011 + 0.33333333f, // 0b100 + 0.5f, // 0b101 + 0.16666667f, // 0b110 + 0.25f // 0b111 +}; + +__device__ static float nf4_dequantization_lut[16] = { + -1.0f, // 0b0000 + -0.6961928009986877f, // 0b0001 + -0.5250730514526367f, // 0b0010 + -0.39491748809814453f, // 0b0011 + -0.28444138169288635f, // 0b0100 + -0.18477343022823334f, // 0b0101 + -0.09105003625154495f, // 0b0110 + 0.0f, // 0b0111 + 0.07958029955625534f, // 0b1000 + 0.16093020141124725f, // 0b1001 + 0.24611230194568634f, // 0b1010 + 0.33791524171829224f, // 0b1011 + 0.44070982933044434f, // 0b1100 + 0.5626170039176941f, // 0b1101 + 0.7229568362236023f, // 0b1110 + 1.0f // 0b1111 +}; + +// ============================================================================ +// atomicMax for float — CUDA needs a custom CAS loop, HIP has native support +// ============================================================================ + +#if !BNB_HIP +// CUDA: no native atomicMax for float, use CAS loop +// source: https://stackoverflow.com/questions/17399119 +__device__ float atomicMax(float* address, float val) { + int* address_as_i = reinterpret_cast(address); + int old = *address_as_i, assumed; + do { + assumed = old; + old = atomicCAS(reinterpret_cast(address), assumed, __float_as_int(fmaxf(val, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); +} +#endif +// HIP: atomicMax for float is available natively in ROCm — no custom impl needed + +// ============================================================================ +// Device helper functions — identical on both platforms +// ============================================================================ + +__device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) { + float sign = 1.0f - 2 * ((val & 0b1000) >> 3); + return fp4_dequantization_lut[val & 0b111] * sign; +} + +__device__ unsigned char dQuantizeFP4(float x) { + int sign = x < 0 ? 
0b1000 : 0b0000; + x = fabsf(x); + if (x > 0.29166667f) + if (x > 0.583333f) + if (x > 0.8333333f) + return 0b0011 + sign; + else + return 0b0010 + sign; + else if (x > 0.4166667f) + return 0b101 + sign; + else + return 0b100 + sign; + else if (x > 0.0859375f) + if (x > 0.20833333f) + return 0b0111 + sign; + else + return 0b0110 + sign; + else if (x > 0.00260417f) + return 0b0001 + sign; + else + return 0b0000 + sign; +} + +__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; } + +__device__ unsigned char dQuantizeNF4(float x) { + if (x > 0.03979014977812767f) + if (x > 0.3893125355243683f) + if (x > 0.6427869200706482f) + if (x > 0.8614784181118011f) + return 0b1111; + else + return 0b1110; + else if (x > 0.5016634166240692f) + return 0b1101; + else + return 0b1100; + else if (x > 0.2035212516784668f) + if (x > 0.2920137718319893f) + return 0b1011; + else + return 0b1010; + else if (x > 0.1202552504837513f) + return 0b1001; + else + return 0b1000; + else if (x > -0.33967943489551544f) + if (x > -0.13791173323988914f) + if (x > -0.045525018125772476f) + return 0b0111; + else + return 0b0110; + else if (x > -0.23460740596055984f) + return 0b0101; + else + return 0b0100; + else if (x > -0.6106329262256622f) + if (x > -0.4599952697753906f) + return 0b0011; + else + return 0b0010; + else if (x > -0.8480964004993439f) + return 0b0001; + else + return 0b0000; +} + +// (dQuantize<> helper omitted for brevity — same pattern, no platform diffs) +template __device__ unsigned char dQuantize(float* smem_code, float rand, float x) { + // Binary search in quantization code — identical on both platforms + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float val = smem_code[pivot]; + for (int i = 64; i > 0; i >>= 1) { + if (x > val) { + lower_pivot = pivot; + pivot += i; + } else { + upper_pivot = pivot; + pivot -= i; + } + val = smem_code[pivot]; + } + + if (upper_pivot == 255) + upper_pivot = 254; + + if (STOCHASTIC) { + if (rand >= (x - smem_code[lower_pivot]) / (smem_code[upper_pivot] - smem_code[lower_pivot])) + return lower_pivot; + else + return upper_pivot; + } else { + if (fabsf(x - smem_code[lower_pivot]) < fabsf(x - smem_code[upper_pivot])) + return lower_pivot; + else + return upper_pivot; + } +} + +// ============================================================================ +// kQuantize — fully shared, zero #ifdefs needed +// +// Before (CUDA): typedef cub::BlockLoad<...> +// Before (HIP): typedef hipcub::BlockLoad<...> +// After (unified): typedef bnb_cub::BlockLoad<...> +// ============================================================================ + +__launch_bounds__(TH, 4) __global__ + void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n) { + const int n_full = (NUM_BLOCK * (n / NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); + int valid_items = (blockIdx.x + 1 == gridDim.x) ? 
n - (blockIdx.x * NUM_BLOCK) : NUM_BLOCK;
+    const int base_idx = (blockIdx.x * NUM_BLOCK);
+
+    float vals[NUM];
+    unsigned char qvals[NUM];
+
+    // vvvvvvvv unified namespace alias — resolves to cub:: or hipcub::
+    typedef bnb_cub::BlockLoad<float, TH, NUM, bnb_cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
+    typedef bnb_cub::BlockStore<unsigned char, TH, NUM, bnb_cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+
+    __shared__ typename LoadFloat::TempStorage loadf;
+    __shared__ typename StoreChar::TempStorage storec;
+    __shared__ float smem_code[256];
+
+    if (threadIdx.x < 256)
+        smem_code[threadIdx.x] = code[threadIdx.x];
+
+    for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_BLOCK) {
+        valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i;
+
+        __syncthreads();
+        LoadFloat(loadf).Load(&(A[i]), vals, valid_items);
+
+#pragma unroll 4
+        for (int j = 0; j < NUM; j++)
+            qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]);
+
+        __syncthreads();
+        StoreChar(storec).Store(&(out[i]), qvals, valid_items);
+    }
+}
+
+// ============================================================================
+// kQuantizeBlockwise — fully shared, uses BNB_MAX_OP
+//
+// The only change vs the original CUDA version:
+//     cub::      → bnb_cub::
+//     cub::Max() → BNB_MAX_OP
+// ============================================================================
+
+template <typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
+__global__ void kQuantizeBlockwise(
+    float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
+    const int rand_offset, const int n
+) {
+    const int n_full = min(gridDim.x * BLOCK_SIZE, INT32_MAX);
+    const int base_idx = blockIdx.x * BLOCK_SIZE;
+    int valid_items = 0;
+
+    T vals[NUM_PER_TH];
+    float rand_vals[NUM_PER_TH];
+    unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH];
+
+    float local_abs_max = 0.0f;
+    int local_rand_idx = 0;
+
+    typedef bnb_cub::BlockLoad<T, BLOCK_SIZE / NUM_PER_TH, NUM_PER_TH, bnb_cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+    typedef bnb_cub::BlockStore<
+        unsigned char, BLOCK_SIZE / NUM_PER_TH, (DATA_TYPE > 0) ?
NUM_PER_TH / 2 : NUM_PER_TH, + bnb_cub::BLOCK_STORE_WARP_TRANSPOSE> + StoreChar; + typedef bnb_cub::BlockReduce BlockReduce; + typedef bnb_cub::BlockLoad + LoadFloat; + + __shared__ typename LoadT::TempStorage loadt; + __shared__ typename LoadFloat::TempStorage loadf; + __shared__ typename StoreChar::TempStorage storec; + __shared__ typename BlockReduce::TempStorage reduce; + __shared__ float smem_code[256]; + __shared__ float smem_absmax_value[1]; + + if (DATA_TYPE == General8bit) + for (int i = threadIdx.x; i < 256; i += blockDim.x) + smem_code[i] = code[i]; + + for (int64_t i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) { + valid_items = min(BLOCK_SIZE, static_cast(n - i)); + local_abs_max = -FLT_MAX; + + __syncthreads(); + LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f); + +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) + local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); + + // vvvvvvvvvv unified reduction op + local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, BNB_MAX_OP, valid_items); + + if (threadIdx.x == 0) { + smem_absmax_value[0] = 1.0f / local_abs_max; + absmax[i / BLOCK_SIZE] = local_abs_max; + } + __syncthreads(); + + local_abs_max = smem_absmax_value[0]; + + if (STOCHASTIC) { + local_rand_idx = ((blockIdx.x * NUM_BLOCK) + (threadIdx.x * NUM) + rand_offset) % (1024 - 4); + LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); + } + + switch (DATA_TYPE) { + case General8bit: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + if (!STOCHASTIC) + qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j]) * local_abs_max); + else + qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j]) * local_abs_max); + } + break; + case FP4: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH / 2; j++) { + qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4; + qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max); + } + break; + case NF4: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH / 2; j++) { + qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4; + qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max); + } + break; + } + + __syncthreads(); + StoreChar(storec).Store( + &(out[(DATA_TYPE > 0) ? i / 2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items + 1) / 2 : valid_items + ); + } +} + +// ============================================================================ +// kQuantizeBlockwiseSmall — unified warp-size-dependent kernel +// +// This replaces: +// CUDA: kQuantizeBlockwise32 (32 threads, blocksize=32, WarpReduce) +// HIP: kQuantizeBlockwise64 (64 threads, blocksize=64, WarpReduce) +// +// Strategy: Use BNB_WARP_SIZE to derive all constants at compile time. +// On CUDA (warp=32): SMALL_BLOCK_SIZE=32, THREADS=32, THREADS_PER_BLOCK=16 +// On HIP (warp=64): SMALL_BLOCK_SIZE=64, THREADS=64, THREADS_PER_BLOCK=32 +// On HIP (warp=32): SMALL_BLOCK_SIZE=32, THREADS=32, THREADS_PER_BLOCK=16 +// +// The algorithm is identical — only the numeric constants change. +// ============================================================================ + +template +__global__ void kQuantizeBlockwiseSmall( + float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, + const int rand_offset, const int n +) { + // All constants derived from BNB_WARP_SIZE — no #ifdefs needed! 
+    constexpr int BLOCK_SIZE = BNB_WARP_SIZE;            // 32 on CUDA, 32 or 64 on HIP
+    constexpr int NUM_PER_TH = 2;
+    constexpr int THREADS = BNB_WARP_SIZE;               // One full hardware warp
+    constexpr int THREADS_PER_BLOCK = BNB_WARP_SIZE / 2; // Half-warp per quantization block
+
+    const int base_idx = blockIdx.x * BLOCK_SIZE * 2; // 2 quantization blocks per thread block
+
+    T vals[NUM_PER_TH];
+    unsigned char qvals[NUM_PER_TH / 2];
+    float local_abs_max = 0.0f;
+
+    const int block_id = threadIdx.x / THREADS_PER_BLOCK;
+    const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK;
+
+    typedef bnb_cub::BlockLoad<T, THREADS, NUM_PER_TH, bnb_cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+    typedef bnb_cub::BlockStore<unsigned char, THREADS, NUM_PER_TH / 2, bnb_cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+    // Logical warp of THREADS_PER_BLOCK: on warp32 HW this is a half-warp,
+    // on warp64 HW this splits the single HW warp into two logical warps
+    typedef bnb_cub::WarpReduce<float, THREADS_PER_BLOCK> WarpReduce;
+
+    __shared__ typename LoadT::TempStorage loadt;
+    __shared__ typename StoreChar::TempStorage storec;
+    __shared__ typename WarpReduce::TempStorage warp_reduce[2];
+    __shared__ float smem_absmax_value[2];
+
+    const int i = base_idx + block_id * BLOCK_SIZE;
+    const bool block_valid = (i < n);
+
+    __syncthreads();
+    LoadT(loadt).Load(&(A[base_idx]), vals, min(BLOCK_SIZE * 2, n - base_idx), (T)0.0f);
+
+    local_abs_max = -FLT_MAX;
+#pragma unroll NUM_PER_TH
+    for (int j = 0; j < NUM_PER_TH; j++)
+        local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));
+
+    local_abs_max = WarpReduce(warp_reduce[block_id]).Reduce(local_abs_max, BNB_MAX_OP);
+
+    if (local_thread_id == 0) {
+        if (block_valid) {
+            smem_absmax_value[block_id] = 1.0f / local_abs_max;
+            absmax[blockIdx.x * 2 + block_id] = local_abs_max;
+        } else {
+            smem_absmax_value[block_id] = 0.0f;
+        }
+    }
+    __syncthreads();
+
+    local_abs_max = smem_absmax_value[block_id];
+
+    switch (DATA_TYPE) {
+    case FP4:
+#pragma unroll NUM_PER_TH
+        for (int j = 0; j < NUM_PER_TH / 2; j++) {
+            qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
+            qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
+        }
+        break;
+    case NF4:
+#pragma unroll NUM_PER_TH
+        for (int j = 0; j < NUM_PER_TH / 2; j++) {
+            qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
+            qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
+        }
+        break;
+    }
+
+    __syncthreads();
+    StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((BLOCK_SIZE * 2 + 1) / 2, (n - base_idx + 1) / 2));
+}
+
+// ============================================================================
+// kDequantizeBlockwise — fully shared
+// ============================================================================
+
+template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
+__global__ void
+    kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n) {
+    const int n_load = (gridDim.x * TILE_SIZE);
+    int valid_items_load = 0;
+    int valid_items_store = 0;
+    const int base_idx = (blockIdx.x * TILE_SIZE);
+
+    T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)];
+    unsigned char qvals[NUM_PER_TH];
+    float local_abs_max = -FLT_MAX;
+
+    typedef bnb_cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, bnb_cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
+    typedef bnb_cub::BlockStore<T, THREADS, NUM_PER_TH * ((DATA_TYPE >
2 : 1), bnb_cub::BLOCK_STORE_WARP_TRANSPOSE> + StoreT; + + __shared__ typename LoadChar::TempStorage loadchar; + __shared__ typename StoreT::TempStorage storet; + + for (int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) { + if (DATA_TYPE > 0) { + valid_items_load = min(TILE_SIZE, static_cast((static_cast(n) + 1) / 2) - i); + valid_items_store = min(TILE_SIZE * 2, n - i * 2); + } else { + valid_items_load = min(TILE_SIZE, n - i); + valid_items_store = valid_items_load; + } + + // blocksize is always power-of-2: use bitwise AND instead of division + __syncthreads(); + LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load); + + switch (DATA_TYPE) { + case General8bit: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + local_abs_max = absmax[(i + (threadIdx.x * NUM_PER_TH) + j) / blocksize]; + vals[j] = (T)(code[qvals[j]] * local_abs_max); + } + break; + case FP4: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + local_abs_max = absmax[((i * 2) + (threadIdx.x * NUM_PER_TH * 2) + (j * 2)) / blocksize]; + vals[j * 2] = (T)(dDequantizeFP4Tree(qvals[j] >> 4) * local_abs_max); + vals[j * 2 + 1] = (T)(dDequantizeFP4Tree(qvals[j] & 0x0F) * local_abs_max); + } + break; + case NF4: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + local_abs_max = absmax[((i * 2) + (threadIdx.x * NUM_PER_TH * 2) + (j * 2)) / blocksize]; + vals[j * 2] = (T)(dDequantizeNF4(qvals[j] >> 4) * local_abs_max); + vals[j * 2 + 1] = (T)(dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max); + } + break; + } + + __syncthreads(); + StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i * 2 : i]), vals, valid_items_store); + } +} + +// ============================================================================ +// Template instantiations — bnb_bfloat16 replaces __nv_bfloat16 / hip_bfloat16 +// ============================================================================ + +#define MAKE_kQuantizeBlockwise(dtype, block_size, num_per_th, stochastic, data_type_name) \ + template __global__ void kQuantizeBlockwise( \ + float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \ + const int rand_offset, const int n \ + ); + +// half instantiations +MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) +// ... (remaining half/float instantiations identical to current) + +// float instantiations +MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) +// ... 
(remaining float instantiations) + +// bnb_bfloat16 — resolves to __nv_bfloat16 on CUDA, hip_bfloat16 on HIP +MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 64, 2, 0, NF4) + +// Unified small-blocksize kernel instantiations +#define MAKE_kQuantizeBlockwiseSmall(dtype, data_type_name) \ + template __global__ void kQuantizeBlockwiseSmall( \ + float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \ + const int rand_offset, const int n \ + ); + +MAKE_kQuantizeBlockwiseSmall(half, FP4) MAKE_kQuantizeBlockwiseSmall(float, FP4) MAKE_kQuantizeBlockwiseSmall( + bnb_bfloat16, FP4 +) MAKE_kQuantizeBlockwiseSmall(half, NF4) MAKE_kQuantizeBlockwiseSmall(float, NF4) MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, NF4) + + // Dequantize instantiations + template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n + ); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, const int blocksize, const int n +); diff --git a/csrc/examples/ops_unified.cu b/csrc/examples/ops_unified.cu new file mode 100644 index 000000000..902c15046 --- /dev/null +++ b/csrc/examples/ops_unified.cu @@ -0,0 +1,333 @@ +// ops_unified.cu — 
EXAMPLE of merged host wrappers for CUDA/HIP +// +// This replaces both csrc/ops.cu and csrc/ops.hip. Shows representative +// functions covering all categories of differences. +// +// Key points: +// - <<>> works on both CUDA and HIP (no hipLaunchKernelGGL needed) +// - BNB_CHECK_RETURN replaces CUDA_CHECK_RETURN / hip equivalent +// - bnb_stream_t replaces cudaStream_t / hipStream_t +// - #if BNB_HIP only for genuinely different library code (igemmlt, spmm_coo) + +#include "common.cuh" +#include "compat.cuh" +#include "kernels.cuh" +#include "ops_unified.cuh" + +#include +#include +#include + +#if !BNB_HIP +#include +#endif + +#define ERR_NOT_IMPLEMENTED 100 + +using std::cout; +using std::endl; + +// ============================================================================ +// Quantize / Dequantize — fully shared, <<<>>> works on both platforms +// ============================================================================ + +void quantize(float* code, float* A, unsigned char* out, int n) { + int num_blocks = n / 1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + kQuantize<<>>(code, A, out, n); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); +} + +void dequantize(float* code, unsigned char* A, float* out, int n, bnb_stream_t stream) { + int num_blocks = n / 1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + kDequantize<<>>(code, A, out, n); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); +} + +// ============================================================================ +// quantizeBlockwise — mostly shared, small warp-size dispatch difference +// ============================================================================ + +template +void quantizeBlockwise( + float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +) { + int num_blocks = n / blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + + if (blocksize == 4096) + kQuantizeBlockwise + <<>>(code, A, absmax, out, rand, rand_offset, n); + else if (blocksize == 2048) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if (blocksize == 1024) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if (blocksize == 512) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if (blocksize == 256) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if (blocksize == 128) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if (blocksize == 64) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + // Smallest blocksize: uses unified kQuantizeBlockwiseSmall + // BNB_WARP_SIZE is the compile-time block size (32 on CUDA, 32 or 64 on HIP) + else if (blocksize == BNB_WARP_SIZE) { + if constexpr (DATA_TYPE > 0) { + int num_blocks_adjusted = (num_blocks + 1) / 2; + kQuantizeBlockwiseSmall + <<>>(code, A, absmax, out, rand, rand_offset, n); + } + } + + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); +} + +// ============================================================================ +// dequantizeBlockwise — fully shared +// ============================================================================ + +template +void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, bnb_stream_t stream +) { + constexpr int tile_size = (DATA_TYPE > 0) ? 
1024 : 512; + int grid_blocks = ((int64_t)n + tile_size - 1) / tile_size; + + if (DATA_TYPE > 0) + kDequantizeBlockwise + <<>>(code, A, absmax, out, blocksize / 2, n); + else + kDequantizeBlockwise + <<>>(code, A, absmax, out, blocksize, n); + + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); +} + +// ============================================================================ +// gemm_4bit_inference_naive — small warp-size difference in block count +// ============================================================================ + +template +void gemm_4bit_inference_naive( + int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, + int blocksize, bnb_stream_t stream +) { + // Warp size affects how many rows each block processes + int num_blocks; + if constexpr (BNB_WARP_SIZE == 64) + num_blocks = (m + 1) / 2; + else + num_blocks = (m + 3) / 4; + + kgemm_4bit_inference_naive + <<>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); +} + +// ============================================================================ +// igemmlt — BLAS library calls genuinely differ between cuBLAS and hipBLAS +// +// This is one of the few functions requiring substantial #if BNB_HIP blocks. +// The algorithm is the same but hipBLAS requires explicit heuristic selection +// while cuBLAS auto-selects. +// ============================================================================ + +template +int igemmlt( + bnb_blasLt_handle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, bnb_stream_t stream +) { +#if BNB_HIP && defined(NO_HIPBLASLT) + return ERR_NOT_IMPLEMENTED; +#else + int has_error = 0; + + bnb_blasLt_matmul_desc_t matmulDesc; + bnb_blasLt_layout_t aDesc, bDesc, cDesc; + + auto outType = DTYPE_OUT == 32 ? BNB_R_32I : BNB_R_8I; + auto scaleType = DTYPE_OUT == 32 ? 
BNB_R_32I : BNB_R_32F; + auto opT = BNB_BLASLT_OP_T; + + has_error |= checkBlasLtStatus(bnb_blasLtLayoutCreate(&aDesc, BNB_R_8I, m, k, lda)); + has_error |= checkBlasLtStatus(bnb_blasLtLayoutCreate(&bDesc, BNB_R_8I, m, n, ldb)); + has_error |= checkBlasLtStatus(bnb_blasLtLayoutCreate(&cDesc, outType, k, n, ldc)); + + has_error |= checkBlasLtStatus(bnb_blasLtMatmulDescCreate(&matmulDesc, BNB_BLASLT_COMPUTE_32I, scaleType)); + has_error |= checkBlasLtStatus(bnb_blasLtMatmulDescSetAttr(matmulDesc, BNB_BLASLT_DESC_TRANSA, &opT, sizeof(opT))); + + if (DTYPE_OUT == 32) { + int alpha = 1, beta = 0; + +#if BNB_HIP + // HIP requires explicit algorithm heuristic selection + bnb_blasLt_preference_t pref; + const int64_t max_workspace_size = 0; + checkBlasLtStatus(bnb_blasLtPrefCreate(&pref)); + checkBlasLtStatus( + bnb_blasLtPrefSetAttr(pref, BNB_BLASLT_PREF_MAX_WORKSPACE, &max_workspace_size, sizeof(max_workspace_size)) + ); + + bnb_blasLt_heuristic_t heuristicResult[1]; + int returnedAlgoCount = 0; + checkBlasLtStatus(bnb_blasLtAlgoGetHeuristic( + ltHandle, matmulDesc, aDesc, bDesc, cDesc, cDesc, pref, 1, heuristicResult, &returnedAlgoCount + )); + + if (returnedAlgoCount == 0) { + has_error = 1; + fprintf(stderr, "Error: Matmul Algo Heuristic didn't return algorithms\n"); + } else { + has_error |= checkBlasLtStatus(bnb_blasLtMatmul( + ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int32_t*)C, cDesc, (int32_t*)C, cDesc, + &heuristicResult[0].algo, NULL, 0, stream + )); + } +#else + // CUDA: cuBLAS auto-selects algorithm + has_error |= checkBlasLtStatus(bnb_blasLtMatmul( + ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int32_t*)C, cDesc, (int32_t*)C, cDesc, NULL, NULL, + 0, stream + )); +#endif + } else { + if (!SCALE_ROWS) { + float alpha = 1.0f, beta = 0.0f; + has_error |= checkBlasLtStatus(bnb_blasLtMatmul( + ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int8_t*)C, cDesc, (int8_t*)C, cDesc, NULL, + NULL, 0, stream + )); + } else { + auto pointerMode = BNB_BLASLT_PTR_MODE_ALPHA_VEC; + float beta = 0.0f; + has_error |= checkBlasLtStatus( + bnb_blasLtMatmulDescSetAttr(matmulDesc, BNB_BLASLT_DESC_POINTER_MODE, &pointerMode, sizeof(pointerMode)) + ); + has_error |= checkBlasLtStatus(bnb_blasLtMatmul( + ltHandle, matmulDesc, row_scale, A, aDesc, B, bDesc, &beta, (int8_t*)C, cDesc, (int8_t*)C, cDesc, NULL, + NULL, 0, stream + )); + } + } + + has_error |= checkBlasLtStatus(bnb_blasLtLayoutDestroy(cDesc)); + has_error |= checkBlasLtStatus(bnb_blasLtLayoutDestroy(bDesc)); + has_error |= checkBlasLtStatus(bnb_blasLtLayoutDestroy(aDesc)); + has_error |= checkBlasLtStatus(bnb_blasLtMatmulDescDestroy(matmulDesc)); + + if (has_error == 1) + printf("error detected"); + + return has_error; +#endif +} + +// ============================================================================ +// spmm_coo — sparse library calls differ but structure is identical +// Uses unified CHECK_SPARSE and bnb_sparse* macros from compat.cuh +// ============================================================================ + +void spmm_coo( + bnb_sparse_handle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, + int B_cols, int ldb, half* B, int ldc, half* C, bool transposed_B +) { +#if BNB_HIP && defined(NO_HIPBLASLT) + // No sparse support on older ROCm +#else + float alpha = 1.0f; + float beta = 0.0f; + void* dBuffer = NULL; + size_t bufferSize = 0; + + // Note: all of these use the bnb_sparse* macros from compat.cuh + // which resolve to cusparse* or hipsparse* as 
appropriate + + // bnb_sparseCreateCoo → cusparseCreateCoo / hipsparseCreateCoo + // BNB_R_16F → CUDA_R_16F / HIP_R_16F + // etc. + + // Omitting the body as it would be identical to what compat.cuh provides + // (see full macro mappings in compat.cuh) + + CHECK_SPARSE(bnb_sparseCreateCoo( + NULL, A_rows, A_cols, A_nnz, A_rowidx, A_colidx, A_vals, BNB_SPARSE_INDEX_32I, BNB_SPARSE_INDEX_BASE_ZERO, + BNB_R_16F + )); + + // ... (rest of spmm_coo using bnb_sparse* macros — same pattern) +#endif +} + +// ============================================================================ +// Simple kernel launchers — fully shared +// ============================================================================ + +void dequant_mm_int32_fp16( + int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, bnb_stream_t stream +) { + const int threads = 512; + const int num_per_thread = 4; + const int n = numRows * numCols; + const int num_blocks = (n + threads * num_per_thread - 1) / (threads * num_per_thread); + + kdequant_mm_int32_fp16 + <<>>(A, rowStats, colStats, out, bias, numRows, numCols, n); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); +} + +void int8VectorQuant( + half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, bnb_stream_t stream +) { + if (threshold == 0.0) { + kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); + } else { + kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); + } + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); +} + +template void func(T* A, T* B, T value, long n) { + int threads = 512; + int blocks = n / threads; + blocks = n % threads == 0 ? blocks : blocks + 1; + blocks = blocks > 65535 ? 65535 : blocks; + kfunc<<>>(A, B, value, n); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); +} + +// ============================================================================ +// Template instantiations +// ============================================================================ + +template void func(float* A, float* B, float value, long n); +template void func(unsigned char* A, unsigned char* B, unsigned char value, long n); +template void func(float* A, float* B, float value, long n); +template void func(float* A, float* B, float value, long n); + +template void gemm_4bit_inference_naive( + int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb, + int ldc, int blocksize, bnb_stream_t stream +); +template void gemm_4bit_inference_naive( + int m, int n, int k, bnb_bfloat16* A, unsigned char* B, float* absmax, float* datatype, bnb_bfloat16* out, int lda, + int ldb, int ldc, int blocksize, bnb_stream_t stream +); +template void gemm_4bit_inference_naive( + int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, + int ldc, int blocksize, bnb_stream_t stream +); + +template int igemmlt<32, 0>( + bnb_blasLt_handle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, bnb_stream_t stream +); +template int igemmlt<8, 0>( + bnb_blasLt_handle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, bnb_stream_t stream +); +template int igemmlt<8, 1>( + bnb_blasLt_handle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, bnb_stream_t stream +); diff --git a/csrc/examples/ops_unified.cuh 
b/csrc/examples/ops_unified.cuh new file mode 100644 index 000000000..b0dd8aaf2 --- /dev/null +++ b/csrc/examples/ops_unified.cuh @@ -0,0 +1,183 @@ +// ops_unified.cuh — EXAMPLE of merged host API declarations for CUDA/HIP +// +// This replaces both csrc/ops.cuh and csrc/ops_hip.cuh. +// Uses compat.cuh types for all platform-specific identifiers. + +#ifndef ops_H +#define ops_H + +#include +#include +#include +#include +#include +#include + +#include "compat.cuh" +#include + +// ============================================================================ +// Error checking helpers — unified via compat.cuh types +// ============================================================================ + +inline void checkDeviceStatus(bnb_error_t status) { + if (status != BNB_SUCCESS) { + printf("Device API failed with status %d: %s\n", status, BNB_GET_ERROR_STRING(status)); + throw std::logic_error("Device API failed"); + } +} + +inline int checkBlasLtStatus(bnb_blas_status_t status) { + if (status != BNB_BLAS_STATUS_SUCCESS) { + printf("BLAS Lt API failed with status %d\n", status); + return 1; + } + return 0; +} + +// ============================================================================ +// Enums — identical on both platforms +// ============================================================================ + +typedef enum Operations_t { ksmul = 0 } Operations_t; + +typedef enum Optimizer_t { + ADAM = 0, + MOMENTUM = 1, + RMSPROP = 2, + LARS = 3, + ADAGRAD = 4, + LION = 5, + ADEMAMIX = 6, +} Optimizer_t; + +typedef enum Funcs_t { FILL = 0, ARANGE = 1, _MUL = 2 } Funcs_t; + +// ============================================================================ +// Context classes — platform-specific handles via #if BNB_HIP +// +// This is one of the few places where #if BNB_HIP is needed, because +// the BLAS handle types and creation APIs genuinely differ. 
+// ============================================================================ + +class Context { + public: +#if BNB_HIP + rocblas_handle m_handle; + + Context() { + rocblas_handle handle; + rocblas_create_handle(&handle); + m_handle = handle; + } +#else + cublasHandle_t m_handle; + + Context() { + cublasHandle_t handle; + cublasCreate_v2(&handle); + m_handle = handle; + } +#endif +}; + +class ContextLt { + public: + bnb_blasLt_handle_t m_handle; + + ContextLt() { + bnb_blasLt_handle_t handle; + bnb_blasLtCreate(&handle); + m_handle = handle; + } +}; + +class ContextSparse { + public: + bnb_sparse_handle_t m_handle; + + ContextSparse() { + bnb_sparse_handle_t handle; + bnb_sparseCreate(&handle); + m_handle = handle; + } +}; + +// ============================================================================ +// Function declarations — use bnb_stream_t / bnb_sparse_handle_t +// ============================================================================ + +void quantize(float* code, float* A, unsigned char* out, int n); +void dequantize(float* code, unsigned char* A, float* out, int n, bnb_stream_t stream); + +template +void quantizeBlockwise( + float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template +void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, T* out, int block_size, const int n, bnb_stream_t stream +); + +template +void optimizer32bit( + T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, float beta1, float beta2, + float beta3, float alpha, float eps, float weight_decay, int step, float lr, const float gnorm_scale, + bool skip_zeros, int n +); + +template +void optimizerStatic8bit( + T* p, T* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, float param_norm, + float beta1, float beta2, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1, + float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, int n +); + +template +void optimizerStatic8bitBlockwise( + T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, + float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, + float weight_decay, const float gnorm_scale, bool skip_zeros, int n +); + +template void percentileClipping(T* g, float* gnorm_vec, int step, const int n); + +void gemmex( + Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, + int ldb, int ldc +); + +template +int igemmlt( + bnb_blasLt_handle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, bnb_stream_t stream +); + +void dequant_mm_int32_fp16( + int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, bnb_stream_t stream +); + +void int8VectorQuant( + half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, bnb_stream_t stream +); + +void spmm_coo( + bnb_sparse_handle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, + int B_cols, int ldb, half* B, int ldc, half* C, bool transposed_B +); + +template +void spmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, + float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int 
colsB +); + +template +void gemm_4bit_inference_naive( + int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, + int blocksize, bnb_stream_t stream +); + +template void func(T* A, T* B, T value, long n); + +#endif From 9c69888035168d37edf77025275980258f5453de Mon Sep 17 00:00:00 2001 From: badaoui Date: Thu, 19 Feb 2026 15:26:15 +0100 Subject: [PATCH 2/2] update --- csrc/examples/CMakeLists_changes.md | 52 - csrc/examples/CMakeLists_unified.txt | 376 ++++ csrc/examples/common_unified.cuh | 4 +- csrc/{ => examples}/compat.cuh | 58 +- csrc/examples/compat_device.cuh | 58 + csrc/examples/kernels_unified.cu | 2431 +++++++++++++++++++-- csrc/examples/ops_unified.cu | 604 ++++- csrc/examples/ops_unified.cuh | 36 +- csrc/examples/pythonInterface_unified.cpp | 890 ++++++++ 9 files changed, 4070 insertions(+), 439 deletions(-) delete mode 100644 csrc/examples/CMakeLists_changes.md create mode 100644 csrc/examples/CMakeLists_unified.txt rename csrc/{ => examples}/compat.cuh (87%) create mode 100644 csrc/examples/compat_device.cuh create mode 100644 csrc/examples/pythonInterface_unified.cpp diff --git a/csrc/examples/CMakeLists_changes.md b/csrc/examples/CMakeLists_changes.md deleted file mode 100644 index 99a43d521..000000000 --- a/csrc/examples/CMakeLists_changes.md +++ /dev/null @@ -1,52 +0,0 @@ -# CMakeLists.txt Changes for Unified Kernels - -## Summary of changes - -Replace separate `CUDA_FILES` and `HIP_FILES` with a single `GPU_FILES` list. -For HIP builds, tell CMake to compile `.cu` files using the HIP language. - -## Diff - -```diff - # Define included source files - set(CPP_FILES csrc/cpu_ops.cpp csrc/pythonInterface.cpp) --set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) --set(HIP_FILES csrc/ops.hip csrc/kernels.hip) -+set(GPU_FILES csrc/ops.cu csrc/kernels.cu) - set(MPS_FILES csrc/mps_ops.mm) - set(METAL_FILES csrc/mps_kernels.metal) - set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp) -``` - -```diff - if(BUILD_CUDA) - # ... (CUDA setup unchanged) -- list(APPEND SRC_FILES ${CUDA_FILES}) -+ list(APPEND SRC_FILES ${GPU_FILES}) - string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") - add_compile_definitions(BUILD_CUDA) - elseif(BUILD_HIP) - # ... (HIP setup unchanged) -- list(APPEND SRC_FILES ${HIP_FILES}) -+ list(APPEND SRC_FILES ${GPU_FILES}) - string(APPEND BNB_OUTPUT_NAME "_rocm") - # ... -``` - -```diff - if(BUILD_HIP) - # ... -- set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP) -+ set_source_files_properties(${GPU_FILES} PROPERTIES LANGUAGE HIP) - set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX) - # ... - endif() -``` - -## Files to delete after migration - -- `csrc/common_hip.cuh` -- `csrc/kernels.hip` -- `csrc/kernels_hip.cuh` -- `csrc/ops.hip` -- `csrc/ops_hip.cuh` diff --git a/csrc/examples/CMakeLists_unified.txt b/csrc/examples/CMakeLists_unified.txt new file mode 100644 index 000000000..99cf334ae --- /dev/null +++ b/csrc/examples/CMakeLists_unified.txt @@ -0,0 +1,376 @@ +# This CMake config hopefully makes it easier to compile. +# Ensure the CUDA Toolkit is available on your path. Then run: +# For GCC: `cmake -B build . && cmake --build build` +# For MSVC: `cmake -B build . && cmake --build build --config Release` +# You can also use the following options and variables +# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend +# - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version +# is whatever CMake finds on your path. 
+# - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC. +# Separate by semicolons, i.e. `-DCOMPUTE_CAPABILITY=89;90;100;120` +# Check your compute capability here: https://developer.nvidia.com/cuda-gpus +# - PTXAS_VERBOSE: Pass the `-v` option to the PTX Assembler +cmake_minimum_required(VERSION 3.22.1) + +project(bitsandbytes LANGUAGES CXX) + +# If run without specifying a build type, default to using the Release configuration: +# optimizing the generated binaries for performance and also adds the `-DNDEBUG` flag, +# which turns off a bunch of asserts which seem to link to new symbols in libstdc++, +# worsening our many_linux compliance.. +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +# Define included source files +set(CPP_FILES csrc/cpu_ops.cpp csrc/pythonInterface.cpp) +set(GPU_FILES csrc/ops.cu csrc/kernels.cu) +set(MPS_FILES csrc/mps_ops.mm) +set(METAL_FILES csrc/mps_kernels.metal) +set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp) +# C++ sources are always included +list(APPEND SRC_FILES ${CPP_FILES}) + +set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, xpu)") +set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps xpu) +option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF) + +if(APPLE) + set(CMAKE_OSX_DEPLOYMENT_TARGET 14.0) +endif() + +set(BNB_OUTPUT_NAME "bitsandbytes") + +message(STATUS "Configuring ${PROJECT_NAME} (Backend: ${COMPUTE_BACKEND})") + +if(${COMPUTE_BACKEND} STREQUAL "cuda") + if(APPLE) + message(FATAL_ERROR "CUDA is not supported on macOS" ) + endif() + set(BUILD_CUDA ON) + set(BUILD_HIP OFF) + set(BUILD_MPS OFF) +elseif(${COMPUTE_BACKEND} STREQUAL "hip") + if(APPLE) + message(FATAL_ERROR "HIP is not supported on macOS" ) + endif() + set(BUILD_CUDA OFF) + set(BUILD_HIP ON) + set(BUILD_MPS OFF) +elseif(${COMPUTE_BACKEND} STREQUAL "mps") + if(NOT APPLE) + message(FATAL_ERROR "MPS is only supported on macOS" ) + endif() + set(BUILD_CUDA OFF) + set(BUILD_HIP OFF) + set(BUILD_MPS ON) +elseif(${COMPUTE_BACKEND} STREQUAL "xpu") + if(APPLE) + message(FATAL_ERROR "XPU is not supported on macOS" ) + endif() + set(BUILD_CUDA OFF) + set(BUILD_HIP OFF) + set(BUILD_MPS OFF) + set(BUILD_XPU ON) +else() + set(BUILD_CUDA OFF) + set(BUILD_HIP OFF) + set(BUILD_MPS OFF) + set(BUILD_XPU OFF) + set(BUILD_CPU ON) +endif() + + +if (BUILD_CPU) + set(CMAKE_CXX_STANDARD 17) + set(CMAKE_CXX_STANDARD_REQUIRED ON) + string(TOLOWER "${CMAKE_SYSTEM_PROCESSOR}" HOST_ARCH) + find_package(OpenMP) +endif() + +if(BUILD_CUDA) + # NVCC normally will only work with MSVC up to 1939. VS2022 17.10+ starts using versions 1940+. + # Workaround: use --allow-unsupported-compiler + # This needs to be added *before* we try to enable the CUDA language so CMake's compiler check passes. + if(MSVC AND MSVC_VERSION VERSION_GREATER_EQUAL 1940) + string(APPEND CMAKE_CUDA_FLAGS " --allow-unsupported-compiler") + + # This is needed to build with VS2022 17.11+ and CUDA < 12.4. + if (MSVC_VERSION VERSION_GREATER_EQUAL 1941) + string(APPEND CMAKE_CUDA_FLAGS " -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH") + endif() + endif() + + enable_language(CUDA) # This will fail if CUDA is not found + find_package(CUDAToolkit REQUIRED) + + # Convert the CUDA version from X.Y.z to XY. There's probably a shorter way of doing this + string(REGEX MATCH "^[0-9]+.[0-9]+" _CUDA_VERSION_FIRST_TWO "${CMAKE_CUDA_COMPILER_VERSION}") + string(REPLACE "." 
"" CUDA_VERSION_SHORT "${_CUDA_VERSION_FIRST_TWO}") + + # Expose a cache variable that the user can set to ensure the correct version of CUDA is found + set(CUDA_VERSION "${CUDA_VERSION_SHORT}" CACHE STRING "Expected CUDA Version Shortcode") + + message(STATUS "CUDA Version: ${CUDA_VERSION_SHORT} (${CMAKE_CUDA_COMPILER_VERSION})") + message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}") + + # It should match the discovered version + if(NOT CUDA_VERSION STREQUAL "${CUDA_VERSION_SHORT}") + message(FATAL_ERROR "You've specified CUDA version ${CUDA_VERSION} however the CUDA compiler found is ${CUDA_VERSION_SHORT}." + " Ensure the desired CUDA compiler is the first one available on your PATH." + ) + endif() + + if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS "11.8") + message(FATAL_ERROR "CUDA Version < 11.8 is not supported") + elseif(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "14.0") + message(FATAL_ERROR "CUDA Version > 13 is not supported") + endif() + + # CMake < 3.23.0 does not define CMAKE_CUDA_ARCHITECTURES_ALL. + if(CMAKE_VERSION VERSION_LESS "3.23.0") + message(STATUS "CMake < 3.23.0; determining CUDA architectures supported...") + + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0") + # Starting in CUDA 13.0, Thor Blackwell is renamed to SM110. + # Support for architectures older than Turing (SM75) is removed. + list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 75 80 86 87 88 89 90 100 103 110 120 121) + list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 80 90 100 110 120) + else() + # 11.8-12.9 supports these at a minimum. + set(CMAKE_CUDA_ARCHITECTURES_ALL 50 52 53 60 61 62 70 72 75 80 86 87 89 90) + set(CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 50 60 70 80 90) + + # CUDA 12.8 adds support for Blackwell. + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.8") + list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 100 101 120 121) + list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 100 120) + endif() + + # CUDA 12.9 adds SM103 (Blackwell B300). + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.9") + list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 103) + endif() + endif() + endif() + + string(APPEND CMAKE_CUDA_FLAGS " --use_fast_math") + + # It's safe for us to enable more aggressive compression for 13.0+ + if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0") + string(APPEND CMAKE_CUDA_FLAGS " --compress-mode=size") + endif() + + if(PTXAS_VERBOSE) + string(APPEND CMAKE_CUDA_FLAGS " -Xptxas=-v") + endif() + + foreach(capability ${CMAKE_CUDA_ARCHITECTURES_ALL}) + # Most of the items here are like: `xx-real`, so we just extract the `xx` portion + string(REGEX MATCH "[0-9]+" capability_id "${capability}") + if(capability_id GREATER 0) + list(APPEND POSSIBLE_CAPABILITIES ${capability_id}) + endif() + endforeach() + + # This can be changed via -D argument to CMake + # By default all possible capabilities are compiled + set(COMPUTE_CAPABILITY "${POSSIBLE_CAPABILITIES}" CACHE STRING "Compute Capabilities Targeted") + + message(STATUS "CUDA Capabilities Available: ${POSSIBLE_CAPABILITIES}") + message(STATUS "CUDA Capabilities Selected: ${COMPUTE_CAPABILITY}") + + # Use the "real" option to build native cubin for all selections. + # Ensure we build the PTX for the latest version. + # This behavior of adding a PTX (virtual) target for the highest architecture + # is similar to how the "all" and "all-major" options would behave in CMake >= 3.23. 
+ # TODO: Consider bumping CMake requirement and using CMAKE_CUDA_ARCHITECTURES=[all | native] by default + list(REMOVE_DUPLICATES COMPUTE_CAPABILITY) + list(SORT COMPUTE_CAPABILITY COMPARE NATURAL) + list(POP_BACK COMPUTE_CAPABILITY _LATEST_CAPABILITY) + list(TRANSFORM COMPUTE_CAPABILITY APPEND "-real" OUTPUT_VARIABLE CMAKE_CUDA_ARCHITECTURES) + list(APPEND CMAKE_CUDA_ARCHITECTURES ${_LATEST_CAPABILITY}) + + message(STATUS "CUDA Targets: ${CMAKE_CUDA_ARCHITECTURES}") + message(STATUS "CUDA NVCC Flags: ${CMAKE_CUDA_FLAGS}") + + list(APPEND SRC_FILES ${GPU_FILES}) + + string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") + add_compile_definitions(BUILD_CUDA) +elseif(BUILD_HIP) + enable_language(HIP) + message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}") + if(DEFINED BNB_ROCM_ARCH) + set(CMAKE_HIP_ARCHITECTURES ${BNB_ROCM_ARCH}) + else() + if (NOT AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) + set(CMAKE_HIP_ARCHITECTURES "gfx90a;gfx942;gfx1100;gfx1101;gfx1150;gfx1151;gfx1200;gfx1201") + elseif (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) + set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS}) + endif() + endif() + message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}") + + list(APPEND SRC_FILES ${GPU_FILES}) + + string(APPEND BNB_OUTPUT_NAME "_rocm") + + # get hip version + execute_process(COMMAND hipconfig --version OUTPUT_VARIABLE HIP_CONFIG_VERSION) + string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}") + string(REPLACE "." "" HIP_VERSION_SHORT "${HIP_VERSION}") + + string(APPEND BNB_OUTPUT_NAME "${HIP_VERSION_SHORT}") + add_compile_definitions(__HIP_PLATFORM_AMD__) + add_compile_definitions(__HIP_PLATFORM_HCC__) + add_compile_definitions(BUILD_HIP) +elseif(BUILD_MPS) + if(NOT APPLE) + message(FATAL_ERROR "MPS is only supported on macOS" ) + endif() + + enable_language(OBJCXX) + + list(APPEND SRC_FILES ${MPS_FILES}) + + string(APPEND BNB_OUTPUT_NAME "_mps") + add_compile_definitions(BUILD_MPS) + file(MAKE_DIRECTORY "build") + add_custom_command(OUTPUT "bitsandbytes/bitsandbytes.metallib" + COMMAND xcrun metal -c -o "build/bitsandbytes.air" ${METAL_FILES} + COMMAND xcrun metallib "build/bitsandbytes.air" -o "bitsandbytes/bitsandbytes.metallib" + DEPENDS "${METAL_FILES}" + COMMENT "Compiling Metal kernels" + VERBATIM) + add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib") +elseif(BUILD_XPU) + list(APPEND SRC_FILES ${XPU_FILES}) + string(APPEND BNB_OUTPUT_NAME "_xpu") + add_compile_definitions(BUILD_XPU) + set(CMAKE_C_COMPILER icx) + set(CMAKE_CXX_COMPILER icpx) + if(WIN32) + set(CMAKE_CXX_COMPILER icx) + endif() +else() + string(APPEND BNB_OUTPUT_NAME "_cpu") + set(GPU_SOURCES) +endif() + + +if(WIN32) + # Export all symbols + set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) +endif() + +if(MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2 /fp:fast") +endif() + +set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX) +add_library(bitsandbytes SHARED ${SRC_FILES}) +target_compile_features(bitsandbytes PUBLIC cxx_std_17) +target_include_directories(bitsandbytes PUBLIC csrc) + +if (BUILD_CPU) + if (OpenMP_CXX_FOUND) + target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX) + add_definitions(-DHAS_OPENMP) + endif() + + if ((HOST_ARCH MATCHES "x86_64|amd64") AND (NOT MSVC)) + include(CheckCXXCompilerFlag) + check_cxx_compiler_flag(-mavx512f HAS_AVX512F_FLAG) + check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16_FLAG) + if (HAS_AVX512F_FLAG) + target_compile_options(bitsandbytes PRIVATE -mavx512f) + 
target_compile_options(bitsandbytes PRIVATE -mavx512dq) + target_compile_options(bitsandbytes PRIVATE -mavx512bw) + target_compile_options(bitsandbytes PRIVATE -mavx512vl) + endif() + if (HAS_AVX512BF16_FLAG) + target_compile_options(bitsandbytes PRIVATE -mavx512bf16) + endif() + target_compile_options( + bitsandbytes PRIVATE + -mprefer-vector-width=256 + -mfma + -mavx2 + -mlzcnt + -mbmi + -mbmi2 + ) + endif() +endif() + + +if(BUILD_CUDA) + target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cusparse) + set_target_properties(bitsandbytes + PROPERTIES + CUDA_SEPARABLE_COMPILATION ON + ) +endif() +if(BUILD_HIP) + if(NOT DEFINED ENV{ROCM_PATH}) + set(ROCM_PATH /opt/rocm) + else() + set(ROCM_PATH $ENV{ROCM_PATH}) + endif() + list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) + macro(find_package_and_print_version PACKAGE_NAME) + find_package("${PACKAGE_NAME}" ${ARGN}) + message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}") + endmacro() + find_package_and_print_version(hipblas REQUIRED) + find_package_and_print_version(hiprand REQUIRED) + find_package_and_print_version(hipsparse REQUIRED) + + ## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. adam8bit because of inaccuracies) + set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "") + set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "") + set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "") + + target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include) + target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib) + target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse) + + target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP) + set_source_files_properties(${GPU_FILES} PROPERTIES LANGUAGE HIP) + set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX) + + if(HIP_VERSION VERSION_LESS "6.1") + target_compile_definitions(bitsandbytes PUBLIC NO_HIPBLASLT) + else() + find_package(hipblaslt) + target_link_libraries(bitsandbytes PUBLIC roc::hipblaslt) + endif() +endif() +if(BUILD_MPS) + add_dependencies(bitsandbytes metallib) + target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph") +endif() +if(BUILD_XPU) + set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'") + set(SYCL_COMPILE_FLAGS "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=spir64_gen,spir64;") + + set_property(TARGET bitsandbytes PROPERTY CXX_STANDARD 20) + target_compile_options(bitsandbytes PRIVATE ${SYCL_COMPILE_FLAGS}) + target_link_options(bitsandbytes PRIVATE ${SYCL_LINK_FLAGS}) + +endif() + +if(WIN32) + set_target_properties(bitsandbytes PROPERTIES PREFIX "lib") +endif() +set_target_properties(bitsandbytes PROPERTIES OUTPUT_NAME ${BNB_OUTPUT_NAME}) +if(MSVC) + set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_RELEASE "${PROJECT_SOURCE_DIR}/bitsandbytes") + set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_DEBUG 
"${PROJECT_SOURCE_DIR}/bitsandbytes") + set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELEASE "${PROJECT_SOURCE_DIR}/bitsandbytes") + set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG "${PROJECT_SOURCE_DIR}/bitsandbytes") +endif() + +set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/bitsandbytes") diff --git a/csrc/examples/common_unified.cuh b/csrc/examples/common_unified.cuh index 081dc7780..48e86be3c 100644 --- a/csrc/examples/common_unified.cuh +++ b/csrc/examples/common_unified.cuh @@ -1,6 +1,6 @@ -// common_unified.cuh — Merged architecture constants for CUDA and HIP +// common.cuh — Merged architecture constants for CUDA and HIP // -// This replaces both csrc/common.cuh and csrc/common_hip.cuh. +// This replaces both the old csrc/common.cuh and csrc/common_hip.cuh. // Platform detection uses compat.cuh's BNB_HIP macro. #pragma once diff --git a/csrc/compat.cuh b/csrc/examples/compat.cuh similarity index 87% rename from csrc/compat.cuh rename to csrc/examples/compat.cuh index 53188e3a7..f2116ef9f 100644 --- a/csrc/compat.cuh +++ b/csrc/examples/compat.cuh @@ -28,62 +28,14 @@ #include #include #include +#include +#include #else // CUDA #include #include #include -#include -#include - -#endif - -// ============================================================================ -// CUB / hipCUB — namespace alias -// -// Usage: bnb_cub::BlockLoad<...>, bnb_cub::BlockReduce<...>, etc. -// This single alias eliminates ~90% of the cub:: vs hipcub:: differences. -// ============================================================================ - -#if BNB_HIP - -#include -namespace bnb_cub = hipcub; - -#else // CUDA - -#include -#include -#include -#include -#include -#include -#include -namespace bnb_cub = cub; - -#endif - -// ============================================================================ -// Reduction operators — CUB's Max()/Sum() API differs across versions -// ============================================================================ - -#if BNB_HIP - -#define BNB_MAX_OP hipcub::Max() -#define BNB_SUM_OP hipcub::Sum() - -#else // CUDA - -// CCCL 2.8.2+ moved to cuda::maximum<>{}, older versions use cub::Max() -#if defined(CCCL_VERSION) && CCCL_VERSION >= 2008002 -#include -#define BNB_MAX_OP \ - cuda::maximum<> {} -#else -#define BNB_MAX_OP cub::Max() -#endif -#define BNB_SUM_OP cub::Sum() #endif @@ -101,6 +53,7 @@ using bnb_error_t = hipError_t; #define BNB_GET_ERROR_STRING(e) hipGetErrorString(e) #define BNB_DEVICE_MALLOC(p, s) hipMalloc(p, s) #define BNB_DEVICE_FREE(p) hipFree(p) +#define BNB_DEVICE_MEMSET(p, v, s) hipMemset(p, v, s) #else // CUDA @@ -112,6 +65,7 @@ using bnb_error_t = cudaError_t; #define BNB_GET_ERROR_STRING(e) cudaGetErrorString(e) #define BNB_DEVICE_MALLOC(p, s) cudaMalloc(p, s) #define BNB_DEVICE_FREE(p) cudaFree(p) +#define BNB_DEVICE_MEMSET(p, v, s) cudaMemset(p, v, s) #endif @@ -239,6 +193,8 @@ using bnb_blas_status_t = cublasStatus_t; #include using bnb_sparse_handle_t = hipsparseHandle_t; +using bnb_sparseSpMatDescr_t = hipsparseSpMatDescr_t; +using bnb_sparseDnMatDescr_t = hipsparseDnMatDescr_t; #define bnb_sparseCreate hipsparseCreate #define bnb_sparseCreateCoo hipsparseCreateCoo @@ -269,6 +225,8 @@ using bnb_sparse_handle_t = hipsparseHandle_t; #include using bnb_sparse_handle_t = cusparseHandle_t; +using bnb_sparseSpMatDescr_t = cusparseSpMatDescr_t; +using bnb_sparseDnMatDescr_t = cusparseDnMatDescr_t; #define bnb_sparseCreate cusparseCreate #define 
bnb_sparseCreateCoo cusparseCreateCoo diff --git a/csrc/examples/compat_device.cuh b/csrc/examples/compat_device.cuh new file mode 100644 index 000000000..586ee8cca --- /dev/null +++ b/csrc/examples/compat_device.cuh @@ -0,0 +1,58 @@ +// compat_device.cuh — Device-only portability layer (CUB, reduction ops, MMA) +// +// Include this from .cu kernel files only (compiled by nvcc/hipcc). +// Do NOT include from .cpp files — use compat.cuh instead for host-safe types. + +#pragma once + +#include "compat.cuh" + +// ============================================================================ +// CUB / hipCUB — namespace alias +// +// Usage: bnb_cub::BlockLoad<...>, bnb_cub::BlockReduce<...>, etc. +// This single alias eliminates ~90% of the cub:: vs hipcub:: differences. +// ============================================================================ + +#if BNB_HIP + +#include +namespace bnb_cub = hipcub; + +#else // CUDA + +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace bnb_cub = cub; + +#endif + +// ============================================================================ +// Reduction operators — CUB's Max()/Sum() API differs across versions +// ============================================================================ + +#if BNB_HIP + +#define BNB_MAX_OP hipcub::Max() +#define BNB_SUM_OP hipcub::Sum() + +#else // CUDA + +// CCCL 2.8.2+ moved to cuda::maximum<>{}, older versions use cub::Max() +#if defined(CCCL_VERSION) && CCCL_VERSION >= 2008002 +#include +#define BNB_MAX_OP \ + cuda::maximum<> {} +#else +#define BNB_MAX_OP cub::Max() +#endif +#define BNB_SUM_OP cub::Sum() + +#endif diff --git a/csrc/examples/kernels_unified.cu b/csrc/examples/kernels_unified.cu index 7244dbce2..d3c16c136 100644 --- a/csrc/examples/kernels_unified.cu +++ b/csrc/examples/kernels_unified.cu @@ -1,45 +1,17 @@ -// kernels_unified.cu — EXAMPLE of merged CUDA/HIP kernel source -// -// This file demonstrates how kernels.cu and kernels.hip can be unified -// into a single source file. It shows representative kernels covering -// all categories of differences: -// -// 1. Shared code (identical on both platforms) — kQuantize, kQuantizeBlockwise -// 2. Platform-specific atomics — atomicMax (CUDA needs custom, HIP has native) -// 3. Warp-size-dependent kernels — kQuantizeBlockwiseSmall (replaces -// kQuantizeBlockwise32 on CUDA and kQuantizeBlockwise64 on HIP) -// 4. Template instantiations — bnb_bfloat16 alias for __nv_bfloat16 / hip_bfloat16 -// -// Key principles: -// - Include "compat.cuh" for all platform abstractions -// - Use bnb_cub:: instead of cub:: or hipcub:: -// - Use BNB_MAX_OP / BNB_SUM_OP instead of cub::Max() / hipcub::Max() -// - Use bnb_bfloat16 instead of __nv_bfloat16 / hip_bfloat16 -// - Use #if BNB_HIP for truly divergent sections -// - <<>> syntax works on both platforms (HIP supports it natively) -// -// This file compiles as: -// - CUDA: nvcc compiles it as .cu (default) -// - HIP: CMake sets LANGUAGE HIP on this .cu file, hipcc compiles it -// // Copyright (c) Facebook, Inc. and its affiliates. +// // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. 
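+// Illustrative sketch only (names are placeholders): with the bnb_cub alias from
+// compat_device.cuh, a block-wide reduction that compiles unchanged under nvcc and hipcc
+// looks like
+//   typedef bnb_cub::BlockReduce<float, 256> BlockReduce;
+//   __shared__ typename BlockReduce::TempStorage temp;
+//   float block_max = BlockReduce(temp).Reduce(thread_val, BNB_MAX_OP);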
-#include "common.cuh" // merged common_unified.cuh in the real version -#include "compat.cuh" -#include "kernels.cuh" // merged kernel declarations -#include // DataType_t enum +#include "common.cuh" +#include "compat_device.cuh" +#include "kernels.cuh" #define HLF_MAX 65504 #define TH 1024 #define NUM 4 #define NUM_BLOCK 4096 -// ============================================================================ -// Lookup tables — identical on both platforms -// ============================================================================ - __device__ static float fp4_dequantization_lut[8] = { 0.0f, // 0b000 0.005208333333f, // 0b001 @@ -70,13 +42,9 @@ __device__ static float nf4_dequantization_lut[16] = { 1.0f // 0b1111 }; -// ============================================================================ -// atomicMax for float — CUDA needs a custom CAS loop, HIP has native support -// ============================================================================ - +// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda +// HIP has native atomicMax for float; CUDA needs a CAS loop #if !BNB_HIP -// CUDA: no native atomicMax for float, use CAS loop -// source: https://stackoverflow.com/questions/17399119 __device__ float atomicMax(float* address, float val) { int* address_as_i = reinterpret_cast(address); int old = *address_as_i, assumed; @@ -86,12 +54,7 @@ __device__ float atomicMax(float* address, float val) { } while (assumed != old); return __int_as_float(old); } -#endif -// HIP: atomicMax for float is available natively in ROCm — no custom impl needed - -// ============================================================================ -// Device helper functions — identical on both platforms -// ============================================================================ +#endif // !BNB_HIP __device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) { float sign = 1.0f - 2 * ((val & 0b1000) >> 3); @@ -99,6 +62,26 @@ __device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) { } __device__ unsigned char dQuantizeFP4(float x) { + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assume input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to notice if you add an extra + // zero somewhere! + int sign = x < 0 ? 
0b1000 : 0b0000; x = fabsf(x); if (x > 0.29166667f) @@ -125,90 +108,164 @@ __device__ unsigned char dQuantizeFP4(float x) { __device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; } __device__ unsigned char dQuantizeNF4(float x) { + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py if (x > 0.03979014977812767f) - if (x > 0.3893125355243683f) - if (x > 0.6427869200706482f) - if (x > 0.8614784181118011f) + if (x > 0.3893125355243683f) // 1 + if (x > 0.6427869200706482f) // 11 + if (x > 0.8614784181118011f) // 111 return 0b1111; else return 0b1110; - else if (x > 0.5016634166240692f) + else if (x > 0.5016634166240692f) // 110 return 0b1101; else return 0b1100; - else if (x > 0.2035212516784668f) - if (x > 0.2920137718319893f) + else if (x > 0.2035212516784668f) // 10 + if (x > 0.2920137718319893f) // 101 return 0b1011; else return 0b1010; - else if (x > 0.1202552504837513f) + else if (x > 0.1202552504837513f) // 100 return 0b1001; else return 0b1000; - else if (x > -0.33967943489551544f) - if (x > -0.13791173323988914f) - if (x > -0.045525018125772476f) + else if (x > -0.33967943489551544f) // 0 + if (x > -0.13791173323988914f) // 01 + if (x > -0.045525018125772476f) // 011 return 0b0111; else return 0b0110; - else if (x > -0.23460740596055984f) + else if (x > -0.23460740596055984f) // 010 return 0b0101; else return 0b0100; - else if (x > -0.6106329262256622f) - if (x > -0.4599952697753906f) + else if (x > -0.6106329262256622f) // 00 + if (x > -0.4599952697753906f) // 001 return 0b0011; else return 0b0010; - else if (x > -0.8480964004993439f) + else if (x > -0.8480964004993439f) // 000 return 0b0001; else return 0b0000; } -// (dQuantize<> helper omitted for brevity — same pattern, no platform diffs) -template __device__ unsigned char dQuantize(float* smem_code, float rand, float x) { - // Binary search in quantization code — identical on both platforms +// sign function for lion +// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA + +template __device__ int sgn(T val) { return (T(0) < val) - (val < T(0)); } + +template __device__ unsigned char dQuantize(float* smem_code, const float rand, float x) { int pivot = 127; int upper_pivot = 255; int lower_pivot = 0; + float lower = -1.0f; + float upper = 1.0f; + float val = smem_code[pivot]; + // i>>=1 = {32, 16, 8, 4, 2, 1} for (int i = 64; i > 0; i >>= 1) { if (x > val) { lower_pivot = pivot; + lower = val; pivot += i; } else { upper_pivot = pivot; + upper = val; pivot -= i; } val = smem_code[pivot]; } if (upper_pivot == 255) - upper_pivot = 254; + upper = smem_code[upper_pivot]; + if (lower_pivot == 0) + lower = smem_code[lower_pivot]; - if (STOCHASTIC) { - if (rand >= (x - smem_code[lower_pivot]) / (smem_code[upper_pivot] - smem_code[lower_pivot])) - return lower_pivot; - else + if (!STOCHASTIC) { + if (x > val) { + float midpoint = (upper + val) * 0.5f; + if (x > midpoint) { + return upper_pivot; + } else + return pivot; + } else { + float midpoint = (lower + val) * 0.5f; + if (x < midpoint) + return lower_pivot; + else + return pivot; + } + } else { + if (x > val) { + float dist_to_upper = fabsf(upper - x); + float dist_full = upper - val; + if (rand >= dist_to_upper / dist_full) + return upper_pivot; + else + return pivot; + } else { + float dist_to_lower = fabsf(lower - x); + float dist_full = val - lower; + if (rand >= dist_to_lower / dist_full) + return lower_pivot; + 
else + return pivot; + } + } +} + +template +__device__ __forceinline__ unsigned char + quantize_2D(float* __restrict__ quadrants, float* __restrict__ const smem_code, float x) { + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = SIGNED ? -1.0f : 0.0f; + float upper = 1.0f; + float midpoint; + float val = quadrants[1]; + int local_pivot = 1; + int offset = 1; + + // i>>=1 = {32, 16, 8, 4, 2, 1} + for (int i = 64; i > 0; i >>= 1) { + if (x > val) { + lower_pivot = pivot; + lower = val; + pivot += i; + // val = i == 64 ? quadrants[2] : smem_code[pivot]; + local_pivot += offset; + } else { + upper_pivot = pivot; + upper = val; + pivot -= i; + // val = i == 64 ? quadrants[0] : smem_code[pivot]; + local_pivot -= offset; + } + val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot]; + offset -= 1; + } + + if (x > val) { + midpoint = (upper + val) * 0.5f; + if (x > midpoint) return upper_pivot; + else + return pivot; } else { - if (fabsf(x - smem_code[lower_pivot]) < fabsf(x - smem_code[upper_pivot])) + midpoint = (lower + val) * 0.5f; + if (x < midpoint) return lower_pivot; else - return upper_pivot; + return pivot; } } -// ============================================================================ -// kQuantize — fully shared, zero #ifdefs needed -// -// Before (CUDA): typedef cub::BlockLoad<...> -// Before (HIP): typedef hipcub::BlockLoad<...> -// After (unified): typedef bnb_cub::BlockLoad<...> -// ============================================================================ - __launch_bounds__(TH, 4) __global__ void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n) { const int n_full = (NUM_BLOCK * (n / NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); @@ -217,19 +274,26 @@ __launch_bounds__(TH, 4) __global__ float vals[NUM]; unsigned char qvals[NUM]; + // const int lane_id = threadIdx.x % 2; - // vvvvvvvv unified namespace alias — resolves to cub:: or hipcub:: typedef bnb_cub::BlockLoad LoadFloat; typedef bnb_cub::BlockStore StoreChar; __shared__ typename LoadFloat::TempStorage loadf; __shared__ typename StoreChar::TempStorage storec; __shared__ float smem_code[256]; + //__shared__ float smem_code[2][257]; - if (threadIdx.x < 256) + if (threadIdx.x < 256) { smem_code[threadIdx.x] = code[threadIdx.x]; + // smem_code[0][threadIdx.x] = code[threadIdx.x]; + // smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x]; + } for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_BLOCK) { + // number of values already processed in blocks + + // number of values already processed in this block + + // rand_offset % mod value valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; __syncthreads(); @@ -244,20 +308,15 @@ __launch_bounds__(TH, 4) __global__ } } -// ============================================================================ -// kQuantizeBlockwise — fully shared, uses BNB_MAX_OP -// -// The only change vs the original CUDA version: -// cub:: → bnb_cub:: -// CUB_REDUCTIONOP_MAX → BNB_MAX_OP -// ============================================================================ - template +//__launch_bounds__(TH, 4) __global__ void kQuantizeBlockwise( float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, const int rand_offset, const int n ) { + // This can overflow, so we clamp to INT32_MAX. We won't have more elements than this. 
const int n_full = min(gridDim.x * BLOCK_SIZE, INT32_MAX); + const int base_idx = blockIdx.x * BLOCK_SIZE; int valid_items = 0; @@ -295,11 +354,14 @@ __global__ void kQuantizeBlockwise( __syncthreads(); LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f); + // 1. compute local max + // 2. broadcast local max + // 3. normalize inputs and quantize + #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); - // vvvvvvvvvv unified reduction op local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, BNB_MAX_OP, valid_items); if (threadIdx.x == 0) { @@ -348,63 +410,56 @@ __global__ void kQuantizeBlockwise( } } -// ============================================================================ -// kQuantizeBlockwiseSmall — unified warp-size-dependent kernel -// -// This replaces: -// CUDA: kQuantizeBlockwise32 (32 threads, blocksize=32, WarpReduce) -// HIP: kQuantizeBlockwise64 (64 threads, blocksize=64, WarpReduce) -// -// Strategy: Use BNB_WARP_SIZE to derive all constants at compile time. -// On CUDA (warp=32): SMALL_BLOCK_SIZE=32, THREADS=32, THREADS_PER_BLOCK=16 -// On HIP (warp=64): SMALL_BLOCK_SIZE=64, THREADS=64, THREADS_PER_BLOCK=32 -// On HIP (warp=32): SMALL_BLOCK_SIZE=32, THREADS=32, THREADS_PER_BLOCK=16 -// -// The algorithm is identical — only the numeric constants change. -// ============================================================================ - +// Unified small-blocksize kernel for 4-bit quantization +// Processes 2 blocks of BNB_WARP_SIZE values per thread block +// On CUDA (warp=32): blocksize=32, 32 threads, WarpReduce<16> +// On HIP (warp=64): blocksize=64, 64 threads, WarpReduce<32> +// On HIP (warp=32): blocksize=32, 32 threads, WarpReduce<16> template __global__ void kQuantizeBlockwiseSmall( float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, const int rand_offset, const int n ) { - // All constants derived from BNB_WARP_SIZE — no #ifdefs needed! 
- constexpr int BLOCK_SIZE = BNB_WARP_SIZE; // 32 on CUDA, 32 or 64 on HIP - constexpr int NUM_PER_TH = 2; - constexpr int THREADS = BNB_WARP_SIZE; // One full hardware warp + constexpr int BLOCK_SIZE = BNB_WARP_SIZE; // Size of each quantization block + constexpr int NUM_PER_TH = 2; // Values per thread (for 4-bit packing) + constexpr int THREADS = BNB_WARP_SIZE; // Total threads (one full warp) constexpr int THREADS_PER_BLOCK = BNB_WARP_SIZE / 2; // Half-warp per quantization block - const int base_idx = blockIdx.x * BLOCK_SIZE * 2; // 2 quantization blocks per thread block + const int base_idx = blockIdx.x * BLOCK_SIZE * 2; // 2 blocks per thread block T vals[NUM_PER_TH]; - unsigned char qvals[NUM_PER_TH / 2]; + unsigned char qvals[NUM_PER_TH / 2]; // For 4-bit: 2 values per byte float local_abs_max = 0.0f; - const int block_id = threadIdx.x / THREADS_PER_BLOCK; - const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK; + const int block_id = threadIdx.x / THREADS_PER_BLOCK; // 0 for threads 0-15, 1 for threads 16-31 + const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK; // Thread ID within the block (0-15) typedef bnb_cub::BlockLoad LoadT; typedef bnb_cub::BlockStore StoreChar; - // Logical warp of THREADS_PER_BLOCK: on warp32 HW this is a half-warp, - // on warp64 HW this splits the single HW warp into two logical warps - typedef bnb_cub::WarpReduce WarpReduce; + typedef bnb_cub::WarpReduce + WarpReduce; // Half-warp logical reduction: each half reduces independently __shared__ typename LoadT::TempStorage loadt; __shared__ typename StoreChar::TempStorage storec; - __shared__ typename WarpReduce::TempStorage warp_reduce[2]; + __shared__ typename WarpReduce::TempStorage warp_reduce[2]; // One per logical warp __shared__ float smem_absmax_value[2]; const int i = base_idx + block_id * BLOCK_SIZE; + // Use a flag instead of early return: BlockLoad/BlockStore/__syncthreads are cooperative + // operations that require ALL 32 threads to participate const bool block_valid = (i < n); + // All 32 threads participate in the load (out-of-bounds threads get 0.0f) __syncthreads(); LoadT(loadt).Load(&(A[base_idx]), vals, min(BLOCK_SIZE * 2, n - base_idx), (T)0.0f); + // Each thread computes max of its values local_abs_max = -FLT_MAX; #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); + // Reduce within each logical warp of 16 threads independently local_abs_max = WarpReduce(warp_reduce[block_id]).Reduce(local_abs_max, BNB_MAX_OP); if (local_thread_id == 0) { @@ -436,17 +491,15 @@ __global__ void kQuantizeBlockwiseSmall( break; } + // All 32 threads participate in the store (valid_items limits the actual writes) __syncthreads(); StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((BLOCK_SIZE * 2 + 1) / 2, (n - base_idx + 1) / 2)); } -// ============================================================================ -// kDequantizeBlockwise — fully shared -// ============================================================================ - template __global__ void kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n) { + const int n_load = (gridDim.x * TILE_SIZE); int valid_items_load = 0; int valid_items_store = 0; @@ -465,6 +518,7 @@ __global__ void for (int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) { if (DATA_TYPE > 0) { + // Cast n to int64_t to avoid overflow for large n valid_items_load = min(TILE_SIZE, static_cast((static_cast(n) + 1) / 2) - i); 
valid_items_store = min(TILE_SIZE * 2, n - i * 2); } else { @@ -472,32 +526,33 @@ __global__ void valid_items_store = valid_items_load; } - // blocksize is always power-of-2: use bitwise AND instead of division + // Since blocksize will always be a power-of-2, we avoid more expensive + // division by the blocksize and instead use a shift operation. + // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. + local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) >> (31 - __clz(blocksize))]); + __syncthreads(); - LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load); + LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); switch (DATA_TYPE) { case General8bit: +// load code through read-only cache via __ldg #pragma unroll NUM_PER_TH - for (int j = 0; j < NUM_PER_TH; j++) { - local_abs_max = absmax[(i + (threadIdx.x * NUM_PER_TH) + j) / blocksize]; - vals[j] = (T)(code[qvals[j]] * local_abs_max); - } + for (int j = 0; j < NUM_PER_TH; j++) + vals[j] = __ldg(&code[qvals[j]]) * local_abs_max; break; case FP4: #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) { - local_abs_max = absmax[((i * 2) + (threadIdx.x * NUM_PER_TH * 2) + (j * 2)) / blocksize]; - vals[j * 2] = (T)(dDequantizeFP4Tree(qvals[j] >> 4) * local_abs_max); - vals[j * 2 + 1] = (T)(dDequantizeFP4Tree(qvals[j] & 0x0F) * local_abs_max); + vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4) * local_abs_max; + vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F) * local_abs_max; } break; case NF4: #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) { - local_abs_max = absmax[((i * 2) + (threadIdx.x * NUM_PER_TH * 2) + (j * 2)) / blocksize]; - vals[j * 2] = (T)(dDequantizeNF4(qvals[j] >> 4) * local_abs_max); - vals[j * 2 + 1] = (T)(dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max); + vals[j * 2] = dDequantizeNF4(qvals[j] >> 4) * local_abs_max; + vals[j * 2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max; } break; } @@ -507,94 +562,2042 @@ __global__ void } } -// ============================================================================ -// Template instantiations — bnb_bfloat16 replaces __nv_bfloat16 / hip_bfloat16 -// ============================================================================ +__global__ void kDequantize(float* code, unsigned char* A, float* out, const int n) { + const unsigned int numThreads = blockDim.x * gridDim.x; + const int idx = (blockIdx.x * blockDim.x) + threadIdx.x; -#define MAKE_kQuantizeBlockwise(dtype, block_size, num_per_th, stochastic, data_type_name) \ - template __global__ void kQuantizeBlockwise( \ - float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \ - const int rand_offset, const int n \ - ); + __shared__ float smem_code[256]; + if (threadIdx.x < 256) { + smem_code[threadIdx.x] = code[threadIdx.x]; + } -// half instantiations -MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit) -MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) -MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) -// ... 
(remaining half/float instantiations identical to current) + __syncthreads(); -// float instantiations -MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) -MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) -// ... (remaining float instantiations) + for (int i = idx; i < n; i += numThreads) { + out[i] = smem_code[A[i]]; + } +} -// bnb_bfloat16 — resolves to __nv_bfloat16 on CUDA, hip_bfloat16 on HIP -MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 1, General8bit) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 2048, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 1024, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 512, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 128, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 64, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 0, FP4) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 2048, 4, 0, FP4) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 1024, 4, 0, FP4) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 512, 2, 0, FP4) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, FP4) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 128, 2, 0, FP4) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 64, 2, 0, FP4) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 0, NF4) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 2048, 4, 0, NF4) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 1024, 4, 0, NF4) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 512, 2, 0, NF4) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, NF4) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 128, 2, 0, NF4) -MAKE_kQuantizeBlockwise(bnb_bfloat16, 64, 2, 0, NF4) +template +__launch_bounds__(BLOCK_SIZE / NUM_VALS, 1) __global__ void kPreconditionOptimizer32bit2State( + T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n +) { -// Unified small-blocksize kernel instantiations -#define MAKE_kQuantizeBlockwiseSmall(dtype, data_type_name) \ - template __global__ void kQuantizeBlockwiseSmall( \ - float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \ - const int rand_offset, const int n \ - ); + const int n_full = (BLOCK_SIZE * (n / BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 
0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; -MAKE_kQuantizeBlockwiseSmall(half, FP4) MAKE_kQuantizeBlockwiseSmall(float, FP4) MAKE_kQuantizeBlockwiseSmall( - bnb_bfloat16, FP4 -) MAKE_kQuantizeBlockwiseSmall(half, NF4) MAKE_kQuantizeBlockwiseSmall(float, NF4) MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, NF4) + T g_vals[NUM_VALS]; - // Dequantize instantiations - template __global__ void kDequantizeBlockwise( - float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n - ); -template __global__ void kDequantizeBlockwise( - float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n -); -template __global__ void kDequantizeBlockwise( - float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n -); -template __global__ void kDequantizeBlockwise( - float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n -); -template __global__ void kDequantizeBlockwise( - float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n -); -template __global__ void kDequantizeBlockwise( - float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n -); -template __global__ void kDequantizeBlockwise( - float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, const int blocksize, const int n -); -template __global__ void kDequantizeBlockwise( - float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, const int blocksize, const int n -); -template __global__ void kDequantizeBlockwise( - float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, const int blocksize, const int n -); + float s1_vals[NUM_VALS]; + float s2_vals[NUM_VALS]; + + const float correction1 = 1.0f / (1.0f - powf(beta1, step)); + const float correction2 = 1.0f / (1.0f - powf(beta2, step)); + + typedef bnb_cub::BlockLoad Load; + typedef bnb_cub::BlockLoad LoadFloat; + typedef bnb_cub::BlockReduce BlockReduce; + + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) { + valid_items = n - i >= (BLOCK_SIZE) ? 
(BLOCK_SIZE) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); + +#pragma unroll NUM_VALS + for (unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale * ((float)g_vals[j]); + +#pragma unroll NUM_VALS + for (unsigned int j = 0; j < NUM_VALS; j++) { + switch (OPTIMIZER) { + case ADAM: + s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j])); + s2_vals[j] = s2_vals[j] * beta2 + ((1.0f - beta2) * (((float)g_vals[j]) * ((float)g_vals[j]))); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + s1_vals[j] = s1_vals[j] / (sqrtf(s2_vals[j]) + eps); // update + s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update) + break; + case ADEMAMIX: + break; + } + } + +#pragma unroll NUM_VALS - 1 + for (unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]); + + if (threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + __syncwarp(); + } +} + +#define NUM_PER_THREAD 4 + +template +__launch_bounds__(TH, 1) __global__ void kOptimizer32bit2State( + T* g, T* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +) { + + const int n_full = ((TH * NUM_PER_THREAD) * (n / (TH * NUM_PER_THREAD))) + + (n % (TH * NUM_PER_THREAD) == 0 ? 0 : (TH * NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + float s2_vals[NUM_PER_THREAD]; + + // AdEMAMix has an additional state buffer, which we packed + // into state1. We need thread-local storage here for these. + // TODO: Mark with [[maybe_unused]] after upgrade to min compiler. + float s3_vals[NUM_PER_THREAD]; + + const float correction1 = 1.0f - powf(beta1, step); + const float correction2 = sqrtf(1.0f - powf(beta2, step)); + const float step_size = -lr * correction2 / correction1; + + if (max_unorm > 0.0f) { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if (update_scale > max_unorm * param_norm) { + update_scale = (max_unorm * param_norm) / update_scale; + } else { + update_scale = 1.0f; + } + } else { + update_scale = 1.0f; + } + + typedef bnb_cub::BlockLoad Load; + typedef bnb_cub::BlockStore Store; + + typedef bnb_cub::BlockLoad LoadFloat; + typedef bnb_cub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * TH * NUM_PER_THREAD) { + valid_items = n - i >= (TH * NUM_PER_THREAD) ? 
(TH * NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + // Load additional state1 data for AdEMAMix + // TODO: Make constexpr after updating min compiler + if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[n + i]), s3_vals, valid_items); + } + +#pragma unroll 4 + for (unsigned int j = 0; j < NUM_PER_THREAD; j++) + g_vals[j] = gnorm_scale * ((float)g_vals[j]); + +#pragma unroll 4 + for (unsigned int j = 0; j < NUM_PER_THREAD; j++) { + switch (OPTIMIZER) { + case ADEMAMIX: + // m1 update: m1 = beta1 * m1 + (1-beta1) * g + s1_vals[j] = (s1_vals[j] * beta1) + ((1.0f - beta1) * (float)g_vals[j]); + + // m2 update: m2 = m2 * beta3 + (1-beta3) * g + s3_vals[j] = (s3_vals[j] * beta3) + ((1.0f - beta3) * (float)g_vals[j]); + + // nu update: nu = beta2 * nu + (1-beta2) * g^2 + s2_vals[j] = (s2_vals[j] * beta2) + ((1.0f - beta2) * (float)g_vals[j] * (float)g_vals[j]); + + p_vals[j] = (float)p_vals[j] - lr * (((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / + ((sqrtf(s2_vals[j]) / correction2) + eps)); + + if (weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay)); + + break; + case ADAM: + + if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { + s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j])); + s2_vals[j] = s2_vals[j] * beta2 + ((1.0f - beta2) * (((float)g_vals[j]) * ((float)g_vals[j]))); + p_vals[j] = ((float)p_vals[j]) + + (update_scale * step_size * (s1_vals[j] / (sqrtf(s2_vals[j]) + (eps * correction2)))); + + if (weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay)); + } + break; + } + } + + __syncthreads(); + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); + + if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[n + i]), s3_vals, valid_items); + } + } +} + +template +__launch_bounds__(BLOCK_SIZE / NUM_VALS, 1) __global__ void kPreconditionOptimizer32bit1State( + T* g, T* p, float* state1, float* unorm, const float beta1, const float beta2, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n +) { + + const int n_full = (BLOCK_SIZE * (n / BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + + typedef bnb_cub::BlockLoad Load; + typedef bnb_cub::BlockLoad LoadFloat; + typedef bnb_cub::BlockReduce BlockReduce; + + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) { + valid_items = n - i >= (BLOCK_SIZE) ? 
(BLOCK_SIZE) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + +#pragma unroll NUM_VALS + for (unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale * ((float)g_vals[j]); + +#pragma unroll NUM_VALS + for (unsigned int j = 0; j < NUM_VALS; j++) { + switch (OPTIMIZER) { + case MOMENTUM: + if (step == 1) + s1_vals[j] = (float)g_vals[j]; // state update + else + s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]); // state update + s1_vals[j] = s1_vals[j] * s1_vals[j]; // update norm + break; + case LION: + s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * (float)g_vals[j]); // state update + break; + case RMSPROP: + s1_vals[j] = + s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j]) * ((float)g_vals[j])); // state update + s1_vals[j] = __fdividef((float)g_vals[j], sqrtf(s1_vals[j]) + eps); // update value + s1_vals[j] = s1_vals[j] * s1_vals[j]; // update norm + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j]) * ((float)g_vals[j]); // state update + s1_vals[j] = __fdividef((float)g_vals[j], sqrtf(s1_vals[j]) + eps); // update value + s1_vals[j] = s1_vals[j] * s1_vals[j]; // update norm + break; + } + } + +#pragma unroll + for (unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); + + if (threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + __syncwarp(); + } +} + +template +__launch_bounds__(TH, 1) __global__ void kOptimizer32bit1State( + T* g, T* p, float* state1, float* unorm, const float max_unorm, const float param_norm, const float beta1, + const float beta2, const float eps, const float weight_decay, const int step, const float lr, + const float gnorm_scale, const bool skip_zeros, const int n +) { + + const int n_full = ((TH * NUM_PER_THREAD) * (n / (TH * NUM_PER_THREAD))) + + (n % (TH * NUM_PER_THREAD) == 0 ? 0 : (TH * NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + + if (max_unorm > 0.0f) { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if (update_scale > max_unorm * param_norm + eps) { + update_scale = (max_unorm * param_norm + eps) / update_scale; + } else { + update_scale = 1.0f; + } + } else { + update_scale = 1.0f; + } + + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + + typedef bnb_cub::BlockLoad Load; + typedef bnb_cub::BlockStore Store; + + typedef bnb_cub::BlockLoad LoadFloat; + typedef bnb_cub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * TH * NUM_PER_THREAD) { + valid_items = n - i >= (TH * NUM_PER_THREAD) ? 
(TH * NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + +#pragma unroll 4 + for (unsigned int j = 0; j < NUM_PER_THREAD; j++) { + g_vals[j] = gnorm_scale * ((float)g_vals[j]); + if (weight_decay > 0.0f) + g_vals[j] = (float)g_vals[j] + (((float)p_vals[j]) * weight_decay); + } + +#pragma unroll 4 + for (unsigned int j = 0; j < NUM_PER_THREAD; j++) { + if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { + switch (OPTIMIZER) { + case MOMENTUM: + if (step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + update_scale * (-lr * (s1_vals[j])); + break; + case LION: + p_vals[j] = + ((float)p_vals[j]) - + update_scale * (lr * sgn(((float)s1_vals[j]) * beta1 + ((1.0f - beta1) * ((float)g_vals[j])))); + s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * ((float)g_vals[j])); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j]) * ((float)g_vals[j])); + p_vals[j] = ((float)p_vals[j]) - + update_scale * (lr * __fdividef((float)g_vals[j], sqrtf((float)s1_vals[j]) + eps)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j]) * ((float)g_vals[j]); + p_vals[j] = ((float)p_vals[j]) - lr * __fdividef((float)g_vals[j], sqrtf((float)s1_vals[j]) + eps); + break; + } + } + } + + __syncthreads(); + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + } +} + +#define NUM8BIT 16 +#define NUM_THREADS 256 +#define NUM_PER_BLOCK 4096 + +template +__global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8bit2State( + T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, unsigned char* __restrict__ const state2, + float* unorm, const float beta1, const float beta2, const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, + float* new_max1, float* new_max2, const float gnorm_scale, const int n +) { + const int n_full = gridDim.x * NUM_PER_BLOCK; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = + n - (blockIdx.x * NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x * NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_max_s2 = -FLT_MAX; + float local_unorm = 0.0f; + + float s2_vals[NUM8BIT]; + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + unsigned char r_c2[NUM8BIT]; + + typedef bnb_cub::BlockLoad LoadT; + typedef bnb_cub::BlockLoad LoadUInt8; + typedef bnb_cub::BlockReduce BlockReduce; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + __shared__ float smem_quantiles1[256]; + __shared__ float smem_quantiles2[256]; + + if (threadIdx.x < 256) { + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x]; + } + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS * gridDim.x * NUM8BIT) { + valid_items = n - i >= (TH * NUM_PER_THREAD) ? 
(TH * NUM_PER_THREAD) : n - i; + + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128); + __syncthreads(); + +#pragma unroll 16 + for (int j = 0; j < NUM8BIT; j++) { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]] * max1[0] * beta1; + s1_vals[j] += (1.0f - beta1) * g_val; + local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); + } + +#pragma unroll 16 + for (int j = 0; j < NUM8BIT; j++) { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s2_vals[j] = smem_quantiles2[r_c2[j]] * max2[0] * beta2; + s2_vals[j] += (1.0f - beta2) * g_val * g_val; + local_max_s2 = fmaxf(local_max_s2, fabsf(s2_vals[j])); + } + + if (unorm != NULL) { +#pragma unroll 16 + for (int j = 0; j < NUM8BIT; j++) { + float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step)); + float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step)); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + float update_val = s1_vals[j] / (sqrtf(s2_vals[j]) + eps); // update + local_unorm += update_val * update_val; + } + } + } + + __syncthreads(); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, BNB_MAX_OP, valid_items); + __syncthreads(); + local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, BNB_MAX_OP, valid_items); + if (unorm != NULL) { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Sum(local_unorm, valid_items); + } + + if (threadIdx.x == 0) { + atomicMax(&new_max1[0], local_max_s1); + atomicMax(&new_max2[0], local_max_s2); + if (unorm != NULL) { + atomicAdd(&unorm[0], local_unorm); + } + } +} + +#define NUM_PER_THREAD2 4 +#define NUM_THREADS2 1024 +#define NUM_PER_BLOCK2 4096 + +template +__global__ void __launch_bounds__(NUM_THREADS2, 1) kOptimizerStatic8bit2State( + T* p, T* const g, unsigned char* state1, unsigned char* state2, const float* unorm, const float max_unorm, + const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, + float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n +) { + + const int n_full = (blockDim.x * gridDim.x) * NUM_PER_THREAD2; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float s2_vals[NUM_PER_THREAD2]; + const float correction1 = 1.0f - powf(beta1, step); + const float correction2 = sqrtf(1.0f - powf(beta2, step)); + const float step_size = -lr * correction2 / correction1; + // const float step_size = -lr*correction2/correction1; + float new_max_val1 = 1.0f / new_max1[0]; + float new_max_val2 = 1.0f / new_max2[0]; + float update_scale = 1.0f; + + if (max_unorm > 0.0f) { + update_scale = max_unorm > 0.0f ? 
sqrtf(unorm[0]) : 1.0f; + if (update_scale > max_unorm * param_norm) { + update_scale = (max_unorm * param_norm) / update_scale; + } else { + update_scale = 1.0f; + } + } else { + update_scale = 1.0f; + } + + unsigned char c1s[NUM_PER_THREAD2]; + unsigned char c2s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef bnb_cub::BlockLoad LoadT; + typedef bnb_cub::BlockLoad + LoadChar; + + typedef bnb_cub::BlockStore + StoreChar; + typedef bnb_cub::BlockStore StoreT; + + __shared__ float smem_quantiles1[256]; + __shared__ float smem_quantiles2[256]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + if (threadIdx.x < 512) { + if (threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + else + smem_quantiles2[threadIdx.x - 256] = quantiles2[threadIdx.x - 256]; + } + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_THREADS2 * NUM_PER_THREAD2) { + valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i; + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if ((i + (threadIdx.x * NUM_PER_THREAD2) + NUM_PER_THREAD2) > n) { + continue; + } + +#pragma unroll 4 + for (unsigned int j = 0; j < NUM_PER_THREAD2; j++) { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[c1s[j]]; + s1_vals[j] = s1_vals[j] * max1[0]; + + s1_vals[j] = (s1_vals[j] * beta1) + (((1.0f - beta1) * g_val)); + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j] * new_max_val1); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if (signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) { + if (s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + + s2_vals[j] = smem_quantiles2[c2s[j]]; + s2_vals[j] = s2_vals[j] * max2[0]; + s2_vals[j] = (s2_vals[j] * beta2) + (((1.0f - beta2) * g_val * g_val)); + c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j] * new_max_val2); + } + +#pragma unroll 4 + for (unsigned int j = 0; j < NUM_PER_THREAD2; j++) { + p_vals[j] = (T)(((float)p_vals[j]) + + ((update_scale * step_size * (s1_vals[j] / (sqrtf(s2_vals[j]) + (correction2 * eps)))))); + if (weight_decay > 0.0f) + p_vals[j] = update_scale * ((float)p_vals[j]) * (1.0f - (lr * weight_decay)); + } + + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + __syncthreads(); + } +} + +template +__global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8bit1State( + T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, const float beta1, + const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1, + float* new_max1, const float weight_decay, const float gnorm_scale, const int n +) { + const int n_full = gridDim.x * NUM_PER_BLOCK; + const int base_idx = (blockIdx.x * blockDim.x * 
NUM_PER_THREAD); + int valid_items = + n - (blockIdx.x * NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x * NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_unorm = 0.0f; + + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + + typedef bnb_cub::BlockLoad LoadT; + typedef bnb_cub::BlockLoad LoadUInt8; + typedef bnb_cub::BlockReduce BlockReduce; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + __shared__ float smem_quantiles1[256]; + + if (threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_THREADS * NUM8BIT) { + valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i; + + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + +#pragma unroll 16 + for (int j = 0; j < NUM8BIT; j++) { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]] * max1[0]; + switch (OPTIMIZER) { + case ADAGRAD: + case MOMENTUM: + if (step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]); + if (unorm != NULL) + local_unorm += s1_vals[j] * s1_vals[j]; + break; + case LION: + s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * (g_val * g_val)); + break; + } + + local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); + } + } + + __syncthreads(); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, BNB_MAX_OP, valid_items); + if (threadIdx.x == 0) { + atomicMax(&new_max1[0], local_max_s1); + } + if (unorm != NULL) { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Sum(local_unorm, valid_items); + if (threadIdx.x == 0) { + atomicAdd(&unorm[0], local_unorm); + } + } +} + +template +__global__ void __launch_bounds__(1024, 1) kOptimizerStatic8bit1State( + T* p, T* const g, unsigned char* state1, const float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale, + const int n +) { + + const int n_full = (blockDim.x * gridDim.x) * NUM_PER_THREAD2; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float new_max_val1 = 1.0f / new_max1[0]; + float update_scale = 1.0f; + + if (max_unorm > 0.0f) { + update_scale = max_unorm > 0.0f ? 
sqrtf(unorm[0]) : 1.0f; + if (update_scale > max_unorm * param_norm) { + update_scale = (max_unorm * param_norm) / update_scale; + } else { + update_scale = 1.0f; + } + } else { + update_scale = 1.0f; + } + + unsigned char c1s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef bnb_cub::BlockLoad LoadT; + typedef bnb_cub::BlockLoad + LoadChar; + + typedef bnb_cub::BlockStore + StoreChar; + typedef bnb_cub::BlockStore StoreT; + + __shared__ float smem_quantiles1[256]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + if (threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_THREADS2 * NUM_PER_THREAD2) { + valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i; + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if ((i + (threadIdx.x * NUM_PER_THREAD2) + NUM_PER_THREAD2) > n) { + continue; + } + +#pragma unroll 4 + for (unsigned int j = 0; j < NUM_PER_THREAD2; j++) { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + + if (weight_decay > 0.0f) { + switch (OPTIMIZER) { + case ADAGRAD: + case MOMENTUM: + case RMSPROP: + g_val += ((float)p_vals[j]) * weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j]) * (1.0f - lr * weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[c1s[j]] * max1[0]; + + switch (OPTIMIZER) { + case ADAGRAD: + case MOMENTUM: + if (step == 1) + s1_vals[j] = g_vals[j]; + else + s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + (-lr * update_scale * (s1_vals[j])); + break; + case LION: + p_vals[j] = + ((float)p_vals[j]) - (lr * sgn(((float)s1_vals[j]) * beta1 + ((1.0f - beta1) * ((float)g_val)))); + s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * (g_val * g_val)); + p_vals[j] = ((float)p_vals[j]) - (lr * __fdividef(g_val, sqrtf(s1_vals[j]) + eps)); + break; + } + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j] * new_max_val1); + + // make sure state1 term has still the same sign after quantization + if (signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) { + if (s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + } +} + +template +__global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n) { + const int n_full = (BLOCK_SIZE * (n / BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + int valid_items = 0; + + typedef bnb_cub::BlockReduce BlockReduce; + typedef bnb_cub::BlockLoad LoadT; + + __shared__ typename BlockReduce::TempStorage reduce; + + __shared__ typename LoadT::TempStorage loadT; + T vals[NUM_VALS]; + float local_sum = 0.0f; + + for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x * BLOCK_SIZE) { + valid_items = n - i > BLOCK_SIZE ? 
BLOCK_SIZE : n - i; + local_sum = 0.0f; + + __syncthreads(); + LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); + +#pragma unroll NUM_VALS + for (int j = 0; j < NUM_VALS; j++) + local_sum += ((float)vals[j]) * ((float)vals[j]); + + local_sum = BlockReduce(reduce).Sum(local_sum, valid_items); + if (threadIdx.x == 0) { + if (step == 1) { + // initialize with the same norm for all positions + // #pragma unroll 10 + for (int j = 0; j < 100; j++) + atomicAdd(&gnorm_vec[j], local_sum); + } else + atomicAdd(&gnorm_vec[step % 100], local_sum); + } + } +} + +#define LANES 2 +#define QUAD 3 + +template +__launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit2StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2, + const float beta3, const float alpha, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, + float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n +) { + + // const int n_full = n + (n%BLOCK_SIZE); + const int n_full = gridDim.x * BLOCK_SIZE; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + float s2_vals[N_PER_TH]; + float s3_vals[N_PER_TH]; + + // 2-5% + const float correction1 = 1.0f - __powf(beta1, step); + const float correction2 = sqrtf(1.0f - __powf(beta2, step)); + const float step_size = __fdividef(-lr * correction2, correction1); + const int lane_id = threadIdx.x % LANES; + float new_local_abs_max1 = -FLT_MAX; + float new_local_abs_max2 = -FLT_MAX; + float new_local_abs_max3 = -FLT_MAX; + float quadrants1[QUAD]; + float quadrants2[QUAD]; + + unsigned char c1s[N_PER_TH]; + unsigned char c2s[N_PER_TH]; + unsigned char c3s[N_PER_TH]; + + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + typedef bnb_cub::BlockLoad LoadT; + typedef bnb_cub::BlockLoad + LoadChar; + + typedef bnb_cub::BlockStore + StoreChar; + typedef bnb_cub::BlockStore StoreT; + + __shared__ float smem_quantiles1[LANES][257]; + __shared__ float smem_quantiles2[LANES][257]; + typedef bnb_cub::BlockReduce BlockReduce1; + typedef bnb_cub::BlockReduce BlockReduce2; + typedef bnb_cub::BlockReduce BlockReduce3; + __shared__ typename BlockReduce1::TempStorage reduce1; + __shared__ typename BlockReduce2::TempStorage reduce2; + __shared__ typename BlockReduce2::TempStorage reduce3; + __shared__ float smem_exchange1[1]; + __shared__ float smem_exchange2[1]; + __shared__ float smem_exchange3[1]; // [[maybe_unused]] + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; + smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x]; +#pragma unroll + for (unsigned int j = 1; j < LANES; j++) { + smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; + smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x]; + } + + __syncthreads(); + +#pragma unroll + for (int k = 0; k < QUAD; k++) { + quadrants1[k] = smem_quantiles1[lane_id][(k * 256 / (QUAD + 1)) + (256 / (QUAD + 1) - 1)]; + quadrants2[k] = smem_quantiles2[lane_id][(k * 256 / (QUAD + 1)) + (256 / (QUAD + 1) - 1)]; + } + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) { + // loads: 0.23 -> 0.85/1.44 + 
valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + + // AdEMAMix has an additional state packed into state1. + if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[n + i]), c3s, valid_items, 128); + } + + new_local_abs_max1 = -FLT_MAX; + new_local_abs_max2 = -FLT_MAX; + new_local_abs_max3 = -FLT_MAX; + +// update: 2.48/1.57 -> 2.51/1.60 +#pragma unroll N_PER_TH + for (unsigned int j = 0; j < N_PER_TH; j++) { + if (!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) { + s2_vals[j] = smem_quantiles2[lane_id][c2s[j]] * absmax2[i / BLOCK_SIZE]; + g_val = g_vals[j]; + // float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); + // g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; + g_val *= gnorm_scale; + + s2_vals[j] = (s2_vals[j] * beta2) + (((1.0f - beta2) * g_val * g_val)); + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]] * absmax1[i / BLOCK_SIZE]; + s1_vals[j] = (s1_vals[j] * beta1) + (((1.0f - beta1) * g_val)); + + if (OPTIMIZER == ADEMAMIX) { + // The absmax for the third state is appended to absmax1 + s3_vals[j] = smem_quantiles1[lane_id][c3s[j]] * absmax1[(n + i) / BLOCK_SIZE]; + s3_vals[j] = (s3_vals[j] * beta3) + (((1.0f - beta3) * g_val)); + } + } else { + s1_vals[j] = 0.0f; + s2_vals[j] = 0.0f; + + if (OPTIMIZER == ADEMAMIX) { + s3_vals[j] = 0.0f; + } + } + + new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); + new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); + + if (OPTIMIZER == ADEMAMIX) { + new_local_abs_max3 = fmaxf(new_local_abs_max3, fabsf(s3_vals[j])); + } + } + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, BNB_MAX_OP); + new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, BNB_MAX_OP); + + if (OPTIMIZER == ADEMAMIX) { + new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, BNB_MAX_OP); + } + + if (threadIdx.x == 0) { + smem_exchange1[0] = new_local_abs_max1; + smem_exchange2[0] = new_local_abs_max2; + + if (OPTIMIZER == ADEMAMIX) { + smem_exchange3[0] = new_local_abs_max3; + } + } + + __syncthreads(); + + if (threadIdx.x == 0) { + absmax1[i / BLOCK_SIZE] = new_local_abs_max1; + absmax2[i / BLOCK_SIZE] = new_local_abs_max2; + + if (OPTIMIZER == ADEMAMIX) { + absmax1[(n + i) / BLOCK_SIZE] = new_local_abs_max3; + } + } else { + new_local_abs_max1 = smem_exchange1[0]; + new_local_abs_max2 = smem_exchange2[0]; + + if (OPTIMIZER == ADEMAMIX) { + new_local_abs_max3 = smem_exchange3[0]; + } + } + + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); +// reduce: 2.67/1.69 -> 2.67/1.70 +#pragma unroll N_PER_TH + for (unsigned int j = 0; j < N_PER_TH; j++) { + // if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + if (!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) { + if (OPTIMIZER == ADEMAMIX) { + p_vals[j] = + T((float)p_vals[j] - lr * (((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / + ((sqrtf(s2_vals[j]) / correction2) + eps))); + } else { + p_vals[j] = + (T)(((float)p_vals[j]) + + ((step_size * (__fdividef(s1_vals[j], (sqrtf(s2_vals[j]) + (correction2 * eps))))))); + } + + if (weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay)); + } + } + + 
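+        // The loop above is the bias-corrected Adam step written as
+        //   p <- p - lr * m_hat / (sqrt(v_hat) + eps),  m_hat = s1 / (1 - beta1^step),  v_hat = s2 / (1 - beta2^step);
+        // step_size = -lr * correction2 / correction1 folds both corrections into a single scalar,
+        // and weight decay is applied in decoupled (AdamW-style) form: p <- p * (1 - lr * weight_decay).
+        // For ADEMAMIX the slow EMA is mixed in instead:
+        //   p <- p - lr * ((s1 / correction1 + alpha * s3) / (sqrt(s2) / correction2 + eps)).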
// store: 0.85/1.44 -> 2.48/1.57 + __syncthreads(); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + +// quantizaztion: 2.67/1.70 -> 3.4/3.3 +#pragma unroll N_PER_TH + for (unsigned int j = 0; j < N_PER_TH; j++) { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j], new_local_abs_max1)); + c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j], new_local_abs_max2)); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if (signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) { + if (s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + + if (OPTIMIZER == ADEMAMIX) { + c3s[j] = + quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s3_vals[j], new_local_abs_max3)); + + if (signbit(smem_quantiles1[lane_id][c3s[j]]) != signbit(s3_vals[j])) { + c3s[j] += (s3_vals[j] > 0.0f) ? 1 : -1; + } + } + } + + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + + if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[n + i]), c3s, valid_items); + } + } +} + +#define LANES 2 +#define QUAD 3 + +template +__launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit1StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps, + const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n +) { + + // const int n_full = n + (n%BLOCK_SIZE); + const int n_full = gridDim.x * BLOCK_SIZE; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + // 2-5% + const int lane_id = threadIdx.x % LANES; + float new_local_abs_max1 = -FLT_MAX; + float quadrants1[QUAD]; + + unsigned char c1s[N_PER_TH]; + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + + typedef bnb_cub::BlockLoad LoadT; + typedef bnb_cub::BlockLoad + LoadChar; + + typedef bnb_cub::BlockStore + StoreChar; + typedef bnb_cub::BlockStore StoreT; + + __shared__ float smem_quantiles1[LANES][257]; + typedef bnb_cub::BlockReduce BlockReduce1; + __shared__ typename BlockReduce1::TempStorage reduce1; + __shared__ float smem_exchange1[1]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; +#pragma unroll + for (unsigned int j = 1; j < LANES; j++) + smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; + + __syncthreads(); + +#pragma unroll + for (int k = 0; k < QUAD; k++) + quadrants1[k] = smem_quantiles1[lane_id][(k * 256 / (QUAD + 1)) + (256 / (QUAD + 1) - 1)]; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? 
BLOCK_SIZE : n - i; + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + + new_local_abs_max1 = -FLT_MAX; + +// update: 2.48/1.57 -> 2.51/1.60 +#pragma unroll N_PER_TH + for (unsigned int j = 0; j < N_PER_TH; j++) { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { + if (weight_decay > 0.0f) { + switch (OPTIMIZER) { + case MOMENTUM: + case ADAGRAD: + case RMSPROP: + g_val += ((float)p_vals[j]) * weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j]) * (1.0f - lr * weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]] * absmax1[i / BLOCK_SIZE]; + + switch (OPTIMIZER) { + case MOMENTUM: + if (step == 1) + s1_vals[j] = g_val; + else + s1_vals[j] = (s1_vals[j] * beta1) + g_val; + break; + case LION: + // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, + // before the momentum is updated by beta2 + g_vals[j] = lr * sgn(((float)s1_vals[j]) * beta1 + ((1.0f - beta1) * g_val)); + s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * (g_val * g_val)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + (g_val * g_val); + break; + } + } + + new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); + } + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, BNB_MAX_OP); + + if (threadIdx.x == 0) + smem_exchange1[0] = new_local_abs_max1; + + __syncthreads(); + + if (threadIdx.x == 0) + absmax1[i / BLOCK_SIZE] = new_local_abs_max1; + else + new_local_abs_max1 = smem_exchange1[0]; + +// reduce: 2.67/1.69 -> 2.67/1.70 +#pragma unroll N_PER_TH + for (unsigned int j = 0; j < N_PER_TH; j++) { + if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { + switch (OPTIMIZER) { + case MOMENTUM: + p_vals[j] = ((float)p_vals[j]) - lr * (s1_vals[j]); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); + break; + case RMSPROP: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr * (__fdividef(g_val, sqrtf(s1_vals[j]) + eps)); + break; + case ADAGRAD: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr * (__fdividef(g_val, sqrtf(s1_vals[j]) + eps)); + break; + } + } + } + + // store: 0.85/1.44 -> 2.48/1.57 + __syncthreads(); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + +// quantizaztion: 2.67/1.70 -> 3.4/3.3 +#pragma unroll N_PER_TH + for (unsigned int j = 0; j < N_PER_TH; j++) { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j], new_local_abs_max1)); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if (signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) { + if (s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + } +} + +// Inputs: +// A [rows, cols] +// Outputs: +// rowStats [rows] +// out [rows, cols] +template +__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__ + void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int 
cols) { + + using BlockReduceT = bnb_cub::BlockReduce; + + // One block per row. + // Threads load column values in a striped arrangement. + // e.g. t0 reads row[0], row[0+nthreads], .. + // and t1 reads row[1], row[1+nthreads], .. + // Each thread will determine its local absmax. + // We then do a blockwise reduction to determine the row's absmax. + + __shared__ typename BlockReduceT::TempStorage temp_storage; + __shared__ T smem_row_absmax; + + const int row_id = blockIdx.x; + const T* row_data = A + (row_id * cols); + + // Threads will read the row values in a striped access pattern and find a local absmax. + T row_local_absmax = -FLT_MIN; + for (int i = threadIdx.x; i < cols; i += THREADS) { + const T absval = fabsf(__ldcs(&(row_data[i]))); + + // For sparse decomposition, values outside of the threshold are not to be + // included when calculating the row's absmax. + if constexpr (SPARSE_DECOMP) { + row_local_absmax = fmaxf(row_local_absmax, absval < T(threshold) ? absval : row_local_absmax); + } else { + row_local_absmax = fmaxf(row_local_absmax, absval); + } + } + + // Reduce thread-local absmax across the block. + const T row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, BNB_MAX_OP, cols); + if (threadIdx.x == 0) { + // Save our block's absmax to shared memory for the quantization step. + rowStats[row_id] = smem_row_absmax = row_absmax; + } + __syncthreads(); + + // Quantize row-wise. + const float scale = __fdividef(127.0f, smem_row_absmax); + for (int i = threadIdx.x; i < cols; i += THREADS) { + float val = row_data[i]; + + if constexpr (SPARSE_DECOMP) { + // For sparse decomposition, we do not want to quantize the outliers. + // Instead they're zeroed out. + out[row_id * cols + i] = fabs(val) < threshold ? __float2int_rn(val * scale) : 0; + } else { + out[row_id * cols + i] = __float2int_rn(val * scale); + } + } +} + +template __global__ void kInt8VectorQuant( + half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols +); +template __global__ void kInt8VectorQuant( + half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols +); + +#define MM_DEQUANT_CONST 6.200012e-05f // 1.0f/(127.0f*127.0f) + +template +__global__ void kdequant_mm_int32_fp16( + int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out, + half* __restrict__ const bias, const int numRows, const int numCols, const int n +) { + const int n_out = numRows * numCols; + + int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD; + int thread_offset = threadIdx.x * ITEMS_PER_THREAD; + + int local_values[ITEMS_PER_THREAD]; + half local_output[ITEMS_PER_THREAD]; + + float local_rowStats[ITEMS_PER_THREAD]; + float local_colStats[ITEMS_PER_THREAD]; + float local_biasValue[ITEMS_PER_THREAD]; + + typedef bnb_cub::BlockLoad LoadInt32; + __shared__ typename LoadInt32::TempStorage loadint32; + + int row_idx, col_idx; + +#pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; ++j) { + + row_idx = (block_offset + thread_offset + j) / numCols; + col_idx = (block_offset + thread_offset + j) % numCols; + + local_colStats[j] = col_idx >= numCols ? 0.0f : __ldg(&colStats[col_idx]); + local_rowStats[j] = row_idx >= numRows ? 0.0f : __ldg(&rowStats[row_idx]); + local_biasValue[j] = ((bias == nullptr) || col_idx >= numCols) ? 
0.0f : __half2float(bias[col_idx]); + } + + // Each block loads THREADS * ITEMS_PER_THREAD values from A + int valid_items = + block_offset + THREADS * ITEMS_PER_THREAD < n_out ? THREADS * ITEMS_PER_THREAD : n_out - block_offset; + LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0); + +#pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; ++j) { + local_output[j] = __float2half( + fmaf(local_values[j] * local_rowStats[j] * local_colStats[j], MM_DEQUANT_CONST, local_biasValue[j]) + ); + } + +#pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; j++) { + int outIdx = block_offset + thread_offset + j; + if (outIdx < n_out) { + out[outIdx] = local_output[j]; + } + } +} + +#define DENORM 1.0f / 127.0f +#define MAX_SPARSE_COUNT 32 +#define SMEM_SIZE 8 * 256 + +template +__global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +) { + + // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block + // If a block finishes, the next one is scheduled. Since the last blocks like have fewer + // elements they finish faster "fillin up" the gaps left by larger blocks + + // without tensor cores + // 1. use rowidx_length to find what to load (as many blocks as there are rows) + // 2. Load A into registers + // 3. each warp loads all required rows of B but each warp is offset by k + // 4. Do mma operations that accumulate into registers + // 5. Each warp stores its output row into matrix C + + const int count = max_count[blockIdx.x]; + const int local_max_idx = max_idx[blockIdx.x]; + const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx - 1]; + const int local_row_idx = rowidx[offset]; + + const int warp_id = threadIdx.x / 32; + const int warp_idx = threadIdx.x % 32; + const int warp_offset = (warp_id * 32) * SPMM_ITEMS; + const int num_items = BITS == 8 ? 8 : 8; + int idx_col_B = warp_offset; + int local_idx_col_B_offset = 0; + + half local_valA[MAX_SPARSE_COUNT]; + int local_colidxA[MAX_SPARSE_COUNT]; + half local_valC[SPMM_ITEMS]; + T local_valsB[num_items]; + half local_valOut[num_items]; + // 128 byte loads per warp == 4 bytes per thread + + // 2. Load A into registers + for (int j = 0; j < MAX_SPARSE_COUNT; j++) { + local_valA[j] = j < count ? values[offset + j] : __float2half(0.0f); + local_colidxA[j] = j < count ? colidx[offset + j] : 0; + } + + // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192 + // we expect each warp to be SPMM_ITEMS*32 apart + // we have a total of 128 bytes for the bank with a bank size of 4 bytes + // added 3 bytes = 6 values between warps should reduce bank conflicts + __shared__ half smem_dequant_stats[SMEM_SIZE]; + + while (idx_col_B < colsB) { + + if (dequant_stats != NULL) { + for (int i = threadIdx.x; i < SMEM_SIZE; i += blockDim.x) + if ((idx_col_B + i - local_idx_col_B_offset) < colsB) + smem_dequant_stats[i] = dequant_stats[idx_col_B + i - local_idx_col_B_offset]; + + __syncthreads(); + } + +#pragma unroll SPMM_ITEMS + for (int j = 0; j < SPMM_ITEMS; j++) + local_valC[j] = 0.0f; + +#pragma unroll + for (int i = 0; i < count; i++) { + // 3. 
each warp loads all required rows of B but each warp is offset by k + int row_offset = colsB * local_colidxA[i]; + +#pragma unroll SPMM_ITEMS + for (int j = 0; j < SPMM_ITEMS; j += num_items) { + // 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached + int idx = idx_col_B + (warp_idx * SPMM_ITEMS) + j; + if (idx >= colsB) { + break; + } + if ((idx + num_items < colsB)) { + if (BITS == 8) + reinterpret_cast(local_valsB)[0] = + reinterpret_cast(B)[(row_offset + idx) / num_items]; + else + reinterpret_cast(local_valsB)[0] = + reinterpret_cast(B)[(row_offset + idx) / num_items]; + } else { +#pragma unroll num_items + for (int k = 0; k < num_items; k++) + if (idx + k < colsB) + local_valsB[k] = B[row_offset + idx + k]; + else + local_valsB[k] = 0.0f; + } +#pragma unroll num_items + for (int k = 0; k < num_items; k++) { + if (BITS == 8 && dequant_stats != NULL) + // we do texture cache reads (__ldg) on dequant_stats which should be super fast + { + float valB = local_valsB[k]; + float valA = local_valA[i]; + if (valB != 0.0 && valA != 0.0) + local_valC[j + k] = + (float)local_valC[j + k] + + ((float)smem_dequant_stats[idx + k - local_idx_col_B_offset]) * DENORM * valB * valA; + } else + local_valC[j + k] = (float)local_valC[j + k] + (float)local_valsB[k] * (float)local_valA[i]; + } + } + } + + int idx_row_C = (colsB * local_row_idx); + +#pragma unroll SPMM_ITEMS + for (int j = 0; j < SPMM_ITEMS; j += num_items) { + // int idx_col_C = idx_col_B + (32*j) + warp_idx; + int idx_col_C = idx_col_B + warp_idx * SPMM_ITEMS + j; + int idx_val = idx_col_C + idx_row_C; + + if (idx_col_C + num_items < colsB) { + + // load outputs to do inplace addition + reinterpret_cast(local_valOut)[0] = + reinterpret_cast(out)[idx_val / num_items]; + +#pragma unroll num_items + for (int k = 0; k < num_items; k++) + local_valC[(j / num_items) + k] = (float)local_valC[(j / num_items) + k] + (float)local_valOut[k]; + + reinterpret_cast(out)[idx_val / num_items] = + reinterpret_cast(local_valC)[j / num_items]; + } else { +#pragma unroll num_items + for (int k = 0; k < num_items; k++) + if (idx_col_C + k < colsB) + out[idx_val + k] = (float)out[idx_val + k] + (float)local_valC[j + k]; + } + } + + idx_col_B += blockDim.x * SPMM_ITEMS; + local_idx_col_B_offset += blockDim.x * SPMM_ITEMS; + } +} + +#define num_values_4bit 32 + +template +__global__ void kgemm_4bit_inference_naive( + int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out, + int lda, int ldb, int ldc, int blocksize +) { + + // per threadblock: + // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] + // 4 warps -> 4 loads per iter + // 1x32 * 32x4 -> 1x4 outputs per thread block + typedef bnb_cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32]; + + const int warp_idx = threadIdx.x / 32; + const int warp_lane = threadIdx.x % 32; + const int row_B = (THREADS / 32) * blockIdx.x + warp_idx; + const int offset_B = ldb * row_B; + const int num_values_8bit = num_values_4bit / 2; + float local_C = 0.0f; + + unsigned char local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit / 4]; + T local_A[num_values_4bit / 4]; + __shared__ T quant_map[16]; + T local_absmax = T(0.0f); + + if (threadIdx.x < 16) + quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x])); + // for(int i = threadIdx.x; i < 16; i++) + // quant_map[i] = T(__ldg(&datatype[i])); + __syncthreads(); + + // A: [1, K] + // B: [N, K] + for (int 
inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) { + const int inner_idx_halved = inner_idx / 2; + + // Since blocksize will always be a power-of-2, we avoid more expensive + // division by the blocksize and instead use a shift operation. + // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. + const int absidx = ((2 * offset_B) + inner_idx) >> (31 - __clz(blocksize)); + + local_absmax = __ldg(&(absmax[absidx])); + + if (row_B < M) { + if ((inner_idx_halved + num_values_8bit) < (K / 2)) { + // this is the most important for performance considerations + reinterpret_cast(local_B_4bit)[0] = + reinterpret_cast(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; + } else { +#pragma unroll + for (int j = 0; j < (num_values_8bit); j++) + if ((inner_idx_halved) + j < (K / 2)) + local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } else { +#pragma unroll + for (int j = 0; j < (num_values_8bit); j++) + local_B_4bit[j] = 0b01110111; + } + + for (int i = 0; i < 4; i++) { +#pragma unroll + for (int k = 0; k < num_values_8bit / 4; k++) { +#if BNB_BF16_AVAILABLE + local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; + local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; +#else + // bf16 multipliation not supported + local_B[k * 2] = + T((float)quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * (float)local_absmax); + local_B[k * 2 + 1] = + T((float)quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * (float)local_absmax); +#endif + } + + if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { + // this is also relatively important for performance + if (BITS == 16) { + reinterpret_cast(local_A)[0] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 4) + i]; + } else { + reinterpret_cast(local_A)[0] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0]; + reinterpret_cast(local_A)[1] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; + } + + } else +#pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) + if (inner_idx + (i * num_values_4bit / 4) + k < K) + local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; + else + local_A[k] = T(0.0f); + +// accumulate in float; small performance hit for Ampere, but lower error for outputs +#pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) { +#if BNB_BF16_AVAILABLE + local_C += (float)(local_A[k] * local_B[k]); +#else + // bf16 multipliation not supported + local_C += ((float)local_A[k] * (float)local_B[k]); +#endif + } + } + } + + local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + + if (row_B < M && warp_lane == 0) + out[row_B] = T(local_C); +} + +template __global__ void kfunc(T* A, T* B, T value, long n) { + for (long i = (blockDim.x * blockIdx.x) + threadIdx.x; i < n; i += (blockDim.x * gridDim.x)) { + switch (FUNC) { + case FILL: + A[i] = (T)value; + break; + case ARANGE: + A[i] = (T)i; + break; + case _MUL: + A[i] = A[i] * B[i]; + break; + } + } +} + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template __global__ void kfunc(float* A, float* B, float value, long n); +template __global__ void kfunc(unsigned char* A, unsigned char* B, unsigned char value, long n); +template __global__ void kfunc(float* A, float* B, float value, long n); +template 
__global__ void kfunc(float* A, float* B, float value, long n); + +template __global__ void kgemm_4bit_inference_naive( + int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, half* out, + int lda, int ldb, int ldc, int blocksize +); +template __global__ void kgemm_4bit_inference_naive( + int M, int N, int K, bnb_bfloat16* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, + bnb_bfloat16* out, int lda, int ldb, int ldc, int blocksize +); +template __global__ void kgemm_4bit_inference_naive( + int M, int N, int K, float* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, + float* out, int lda, int ldb, int ldc, int blocksize +); + +template __global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +); +template __global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +); +template __global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +); +template __global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +); +template __global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +); +template __global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +); + +template __global__ void kdequant_mm_int32_fp16<4, 512>( + int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out, + half* __restrict__ const bias, const int numRows, const int numCols, const int n +); + +template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); +template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); + +#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ + template __global__ void kPreconditionOptimizer32bit1State( \ + gtype * g, gtype * p, float* state1, float* unorm, const float beta1, const float beta2, const float eps, \ + const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n \ + ); + +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, bnb_bfloat16) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, bnb_bfloat16) +MAKE_PreconditionOptimizer32bit1State(LION, half) +MAKE_PreconditionOptimizer32bit1State(LION, float) 
+MAKE_PreconditionOptimizer32bit1State(LION, bnb_bfloat16) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, bnb_bfloat16) + +#define MAKE_Optimizer32bit1State(oname, gtype) \ + template __global__ void kOptimizer32bit1State( \ + gtype * g, gtype * p, float* state1, float* unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, const int step, \ + const float lr, const float gnorm_scale, const bool skip_zeros, const int n \ + ); + +MAKE_Optimizer32bit1State(MOMENTUM, half) +MAKE_Optimizer32bit1State(MOMENTUM, float) +MAKE_Optimizer32bit1State(MOMENTUM, bnb_bfloat16) +MAKE_Optimizer32bit1State(RMSPROP, half) +MAKE_Optimizer32bit1State(RMSPROP, float) +MAKE_Optimizer32bit1State(RMSPROP, bnb_bfloat16) +MAKE_Optimizer32bit1State(LION, half) +MAKE_Optimizer32bit1State(LION, float) +MAKE_Optimizer32bit1State(LION, bnb_bfloat16) +MAKE_Optimizer32bit1State(ADAGRAD, half) +MAKE_Optimizer32bit1State(ADAGRAD, float) +MAKE_Optimizer32bit1State(ADAGRAD, bnb_bfloat16) + +#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ + template __global__ void kPreconditionOptimizer32bit2State( \ + gtype * g, gtype * p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, \ + const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, \ + const int n \ + ); + +MAKE_PreconditionOptimizer32bit2State(ADAM, float) +MAKE_PreconditionOptimizer32bit2State(ADAM, half) +MAKE_PreconditionOptimizer32bit2State(ADAM, bnb_bfloat16) +MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float) +MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half) +MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, bnb_bfloat16) + +template __global__ void kOptimizer32bit2State( + float* g, float* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +); +template __global__ void kOptimizer32bit2State( + half* g, half* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +); +template __global__ void kOptimizer32bit2State( + bnb_bfloat16* g, bnb_bfloat16* p, float* state1, float* state2, float* unorm, const float max_unorm, + const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +); +template __global__ void kOptimizer32bit2State( + float* g, float* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +); +template __global__ void kOptimizer32bit2State( + half* g, half* p, float* state1, float* state2, float* unorm, const float max_unorm, const float 
param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +); +template __global__ void kOptimizer32bit2State( + bnb_bfloat16* g, bnb_bfloat16* p, float* state1, float* state2, float* unorm, const float max_unorm, + const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +); + +#define MAKE_PreconditionStatic8bit1State(oname, gtype) \ + template __global__ void kPreconditionOptimizerStatic8bit1State( \ + gtype * p, gtype* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, \ + const float beta1, const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, \ + float* max1, float* new_max1, const float weight_decay, const float gnorm_scale, const int n \ + ); + +MAKE_PreconditionStatic8bit1State(MOMENTUM, half) +MAKE_PreconditionStatic8bit1State(MOMENTUM, float) +MAKE_PreconditionStatic8bit1State(RMSPROP, half) +MAKE_PreconditionStatic8bit1State(RMSPROP, float) +MAKE_PreconditionStatic8bit1State(LION, half) +MAKE_PreconditionStatic8bit1State(LION, float) +MAKE_PreconditionStatic8bit1State(ADAGRAD, half) +MAKE_PreconditionStatic8bit1State(ADAGRAD, float) + +#define MAKE_optimizerStatic8bit1State(oname, gtype) \ + template __global__ void kOptimizerStatic8bit1State( \ + gtype * p, gtype* const g, unsigned char* state1, const float* unorm, const float max_unorm, \ + const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, \ + const float gnorm_scale, const int n \ + ); + +MAKE_optimizerStatic8bit1State(MOMENTUM, half) +MAKE_optimizerStatic8bit1State(MOMENTUM, float) +MAKE_optimizerStatic8bit1State(RMSPROP, half) +MAKE_optimizerStatic8bit1State(RMSPROP, float) +MAKE_optimizerStatic8bit1State(LION, half) +MAKE_optimizerStatic8bit1State(LION, float) +MAKE_optimizerStatic8bit1State(ADAGRAD, half) +MAKE_optimizerStatic8bit1State(ADAGRAD, float) + +#define MAKE_PreconditionStatic8bit2State(oname, gtype) \ + template __global__ void kPreconditionOptimizerStatic8bit2State( \ + gtype * p, gtype* __restrict__ const g, unsigned char* __restrict__ const state1, \ + unsigned char* __restrict__ const state2, float* unorm, const float beta1, const float beta2, const float eps, \ + const int step, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, \ + float* max2, float* new_max1, float* new_max2, const float gnorm_scale, const int n \ + ); + +MAKE_PreconditionStatic8bit2State(ADAM, half) +MAKE_PreconditionStatic8bit2State(ADAM, float) + +#define MAKE_optimizerStatic8bit2State(oname, gtype) \ + template __global__ void kOptimizerStatic8bit2State( \ + gtype * p, gtype* const g, unsigned char* state1, unsigned char* state2, const float* unorm, \ + const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, \ + const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, \ + const int n \ + ); + +MAKE_optimizerStatic8bit2State(ADAM, 
half) +MAKE_optimizerStatic8bit2State(ADAM, float) + +template __global__ void + kPercentileClipping(float* __restrict__ g, float* gnorm_vec, int step, const int n); +template __global__ void + kPercentileClipping(half* __restrict__ g, float* gnorm_vec, int step, const int n); + +#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ + template __global__ void kQuantizeBlockwise( \ + float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \ + const int rand_offset, const int n \ + ); + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) + +MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 512, 2, 0, FP4) 
+MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 64, 2, 0, NF4) + +// Template instantiations for blocksize=32 specialized kernel (4-bit only) +#define MAKE_kQuantizeBlockwiseSmall(dtype, data_type_name) \ + template __global__ void kQuantizeBlockwiseSmall( \ + float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \ + const int rand_offset, const int n \ + ); + +// FP4 instantiations for blocksize=32 +MAKE_kQuantizeBlockwiseSmall(half, FP4) MAKE_kQuantizeBlockwiseSmall(float, FP4) MAKE_kQuantizeBlockwiseSmall( + bnb_bfloat16, FP4 +) + + // NF4 instantiations for blocksize=32 + MAKE_kQuantizeBlockwiseSmall(half, NF4) MAKE_kQuantizeBlockwiseSmall(float, NF4) MAKE_kQuantizeBlockwiseSmall( + bnb_bfloat16, NF4 + ) + + template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n + ); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, const int blocksize, const int n +); +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, const int blocksize, const int n +); + +#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ + template __global__ void kOptimizerStatic8bit2StateBlockwise( \ + gtype * p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, \ + const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, \ + float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n \ + ); + +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, bnb_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, 
bnb_bfloat16, 256, 1) + +#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ + template __global__ void kOptimizerStatic8bit1StateBlockwise( \ + gtype * p, gtype* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, \ + const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, \ + float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n \ + ); + +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, bnb_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, bnb_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, bnb_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, bnb_bfloat16, 256, 1) diff --git a/csrc/examples/ops_unified.cu b/csrc/examples/ops_unified.cu index 902c15046..23418db61 100644 --- a/csrc/examples/ops_unified.cu +++ b/csrc/examples/ops_unified.cu @@ -1,36 +1,18 @@ -// ops_unified.cu — EXAMPLE of merged host wrappers for CUDA/HIP +// Copyright (c) Facebook, Inc. and its affiliates. // -// This replaces both csrc/ops.cu and csrc/ops.hip. Shows representative -// functions covering all categories of differences. -// -// Key points: -// - <<>> works on both CUDA and HIP (no hipLaunchKernelGGL needed) -// - BNB_CHECK_RETURN replaces CUDA_CHECK_RETURN / hip equivalent -// - bnb_stream_t replaces cudaStream_t / hipStream_t -// - #if BNB_HIP only for genuinely different library code (igemmlt, spmm_coo) - -#include "common.cuh" -#include "compat.cuh" -#include "kernels.cuh" -#include "ops_unified.cuh" +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. #include -#include +#include #include - -#if !BNB_HIP -#include -#endif +#include #define ERR_NOT_IMPLEMENTED 100 using std::cout; using std::endl; -// ============================================================================ -// Quantize / Dequantize — fully shared, <<<>>> works on both platforms -// ============================================================================ - void quantize(float* code, float* A, unsigned char* out, int n) { int num_blocks = n / 1024; num_blocks = n % 1024 == 0 ? 
num_blocks : num_blocks + 1; @@ -45,10 +27,6 @@ void dequantize(float* code, unsigned char* A, float* out, int n, bnb_stream_t s BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); } -// ============================================================================ -// quantizeBlockwise — mostly shared, small warp-size dispatch difference -// ============================================================================ - template void quantizeBlockwise( float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n @@ -69,30 +47,38 @@ void quantizeBlockwise( kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if (blocksize == 128) kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); - else if (blocksize == 64) + else if (blocksize == 64) { +#if BNB_HIP + // On HIP with 64-wide warps (CDNA), use specialized kernel for 4-bit types + if constexpr (DATA_TYPE > 0) { + kQuantizeBlockwiseSmall + <<<(num_blocks + 1) / 2, 64>>>(code, A, absmax, out, rand, rand_offset, n); + } else { + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + } +#else kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); - // Smallest blocksize: uses unified kQuantizeBlockwiseSmall - // BNB_WARP_SIZE is the compile-time block size (32 on CUDA, 32 or 64 on HIP) - else if (blocksize == BNB_WARP_SIZE) { +#endif + } else if (blocksize == 32) { + // For 4-bit: use specialized kernel that processes 2 blocks per warp + // Each CUDA block handles 2 quantization blocks, so divide num_blocks by 2 if constexpr (DATA_TYPE > 0) { int num_blocks_adjusted = (num_blocks + 1) / 2; kQuantizeBlockwiseSmall - <<>>(code, A, absmax, out, rand, rand_offset, n); + <<>>(code, A, absmax, out, rand, rand_offset, n); } } BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); } -// ============================================================================ -// dequantizeBlockwise — fully shared -// ============================================================================ - template void dequantizeBlockwise( float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, bnb_stream_t stream ) { constexpr int tile_size = (DATA_TYPE > 0) ? 1024 : 512; + + // Upcast to int64 to avoid overflow for large n int grid_blocks = ((int64_t)n + tile_size - 1) / tile_size; if (DATA_TYPE > 0) @@ -105,94 +91,321 @@ void dequantizeBlockwise( BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); } -// ============================================================================ -// gemm_4bit_inference_naive — small warp-size difference in block count -// ============================================================================ +template +void optimizer32bit( + T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, const float beta1, + const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, const int step, + const float lr, const float gnorm_scale, bool skip_zeros, const int n +) { + int num_blocks = n / 4096; + num_blocks = n % 4096 == 0 ? 
num_blocks : num_blocks + 1; + switch (OPTIMIZER) { + case ADAM: + case ADEMAMIX: + if (max_unorm > 0.0f) { + BNB_CHECK_RETURN(BNB_DEVICE_MEMSET(unorm, 0, 1 * sizeof(float))); + kPreconditionOptimizer32bit2State<<>>( + g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n + ); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); + } + kOptimizer32bit2State<<>>( + g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, + gnorm_scale, skip_zeros, n + ); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + if (max_unorm > 0.0f) { + BNB_CHECK_RETURN(BNB_DEVICE_MEMSET(unorm, 0, 1 * sizeof(float))); + kPreconditionOptimizer32bit1State + <<>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); + } -template -void gemm_4bit_inference_naive( - int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, - int blocksize, bnb_stream_t stream + kOptimizer32bit1State<<>>( + g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, + skip_zeros, n + ); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); + break; + case LION: + // in lion, the momentum update after the parameter update + kOptimizer32bit1State<<>>( + g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, + skip_zeros, n + ); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); + + if (max_unorm > 0.0f) { + BNB_CHECK_RETURN(BNB_DEVICE_MEMSET(unorm, 0, 1 * sizeof(float))); + kPreconditionOptimizer32bit1State + <<>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); + } + break; + } +} + +template +void optimizerStatic8bit( + T* p, T* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, float param_norm, + float beta1, float beta2, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1, + float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, int n ) { - // Warp size affects how many rows each block processes - int num_blocks; - if constexpr (BNB_WARP_SIZE == 64) - num_blocks = (m + 1) / 2; - else - num_blocks = (m + 3) / 4; + int num_blocks = n / 4096; + num_blocks = n % 4096 == 0 ? 
num_blocks : num_blocks + 1; - kgemm_4bit_inference_naive - <<>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + if (max_unorm > 0.0f) { + BNB_CHECK_RETURN(BNB_DEVICE_MEMSET(unorm, 0, 1 * sizeof(float))); + } + + switch (OPTIMIZER) { + case ADAM: + BNB_CHECK_RETURN(BNB_DEVICE_MEMSET(new_max1, 0, 1 * sizeof(float))); + BNB_CHECK_RETURN(BNB_DEVICE_MEMSET(new_max2, 0, 1 * sizeof(float))); + kPreconditionOptimizerStatic8bit2State<<>>( + p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, + new_max2, gnorm_scale, n + ); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); + kOptimizerStatic8bit2State<<>>( + p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, + max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n + ); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + BNB_CHECK_RETURN(BNB_DEVICE_MEMSET(new_max1, 0, 1 * sizeof(float))); + kPreconditionOptimizerStatic8bit1State<<>>( + p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n + ); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); + kOptimizerStatic8bit1State<<>>( + p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, + weight_decay, gnorm_scale, n + ); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); + break; + case LION: + // in lion, the momentum update happens after the parameter update + kOptimizerStatic8bit1State<<>>( + p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, + weight_decay, gnorm_scale, n + ); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); + + BNB_CHECK_RETURN(BNB_DEVICE_MEMSET(new_max1, 0, 1 * sizeof(float))); + kPreconditionOptimizerStatic8bit1State<<>>( + p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n + ); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); + break; + default: + break; + } +} + +#define BLOCKSIZE_2STATE 256 +#define NUM_2STATE 1 +#define BLOCKSIZE_1STATE 256 +#define NUM_1STATE 1 + +template +void optimizerStatic8bitBlockwise( + T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, + float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, + float weight_decay, const float gnorm_scale, bool skip_zeros, int n +) { + + int num_blocks = 0; + switch (OPTIMIZER) { + case ADAM: + case ADEMAMIX: + num_blocks = n / BLOCKSIZE_2STATE; + num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; + kOptimizerStatic8bit2StateBlockwise + <<>>( + p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, + absmax2, weight_decay, gnorm_scale, skip_zeros, n + ); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + case LION: + num_blocks = n / BLOCKSIZE_1STATE; + num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; + kOptimizerStatic8bit1StateBlockwise + <<>>( + p, g, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n + ); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); + break; + } +} + +template void percentileClipping(T* g, float* gnorm_vec, int step, const int n) { + int num_blocks = n / 2048; + num_blocks = n % 2048 == 0 ? 
num_blocks : num_blocks + 1; + BNB_CHECK_RETURN(BNB_DEVICE_MEMSET(&gnorm_vec[step % 100], 0, 1 * sizeof(float))); + kPercentileClipping<<>>(g, gnorm_vec, step, n); BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); } -// ============================================================================ -// igemmlt — BLAS library calls genuinely differ between cuBLAS and hipBLAS -// -// This is one of the few functions requiring substantial #if BNB_HIP blocks. -// The algorithm is the same but hipBLAS requires explicit heuristic selection -// while cuBLAS auto-selects. -// ============================================================================ +void gemmex( + Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, + int ldb, int ldc +) { + const int falpha = 1; + const int fbeta = 0; + const void* alpha = &falpha; + const void* beta = &fbeta; + +#if BNB_HIP + hipblasStatus_t status; + + status = hipblasGemmEx( + context->m_handle, transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, + alpha, A, HIP_R_8I, lda, B, HIP_R_8I, ldb, beta, C, HIP_R_32I, ldc, HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT + ); + + if (status != HIPBLAS_STATUS_SUCCESS) { + std::cout << "HIPBLAS ERROR: Status " << status << std::endl; + } +#else + cublasStatus_t status; + + status = cublasGemmEx( + context->m_handle, transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, m, n, k, + alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta, C, CUDA_R_32I, ldc, CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP + ); + + if (status != CUBLAS_STATUS_SUCCESS) { + std::cout << "CUBLAS ERROR: Status " << status << std::endl; + } +#endif +} + +void strided_gemmex( + Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, + int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount +) { + const int falpha = 1; + const int fbeta = 0; + const void* alpha = &falpha; + const void* beta = &fbeta; + +#if BNB_HIP + hipblasStatus_t status; + + status = hipblasGemmStridedBatchedEx( + context->m_handle, transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, + alpha, A, HIP_R_8I, lda, (long long int)strideA, B, HIP_R_8I, ldb, (long long int)strideB, beta, C, HIP_R_32I, + ldc, (long long int)strideC, batchCount, HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT + ); + + if (status != HIPBLAS_STATUS_SUCCESS) { + std::cout << "HIPBLAS ERROR: Status " << status << std::endl; + } +#else + cublasStatus_t status; + + status = cublasGemmStridedBatchedEx( + context->m_handle, transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, m, n, k, + alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta, C, + CUDA_R_32I, ldc, (long long int)strideC, batchCount, CUDA_R_32I, CUBLAS_GEMM_DEFAULT + ); + + if (status != CUBLAS_STATUS_SUCCESS) { + std::cout << "CUBLAS ERROR: Status " << status << std::endl; + } +#endif +} + +int roundoff(int v, int d) { return (v + d - 1) / d * d; } template int igemmlt( bnb_blasLt_handle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda, int ldb, int ldc, bnb_stream_t stream ) { + #if BNB_HIP && defined(NO_HIPBLASLT) return ERR_NOT_IMPLEMENTED; #else + + // Calculate C = A^T @ B, in col-major layout. + // + // Use the IMMA kernels requires: + // * A must be transposed and B must be non-transposed. 
+ // * Dimensions m and k must be multiples of 4. + // * All pointers must be 4-byte aligned; 16-byte alignment preferred. + int has_error = 0; bnb_blasLt_matmul_desc_t matmulDesc; bnb_blasLt_layout_t aDesc, bDesc, cDesc; + auto opT = BNB_BLASLT_OP_T; auto outType = DTYPE_OUT == 32 ? BNB_R_32I : BNB_R_8I; auto scaleType = DTYPE_OUT == 32 ? BNB_R_32I : BNB_R_32F; - auto opT = BNB_BLASLT_OP_T; + + auto pointerMode = BNB_BLASLT_PTR_MODE_ALPHA_VEC; has_error |= checkBlasLtStatus(bnb_blasLtLayoutCreate(&aDesc, BNB_R_8I, m, k, lda)); has_error |= checkBlasLtStatus(bnb_blasLtLayoutCreate(&bDesc, BNB_R_8I, m, n, ldb)); has_error |= checkBlasLtStatus(bnb_blasLtLayoutCreate(&cDesc, outType, k, n, ldc)); + // Default layout order is col major + has_error |= checkBlasLtStatus(bnb_blasLtMatmulDescCreate(&matmulDesc, BNB_BLASLT_COMPUTE_32I, scaleType)); has_error |= checkBlasLtStatus(bnb_blasLtMatmulDescSetAttr(matmulDesc, BNB_BLASLT_DESC_TRANSA, &opT, sizeof(opT))); if (DTYPE_OUT == 32) { - int alpha = 1, beta = 0; - #if BNB_HIP - // HIP requires explicit algorithm heuristic selection + // HIP requires heuristic algo selection + const int64_t max_workspace_size = 0; // set to 0 to avoid choosing GSU kernel + bnb_blasLt_preference_t pref; - const int64_t max_workspace_size = 0; checkBlasLtStatus(bnb_blasLtPrefCreate(&pref)); checkBlasLtStatus( bnb_blasLtPrefSetAttr(pref, BNB_BLASLT_PREF_MAX_WORKSPACE, &max_workspace_size, sizeof(max_workspace_size)) ); - bnb_blasLt_heuristic_t heuristicResult[1]; + const int request_solutions = 1; + bnb_blasLt_heuristic_t heuristicResult[request_solutions]; int returnedAlgoCount = 0; checkBlasLtStatus(bnb_blasLtAlgoGetHeuristic( - ltHandle, matmulDesc, aDesc, bDesc, cDesc, cDesc, pref, 1, heuristicResult, &returnedAlgoCount + ltHandle, matmulDesc, aDesc, bDesc, cDesc, cDesc, pref, request_solutions, heuristicResult, + &returnedAlgoCount )); if (returnedAlgoCount == 0) { has_error = 1; fprintf(stderr, "Error: Matmul Algo Heuristic didn't return algorithms\n"); } else { + int alpha = 1, beta = 0; has_error |= checkBlasLtStatus(bnb_blasLtMatmul( ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int32_t*)C, cDesc, (int32_t*)C, cDesc, &heuristicResult[0].algo, NULL, 0, stream )); } #else - // CUDA: cuBLAS auto-selects algorithm + int alpha = 1, beta = 0; has_error |= checkBlasLtStatus(bnb_blasLtMatmul( ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int32_t*)C, cDesc, (int32_t*)C, cDesc, NULL, NULL, 0, stream )); #endif } else { + // This path is unlikely to be used, as 8-bit accumulation can lead to likely overflows. + if (!SCALE_ROWS) { float alpha = 1.0f, beta = 0.0f; has_error |= checkBlasLtStatus(bnb_blasLtMatmul( @@ -200,10 +413,10 @@ int igemmlt( NULL, 0, stream )); } else { - auto pointerMode = BNB_BLASLT_PTR_MODE_ALPHA_VEC; + auto alphaVec = BNB_BLASLT_PTR_MODE_ALPHA_VEC; float beta = 0.0f; has_error |= checkBlasLtStatus( - bnb_blasLtMatmulDescSetAttr(matmulDesc, BNB_BLASLT_DESC_POINTER_MODE, &pointerMode, sizeof(pointerMode)) + bnb_blasLtMatmulDescSetAttr(matmulDesc, BNB_BLASLT_DESC_POINTER_MODE, &pointerMode, sizeof(alphaVec)) ); has_error |= checkBlasLtStatus(bnb_blasLtMatmul( ltHandle, matmulDesc, row_scale, A, aDesc, B, bDesc, &beta, (int8_t*)C, cDesc, (int8_t*)C, cDesc, NULL, @@ -221,70 +434,116 @@ int igemmlt( printf("error detected"); return has_error; -#endif +#endif // NO_HIPBLASLT +} + +int fill_up_to_nearest_multiple(int value, int multiple) { + return value + (value % multiple == 0 ? 
0 : (multiple - (value % multiple))); } -// ============================================================================ -// spmm_coo — sparse library calls differ but structure is identical -// Uses unified CHECK_SPARSE and bnb_sparse* macros from compat.cuh -// ============================================================================ +void dequant_mm_int32_fp16( + int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, bnb_stream_t stream +) { + const int threads = 512; + const int num_per_thread = 4; + const int num_per_block = threads * num_per_thread; + const int n = numRows * numCols; + const int num_blocks = (n + num_per_block - 1) / num_per_block; + + kdequant_mm_int32_fp16 + <<>>(A, rowStats, colStats, out, bias, numRows, numCols, n); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); +} + +void int8VectorQuant( + half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, bnb_stream_t stream +) { + if (threshold == 0.0) { + kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); + } else { + kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); + } + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); +} void spmm_coo( bnb_sparse_handle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half* B, int ldc, half* C, bool transposed_B ) { #if BNB_HIP && defined(NO_HIPBLASLT) - // No sparse support on older ROCm + return; #else + bnb_sparseSpMatDescr_t descA; + bnb_sparseDnMatDescr_t descB, descC; + float alpha = 1.0f; float beta = 0.0f; void* dBuffer = NULL; size_t bufferSize = 0; - // Note: all of these use the bnb_sparse* macros from compat.cuh - // which resolve to cusparse* or hipsparse* as appropriate - - // bnb_sparseCreateCoo → cusparseCreateCoo / hipsparseCreateCoo - // BNB_R_16F → CUDA_R_16F / HIP_R_16F - // etc. - - // Omitting the body as it would be identical to what compat.cuh provides - // (see full macro mappings in compat.cuh) - CHECK_SPARSE(bnb_sparseCreateCoo( - NULL, A_rows, A_cols, A_nnz, A_rowidx, A_colidx, A_vals, BNB_SPARSE_INDEX_32I, BNB_SPARSE_INDEX_BASE_ZERO, + &descA, A_rows, A_cols, A_nnz, A_rowidx, A_colidx, A_vals, BNB_SPARSE_INDEX_32I, BNB_SPARSE_INDEX_BASE_ZERO, BNB_R_16F )); + // Create dense matrix C + CHECK_SPARSE(bnb_sparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, BNB_R_16F, BNB_SPARSE_ORDER_ROW)); + // Create dense matrix B + if (transposed_B) { + int tmp = A_cols; + A_cols = B_cols; + B_cols = tmp; + } + + CHECK_SPARSE(bnb_sparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, BNB_R_16F, BNB_SPARSE_ORDER_ROW)); + // allocate an external buffer if needed + CHECK_SPARSE(bnb_sparseSpMM_bufSize( + handle, BNB_SPARSE_OP_NON_TRANSPOSE, transposed_B ? BNB_SPARSE_OP_TRANSPOSE : BNB_SPARSE_OP_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, BNB_R_32F, BNB_SPARSE_SPMM_ALG_DEFAULT, &bufferSize + )); + BNB_CHECK_RETURN(BNB_DEVICE_MALLOC(&dBuffer, bufferSize)); + + // execute SpMM + CHECK_SPARSE(bnb_sparseSpMM( + handle, BNB_SPARSE_OP_NON_TRANSPOSE, transposed_B ? BNB_SPARSE_OP_TRANSPOSE : BNB_SPARSE_OP_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, BNB_R_32F, BNB_SPARSE_SPMM_ALG_DEFAULT, dBuffer + )); - // ... 
(rest of spmm_coo using bnb_sparse* macros — same pattern) + // destroy matrix/vector descriptors + CHECK_SPARSE(bnb_sparseDestroySpMat(descA)); + CHECK_SPARSE(bnb_sparseDestroyDnMat(descB)); + CHECK_SPARSE(bnb_sparseDestroyDnMat(descC)); + BNB_CHECK_RETURN(BNB_DEVICE_FREE(dBuffer)); #endif } -// ============================================================================ -// Simple kernel launchers — fully shared -// ============================================================================ - -void dequant_mm_int32_fp16( - int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, bnb_stream_t stream +template +void spmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, + float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB ) { - const int threads = 512; - const int num_per_thread = 4; - const int n = numRows * numCols; - const int num_blocks = (n + threads * num_per_thread - 1) / (threads * num_per_thread); - kdequant_mm_int32_fp16 - <<>>(A, rowStats, colStats, out, bias, numRows, numCols, n); + kspmm_coo_very_sparse_naive<<>>( + max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB + ); BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); } -void int8VectorQuant( - half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, bnb_stream_t stream +template +void gemm_4bit_inference_naive( + int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, + int blocksize, bnb_stream_t stream ) { - if (threshold == 0.0) { - kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); - } else { - kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); + + int num_blocks = (m + 3) / 4; +#if BNB_HIP + // On 64-wide warp architectures, each warp processes 2 rows instead of 4 + if (BNB_WARP_SIZE == 64) { + num_blocks = (m + 1) / 2; } +#endif + + kgemm_4bit_inference_naive + <<>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); } @@ -297,9 +556,9 @@ template void func(T* A, T* B, T value, long n) { BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); } -// ============================================================================ -// Template instantiations -// ============================================================================ +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== template void func(float* A, float* B, float value, long n); template void func(unsigned char* A, unsigned char* B, unsigned char value, long n); @@ -319,6 +578,15 @@ template void gemm_4bit_inference_naive( int ldc, int blocksize, bnb_stream_t stream ); +template void spmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, + float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB +); +template void spmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, + float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB +); + template int igemmlt<32, 0>( bnb_blasLt_handle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda, int ldb, int ldc, 
bnb_stream_t stream @@ -331,3 +599,125 @@ template int igemmlt<8, 1>( bnb_blasLt_handle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda, int ldb, int ldc, bnb_stream_t stream ); + +template void quantizeBlockwise( + float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template void quantizeBlockwise( + float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template void quantizeBlockwise( + float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template void quantizeBlockwise( + float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template void quantizeBlockwise( + float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template void quantizeBlockwise( + float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template void quantizeBlockwise( + float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template void quantizeBlockwise( + float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template void quantizeBlockwise( + float* code, bnb_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, + const int n +); +template void quantizeBlockwise( + float* code, bnb_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, + const int n +); +template void quantizeBlockwise( + float* code, bnb_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, + const int n +); +template void quantizeBlockwise( + float* code, bnb_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, + const int n +); + +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, bnb_stream_t stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, bnb_stream_t stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, bnb_stream_t stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, bnb_stream_t stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, bnb_stream_t stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, bnb_stream_t stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, int blocksize, const int n, bnb_stream_t stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, int blocksize, const int n, bnb_stream_t stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, int blocksize, const int n, bnb_stream_t stream +); + +#define MAKE_optimizer32bit(name, gtype) \ + template void 
optimizer32bit( \ + gtype * g, gtype * p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, \ + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, \ + const int n \ + ); + +MAKE_optimizer32bit(ADAM, half) MAKE_optimizer32bit(ADAM, float) MAKE_optimizer32bit(ADAM, bnb_bfloat16) MAKE_optimizer32bit(MOMENTUM, half) MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit(MOMENTUM, bnb_bfloat16) MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, float) MAKE_optimizer32bit(RMSPROP, bnb_bfloat16) MAKE_optimizer32bit( + LION, half +) MAKE_optimizer32bit(LION, float) MAKE_optimizer32bit(LION, bnb_bfloat16) MAKE_optimizer32bit(ADAGRAD, half) MAKE_optimizer32bit(ADAGRAD, float) MAKE_optimizer32bit(ADAGRAD, bnb_bfloat16) MAKE_optimizer32bit(ADEMAMIX, half) MAKE_optimizer32bit(ADEMAMIX, bnb_bfloat16) MAKE_optimizer32bit(ADEMAMIX, float) + +#define MAKE_optimizerStatic8bit(name, gtype) \ + template void optimizerStatic8bit( \ + gtype * p, gtype * g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, \ + float param_norm, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, \ + float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, \ + const float gnorm_scale, int n \ + ); + + MAKE_optimizerStatic8bit(ADAM, half) MAKE_optimizerStatic8bit(ADAM, float) MAKE_optimizerStatic8bit(MOMENTUM, half) MAKE_optimizerStatic8bit(MOMENTUM, float) MAKE_optimizerStatic8bit( + RMSPROP, half + ) MAKE_optimizerStatic8bit(RMSPROP, float) MAKE_optimizerStatic8bit(LION, half) MAKE_optimizerStatic8bit(LION, float) MAKE_optimizerStatic8bit(ADAGRAD, half) MAKE_optimizerStatic8bit(ADAGRAD, float) + +#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ + template void optimizerStatic8bitBlockwise( \ + gtype * p, gtype * g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, \ + float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, \ + float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n \ + ); + + MAKE_optimizerStatic8bitBlockwise(half, ADAM); +MAKE_optimizerStatic8bitBlockwise(float, ADAM); +MAKE_optimizerStatic8bitBlockwise(bnb_bfloat16, ADAM); +MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(bnb_bfloat16, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(bnb_bfloat16, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(half, LION); +MAKE_optimizerStatic8bitBlockwise(float, LION); +MAKE_optimizerStatic8bitBlockwise(bnb_bfloat16, LION); +MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(bnb_bfloat16, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX); +MAKE_optimizerStatic8bitBlockwise(bnb_bfloat16, ADEMAMIX); +MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX); + +template void percentileClipping(float* g, float* gnorm_vec, int step, const int n); +template void percentileClipping(half* g, float* gnorm_vec, int step, const int n); diff --git a/csrc/examples/ops_unified.cuh b/csrc/examples/ops_unified.cuh index b0dd8aaf2..94184fedc 100644 --- 
a/csrc/examples/ops_unified.cuh +++ b/csrc/examples/ops_unified.cuh @@ -1,7 +1,7 @@ -// ops_unified.cuh — EXAMPLE of merged host API declarations for CUDA/HIP +// Copyright (c) Facebook, Inc. and its affiliates. // -// This replaces both csrc/ops.cuh and csrc/ops_hip.cuh. -// Uses compat.cuh types for all platform-specific identifiers. +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. #ifndef ops_H #define ops_H @@ -17,7 +17,7 @@ #include // ============================================================================ -// Error checking helpers — unified via compat.cuh types +// Error checking helpers // ============================================================================ inline void checkDeviceStatus(bnb_error_t status) { @@ -36,10 +36,12 @@ inline int checkBlasLtStatus(bnb_blas_status_t status) { } // ============================================================================ -// Enums — identical on both platforms +// Enums // ============================================================================ -typedef enum Operations_t { ksmul = 0 } Operations_t; +typedef enum Operations_t { + ksmul = 0, +} Operations_t; typedef enum Optimizer_t { ADAM = 0, @@ -51,13 +53,14 @@ typedef enum Optimizer_t { ADEMAMIX = 6, } Optimizer_t; -typedef enum Funcs_t { FILL = 0, ARANGE = 1, _MUL = 2 } Funcs_t; +typedef enum Funcs_t { + FILL = 0, + ARANGE = 1, + _MUL = 2, +} Funcs_t; // ============================================================================ -// Context classes — platform-specific handles via #if BNB_HIP -// -// This is one of the few places where #if BNB_HIP is needed, because -// the BLAS handle types and creation APIs genuinely differ. +// Context classes // ============================================================================ class Context { @@ -104,12 +107,11 @@ class ContextSparse { }; // ============================================================================ -// Function declarations — use bnb_stream_t / bnb_sparse_handle_t +// Function declarations // ============================================================================ void quantize(float* code, float* A, unsigned char* out, int n); void dequantize(float* code, unsigned char* A, float* out, int n, bnb_stream_t stream); - template void quantizeBlockwise( float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n @@ -146,6 +148,10 @@ void gemmex( Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, int ldb, int ldc ); +void strided_gemmex( + Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, + int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount +); template int igemmlt( @@ -153,10 +159,12 @@ int igemmlt( int lda, int ldb, int ldc, bnb_stream_t stream ); +void cutlass_igemm( + bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, int ldb, int ldc +); void dequant_mm_int32_fp16( int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, bnb_stream_t stream ); - void int8VectorQuant( half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, bnb_stream_t stream ); diff --git a/csrc/examples/pythonInterface_unified.cpp b/csrc/examples/pythonInterface_unified.cpp new file mode 100644 index 000000000..9ed2033e1 --- /dev/null 
+++ b/csrc/examples/pythonInterface_unified.cpp @@ -0,0 +1,890 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#if BUILD_CUDA +#include +#include +#endif +#if BUILD_HIP +#include +#endif +#if BUILD_MPS +// #include +#endif +#if BUILD_XPU +#include +#endif +#include + +// Compatibility between HIP/CUDA APIs +#if BUILD_HIP +#define cudaStream_t hipStream_t +#define __nv_bfloat16 hip_bfloat16 +#define cublasLtHandle_t hipblasLtHandle_t +#define cusparseHandle_t hipsparseHandle_t +#define cudaMallocManaged hipMallocManaged +#define cudaMemAttachHost hipMemAttachHost +#define cudaPeekAtLastError hipPeekAtLastError +#define cudaDeviceGetAttribute hipDeviceGetAttribute +#define cudaDevAttrConcurrentManagedAccess hipDeviceAttributeConcurrentManagedAccess +#define cudaMemPrefetchAsync hipMemPrefetchAsync +#endif + +// We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary. +// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to +// maintain all that boilerplate +//=================================================================================== +// UNMANGLED CALLS +//=================================================================================== + +#if BUILD_CUDA || BUILD_HIP + +void gemm_4bit_inference_naive_fp16( + int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb, + int ldc, int blocksize, cudaStream_t stream +) { + gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +void gemm_4bit_inference_naive_bf16( + int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out, + int lda, int ldb, int ldc, int blocksize, cudaStream_t stream +) { + gemm_4bit_inference_naive<__nv_bfloat16, 16>( + m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream + ); +} + +void gemm_4bit_inference_naive_fp32( + int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, + int ldc, int blocksize, cudaStream_t stream +) { + gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ + void fname##_##type_name(ctype* A, ctype* B, ctype value, long n) { func(A, B, value, n); } + +MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL) +MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL) +MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) +MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) + +#define MAKE_FUNC32(fname, oname, gtype, gbits) \ + void fname##32bit_grad_##gbits( \ + gtype* g, gtype* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, \ + const float weight_decay, const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n \ + ) { \ + optimizer32bit( \ + g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, \ + lr, gnorm_scale, skip_zeros, n \ + ); \ + } + +MAKE_FUNC32(momentum, MOMENTUM, float, 32) +MAKE_FUNC32(momentum, MOMENTUM, half, 16) +MAKE_FUNC32(adam, ADAM, float, fp32) +MAKE_FUNC32(adam, ADAM, half, fp16) +MAKE_FUNC32(adam, ADAM, 
__nv_bfloat16, bf16) +MAKE_FUNC32(rmsprop, RMSPROP, float, 32) +MAKE_FUNC32(rmsprop, RMSPROP, half, 16) +MAKE_FUNC32(lion, LION, float, fp32) +MAKE_FUNC32(lion, LION, half, fp16) +MAKE_FUNC32(lion, LION, __nv_bfloat16, bf16) +MAKE_FUNC32(adagrad, ADAGRAD, float, 32) +MAKE_FUNC32(adagrad, ADAGRAD, half, 16) +MAKE_FUNC32(ademamix, ADEMAMIX, float, fp32) +MAKE_FUNC32(ademamix, ADEMAMIX, half, fp16) +MAKE_FUNC32(ademamix, ADEMAMIX, __nv_bfloat16, bf16) + +#define MAKE_FUNC8(fname, oname, gtype, gbits) \ + void fname##_static_8bit_grad_##gbits( \ + gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, \ + float param_norm, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, \ + float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, \ + float gnorm_scale, int n \ + ) { \ + optimizerStatic8bit( \ + g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, \ + max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n \ + ); \ + } + +MAKE_FUNC8(adam, ADAM, float, 32) +MAKE_FUNC8(adam, ADAM, half, 16) +MAKE_FUNC8(momentum, MOMENTUM, float, 32) +MAKE_FUNC8(momentum, MOMENTUM, half, 16) +MAKE_FUNC8(rmsprop, RMSPROP, float, 32) +MAKE_FUNC8(rmsprop, RMSPROP, half, 16) +MAKE_FUNC8(lion, LION, float, 32) +MAKE_FUNC8(lion, LION, half, 16) + +#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \ + void fname##_8bit_blockwise_grad_##gbits( \ + gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, \ + float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, \ + float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n \ + ) { \ + optimizerStatic8bitBlockwise( \ + p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, \ + weight_decay, gnorm_scale, skip_zeros, n \ + ); \ + } + +MAKE_BLOCKWISE8(adam, ADAM, half, fp16) +MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) +MAKE_BLOCKWISE8(adam, ADAM, float, fp32) +MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16) +MAKE_BLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16) +MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32) +MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16) +MAKE_BLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16) +MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32) +MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16) +MAKE_BLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16) +MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32) +MAKE_BLOCKWISE8(lion, LION, half, fp16) +MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16) +MAKE_BLOCKWISE8(lion, LION, float, fp32) +MAKE_BLOCKWISE8(ademamix, ADEMAMIX, half, fp16) +MAKE_BLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16) +MAKE_BLOCKWISE8(ademamix, ADEMAMIX, float, fp32) + +void percentileClipping_g32(float* g, float* gnorm_vec, int step, const int n) { + percentileClipping(g, gnorm_vec, step, n); +} + +void percentileClipping_g16(half* g, float* gnorm_vec, int step, const int n) { + percentileClipping(g, gnorm_vec, step, n); +} + +void quantizeBlockwise_fp16(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) { + quantizeBlockwise(code, A, absmax, out, nullptr, 0, blocksize, n); +} + +void quantizeBlockwise_fp16_fp4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) { + quantizeBlockwise(nullptr, A, absmax, out, nullptr, 0, blocksize, n); +} + 
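// Illustration (a sketch, not part of the patch): one expected expansion of the MAKE_FUNC32 macro above,
// assuming optimizer32bit is instantiated as optimizer32bit<float, ADAM>, i.e. with the <gtype, optimizer>
// parameter order used by the explicit instantiations in ops_unified.cu. The token pasting in
// fname##32bit_grad_##gbits is what yields the unmangled, C-callable symbol (here adam32bit_grad_fp32)
// that the Python side can look up by name.
//
//   MAKE_FUNC32(adam, ADAM, float, fp32)  ==>
//
//   void adam32bit_grad_fp32(
//       float* g, float* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm,
//       const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
//       const float weight_decay, const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n
//   ) {
//       optimizer32bit<float, ADAM>(
//           g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay,
//           step, lr, gnorm_scale, skip_zeros, n
//       );
//   }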
+void quantizeBlockwise_fp16_nf4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) { + quantizeBlockwise(nullptr, A, absmax, out, nullptr, 0, blocksize, n); +} + +void quantizeBlockwise_bf16( + float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n +) { + quantizeBlockwise<__nv_bfloat16, 0, General8bit>(code, A, absmax, out, nullptr, 0, blocksize, n); +} + +void quantizeBlockwise_bf16_fp4( + float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n +) { + quantizeBlockwise<__nv_bfloat16, 0, FP4>(nullptr, A, absmax, out, nullptr, 0, blocksize, n); +} + +void quantizeBlockwise_bf16_nf4( + float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n +) { + quantizeBlockwise<__nv_bfloat16, 0, NF4>(nullptr, A, absmax, out, nullptr, 0, blocksize, n); +} + +void quantizeBlockwise_fp32(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) { + quantizeBlockwise(code, A, absmax, out, nullptr, 0, blocksize, n); +} + +void quantizeBlockwise_fp32_fp4(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) { + quantizeBlockwise(nullptr, A, absmax, out, nullptr, 0, blocksize, n); +} + +void quantizeBlockwise_fp32_nf4(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) { + quantizeBlockwise(nullptr, A, absmax, out, nullptr, 0, blocksize, n); +} + +void dequantizeBlockwise_fp16( + float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp16_fp4( + float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp16_nf4( + float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp32( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp32_fp4( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp32_nf4( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_bf16( + float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_bf16_fp4( + float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise<__nv_bfloat16, FP4>(nullptr, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_bf16_nf4( + float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise<__nv_bfloat16, NF4>(nullptr, A, absmax, out, blocksize, 
n, stream); +} + +int igemmlt_32( + cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, cudaStream_t stream +) { + return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); +} + +int igemmlt_8( + cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, cudaStream_t stream +) { + return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); +} + +int igemmlt_8_rowscale( + cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, cudaStream_t stream +) { + return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); +} + +void spmm_coo_very_sparse_naive_fp16( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, + float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB +) { + spmm_coo_very_sparse_naive( + max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, + colsB + ); +} + +void spmm_coo_very_sparse_naive_int8( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, + float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB +) { + spmm_coo_very_sparse_naive( + max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, + colsB + ); +} +#endif + +#if BUILD_XPU + +void dequantizeBlockwise_fp16( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp16_fp4( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp16_nf4( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp32( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp32_fp4( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp32_nf4( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_bf16( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_bf16_fp4( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_bf16_nf4( + float* code, 
unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise(nullptr, A, absmax, out, blocksize, n, stream); +} + +void gemv_4bit_inference_fp16( + int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda, + int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +void gemv_4bit_inference_bf16( + int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype, + sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference( + m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream + ); +} + +void gemv_4bit_inference_fp32( + int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, + int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +#endif + +extern "C" { +#if BUILD_CUDA || BUILD_HIP +void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); } + +void cdequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream) { + dequantize(code, A, out, n, stream); +} + +void cdequantize_blockwise_fp16_fp4( + float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp16( + float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp16_nf4( + float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream); +} + +void cquantize_blockwise_fp16(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) { + quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); +} + +void cquantize_blockwise_fp16_fp4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) { + quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); +} + +void cquantize_blockwise_fp16_nf4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) { + quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); +} + +void cquantize_blockwise_fp32(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n) { + quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); +} + +void cquantize_blockwise_fp32_fp4( + float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n +) { + quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); +} + +void cquantize_blockwise_fp32_nf4( + float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n +) { + quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_fp32( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp32_fp4( + float* code, 
unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp32_nf4( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream); +} + +void cquantize_blockwise_bf16( + float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n +) { + quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); +} + +void cquantize_blockwise_bf16_fp4( + float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n +) { + quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); +} + +void cquantize_blockwise_bf16_nf4( + float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, int blocksize, const int n +) { + quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_bf16( + float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_bf16_fp4( + float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_bf16_nf4( + float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream +) { + dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); +} + +#define MAKE_CFUNC32(name, gtype, gbits) \ + void c##name##32bit_grad_##gbits( \ + gtype* g, gtype* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, \ + const float weight_decay, const int step, const float lr, const float gnorm_scale, bool skip_zeros, \ + const int n \ + ) { \ + name##32bit_grad_##gbits( \ + g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, \ + lr, gnorm_scale, skip_zeros, n \ + ); \ + } + +MAKE_CFUNC32(adam, float, fp32) +MAKE_CFUNC32(adam, half, fp16) +MAKE_CFUNC32(adam, __nv_bfloat16, bf16) +MAKE_CFUNC32(momentum, float, 32) +MAKE_CFUNC32(momentum, half, 16) +MAKE_CFUNC32(rmsprop, float, 32) +MAKE_CFUNC32(rmsprop, half, 16) +MAKE_CFUNC32(lion, float, fp32) +MAKE_CFUNC32(lion, half, fp16) +MAKE_CFUNC32(lion, __nv_bfloat16, bf16) +MAKE_CFUNC32(adagrad, float, 32) +MAKE_CFUNC32(adagrad, half, 16) +MAKE_CFUNC32(ademamix, float, fp32) +MAKE_CFUNC32(ademamix, half, fp16) +MAKE_CFUNC32(ademamix, __nv_bfloat16, bf16) + +#define MAKE_CFUNC8(name, gtype, gbits) \ + void c##name##_static_8bit_grad_##gbits( \ + gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, \ + float param_norm, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, \ + float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, \ + float gnorm_scale, int n \ + ) { \ + name##_static_8bit_grad_##gbits( \ + g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, \ + max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n \ + ); \ + } + +MAKE_CFUNC8(adam, float, 32) 
+MAKE_CFUNC8(adam, half, 16) +MAKE_CFUNC8(momentum, float, 32) +MAKE_CFUNC8(momentum, half, 16) +MAKE_CFUNC8(rmsprop, float, 32) +MAKE_CFUNC8(rmsprop, half, 16) +MAKE_CFUNC8(lion, float, 32) +MAKE_CFUNC8(lion, half, 16) + +#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ + void c##fname##_8bit_blockwise_grad_##gbits( \ + gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, \ + float alpha, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, \ + float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n \ + ) { \ + fname##_8bit_blockwise_grad_##gbits( \ + p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, \ + weight_decay, gnorm_scale, skip_zeros, n \ + ); \ + } + +MAKE_CBLOCKWISE8(adam, ADAM, half, fp16) +MAKE_CBLOCKWISE8(adam, ADAM, float, fp32) +MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) +MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16) +MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32) +MAKE_CBLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16) +MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16) +MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32) +MAKE_CBLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16) +MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16) +MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32) +MAKE_CBLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16) +MAKE_CBLOCKWISE8(lion, LION, half, fp16) +MAKE_CBLOCKWISE8(lion, LION, float, fp32) +MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16) +MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, half, fp16) +MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, float, fp32) +MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16) + +void cpercentile_clipping_g32(float* g, float* gnorm_vec, int step, const int n) { + percentileClipping_g32(g, gnorm_vec, step, n); +} + +void cpercentile_clipping_g16(half* g, float* gnorm_vec, int step, const int n) { + percentileClipping_g16(g, gnorm_vec, step, n); +} + +void cigemm( + Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, + int ldb, int ldc +) { + gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc); +} + +void cbatched_igemm( + Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, + int ldb, int ldc, long strideA, long strideB, long strideC, int batchCount +) { + strided_gemmex( + context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount + ); +} + +Context* get_context() { return new Context(); } + +ContextSparse* get_cusparse() { return new ContextSparse(); } + +int cigemmlt_32( + Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda, + int ldb, int ldc, cudaStream_t stream +) { + return igemmlt_32((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); +} + +int cigemmlt_8( + Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda, + int ldb, int ldc, cudaStream_t stream +) { + return igemmlt_8((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); +} + +int cigemmlt_8_rowscale( + Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda, + int ldb, int ldc, cudaStream_t stream +) { + return igemmlt_8_rowscale((cublasLtHandle_t)context->m_handle, m, n, k, 
A, B, C, row_scale, lda, ldb, ldc, stream); +} + +void cdequant_mm_int32_fp16( + int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream +) { + dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream); +} + +void cint8_vector_quant( + half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream +) { + int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream); +} + +void cspmm_coo( + ContextSparse* context, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, + int ldb, half* B, int ldc, half* C, bool transposed_B +) { + spmm_coo( + (cusparseHandle_t)context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, + transposed_B + ); +} + +void cspmm_coo_very_sparse_naive_fp16( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, + float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB +) { + spmm_coo_very_sparse_naive_fp16( + max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, + colsB + ); +} + +void cspmm_coo_very_sparse_naive_int8( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, + float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB +) { + spmm_coo_very_sparse_naive_int8( + max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, + colsB + ); +} + +void* cget_managed_ptr(size_t bytes) { + void* ptr; + CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost)); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + + return ptr; +} + +void cprefetch(void* ptr, size_t bytes, int device) { + + int hasPrefetch = 0; + CUDA_CHECK_RETURN( + cudaDeviceGetAttribute(&hasPrefetch, cudaDevAttrConcurrentManagedAccess, device) + ); // 40ns overhead + if (hasPrefetch == 0) + return; + +#if CUDART_VERSION >= 13000 + cudaMemLocation loc{}; + loc.type = cudaMemLocationTypeDevice; + loc.id = device; + CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, loc, 0u, 0)); +#else + CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0)); +#endif + + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +#define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ + void c##fname##_##type_name(ctype* A, ctype* B, ctype value, long n) { fname##_##type_name(A, B, value, n); } + +CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL) +CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL) +CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) +CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) + +void cgemm_4bit_inference_naive_fp16( + int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb, + int ldc, int blocksize, cudaStream_t stream +) { + gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +void cgemm_4bit_inference_naive_bf16( + int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out, + int lda, int ldb, int ldc, int blocksize, cudaStream_t stream +) { + gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +void cgemm_4bit_inference_naive_fp32( + int m, int n, int k, float* A, unsigned char* B, float* absmax, float* 
datatype, float* out, int lda, int ldb, + int ldc, int blocksize, cudaStream_t stream +) { + gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +#endif + +#if BUILD_XPU + +void cdequantize_blockwise_fp16_fp4( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp16( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp16_nf4( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp32( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp32_fp4( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp32_nf4( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_bf16( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_bf16_fp4( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_bf16_nf4( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); +} + +void cgemv_4bit_inference_fp16( + int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda, + int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +void cgemv_4bit_inference_bf16( + int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype, + sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +void cgemv_4bit_inference_fp32( + int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, + int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +#endif + +void cquantize_blockwise_cpu_fp32( + float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n +) { + quantize_cpu(code, A, absmax, out, blocksize, n); +} + +void 
cdequantize_blockwise_cpu_fp32( + float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n +) { + dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_cpu_bf16( + float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n +) { + dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_cpu_fp16( + float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n +) { + dequantizeBlockwise8bitCpu(code, A, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_cpu_fp4_fp32( + unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n +) { + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); +} + +void cdequantize_blockwise_cpu_fp4_bf16( + unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n +) { + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); +} + +void cdequantize_blockwise_cpu_fp4_fp16( + unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n +) { + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); +} + +void cdequantize_blockwise_cpu_nf4_fp32( + unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n +) { + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); +} + +void cdequantize_blockwise_cpu_nf4_bf16( + unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n +) { + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); +} + +void cdequantize_blockwise_cpu_nf4_fp16( + unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n +) { + dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); +} + +#if defined(__AVX512F__) && defined(__AVX512BF16__) +void gemv_4bit_inference_cpu_fp4_bf16( + int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, + const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride +) { + gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); +} + +void gemv_4bit_inference_cpu_nf4_bf16( + int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, + const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride +) { + gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); +} +#endif +#if defined(__AVX512F__) +bool has_avx512f_cpu() { return has_avx512f(); } +#if defined(__AVX512BF16__) +bool has_avx512bf16_cpu() { return has_avx512bf16(); } +#endif +#endif +}
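For reference, the optimizer wrappers above are pure token-pasting pass-throughs. A single instantiation such as MAKE_CFUNC32(lion, half, fp16) unrolls to the shape below (shown expanded for readability; `half` comes from cuda_fp16.h and the callee name is whatever name##32bit_grad_##gbits resolves to in the ops headers):

void clion32bit_grad_fp16(
    half* g, half* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm,
    const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
    const float weight_decay, const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n
) {
    // Forwards every argument unchanged to the non-extern-"C" implementation.
    lion32bit_grad_fp16(
        g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step,
        lr, gnorm_scale, skip_zeros, n
    );
}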
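A minimal host-side smoke test for the blockwise quantize/dequantize entry points and the managed-memory helpers might look like the sketch below. It is not part of the patch: it assumes the shared library built from this tree is linked in, the extern "C" declarations simply mirror the signatures defined above, and the linear `code` table is only a stand-in for the 256-entry quantization map a real caller would supply.

// round_trip_smoke.cpp — hypothetical usage sketch, not library code
#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

extern "C" {
void* cget_managed_ptr(size_t bytes);
void cprefetch(void* ptr, size_t bytes, int device);
void cquantize_blockwise_fp32(float* code, float* A, float* absmax, unsigned char* out, int blocksize, const int n);
void cdequantize_blockwise_fp32(
    float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream
);
}

int main() {
    const int n = 4096;
    const int blocksize = 256;                  // one absmax entry per block
    const int blocks = (n + blocksize - 1) / blocksize;

    // Managed allocations so both host and device can touch the buffers.
    float* code = static_cast<float*>(cget_managed_ptr(256 * sizeof(float)));
    float* A = static_cast<float*>(cget_managed_ptr(n * sizeof(float)));
    float* absmax = static_cast<float*>(cget_managed_ptr(blocks * sizeof(float)));
    unsigned char* q = static_cast<unsigned char*>(cget_managed_ptr(n));
    float* out = static_cast<float*>(cget_managed_ptr(n * sizeof(float)));

    for (int i = 0; i < 256; ++i)
        code[i] = -1.0f + 2.0f * i / 255.0f;    // placeholder code table spanning [-1, 1]
    for (int i = 0; i < n; ++i)
        A[i] = std::sin(i * 0.01f);             // arbitrary test signal

    cprefetch(A, n * sizeof(float), 0);         // optional hint: migrate the input to device 0

    cquantize_blockwise_fp32(code, A, absmax, q, blocksize, n);
    cdequantize_blockwise_fp32(code, q, absmax, out, blocksize, n, 0);
    cudaDeviceSynchronize();                    // wait for the kernels before reading on the host

    printf("A[0] = %f, round-trip = %f\n", A[0], out[0]);
    return 0;
}

The same pattern extends to the fp16/bf16 and fp4/nf4 variants; only the element type of A/out and the wrapper name change.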